google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0
32 stars 14 forks

Pick a slot from 0 to batch_size-1 during run_interactive.py #67

bhavya01 closed this 4 months ago

bhavya01 commented 4 months ago

This PR fixes the following IndexError:

warnings.warn("Some donated buffers were not usable:"
----
Streaming decode started on #slot32.
new_pos Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> cache_seq_len 2048
Traceback (most recent call last):
  File "/home/bbahl/jetstream-pytorch/run_interactive.py", line 168, in <module>
    app.run(main)
  File "/home/bbahl/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/bbahl/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/bbahl/jetstream-pytorch/run_interactive.py", line 146, in main
    token_id = slot_tokens[slot, 0].item()
  File "/home/bbahl/.local/lib/python3.10/site-packages/jax/_src/array.py", line 348, in __getitem__
    return lax_numpy._rewriting_take(self, idx)
  File "/home/bbahl/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4604, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File "/home/bbahl/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4613, in _gather
    indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
  File "/home/bbahl/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4831, in _index_to_gather
    raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
IndexError: index is out of bounds for axis 0 with size 0
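The fix named in the title amounts to keeping the chosen decode slot inside [0, batch_size - 1]. Below is a minimal sketch, assuming slots are picked at random (an assumption; run_interactive.py may choose them differently). `pick_slot` and the list-based `slot_tokens` buffer are hypothetical stand-ins, not the script's actual code. The sketch also shows one general way an out-of-range slot can surface as a size-0 axis: slicing past the end of a sequence returns an empty result instead of raising, so the error only appears at a later index.

```python
import random


def pick_slot(batch_size: int, rng: random.Random) -> int:
    """Hypothetical helper: return a decode slot in [0, batch_size - 1]."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # random.randint is inclusive at both ends, so the upper bound must be
    # batch_size - 1; randint(0, batch_size) could return batch_size itself.
    return rng.randint(0, batch_size - 1)


# Hypothetical per-slot token buffer: one row per slot.
batch_size = 32
slot_tokens = [[0] for _ in range(batch_size)]

# Slicing with an out-of-range slot does not raise; it yields an empty
# sequence, and a later index into it fails on a size-0 axis.
print(len(slot_tokens[batch_size:batch_size + 1]))  # 0

# Every slot produced by pick_slot is a valid row index.
rng = random.Random(0)
slots = [pick_slot(batch_size, rng) for _ in range(1000)]
print(min(slots) >= 0 and max(slots) <= batch_size - 1)  # True
```

Bounding the slot at selection time keeps the invalid index from ever reaching the decode loop, rather than relying on a later lookup to catch it.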

FanhaiLu1 commented 4 months ago

Thanks for fixing it! I submitted the fix for multiple hosts but forgot to change this file.