learning-at-home / hivemind

Decentralized deep learning in PyTorch. Built to train models on thousands of volunteers across the world.

Read {run_id}_progress from DHT manually throws exceptions #533

Open cirquit opened 1 year ago

cirquit commented 1 year ago

Hi,

I can't read the training information (like here) out of the DHT that hivemind created.

I can connect to the DHT and run the following:

>>> dht.store("key", "value", expiration_time=get_dht_time() + 600)
>>> dht.get("key")
ValueWithExpiration(value='value', expiration_time=1670845892.2483625)

However, when training with hivemind, I can't retrieve the progress data the same way, and two consecutive calls to get behave differently.

Only the second call shows some actual training progress data, and even then it is incomplete (1 out of 4 peers) and only appears inside the error output, so I can't access it the way the documentation describes.

It seems that the get call runs asynchronously and fails to decode the returned LocalTrainingProgress.

How does the get/store usage in the tutorial differ from what hivemind does with LocalTrainingProgress?

First call to get

>>> dht.get("hivemind-123_progress")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/dht/dht.py", line 173, in get
    return future if return_future else future.result()
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/utils/mpfuture.py", line 257, in result
    return super().result(timeout)
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/concurrent/futures/_base.py", line 446, in result
    return self.__get_result()
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/concurrent/futures/_base.py", line 391, in __get_result
    raise self._exception
msgpack.exceptions.ExtraData: unpack(b) received extra data.

Second call to get

>>> dht.get("hivemind-123_progress")
Dec 12 12:43:20.841 [ERROR] [asyncio._run:129] Task exception was never retrieved
future: <Task finished name='Task-13381' coro=<DHT._get() done, defined at /home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/dht/dht.py:175> exception=ExtraData({'peer_id': b"\x12 W\xb23\xa4\x85\xd0\xfa\xad\n[t\xec\xc7\xfe'\xed\x1d\x94\x03\n\xf6\x11e\xf4\xe3j,\xf7\xae\xd5h\xca", 'epoch': 24, 'samples_accumulated': 0, 'samples_per_second': 10.078083213276257, 'time': 1670842945.1815588, 'client_mode': False}, b'[signature:P3NGbBDc4ujJwy2afKJSEXD/lsM1s7icix+h5LoxGk1K6ZFvq5vaf7vs4mokUm0TmYbeGMq85DV1M3nr/+lrVg/WGAtC3moq9iiigaKiNnhszcZPx1ls+UOoIbZXGh35kdIzCIr2qsV9GxheuPaohErMoEzxN+kAytZ+wEtxoxEgOCAXEdOGVmee0Dx6eIQVzs96d7aIEpucNLGRu8ylOvgjcZNOu+MMyqVTom3R6yvl8RRTh3Dj/0cS7a0ajo+osIx7ENIadL8Zh8Vqmw+evLR2dZhAULYhN/wq1C/8dNYZzM1C2spbjG9hMYlD33RUhmD0gE+rWP0OKHA7vUPtSA==]')>
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/dht/dht.py", line 177, in _get
    result = await self._node.get(key, latest=latest, **kwargs)
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/dht/node.py", line 543, in get
    result = await self.get_many([key], **kwargs)
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/dht/node.py", line 565, in get_many
    results_by_id = await self.get_many_by_id(key_ids, sufficient_expiration_time, **kwargs)
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/dht/node.py", line 620, in get_many_by_id
    search_results[key_id].add_candidate(self.protocol.storage.get(key_id), source_node_id=self.node_id)
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/dht/node.py", line 844, in add_candidate
    self.finish_search()
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/dht/node.py", line 873, in finish_search
    self.serializer.loads(value_bytes), item_expiration_time
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/utils/serializer.py", line 72, in loads
    return msgpack.loads(buf, ext_hook=cls._decode_ext_types, raw=False)
  File "msgpack/_unpacker.pyx", line 201, in msgpack._cmsgpack.unpackb
msgpack.exceptions.ExtraData: unpack(b) received extra data.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/dht/dht.py", line 173, in get
    return future if return_future else future.result()
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/site-packages/hivemind/utils/mpfuture.py", line 257, in result
    return super().result(timeout)
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/concurrent/futures/_base.py", line 446, in result
    return self.__get_result()
  File "/home/ubuntu/miniconda3/envs/conda-hivemind/lib/python3.9/concurrent/futures/_base.py", line 391, in __get_result
    raise self._exception
msgpack.exceptions.ExtraData: unpack(b) received extra data.
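
The exception logged by the second call carries a fully decoded dictionary followed by a b'[signature:...]' suffix, which suggests the decoder is handed a signed record whose signature was never stripped. The sketch below reproduces that failure mode with the standard library's json module standing in for msgpack. The sample record and the strip_signature helper are hypothetical and not part of hivemind.

```python
import json
import re

# Hypothetical signed record: a complete serialized payload followed by a
# signature suffix, mirroring the shape seen in the error log above.
payload = json.dumps({"epoch": 24, "samples_accumulated": 0}).encode()
signed_record = payload + b"[signature:ZmFrZQ==]"

SIGNATURE_SUFFIX = re.compile(rb"\[signature:[^\]]*\]$")


def strip_signature(record: bytes) -> bytes:
    """Illustrative helper (not hivemind API): drop a trailing signature marker."""
    return SIGNATURE_SUFFIX.sub(b"", record)


def try_decode(record: bytes):
    """Decode a record, or return a short description of the decode error."""
    try:
        return json.loads(record)
    except json.JSONDecodeError as exc:
        return f"decode failed: {exc.msg}"


# Decoding the signed bytes directly fails on the trailing signature.
raw_result = try_decode(signed_record)
# Decoding after the suffix is removed succeeds.
stripped_result = try_decode(strip_signature(signed_record))

print(raw_result)
print(stripped_result)
```

A strict decoder rejects any bytes that follow a complete object, so whatever component understands the signature format has to remove it before the payload is deserialized.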

cirquit commented 1 year ago

So I found the issue: additional validators must be added to the DHT before it can parse LocalTrainingProgress records.

from typing import Dict, Optional

from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint


# Per-peer progress record; the fields match the dictionary shown in the error log above.
class LocalTrainingProgress(BaseModel):
    peer_id: bytes
    epoch: conint(ge=0, strict=True)
    samples_accumulated: conint(ge=0, strict=True)
    samples_per_second: confloat(ge=0.0, strict=True)
    time: StrictFloat
    client_mode: StrictBool


# Schema for the "{run_id}_progress" key: one optional record per signed subkey.
class TrainingProgressSchema(BaseModel):
    progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]


run_id = ...  # placeholder: get run_id
dht = ...  # placeholder: init dht

# Register the validators so the DHT can verify and decode the signed progress records.
signature_validator = RSASignatureValidator(None)
local_public_key = signature_validator.local_public_key
dht.add_validators(
    [
        SchemaValidator(TrainingProgressSchema, prefix=f"{run_id}"),
        signature_validator,
    ]
)

metadata, expiration = dht.get(key=f"{run_id}_progress", return_future=False)
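
With the validators registered, the get call should return the progress metadata instead of raising. Here is a minimal sketch of summarizing it, assuming metadata is a dictionary that maps each peer's subkey to an entry with value and expiration_time fields (the ValueWithExpiration shape shown in the first example). The namedtuple stand-in, the sample values, and summarize_progress are made up for illustration, and this is not hivemind's own aggregation logic.

```python
from collections import namedtuple

# Stand-in for hivemind's ValueWithExpiration so the sketch runs without a DHT.
ValueWithExpiration = namedtuple("ValueWithExpiration", ["value", "expiration_time"])


def summarize_progress(metadata):
    """Aggregate per-peer progress entries into one summary dict (hypothetical helper)."""
    # Skip peers whose entry is None; the schema allows Optional records.
    entries = [entry.value for entry in metadata.values() if entry.value is not None]
    if not entries:
        return None
    return {
        "num_peers": len(entries),
        "global_epoch": max(p["epoch"] for p in entries),
        "samples_accumulated": sum(p["samples_accumulated"] for p in entries),
        "samples_per_second": sum(p["samples_per_second"] for p in entries),
    }


# Made-up sample data in the assumed shape; real subkeys would contain public keys.
metadata = {
    b"peer-a": ValueWithExpiration(
        {"epoch": 24, "samples_accumulated": 64, "samples_per_second": 10.0}, 1670845892.0
    ),
    b"peer-b": ValueWithExpiration(
        {"epoch": 23, "samples_accumulated": 32, "samples_per_second": 12.5}, 1670845893.0
    ),
    b"peer-c": ValueWithExpiration(None, 1670845894.0),
}

summary = summarize_progress(metadata)
print(summary)
```

Because dht.get may return nothing when the key is absent or expired, it is worth checking the result before unpacking it into metadata and expiration.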

I'm planning to create a pull request to update the documentation with a full example to access the GlobalTrainingProgress. You're welcome to either keep the issue open for me to reference it or close it.