@@ -0,0 +1,215 @@
+from dotenv import load_dotenv
+load_dotenv('environment.env')
+
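+# environment.env is expected to define OPENAI_API_KEY, SUPABASE_URL,
+# SUPABASE_KEY, and SUPABASE_URI (the Postgres connection string used below).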
+from fastapi import FastAPI, Request, HTTPException, status, Body, Depends
+# from fastapi.templating import Jinja2Templates
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from contextlib import asynccontextmanager
+from pydantic import BaseModel
+from typing import List, Optional
+import uvicorn
+
+import sqlparse
+from sqlalchemy import create_engine
+import pandas as pd
+# from retrying import retry
+import datetime
+import json
+from json import loads
+import time
+from langchain.callbacks import get_openai_callback
+
+from langchain_community.vectorstores import Chroma
+from langchain_openai import OpenAIEmbeddings
+from RAG_strategy import multi_query, naive_rag, naive_rag_for_qapairs
+from Indexing_Split import create_retriever as split_retriever
+from Indexing_Split import gen_doc_from_database, gen_doc_from_history
+
+import os
+from langchain_community.vectorstores import SupabaseVectorStore
+from supabase.client import Client, create_client
+from add_vectordb import GetVectorStore
+from langchain_community.cache import RedisSemanticCache  # updated import path
+from langchain_core.prompts import PromptTemplate
+import openai
+
+# Log through uvicorn's error logger so messages appear in the server log
+import logging
+logger = logging.getLogger("uvicorn.error")
+
+openai_api_key = os.getenv("OPENAI_API_KEY")
+URI = os.getenv("SUPABASE_URI")
+openai.api_key = openai_api_key
+
+
+# Globals initialised once at startup by the lifespan handler below
+global_retriever = None
+vector_store = None
+
+# FastAPI lifespan manager: runs once at startup to build the retriever, and resumes at shutdown
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global global_retriever
+    global vector_store
+
+    start = time.time()
+    # Alternative retrievers, kept for reference:
+    # global_retriever = split_retriever(path='./Documents', extension="docx")
+    # global_retriever = raptor_retriever(path='../Documents', extension="txt")
+    # global_retriever = unstructured_retriever(path='../Documents')
+
+    supabase_url = os.getenv("SUPABASE_URL")
+    supabase_key = os.getenv("SUPABASE_KEY")
+    document_table = "documents"
+    supabase: Client = create_client(supabase_url, supabase_key)
+
+    # Wrap the Supabase documents table in a vector store and expose a top-4 similarity retriever
+    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
+    vector_store = GetVectorStore(embeddings, supabase, document_table)
+    global_retriever = vector_store.as_retriever(search_kwargs={"k": 4})
+
+    print(f"Retriever initialised in {time.time() - start:.2f}s")
+    yield
+
+# Dependency-injection helpers that hand the startup globals to request handlers
+def get_retriever():
+    return global_retriever
+
+
+def get_vector_store():
+    return vector_store
+
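+# Routing these through Depends(...) rather than reading the globals directly keeps
+# the endpoints overridable in tests via app.dependency_overrides.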
+# Create the FastAPI application and configure CORS middleware
+app = FastAPI(lifespan=lifespan)
+
+# templates = Jinja2Templates(directory="temp")
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
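+# NOTE: allow_origins=["*"] combined with allow_credentials=True accepts requests
+# from any site; restrict the origin list before exposing this API publicly.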
+
+# API routes
+# Answer an incoming question with multi-query RAG
+@app.get("/answer2")
+def multi_query_answer(question: str, retriever=Depends(get_retriever)):
+    try:
+        start = time.time()
+
+        with get_openai_callback() as cb:
+            # QA-pair lookup, kept for reference:
+            # qa_doc = gen_doc_from_database()
+            # qa_history_doc = gen_doc_from_history()
+            # qa_doc.extend(qa_history_doc)
+            # vectorstore = Chroma.from_documents(documents=qa_doc, embedding=OpenAIEmbeddings(), collection_name="qa_pairs")
+            # retriever_qa = vectorstore.as_retriever(search_kwargs={"k": 3})
+            # final_answer, reference_docs = naive_rag_for_qapairs(question, retriever_qa)
+            # 'False' mimics the miss sentinel of the disabled QA-pair path above,
+            # so the multi-query fallback below always runs while it is commented out
+            final_answer = 'False'
+            if final_answer == 'False':
+                final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
+
+        processing_time = time.time() - start
+        print(processing_time)
+        save_history(question, final_answer, reference_docs, cb, processing_time)
+
+        return {"Answer": final_answer}
+    except Exception as e:
+        logger.error(f"Error in /answer2 endpoint: {e}")
+        raise HTTPException(status_code=500, detail="Internal Server Error")
+
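+# Example call, assuming the local uvicorn settings at the bottom of this file:
+#   curl "http://127.0.0.1:8081/answer2?question=..."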
+# One question/answer turn of prior conversation
+class ChatHistoryItem(BaseModel):
+    q: str
+    a: str
+
+# Answer a question, conditioning on the supplied chat history
+@app.post("/answer_with_history")
+def multi_query_answer_with_history(question: Optional[str] = '', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
+    start = time.time()
+
+    # Keep only completed turns and convert them to (question, answer) tuples
+    chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
+    print(chat_history)
+
+    # TODO: similarity search
+
+    with get_openai_callback() as cb:
+        final_answer, reference_docs = multi_query(question, retriever, chat_history)
+        processing_time = time.time() - start
+        print(processing_time)
+        save_history(question, final_answer, reference_docs, cb, processing_time)
+
+    return {"Answer": final_answer}
+
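+# Example request, with a hypothetical history payload:
+#   POST /answer_with_history?question=...
+#   [{"q": "previous question", "a": "previous answer"}]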
+# Answer a question with chat history, restricting retrieval by file extension
+@app.post("/answer_with_history2")
+def multi_query_answer_with_history_filtered(question: Optional[str] = '', extension: Optional[str] = 'pdf', chat_history: List[ChatHistoryItem] = Body(...), vector_store=Depends(get_vector_store)):
+    start = time.time()
+
+    # Rebuild the retriever per request so the metadata filter matches the requested extension
+    retriever = vector_store.as_retriever(search_kwargs={"k": 4,
+                                                         'filter': {'extension': extension}})
+
+    # Keep only completed turns and convert them to (question, answer) tuples
+    chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
+    print(chat_history)
+
+    # TODO: similarity search
+
+    with get_openai_callback() as cb:
+        final_answer, reference_docs = multi_query(question, retriever, chat_history)
+        processing_time = time.time() - start
+        print(processing_time)
+        save_history(question, final_answer, reference_docs, cb, processing_time)
+
+    return {"Answer": final_answer}
+
+# Persist one question/answer exchange, its token usage, and timing to the database
+def save_history(question, answer, reference, cb, processing_time):
+    # reference = [doc.dict() for doc in reference]
+    record = {
+        'Question': [question],
+        'Answer': [answer],
+        'Total_Tokens': [cb.total_tokens],
+        'Total_Cost': [cb.total_cost],
+        'Processing_time': [processing_time],
+        'Contexts': [str(reference)]
+    }
+    df = pd.DataFrame(record)
+    engine = create_engine(URI)
+    df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
+
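+# to_sql(if_exists='append') creates the systex_records table on first write,
+# so no separate migration step is needed.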
+class HistoryOutput(BaseModel):
+    Question: str
+    Answer: str
+    Contexts: str
+    Total_Tokens: int
+    Total_Cost: float
+    Processing_time: float
+    Time: datetime.datetime
+
+# Return all stored question/answer records
+@app.get('/history', response_model=List[HistoryOutput])
+async def get_history():
+    engine = create_engine(URI, echo=True)
+
+    df = pd.read_sql_table("systex_records", engine.connect())
+    df.fillna('', inplace=True)
+    # Serialise via JSON, then return the per-row dicts
+    result = df.to_json(orient='index', force_ascii=False)
+    result = loads(result)
+    return list(result.values())
+
+@app.get("/")
+def read_root():
+    return {"message": "Welcome to the Carbon Chatbot API"}
+
+
+if __name__ == "__main__":
+    uvicorn.run("RAG_app_copy:app", host='127.0.0.1', port=8081, reload=True)
+
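+# NOTE: the "RAG_app_copy:app" import string must match this file's module name
+# for reload to work.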
+# Production variant with TLS, kept for reference:
+# if __name__ == "__main__":
+#     uvicorn.run("RAG_app:app", host='cmm.ai', port=8081, reload=True, ssl_keyfile="/etc/letsencrypt/live/cmm.ai/privkey.pem",
+#                 ssl_certfile="/etc/letsencrypt/live/cmm.ai/fullchain.pem")