v6d-io / v6d

vineyard (v6d): an in-memory immutable data manager. (Project under CNCF, TAG-Storage)
https://v6d.io
Apache License 2.0

Improve the implementation of migrating an object with lots of blobs #1914

Closed dashanji closed 2 weeks ago

dashanji commented 2 weeks ago

What do these changes do?

Start two vineyardd instances on different machines:

# On machine 1:
./bin/vineyardd --compression=false --reserve_memory=true --size=20Gi --etcd_endpoint=192.168.0.239:2379
# On machine 2:
./bin/vineyardd --compression=false --reserve_memory=true --size=20Gi --etcd_endpoint=192.168.0.239:2379

Improve the performance of migrating a single object

The code is:

import safetensors
import safetensors.torch
import torch
import vineyard
import time
from vineyard.contrib.ml.torch import torch_context

## Run the following code in vineyardd instance 1
state_file = "./stable-diffusion-v1-5-pruned-emaonly.safetensors"
with open(state_file, 'rb') as f:
    state_dict = safetensors.torch.load(f.read())

client = vineyard.connect("/var/run/vineyard.sock")
with torch_context(client):
    obj = client.put(state_dict, persist=True, name="state_dict")

## Run the following code in vineyardd instance 2
client2 = vineyard.connect("/var/run/vineyard.sock")

print('start to migrate')
start = time.time()
with torch_context(client2):
    data2 = client2.get(name="state_dict", fetch=True)
end = time.time()
print(f"Migrate Time taken: {end - start} seconds.")

start = time.time()
rpc_client = vineyard.connect(host="192.168.0.237", port=9600)
with torch_context(rpc_client):
    data3 = rpc_client.get(name="state_dict")
end = time.time()

assert len(state_dict) == len(data3), "State dict and data3 have different lengths"
assert len(state_dict) == len(data2), "State dict and data2 have different lengths"

for k, v in state_dict.items():
    assert torch.equal(v, data3[k]), f"Tensors for key {k} do not match in data3."
    assert torch.equal(v, data2[k]), f"Tensors for key {k} do not match in data2."

print(f"RPC Time taken: {end - start} seconds.")

print("All operations are completed.")

Without any compression, the performance of the original version is:

Migrating an SD model takes 27s.
An RPC get of an SD model takes 3.7s.

Without any compression, the performance of this PR is:

Migrating an SD model takes 3.5s.
An RPC get of an SD model takes 3.7s.

Migration is slightly faster than the RPC get because the memory was reserved in vineyardd beforehand.
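A quick sanity check on the numbers above (the timings are copied from the benchmark output; the ratios are plain arithmetic):

```python
# Timings reported above, in seconds.
orig_migrate = 27.0  # original version, migrate
pr_migrate = 3.5     # this PR, migrate
rpc_get = 3.7        # RPC get (unchanged by this PR)

# Roughly a 7.7x improvement for migration, and migration now
# slightly beats the RPC get path.
print(f"migration speedup: {orig_migrate / pr_migrate:.1f}x")  # 7.7x
print(f"migrate vs. RPC get: {pr_migrate / rpc_get:.2f}")      # 0.95
```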

Improve the performance when loading a model with multiple processes (simulating multiple GPUs)

Here, we use the SD models as an example.

import safetensors
import safetensors.torch
import torch
import vineyard
import time
from vineyard.contrib.ml.torch import torch_context

for i in range(0, 4):
    if i == 0:
        state_file = f'./stable-diffusion-v1-5-pruned-emaonly.safetensors'
    else:
        state_file = f'./stable-diffusion-v1-5-pruned-emaonly-{i}.safetensors'
    with open(state_file, 'rb') as f:
        state_dict = safetensors.torch.load(f.read())

    client = vineyard.connect(host="192.168.0.237", port=9600)

    with torch_context(client):
        obj = client.put(state_dict, persist=True, name=f"state_dict_{i}")

for i in range(4, 7):
    state_file = f'./stable-diffusion-v1-5-pruned-emaonly-{i}.safetensors'
    with open(state_file, 'rb') as f:
        state_dict = safetensors.torch.load(f.read())

    client = vineyard.connect(host="192.168.0.239", port=9600)

    with torch_context(client):
        obj = client.put(state_dict, persist=True, name=f"state_dict_{i}")

print("All operations are completed.")

Use the RPC client to get the models as follows.

import threading
import time
import vineyard
from vineyard.contrib.ml.torch import torch_context

def read_model():
    # read all models from the remote vineyardd over RPC
    start = time.time()
    client = vineyard.connect(host="192.168.0.239", port=9600)
    for i in range(0, 7):
        with torch_context(client):
            sd = client.get(name=f"state_dict_{i}")
    end = time.time()
    print("Time taken to read the model: ", end - start)

if __name__ == "__main__":

    threads = []

    start = time.time()
    for i in range(4):
        thread = threading.Thread(target=read_model)
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()

    end = time.time()
    print("All threads have finished reading the model: ", end - start)

The result is as follows.

[root@iZj6cgaobf64z0jwsa4gkhZ caoye]# python3 test_rpc_client.py
Time taken to read the model:  63.97634696960449
Time taken to read the model:  63.9770827293396
Time taken to read the model:  67.01051378250122
Time taken to read the model:  70.00218796730042
All threads have finished reading the model:  70.1090898513794

Then we use migration to get these models.

import threading
import time
import vineyard
from vineyard.contrib.ml.torch import torch_context

def read_model():
    # read all models through the local vineyardd, triggering migration
    start = time.time()
    client = vineyard.connect("/var/run/vineyard-local.sock")
    for i in range(0, 7):
        with torch_context(client):
            sd = client.get(name=f"state_dict_{i}", fetch=True)
    end = time.time()
    print("Time taken to read the model: ", end - start)

if __name__ == "__main__":
    threads = []

    start = time.time()
    for i in range(4):
        thread = threading.Thread(target=read_model)
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()

    end = time.time()
    print("All threads have finished reading the model: ", end - start)

The result is as follows.

[root@iZj6cgaobf64z0jwsa4gkhZ caoye]# python3 test_migrate.py
Time taken to read the model:  28.407331943511963
Time taken to read the model:  28.407649278640747
Time taken to read the model:  28.41820740699768
Time taken to read the model:  28.463479042053223
All threads have finished reading the model:  28.570210933685303
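For the multi-reader case, the end-to-end speedup from switching RPC gets to migration follows directly from the two wall-clock totals reported above:

```python
# Wall-clock totals from the two runs above, in seconds.
rpc_total = 70.1090898513794        # all threads finished, RPC get
migrate_total = 28.570210933685303  # all threads finished, migrate

# About a 2.45x end-to-end improvement for four concurrent readers.
print(f"end-to-end speedup: {rpc_total / migrate_total:.2f}x")  # 2.45x
```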

Related issue number

Fixes #1905