langchain-ai / langchain-nvidia

MIT License

Retrieval does not work with ChatNVIDIA #86

Status: Open. crslen opened this issue 1 month ago.

crslen commented 1 month ago

When I run the code from the LangChain documentation, the expected response is not returned. When I swap the ChatNVIDIA class for OllamaLLM to check whether I get the same response, I get the correct one.

ChatNVIDIA class response: [Document(page_content='harrison worked at kensho')] Harrison is a common name, and without additional context, it's not possible to determine exactly where a specific Harrison has worked. If you could please provide more details, such as the last name or the industry, that would help narrow down the search.

OllamaLLM class response: [Document(page_content='harrison worked at kensho')] Based on the information provided in the context, Harrison worked at Kensho.

It appears that ChatNVIDIA ignores the retrieved Document when generating the answer.

```python
import os

from dotenv import load_dotenv
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_nvidia_ai_endpoints import ChatNVIDIA
# from nemo_embed import NemoEmbeddings  # imported but never used in this snippet
# from langchain_ollama.llms import OllamaLLM

# Read model settings from a .env file. Environment variables are strings,
# so the numeric settings are converted explicitly.
load_dotenv()
llm = os.getenv("LLM")
api_key = os.getenv("API_KEY")  # loaded but not passed to ChatNVIDIA below
NIMhost = os.getenv("NIMHOST")
token = int(os.getenv("MAXTOKEN"))
temp = float(os.getenv("TEMPERATURE"))
top_p = float(os.getenv("TOP_P"))

embeddings = FastEmbedEmbeddings()

# Build a one-document FAISS index and check that retrieval works on its own.
vectorstore = FAISS.from_texts(
    ["harrison worked at kensho"],
    embedding=embeddings,
)
retriever = vectorstore.as_retriever()
docs = retriever.invoke("where did harrison work?")
print(docs)

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer solely based on the following context:\n<Documents>\n{context}\n</Documents>",
        ),
        ("user", "{question}"),
    ]
)

# model = OllamaLLM(model="llama2")
# Point ChatNVIDIA at a locally hosted NIM.
model = ChatNVIDIA(
    model=llm,
    temperature=temp,
    top_p=top_p,
    max_tokens=token,
    base_url=f"http://{NIMhost}/v1",
)

# Fill {context} with the retriever output and {question} with the raw input.
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

print(chain.invoke("where did harrison work?"))
```

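In the chain above, the `{context}` placeholder receives the retriever's raw output, a Python list of `Document` objects, so the model sees that list's string representation rather than plain text. Whether this explains the ChatNVIDIA behavior is an assumption, not something confirmed in this thread, but a common pattern in LangChain tutorials is to join the documents' `page_content` before it reaches the prompt. The sketch below shows that step in plain Python; `Doc` is a hypothetical stand-in for `langchain_core.documents.Document`, and since `format_docs` only reads `.page_content` it should work unchanged on real `Document` objects.

```python
from dataclasses import dataclass


@dataclass
class Doc:
    """Hypothetical stand-in for langchain_core.documents.Document."""
    page_content: str


def format_docs(docs):
    """Join retrieved documents into plain text for the {context} slot."""
    return "\n\n".join(doc.page_content for doc in docs)


# Same system message as the prompt above, filled roughly the way
# ChatPromptTemplate's default f-string formatting would fill it.
SYSTEM_TEMPLATE = (
    "Answer solely based on the following context:\n"
    "<Documents>\n{context}\n</Documents>"
)

docs = [Doc("harrison worked at kensho")]
print(SYSTEM_TEMPLATE.format(context=format_docs(docs)))
```

In the real chain this would become `{"context": retriever | format_docs, "question": RunnablePassthrough()}`; LangChain coerces a plain function into a runnable when it is piped after one.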
raspawar commented 1 month ago

Hi @crslen

The provided snippet works for me.

Here is a snippet, run against the meta/llama2-70b model, that also returns the retrieved context so you can trace what the model was given:

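```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_nvidia_ai_endpoints import ChatNVIDIA

# Reuses llm, temp, top_p, token, prompt and retriever from the snippet above.
# No base_url is set, so ChatNVIDIA uses its default hosted NVIDIA endpoint
# rather than a local NIM.
model = ChatNVIDIA(
    model=llm,
    temperature=temp,
    top_p=top_p,
    max_tokens=token,
)

# prompt -> model -> plain string
chain = (
    prompt
    | model
    | StrOutputParser()
)

# Retrieve the context and pass the question through, then add the model's
# answer under "answer" so the context is visible next to it in the output.
rag_chain = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
).assign(answer=chain)

print(rag_chain.invoke("where did harrison work?"))
```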

This returns:

{'context': [Document(page_content='harrison worked at kensho')], 'question': 'where did harrison work?', 'answer': 'Harrison worked at Kensho.'}
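The `RunnableParallel(...).assign(answer=chain)` pattern above is useful for debugging because the output keeps the retrieved context next to the answer. As a rough mental model only (plain Python, not LangChain's implementation), the data flow is: build a dict with `context` and `question` from the same input, then pass that whole dict to the answer chain and merge its result under `answer`. `run_rag`, `fake_retrieve` and `fake_answer_chain` are hypothetical names standing in for the retriever and the prompt | model | parser chain.

```python
def run_rag(question, retrieve, answer_chain):
    # Step 1: like RunnableParallel, compute each key from the same input
    # (sequentially here; RunnableParallel may run the branches concurrently).
    state = {"context": retrieve(question), "question": question}
    # Step 2: like .assign(answer=...), feed the whole dict to the chain
    # and add its output as a new key while keeping the existing ones.
    return {**state, "answer": answer_chain(state)}


def fake_retrieve(question):
    # Stands in for the FAISS retriever; always returns the one indexed text.
    return ["harrison worked at kensho"]


def fake_answer_chain(state):
    # A real chain would format the prompt and call the model here.
    return "Harrison worked at Kensho." if state["context"] else "No context."


print(run_rag("where did harrison work?", fake_retrieve, fake_answer_chain))
```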

NOTE: `NemoEmbeddings` should not be used; please use `NVIDIAEmbeddings` instead.

Please provide more information if the issue persists.