HKUDS / LightRAG

"LightRAG: Simple and Fast Retrieval-Augmented Generation"
https://arxiv.org/abs/2410.05779
MIT License
9.64k stars 1.2k forks source link

向量存储维度问题 #146

Closed wangsj1018 closed 1 month ago

wangsj1018 commented 1 month ago

使用智谱的大模型,一直出现向量存储的维度问题,麻烦大佬帮忙看看

这是我的代码

import os

import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc, compute_args_hash
from lightrag.base import BaseKVStorage
from zhipuai import ZhipuAI

WORKING_DIR = "./dickens"
ZHIPU_APIKEY = "我自己的key"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

async def zhipu_model_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    model_name = kwargs['hashing_kv'].global_config['llm_model_name']
    return await zhipu_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )

async def zhipu_model_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)
    zhipu_client = ZhipuAI(api_key=ZHIPU_APIKEY)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = zhipu_client.chat.completions.create(
        model=model, 
        messages=messages, 
        **kwargs
    )

    result = response.choices[0].message.content

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": result, "model": model}})

    return result

async def zhipu_embedding(texts: list[str], embed_model) -> np.ndarray:
    zhipu_client = ZhipuAI(api_key=ZHIPU_APIKEY)
    response = zhipu_client.embeddings.create(model=embed_model, input=texts)

    return np.array([dp.embedding for dp in response.data])

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=zhipu_model_complete,  
    llm_model_name='GLM-4-Flash',
    embedding_func=EmbeddingFunc(
        embedding_dim=512,
        max_token_size=8192,
        func=lambda texts: zhipu_embedding(
            texts, 
            embed_model="embedding-3"
        )
    ),
)

with open("./book.txt", 'r', encoding='utf-8') as f:
    rag.insert(f.read())

# Perform naive search
# print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))

# # Perform local search
# print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))

# # Perform global search
# print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))

# # Perform hybrid search
print(rag.query("please provide the context or content from which I should", param=QueryParam(mode="hybrid")))

报错如下:

Traceback (most recent call last):
  File "D:\code\py\LightRAG\lightrag_zhipu_demo.py", line 87, in <module>
    rag.insert(f.read())
  File "D:\code\py\LightRAG\lightrag\lightrag.py", line 166, in insert
    return loop.run_until_complete(self.ainsert(string_or_strings))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\anaconda3\Lib\asyncio\base_events.py", line 687, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "D:\code\py\LightRAG\lightrag\lightrag.py", line 210, in ainsert
    await self.chunks_vdb.upsert(inserting_chunks)
  File "D:\code\py\LightRAG\lightrag\storage.py", line 103, in upsert
    results = self._client.upsert(datas=list_data)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\anaconda3\Lib\site-packages\nano_vectordb\dbs.py", line 108, in upsert
    self.__storage["matrix"] = np.vstack([self.__storage["matrix"], new_matrix])
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\anaconda3\Lib\site-packages\numpy\core\shape_base.py", line 289, in vstack
    return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 512 and the array at index 1 has size 2048
Dormiveglia-elf commented 1 month ago

Please refer to my previous PR #116

cristianohello commented 1 month ago

怎么修改代码?还是报错

wangsj1018 commented 1 month ago

怎么修改代码?还是报错

不太会python, 我多练练

wangsj1018 commented 1 month ago

Please refer to my previous PR #116

thx