haoliuhl / ringattention

Transformers with Arbitrarily Large Context
Apache License 2.0
630 stars 50 forks

scripts/jax2hf.py error #17

Open liuxpro opened 5 months ago

liuxpro commented 5 months ago

Hi, I am trying to use the current script RingAttention-main/scripts/jax2hf.py to convert the JAX model from https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M-Jax/tree/main to Hugging Face format, but the script fails with an error. How can I solve it? Thanks!

command: python scripts/jax2hf.py --load_checkpoint params::t2t_chat_1m/params --output_dir hg2jax_1mchat --tokenizer_path t2t_chat_1m/tokenizer.model

Error info:
Fetching the tokenizer from t2t_chat_1m/tokenizer.model.
Traceback (most recent call last):
  File "lworldmodel/RingAttention-main/scripts/jax2hf.py", line 280, in <module>
    run(main)
  File "/home/liuxiao/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/liuxiao/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "RingAttention-main/scripts/jax2hf.py", line 273, in main
    load_and_convert_checkpoint(FLAGS.load_checkpoint),
  File "RingAttention-main/scripts/jax2hf.py", line 83, in load_and_convert_checkpoint
    _, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
  File "/home/liuxiao/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 216, in load_trainstate_checkpoint
    restored_params = cls.load_checkpoint(
  File "/home/liuxiao/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 130, in load_checkpoint
    for key, value in unpacker:
  File "msgpack/_unpacker.pyx", line 540, in msgpack._cmsgpack.Unpacker.__next__
  File "msgpack/_unpacker.pyx", line 474, in msgpack._cmsgpack.Unpacker._unpack
  File "msgpack/_unpacker.pyx", line 446, in msgpack._cmsgpack.Unpacker.read_from_file
msgpack.exceptions.BufferFull

liuxpro commented 5 months ago

The previous command used --load_checkpoint params::. We also tried --load_checkpoint flax_params::, but that fails with an error as well.

command: python scripts/jax2hf.py --load_checkpoint flax_params::lwm_jax/t2t_chat_1m/params --output_dir hg2jax_1mchat --tokenizer_path t2t_chat_1m/tokenizer.model

Error info:
Fetching the tokenizer from lwm_jax/t2t_chat_1m/tokenizer.model.
Traceback (most recent call last):
  File "RingAttention-main/scripts/jax2hf.py", line 280, in <module>
    run(main)
  File "/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/liuxiao/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "RingAttention-main/scripts/jax2hf.py", line 273, in main
    load_and_convert_checkpoint(FLAGS.load_checkpoint),
  File "RingAttention-main/scripts/jax2hf.py", line 83, in load_and_convert_checkpoint
    _, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
  File "/home/liuxiao/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 228, in load_trainstate_checkpoint
    restored_params = cls.load_flax_checkpoint(
  File "/home/liuxiao/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 165, in load_flax_checkpoint
    state_dict = flax.serialization.msgpack_restore(encoded_bytes)
  File "/home/liuxiao/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/serialization.py", line 407, in msgpack_restore
    state_dict = msgpack.unpackb(
  File "msgpack/_unpacker.pyx", line 201, in msgpack._cmsgpack.unpackb
msgpack.exceptions.ExtraData: unpack(b) received extra data.
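
For reference, the two tracebacks suggest that the prefix before `::` in --load_checkpoint selects which loader runs: `params::` ended up in `load_checkpoint`, and `flax_params::` ended up in `load_flax_checkpoint`. Below is a minimal sketch of that relationship. It is not the tux source: the helpers `split_load_spec` and `describe` and the mapping table are hypothetical and only record what the tracebacks show.

```python
from typing import Tuple

# Loader names exactly as they appear in the two tracebacks (tux/checkpoint.py).
# The table is an illustration built from this issue, not code copied from tux.
LOADER_SEEN_IN_TRACEBACK = {
    "params": "StreamingCheckpointer.load_checkpoint",
    "flax_params": "StreamingCheckpointer.load_flax_checkpoint",
}


def split_load_spec(spec: str) -> Tuple[str, str]:
    """Split 'prefix::path' into (prefix, path). Hypothetical helper."""
    prefix, sep, path = spec.partition("::")
    if not sep or not prefix or not path:
        raise ValueError(f"expected 'prefix::path', got {spec!r}")
    return prefix, path


def describe(spec: str) -> str:
    """Report which loader the tracebacks in this issue associate with a spec."""
    prefix, path = split_load_spec(spec)
    loader = LOADER_SEEN_IN_TRACEBACK.get(prefix, "unknown (not shown in this issue)")
    return f"{path} -> {loader}"


if __name__ == "__main__":
    for spec in ("params::t2t_chat_1m/params",
                 "flax_params::lwm_jax/t2t_chat_1m/params"):
        print(describe(spec))
```

Both loaders fail inside msgpack on the same params file, once with BufferFull and once with ExtraData, so switching the prefix alone does not seem to be the fix.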