bigscience-workshop / petals

🌸 Run LLMs at home, BitTorrent-style. Fine-tuning and inference up to 10x faster than offloading
https://petals.dev

Issue with beam search decoding #503

Closed Vincent-Stragier closed 11 months ago

Vincent-Stragier commented 1 year ago

Using the Getting Started Colab configured for Llama 2, I'm unable to run beam search decoding; only greedy decoding works.
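For context, the plain greedy call from the notebook works fine. A minimal sketch, reusing the same tokenizer and model objects:

inputs = tokenizer('A cat in French is "', return_tensors="pt")["input_ids"].cuda()
outputs = model.generate(inputs, max_new_tokens=20)  # defaults to greedy decoding
print(tokenizer.decode(outputs[0], skip_special_tokens=True))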

I adapted the greedy snippet as follows to enable beam search:

inputs = tokenizer('A cat in French is "', return_tensors="pt")["input_ids"].cuda()
outputs = model.generate(
    inputs,
    max_new_tokens=20,
    num_beams=5,               # beam search with 5 beams
    early_stopping=True,
    # no_repeat_ngram_size=2,
    num_return_sequences=5,    # return all 5 beams
    do_sample=True,            # sample within beam search
    temperature=0.01,
    top_p=0.5,
    # top_k=0
)
# print(tokenizer.decode(outputs[0]))

print("Output:\n" + 100 * '-')
for i, output in enumerate(outputs):
    print("{}: {}".format(i, tokenizer.decode(output, skip_special_tokens=True)))

The output I got is:

Sep 05 08:28:14.822 [INFO] Route found: 0:40 via …aDb6mP => 40:80 via …xxYbSH
Sep 05 08:28:18.621 [WARN] [petals.client.inference_session.step:327] Caught exception when running inference via RemoteSpanInfo(peer_id=<libp2p.peer.id.ID (QmNSatwUFbjgAQ5C7ZFDpzcCuHaFHhvp6Zi4s44zaDb6mP)>, start=0, end=40, server_info=ServerInfo(state=<ServerState.ONLINE: 2>, throughput=1536.091125543488, public_name='slush 🤖💪', version='2.1.0', network_rps=4603.365429029381, forward_rps=31489.868073641504, inference_rps=908.2264508328947, adapters=(), torch_dtype='float16', quant_type='nf4', using_relay=False, cache_tokens_left=2621440, next_pings={'12D3KooWD3Nk7EF2CNKYHoBBNrfgcoB3kDTMabXutNppHKxMfXqm': 0.20110055749622757, 'QmctZZcBEUkKUbbMsLfQRu4C4eSyp8438QD8WqiDxxYbSH': 0.001488511844204852})) (retry in 0 sec): AssertionError()
Sep 05 08:28:18.626 [WARN] [petals.client.routing.sequence_manager.maybe_log_traceback:537] See detailed traceback below:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/petals/client/inference_session.py", line 312, in step
    inputs = server_session.step(
  File "/usr/local/lib/python3.10/dist-packages/petals/client/inference_session.py", line 159, in step
    tensors=[
  File "/usr/local/lib/python3.10/dist-packages/petals/client/inference_session.py", line 160, in <listcomp>
    serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
  File "/usr/local/lib/python3.10/dist-packages/hivemind/compression/serialization.py", line 38, in serialize_torch_tensor
    assert tensor.device == torch.device("cpu")
AssertionError
Sep 05 08:28:20.045 [INFO] Route found: 0:40 via …aDb6mP
Sep 05 08:28:20.058 [WARN] [petals.client.inference_session.step:327] Caught exception when running inference via RemoteSpanInfo(peer_id=<libp2p.peer.id.ID (QmNSatwUFbjgAQ5C7ZFDpzcCuHaFHhvp6Zi4s44zaDb6mP)>, start=0, end=40, server_info=ServerInfo(state=<ServerState.ONLINE: 2>, throughput=1536.091125543488, public_name='slush 🤖💪', version='2.1.0', network_rps=4603.365429029381, forward_rps=31489.868073641504, inference_rps=908.2264508328947, adapters=(), torch_dtype='float16', quant_type='nf4', using_relay=False, cache_tokens_left=2580480, next_pings={'12D3KooWD3Nk7EF2CNKYHoBBNrfgcoB3kDTMabXutNppHKxMfXqm': 0.20019696359692774, 'QmctZZcBEUkKUbbMsLfQRu4C4eSyp8438QD8WqiDxxYbSH': 0.0013732088753644585})) (retry in 1 sec): AssertionError("Broken input cache: span=RemoteSpanInfo(peer_id=<libp2p.peer.id.ID (QmNSatwUFbjgAQ5C7ZFDpzcCuHaFHhvp6Zi4s44zaDb6mP)>, start=0, end=40, server_info=ServerInfo(state=<ServerState.ONLINE: 2>, throughput=1536.091125543488, public_name='slush 🤖💪', version='2.1.0', network_rps=4603.365429029381, forward_rps=31489.868073641504, inference_rps=908.2264508328947, adapters=(), torch_dtype='float16', quant_type='nf4', using_relay=False, cache_tokens_left=2580480, next_pings={'12D3KooWD3Nk7EF2CNKYHoBBNrfgcoB3kDTMabXutNppHKxMfXqm': 0.20019696359692774, 'QmctZZcBEUkKUbbMsLfQRu4C4eSyp8438QD8WqiDxxYbSH': 0.0013732088753644585})) shape=torch.Size([5, 7, 8192]) position=0 n_input_tokens=1")
[The same "Broken input cache: … shape=torch.Size([5, 7, 8192]) position=0 n_input_tokens=1" warning then repeats after every new "Route found: 0:40 via …aDb6mP" line, with the retry delay backing off from 2 s up to 60 s, through Sep 05 08:34:24.]
Vincent-Stragier commented 1 year ago

The current workaround is to run Petals in CPU mode, i.e., remove all .cuda() calls and change .from_pretrained(model_name) to .from_pretrained(model_name, torch_dtype=torch.float32).
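For reference, here is a minimal sketch of that CPU-mode setup. It assumes the client classes used in the Colab (petals' AutoDistributedModelForCausalLM plus transformers' AutoTokenizer) and a placeholder model name; substitute whatever model the notebook is configured for:

import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "meta-llama/Llama-2-70b-chat-hf"  # placeholder; use the model selected in the Colab

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the local parts of the model in float32 and keep every tensor on the CPU (no .cuda() anywhere).
model = AutoDistributedModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

inputs = tokenizer('A cat in French is "', return_tensors="pt")["input_ids"]  # stays on the CPU
outputs = model.generate(
    inputs,
    max_new_tokens=20,
    num_beams=5,
    early_stopping=True,
    num_return_sequences=5,
)
for i, output in enumerate(outputs):
    print(i, tokenizer.decode(output, skip_special_tokens=True))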

borzunov commented 11 months ago

Hi @Vincent-Stragier,

Sorry for the slow fix and thanks for reporting! This issue was finally resolved in #531.