bigscience-workshop / petals

🌸 Run LLMs at home, BitTorrent-style. Fine-tuning and inference up to 10x faster than offloading
https://petals.dev
MIT License

Improve default arguments for clients and servers #530

Closed · borzunov closed this 8 months ago

borzunov commented 8 months ago

This PR updates multiple default arguments in clients and servers:

  1. The client defaults to torch_dtype=torch.float32 instead of torch_dtype="auto".

    The old default was to load weights in the dtype they are saved in (usually bfloat16/float16), which caused issues when the client was run on a CPU (the default unless you call .cuda()). Specifically, bfloat16 is slow on most CPUs (unless the CPU supports AVX512), and float16 is not supported natively on CPU and raises an exception. This default was a legacy of the earliest Petals versions designed to run BLOOM: its embeddings were so large that they did not fit into RAM in float32 (e.g., in Colab). The newer models don't have this issue.

    In contrast, the new default gives good speed on all CPUs and is consistent with PyTorch and HF Transformers. Also, the client now shows the "bfloat16 on a non-AVX512 CPU" warning in all cases (previously, this warning was shown only if the machine had enough RAM to fit the float32 weights, which could hide the actual reason inference was slow). A usage sketch is given after this list.

    Note: This change is backward-incompatible, so we have to increase at least the minor package version (2.2.0 -> 2.3.0.dev0).

  2. The server now uses a 2x smaller --attn_cache_tokens by default.

    The old default led to loading 39 (out of 80) or 78 (out of 80) blocks for popular models on some GPU types, which visibly slowed down inference due to an extra network hop. It also reserved too much cache, so inference slowed down long before the cache was actually used.

    The new default leads to more efficient block layouts and makes the inference routing algorithm choose alternative paths through other servers when a particular server already has enough active inference sessions (i.e., its cache is full). See the second sketch after this list.

  3. The client's max number of retries can be limited by the PETALS_MAX_RETRIES env var.

    We use this to limit ClientConfig.max_retries in tests, so that we see tracebacks instead of indefinite retries when errors occur (see the last sketch below).
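
For item 1, here is a minimal sketch of the new behavior from the client's perspective. The model name is only an example; torch_dtype is the standard HF Transformers argument that from_pretrained() accepts:

```python
import torch
from petals import AutoDistributedModelForCausalLM

# With the new default, omitting torch_dtype loads local weights in float32,
# which runs at good speed on any CPU (no AVX512 required, no float16 errors).
model = AutoDistributedModelForCausalLM.from_pretrained("petals-team/StableBeluga2")

# GPU users can still opt into half precision explicitly:
model = AutoDistributedModelForCausalLM.from_pretrained(
    "petals-team/StableBeluga2", torch_dtype=torch.bfloat16
).cuda()
```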
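
For item 2, a sketch of how a server operator can override the new default if it does not fit their setup. The model name and the token count are illustrative values, not the actual defaults:

```sh
# Hypothetical invocation: pick a cache size that matches your GPU memory.
# --attn_cache_tokens bounds how many tokens the server can keep in its
# attention cache, which is shared by its active inference sessions.
python -m petals.cli.run_server petals-team/StableBeluga2 --attn_cache_tokens 16384
```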
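
For item 3, a sketch of how the new environment variable might be used in a test setup. The value and the assumption that it must be set before the client is configured are illustrative:

```python
import os

# Cap retries so errors surface as tracebacks instead of indefinite retries.
# Assumption: PETALS_MAX_RETRIES is read when the client config is created,
# so we set it before constructing the model.
os.environ["PETALS_MAX_RETRIES"] = "3"

from petals import AutoDistributedModelForCausalLM

model = AutoDistributedModelForCausalLM.from_pretrained("petals-team/StableBeluga2")
```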