kokoichi206 / til

0 stars 1 forks source link

tmp #14

Open kokoichi206 opened 5 months ago

kokoichi206 commented 5 months ago
import os
import logging
import sys
from http import HTTPStatus
from dataclasses import dataclass

from dotenv import load_dotenv
from flask import Flask, request, jsonify
from sqlalchemy.orm import sessionmaker
from sqlalchemy.types import Integer, String
from sqlalchemy.schema import Column
from sqlalchemy.orm import declarative_base
from sqlalchemy import create_engine

from llama_index import (
    load_index_from_storage,
    StorageContext,
)

from llama_index.prompts import PromptTemplate
from llama_index.storage.docstore import SimpleDocumentStore
from llama_index.storage.index_store import SimpleIndexStore
from llama_index.vector_stores import SimpleVectorStore

# ========================== db 設定初期化 ==========================
CONNECT_STR = f"{os.getenv('DATABASE')}://{os.getenv('USER')}:{os.getenv('PASSWORD')}@{os.getenv('HOST')}:{os.getenv('PORT')}/{os.getenv('DB_NAME')}"
logging.info("CONNECT_STR: %s", CONNECT_STR)

engine = create_engine(CONNECT_STR, pool_pre_ping=True)
Base = declarative_base()

class DocumentSQLAlchemy(Base):
    __tablename__ = "documents"                         # 検索に使用するドキュメント情報
    id = Column(
        Integer, primary_key=True, autoincrement=True)  # ID
    name = Column(String(length=128), unique=True)      # ドキュメント名
    detail = Column(String())                           # 詳細
    admin_name = Column(String(length=24))              # 管理者名(出したくない)
    url = Column(String())                              # ドキュメント URL (出したくない)

def migrate_all():
    """
    Base で継承された全クラスのマイグレーションを行う。

    実行後再度呼んでもエラーにはならない。
    """
    Base.metadata.create_all(engine)

# TODO: 毎回呼んでも大丈夫か。
migrate_all()

@dataclass
class DocumentItem:
    """Document を RDB で管理するためのデータクラス。"""
    id: str
    name: str
    detail: str
    admin_name: str
    url: str

    @classmethod
    def ftom_dict(cls, d: dict) -> 'DocumentItem':
        """Dictonary から Documentitem を作成する。"""
        doc = DocumentItem.__new__(DocumentItem)
        doc.__dict__.update(d)
        return doc

# ========================== llama-index 設定初期化 ==========================
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)

load_dotenv()

persist_dir = "./../index"
storage_context = StorageContext.from_defaults(
    docstore=SimpleDocumentStore.from_persist_dir(persist_dir=persist_dir),
    # from_persist_dir メソッドだと vector_store.json を期待するが、
    # index_generate が生成するのは default__vector_store.json であるため from_persist_path を使っている。
    vector_store=SimpleVectorStore.from_persist_path(
        persist_path=f"{persist_dir}/default__vector_store.json",
    ),
    index_store=SimpleIndexStore.from_persist_dir(persist_dir=persist_dir),
    persist_dir=persist_dir,
)

index = load_index_from_storage(storage_context)

file_name = 'prompt.txt'
with open(file_name, 'r', encoding='utf-8') as file:
    QA_PROMPT_TEMPLATE = file.read()

qa_template = PromptTemplate(QA_PROMPT_TEMPLATE)
query_engine = index.as_query_engine(
    text_qa_template=qa_template,
)

def insert_document(doc: DocumentItem) -> None:
    """
    Database にドキュメントを登録する。
    """
    SessionClass = sessionmaker(engine)  # セッションを作るクラスを作成
    session = SessionClass()

    sql_doc = DocumentSQLAlchemy(
        name=doc.name,
        detail=doc.detail,
        admin_name=doc.admin_name,
        url=doc.url,
    )
    session.add(sql_doc)
    session.commit()

    session.close()

# ========================== Flask 設定初期化 ==========================
API_KEY = os.getenv('API_KEY', None)
if not API_KEY:
    logging.critical('API_KEY is not set.')
    os._exit(1)

def authorize() -> bool:
    logging.info("API_KEY: %s", request.headers.get('API_KEY'))
    if not request.headers.get("API_KEY") == API_KEY:
        return False

    return True

def create_app():
    # create and configure the app
    app = Flask(__name__, instance_relative_config=True)
    app.config.from_mapping(
        SECRET_KEY='dev',
        DATABASE=os.path.join(app.instance_path, 'flaskr.sqlite'),
    )

    def _dump_debug_log() -> None:
        app.logger.debug("===========: get request :===========")
        app.logger.debug("request.path: %s", request.path)
        app.logger.debug("request.args: %s", request.args)
        if request.method == 'POST':
            app.logger.debug("request.json: %s", request.json)
        app.logger.debug("request.headers: %s", request.headers)

    @app.route("/health", methods=["GET"])
    def health_check():
        return jsonify({"status": "ok"}), HTTPStatus.OK

    @app.route("/docs/search", methods=["GET"])
    def search():
        """
        ドキュメントまたはデータベース内で特定のクエリに基づいて検索を行い、結果を返すエンドポイント。

        このエンドポイントは、クエリパラメータとして与えられた検索語句('query')を使用して検索を実行します。検索結果はJSON形式で返されます。
        クエリパラメータが指定されていない場合、エンドポイントはHTTPステータス400(BAD REQUEST)とともにエラーメッセージを返します。

        Parameters:
        - query (str): 検索語句。URL のクエリパラメータとして指定されます。例: '/docs/search?query=ハッカソンのドキュメントの URL 教えて'

        Returns:
        - response (dict): 検索結果を含む JSON 形式の辞書。検索語句に基づいて得られた結果が含まれます。
        """
        _dump_debug_log()

        if not authorize():
            return jsonify({"message": "unahotorized"}), HTTPStatus.UNAUTHORIZED

        query_text = request.args.get("query", None)
        if not query_text:
            return jsonify({"message": "query must not be empty"}), HTTPStatus.BAD_REQUEST

        response = query_engine.query(query_text)
        return jsonify({"message": str(response)}), HTTPStatus.OK

    @app.route("/docs", methods=["POST"])
    def register():
        _dump_debug_log()

        if not authorize():
            return jsonify({"message": "unahotorized"}), HTTPStatus.UNAUTHORIZED

        doc = DocumentItem.ftom_dict(request.json)
        # TODO: ドキュメントの validate.
        insert_document(doc=doc)
        return '', HTTPStatus.NO_CONTENT

    return app

if __name__ == '__main__':
    local_app = create_app()
    local_app.run(host="0.0.0.0", port=5601)
import os
import logging
import sys
from pathlib import Path
from llama_index.indices.service_context import ServiceContext
from llama_index.llms import OpenAI
from llama_index import download_loader
from llama_index import VectorStoreIndex

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)

SimpleCSVReader = download_loader("SimpleCSVReader")
loader = SimpleCSVReader(encoding="utf-8")

llm = OpenAI(model='gpt-3.5-turbo', temperature=0)
service_context = ServiceContext.from_defaults(llm=llm)
tmp_index_folder_name = ("../index")
documents = loader.load_data(
    file=Path('./data.csv'))

index = VectorStoreIndex.from_documents(
    documents, service_context=service_context)
index.storage_context.persist(tmp_index_folder_name)
kokoichi206 commented 4 months ago

整理