huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Getting AssertionError when calling Neo4j chain for inference #29100

Closed: KaifAhmad1 closed this issue 9 months ago

KaifAhmad1 commented 9 months ago

System Info

langchain = 0.1.7, bitsandbytes = 0.42.0, pip = 24.0, CUDA = 12.1, OS = Windows 11 x64

Who can help?

Hey @SunMarc @younesbelkada, please help me out.

Reproduction

I've raised this issue on LangChain, but Dosu-bot indicates that it is actually related to bitsandbytes.

The LangChain discussion is at https://github.com/langchain-ai/langchain/discussions/17701 and I also raised it on the bitsandbytes repo (https://github.com/TimDettmers/bitsandbytes/issues/1067) but did not get support.

Expected behavior

The chain should return an answer without raising an exception.

amyeroberts commented 9 months ago

Hi @KaifAhmad1, thanks for opening an issue!

Please make sure to provide a minimal code reproducer and information about the bug encountered, including the full error traceback when reporting an issue.
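Notebook frontends such as Colab can collapse part of a stack trace (the "49 frames" line later in this thread), so the pasted traceback is not always the full one. A minimal sketch, using only the standard library, that records the complete traceback as text so it can be pasted into an issue; `run_with_full_traceback` and `_failing_step` are hypothetical names, not part of transformers or LangChain.

```python
import traceback
from typing import Any, Callable, Optional, Tuple


def run_with_full_traceback(
    fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> Tuple[Optional[Any], Optional[str]]:
    """Call fn; return (result, None) on success or (None, traceback text) on failure."""
    try:
        return fn(*args, **kwargs), None
    except Exception:
        # format_exc() renders the exception being handled with every frame, so nothing
        # is hidden the way a notebook's collapsed "N frames" view hides it.
        return None, traceback.format_exc()


def _failing_step() -> None:
    # Stand-in for the failing call; it raises the same exception type as the report.
    quant_state = None
    assert quant_state is not None


result, tb_text = run_with_full_traceback(_failing_step)
print(tb_text)
```

In the reproducer, the same helper could wrap the failing call, for example `run_with_full_traceback(chain.run, question)` with `question` being the query string, and the returned text pasted into the issue.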

If the error is coming from bitsandbytes there isn't anything the transformers team can do.

KaifAhmad1 commented 9 months ago

Hey @amyeroberts, following the transformers documentation, I tagged the bitsandbytes maintainers on this issue: @SunMarc @younesbelkada

bitsandbytes = 0.42.0, pip = 24.0, python = 3.10.10, CUDA = 12.1, OS = Windows 11 x64

import os

import torch
from torch import cuda, bfloat16
import transformers

model_id = 'microsoft/phi-2'
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

# Hugging Face access token (assumption: read from the HF_TOKEN environment variable;
# None is fine for public models)
hf_auth = os.environ.get('HF_TOKEN')

model_config = transformers.AutoConfig.from_pretrained(
    model_id,
    token=hf_auth,
    trust_remote_code=True
)

# bitsandbytes 4-bit configuration: NF4 weights, double quantization, bfloat16 compute
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    device_map='auto',
    token=hf_auth,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True
)

# Switch to inference mode (in a notebook, this also displays the model structure)
model.eval()

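from typing import Any, Dict, List

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering.stuff_prompt import CHAT_PROMPT
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.graphs import Neo4jGraph
from langchain_community.llms import HuggingFacePipeline
from langchain_core.pydantic_v1 import Field  # LangChain 0.1.x chains are pydantic v1 models

# Assumption: the snippet did not show how `llm`, `graph` and `embeddings` were built;
# these definitions are a plausible reconstruction, not the reporter's code.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
generate_text = transformers.pipeline(
    'text-generation', model=model, tokenizer=tokenizer, max_new_tokens=256
)
llm = HuggingFacePipeline(pipeline=generate_text)
graph = Neo4jGraph()  # reads NEO4J_URI, NEO4J_USERNAME and NEO4J_PASSWORD from the environment
embeddings = HuggingFaceBgeEmbeddings()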

vector_search = """
WITH 
k, e) yield node, score
RETURN node.text AS result
ORDER BY score DESC
LIMIT 3
"""

print(graph.schema)

class Neo4jVectorChain(Chain):
    """Retrieve the top-k text chunks from a Neo4j vector index and answer with an LLM."""

    graph: Neo4jGraph = Field(exclude=True)
    input_key: str = "query"
    output_key: str = "result"
    # Passed in by the caller, so no second embedding model is loaded at class definition
    embeddings: HuggingFaceBgeEmbeddings
    qa_chain: LLMChain = LLMChain(llm=llm, prompt=CHAT_PROMPT)

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: CallbackManagerForChainRun | None = None,
        k: int = 3,
    ) -> Dict[str, Any]:
        question = inputs[self.input_key]
        # Embed the question and fetch the k most similar nodes from the vector index
        embedding = self.embeddings.embed_query(question)
        context = self.graph.query(vector_search, {'embedding': embedding, 'k': k})
        context = [el['result'] for el in context]

        # Answer with the LLM, using the retrieved chunks as context
        result = self.qa_chain({"question": question, "context": context})
        final_result = result[self.qa_chain.output_key]
        return {self.output_key: final_result}

chain = Neo4jVectorChain(graph=graph, embeddings=embeddings, verbose=True)

graph_result = chain.run("How can we enhance the specificity and efficiency of CRISPR/Cas9 gene-editing technology to minimize off-target effects and increase its potential for therapeutic applications?")

> Entering new Neo4jVectorChain chain...
/usr/local/lib/python3.10/dist-packages/langchain_core/_api/deprecation.py:117: LangChainDeprecationWarning: The function `__call__` was deprecated in LangChain 0.1.0 and will be removed in 0.2.0. Use invoke instead.
  warn_deprecated(
/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.3` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.
  warnings.warn(
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-42-4ff3ab735a16> in <cell line: 1>()
----> 1 graph_result = chain.run("How can we enhance the specificity and efficiency of CRISPR/Cas9 gene-editing technology to minimize off-target effects and increase its potential for therapeutic applications?")

49 frames
/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py in matmul_4bit(A, B, quant_state, out, bias)
    564 
    565 def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
--> 566     assert quant_state is not None
    567     if A.numel() == A.shape[-1] and A.requires_grad == False:
    568         if A.shape[-1] % quant_state.blocksize != 0:

AssertionError:
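The two messages at the end of the log come from the 4-bit path: a notice that the FP4 quantization state was not initialized, then `assert quant_state is not None` failing in `matmul_4bit`. The toy below models only that control flow, assuming (as the notice itself suggests) that the state is created when the layer is moved to the GPU. `ToyQuantState`, `ToyLinear4bit`, `toy_matmul_4bit` and the plain string device names are hypothetical stand-ins, not bitsandbytes classes.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyQuantState:
    """Hypothetical stand-in for a layer's 4-bit quantization metadata."""

    blocksize: int = 64


def toy_matmul_4bit(x: List[float], weight: List[float], quant_state: Optional[ToyQuantState]) -> float:
    # Same guard as the first line of bitsandbytes' matmul_4bit shown in the traceback.
    assert quant_state is not None
    # Plain dot product standing in for the real dequantize-and-multiply step.
    return sum(a * b for a, b in zip(x, weight))


class ToyLinear4bit:
    """Toy layer whose quantization state is only created when it is moved to 'cuda'."""

    def __init__(self, weight: List[float]) -> None:
        self.weight = weight
        self.quant_state: Optional[ToyQuantState] = None

    def to(self, device: str) -> "ToyLinear4bit":
        # Assumption taken from the log message: moving the layer to the GPU creates the state.
        if device.startswith("cuda"):
            self.quant_state = ToyQuantState()
        return self

    def forward(self, x: List[float]) -> float:
        if self.quant_state is None:
            # Same text as the message printed just before the traceback in the report.
            print("FP4 quantization state not initialized. "
                  "Please call .cuda() or .to(device) on the LinearFP4 layer first.")
        return toy_matmul_4bit(x, self.weight, self.quant_state)


layer = ToyLinear4bit([1.0, 2.0])
try:
    layer.forward([3.0, 4.0])
    failed_before_move = False
except AssertionError:
    failed_before_move = True
print("AssertionError before .to('cuda'):", failed_before_move)  # True

result_after_move = layer.to("cuda:0").forward([3.0, 4.0])
print(result_after_move)  # 11.0
```

The toy only shows that the assertion fires whenever a forward pass reaches a 4-bit layer whose state was never created; it does not say which code path left the state empty in this report.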

younesbelkada commented 9 months ago

Hi @KaifAhmad1, thanks very much for the issue! You are using the trust_remote_code version of the model, which we don't maintain. Can you try phi-2 without trust_remote_code? I think 4-bit should work out of the box with the non-trust_remote_code model.
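A minimal sketch of that suggestion: it keeps the reproducer's 4-bit settings (NF4, double quantization, bfloat16 compute) but drops `trust_remote_code`, so transformers uses its own Phi implementation. It assumes a transformers release that includes the Phi architecture, plus `bitsandbytes` and a CUDA GPU. The `HF_TOKEN` environment variable and the names `MODEL_ID`, `BNB_4BIT_SETTINGS` and `load_phi2_4bit` are hypothetical, not from the thread.

```python
import os

MODEL_ID = "microsoft/phi-2"

# 4-bit settings copied from the reproducer (compute dtype is added at load time).
BNB_4BIT_SETTINGS = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)


def load_phi2_4bit(model_id: str = MODEL_ID):
    """Load phi-2 in 4-bit with the Phi model class that ships with transformers."""
    # Imported lazily so this module can be imported without the GPU stack installed.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    token = os.environ.get("HF_TOKEN")  # assumption: optional access token
    bnb_config = BitsAndBytesConfig(**BNB_4BIT_SETTINGS, bnb_4bit_compute_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=bnb_config,
        token=token,
        # No trust_remote_code: use the in-library implementation instead of the Hub code.
    )
    model.eval()
    return tokenizer, model
```

Calling `tokenizer, model = load_phi2_4bit()` and then running `model.generate` on a short prompt would show whether 4-bit inference works on its own, before the model is wrapped in the LangChain pipeline.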