from dotenv import load_dotenv
load_dotenv('environment.env')
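# environment.env is loaded above; a minimal sketch of what it needs to contain
# (variable names taken from the os.getenv calls in this file; values are placeholders):
#
#   OPENAI_API_KEY=sk-...
#   SUPABASE_URL=https://<project>.supabase.co
#   SUPABASE_KEY=<service-role-or-anon-key>
#   SUPABASE_URI=postgresql://<user>:<password>@<host>:5432/postgres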
from fastapi import FastAPI, Request, HTTPException, status, Body
# from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi import Depends
from contextlib import asynccontextmanager
from pydantic import BaseModel
from typing import List, Optional
import uvicorn
import sqlparse
from sqlalchemy import create_engine
import pandas as pd
# from retrying import retry
import datetime
import json
from json import loads
import time
from langchain.callbacks import get_openai_callback
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from RAG_strategy import multi_query, naive_rag, naive_rag_for_qapairs
from Indexing_Split import create_retriever as split_retriever
from Indexing_Split import gen_doc_from_database, gen_doc_from_history
import os
from langchain_community.vectorstores import SupabaseVectorStore
from supabase.client import Client, create_client
from add_vectordb import GetVectorStore
from langchain_community.cache import RedisSemanticCache  # updated import path
from langchain_core.prompts import PromptTemplate
import openai

# API error logging via uvicorn's logger
import logging
logger = logging.getLogger("uvicorn.error")

openai_api_key = os.getenv("OPENAI_API_KEY")
URI = os.getenv("SUPABASE_URI")
openai.api_key = openai_api_key

global_retriever = None
vector_store = None
# FastAPI lifespan manager: build the vector store and retriever once on startup,
# and resume after the yield for any shutdown work
@asynccontextmanager
async def lifespan(app: FastAPI):
    global global_retriever
    global vector_store

    start = time.time()
    # global_retriever = split_retriever(path='./Documents', extension="docx")
    # global_retriever = raptor_retriever(path='../Documents', extension="txt")
    # global_retriever = unstructured_retriever(path='../Documents')
    supabase_url = os.getenv("SUPABASE_URL")
    supabase_key = os.getenv("SUPABASE_KEY")
    document_table = "documents"
    supabase: Client = create_client(supabase_url, supabase_key)
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    vector_store = GetVectorStore(embeddings, supabase, document_table)
    global_retriever = vector_store.as_retriever(search_kwargs={"k": 4})
    print(time.time() - start)

    yield
# Dependency-injection helpers so route handlers can access the globals built in lifespan
def get_retriever():
    return global_retriever

def get_vector_store():
    return vector_store
# Create the FastAPI app instance and configure CORS middleware
app = FastAPI(lifespan=lifespan)
# templates = Jinja2Templates(directory="temp")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# API routes and handlers
# Answer an incoming question using multi-query retrieval
@app.get("/answer2")
def multi_query_answer(question: str, retriever=Depends(get_retriever)):
    try:
        start = time.time()
        with get_openai_callback() as cb:
            # qa_doc = gen_doc_from_database()
            # qa_history_doc = gen_doc_from_history()
            # qa_doc.extend(qa_history_doc)
            # vectorstore = Chroma.from_documents(documents=qa_doc, embedding=OpenAIEmbeddings(), collection_name="qa_pairs")
            # retriever_qa = vectorstore.as_retriever(search_kwargs={"k": 3})
            # final_answer, reference_docs = naive_rag_for_qapairs(question, retriever_qa)
            final_answer = 'False'  # sentinel: the QA-pairs path above is disabled, so we always fall through
            if final_answer == 'False':
                final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
            # print(CHAT_HISTORY)

            # with get_openai_callback() as cb:
            #     final_answer, reference_docs = multi_query(question, retriever)
        processing_time = time.time() - start
        print(processing_time)
        save_history(question, final_answer, reference_docs, cb, processing_time)
        return {"Answer": final_answer}
    except Exception as e:
        logger.error(f"Error in /answer2 endpoint: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
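# A hypothetical request against this endpoint, assuming the default host/port from
# the __main__ block at the bottom of this file (the question text is a placeholder):
#   curl "http://127.0.0.1:8081/answer2?question=What%20is%20a%20carbon%20footprint"
# which should return JSON of the form {"Answer": "..."}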
class ChatHistoryItem(BaseModel):
    q: str
    a: str
# Answer a question, conditioning retrieval on prior chat history
@app.post("/answer_with_history")
def multi_query_answer_with_history(question: Optional[str] = '', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
    start = time.time()

    # keep only turns that already have an answer
    chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
    print(chat_history)
    # TODO: similarity search

    with get_openai_callback() as cb:
        final_answer, reference_docs = multi_query(question, retriever, chat_history)
    processing_time = time.time() - start
    print(processing_time)
    save_history(question, final_answer, reference_docs, cb, processing_time)
    return {"Answer": final_answer}
# Answer a question with chat history, filtering retrieval by file extension
@app.post("/answer_with_history2")
def multi_query_answer_with_history_filtered(question: Optional[str] = '', extension: Optional[str] = 'pdf', chat_history: List[ChatHistoryItem] = Body(...), vector_store=Depends(get_vector_store)):
    start = time.time()
    # rebuild the retriever with a metadata filter on the source file's extension
    retriever = vector_store.as_retriever(search_kwargs={"k": 4,
                                                         'filter': {'extension': extension}})

    # keep only turns that already have an answer
    chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
    print(chat_history)
    # TODO: similarity search

    with get_openai_callback() as cb:
        final_answer, reference_docs = multi_query(question, retriever, chat_history)
    processing_time = time.time() - start
    print(processing_time)
    save_history(question, final_answer, reference_docs, cb, processing_time)
    return {"Answer": final_answer}
# Save history: persist one Q&A record (tokens, cost, timing, contexts) to the database
def save_history(question, answer, reference, cb, processing_time):
    # reference = [doc.dict() for doc in reference]
    record = {
        'Question': [question],
        'Answer': [answer],
        'Total_Tokens': [cb.total_tokens],
        'Total_Cost': [cb.total_cost],
        'Processing_time': [processing_time],
        'Contexts': [str(reference)]
    }
    df = pd.DataFrame(record)
    engine = create_engine(URI)
    df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
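# A rough sketch of the resulting table: to_sql with if_exists='append' creates
# systex_records on first write and lets pandas/SQLAlchemy infer column types, so the
# types below are assumptions, not a checked schema. The Time column read back by
# /history is not written here, so it presumably comes from a database-side default:
#   Question TEXT, Answer TEXT, Total_Tokens BIGINT, Total_Cost FLOAT,
#   Processing_time FLOAT, Contexts TEXT, Time TIMESTAMP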
class HistoryOutput(BaseModel):
    Question: str
    Answer: str
    Contexts: str
    Total_Tokens: int
    Total_Cost: float
    Processing_time: float
    Time: datetime.datetime
# Route to fetch all saved history records
@app.get('/history', response_model=List[HistoryOutput])
async def get_history():
    engine = create_engine(URI, echo=True)
    df = pd.read_sql_table("systex_records", engine.connect())
    df.fillna('', inplace=True)
    result = df.to_json(orient='index', force_ascii=False)
    result = loads(result)
    return list(result.values())
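# One element of the returned list, matching HistoryOutput (all values illustrative):
#   {"Question": "...", "Answer": "...", "Contexts": "[Document(...)]",
#    "Total_Tokens": 1234, "Total_Cost": 0.0021, "Processing_time": 3.2,
#    "Time": "2024-01-01T00:00:00"}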
- @app.get("/")
- def read_root():
- return {"message": "Welcome to the Carbon Chatbot API"}
- if __name__ == "__main__":
- uvicorn.run("RAG_app_copy:app", host='127.0.0.1', port=8081, reload=True)
-
- # if __name__ == "__main__":
- # uvicorn.run("RAG_app:app", host='cmm.ai', port=8081, reload=True, ssl_keyfile="/etc/letsencrypt/live/cmm.ai/privkey.pem",
- # ssl_certfile="/etc/letsencrypt/live/cmm.ai/fullchain.pem")