Browse source

integrated taide model to RAG_strategy

SherryLiu 11 months ago
parent
commit
ef69f16e46
6 changed files with 102 additions and 370 deletions
  1. RAG_app_copy.py  + 11 - 64
  2. RAG_strategy.py  + 61 - 144
  3. ragas_data_generation.py  + 0 - 29
  4. run.sh  + 0 - 31
  5. taide_rag.py  + 30 - 0
  6. test_connection.py  + 0 - 102

+ 11 - 64
RAG_app_copy.py

@@ -1,39 +1,27 @@
 from dotenv import load_dotenv
 load_dotenv('environment.env')
 
-from fastapi import FastAPI, Request, HTTPException, status, Body
-# from fastapi.templating import Jinja2Templates
+from fastapi import FastAPI, HTTPException, status, Body, Depends
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse
-from fastapi import Depends
 from contextlib import asynccontextmanager
 from pydantic import BaseModel
 from typing import List, Optional
 import uvicorn
 
-import sqlparse
 from sqlalchemy import create_engine
 import pandas as pd
-#from retrying import retry
 import datetime
 import json
 from json import loads
 import time
 from langchain.callbacks import get_openai_callback
 
-from langchain_community.vectorstores import Chroma
 from langchain_openai import OpenAIEmbeddings
 from RAG_strategy import multi_query, naive_rag, naive_rag_for_qapairs
-from Indexing_Split import create_retriever as split_retriever
-from Indexing_Split import gen_doc_from_database, gen_doc_from_history
 
 import os
-from langchain_community.vectorstores import SupabaseVectorStore
-from langchain_openai import OpenAIEmbeddings
 from supabase.client import Client, create_client
 from add_vectordb import GetVectorStore
-from langchain_community.cache import RedisSemanticCache  # 更新导入路径
-from langchain_core.prompts import PromptTemplate
 import openai
 
 # Get API log
@@ -44,23 +32,17 @@ openai_api_key = os.getenv("OPENAI_API_KEY")
 URI = os.getenv("SUPABASE_URI")
 openai.api_key = openai_api_key
 
-
 global_retriever = None
 
-# 定義FastAPI的生命週期管理器,在啟動和關閉時執行特定操作
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global global_retriever
     global vector_store
     
     start = time.time()
-    # global_retriever = split_retriever(path='./Documents', extension="docx")
-    # global_retriever = raptor_retriever(path='../Documents', extension="txt")
-    # global_retriever = unstructured_retriever(path='../Documents')
 
     supabase_url = os.getenv("SUPABASE_URL")
     supabase_key = os.getenv("SUPABASE_KEY")
-    URI = os.getenv("SUPABASE_URI")
     document_table = "documents"
     supabase: Client = create_client(supabase_url, supabase_key)
 
@@ -68,21 +50,17 @@ async def lifespan(app: FastAPI):
     vector_store = GetVectorStore(embeddings, supabase, document_table)
     global_retriever = vector_store.as_retriever(search_kwargs={"k": 4})
 
-    print(time.time() - start)
+    print(f"Initialization time: {time.time() - start}")
     yield
 
-# 定義依賴注入函數,用於在請求處理過程中獲取全局變量
 def get_retriever():
     return global_retriever
 
-
 def get_vector_store():
     return vector_store
 
-# 創建FastAPI應用實例並配置以及中間件
 app = FastAPI(lifespan=lifespan)
 
-# templates = Jinja2Templates(directory="temp")
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -91,31 +69,16 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
-# 定義API路由和處理函數
-# 處理傳入的問題並返回答案
 @app.get("/answer2")
 def multi_query_answer(question, retriever=Depends(get_retriever)):
     try:
         start = time.time()
 
         with get_openai_callback() as cb:
-            # qa_doc = gen_doc_from_database()
-            # qa_history_doc = gen_doc_from_history()
-            # qa_doc.extend(qa_history_doc)
-            # vectorstore = Chroma.from_documents(documents=qa_doc, embedding=OpenAIEmbeddings(), collection_name="qa_pairs")
-            # retriever_qa = vectorstore.as_retriever(search_kwargs={"k": 3})
-            # final_answer, reference_docs = naive_rag_for_qapairs(question, retriever_qa)
-            final_answer = 'False'
-            if final_answer == 'False':
-                final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
-
-        # print(CHAT_HISTORY)
-        
-        # with get_openai_callback() as cb:
-        #     final_answer, reference_docs = multi_query(question, retriever)
+            final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
+
         processing_time = time.time() - start
-        print(processing_time)
+        print(f"Processing time: {processing_time}")
         save_history(question, final_answer, reference_docs, cb, processing_time)
 
         return {"Answer": final_answer}
@@ -127,48 +90,39 @@ class ChatHistoryItem(BaseModel):
     q: str
     a: str
 
-# 處理帶有歷史聊天紀錄的問題並返回答案
 @app.post("/answer_with_history")
 def multi_query_answer(question: Optional[str] = '', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
     start = time.time()
     
     chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
-    print(chat_history)
-
-    # TODO: similarity search
+    print(f"Chat history: {chat_history}")
     
     with get_openai_callback() as cb:
         final_answer, reference_docs = multi_query(question, retriever, chat_history)
     processing_time = time.time() - start
-    print(processing_time)
+    print(f"Processing time: {processing_time}")
     save_history(question, final_answer, reference_docs, cb, processing_time)
 
     return {"Answer": final_answer}
 
-# 處理帶有聊天歷史紀錄和文件名過濾的問題,並返回答案
 @app.post("/answer_with_history2")
 def multi_query_answer(question: Optional[str] = '', extension: Optional[str] = 'pdf', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
     start = time.time()
 
-    retriever = vector_store.as_retriever(search_kwargs={"k": 4,
-                                                         'filter': {'extension':extension}})
+    retriever = vector_store.as_retriever(search_kwargs={"k": 4, 'filter': {'extension':extension}})
     
     chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
-    print(chat_history)
-
-    # TODO: similarity search
+    print(f"Chat history: {chat_history}")
     
     with get_openai_callback() as cb:
         final_answer, reference_docs = multi_query(question, retriever, chat_history)
     processing_time = time.time() - start
-    print(processing_time)
+    print(f"Processing time: {processing_time}")
     save_history(question, final_answer, reference_docs, cb, processing_time)
 
     return {"Answer": final_answer}
 
-# 保存歷史。將處理結果儲存到數據庫
 def save_history(question, answer, reference, cb, processing_time):
-    # reference = [doc.dict() for doc in reference]
     record = {
         'Question': [question],
         'Answer': [answer],
@@ -190,7 +144,6 @@ class history_output(BaseModel):
     Processing_time: float
     Time: datetime.datetime
 
-# 定義獲取歷史紀錄的路由
 @app.get('/history', response_model=List[history_output])
 async def get_history():
     engine = create_engine(URI, echo=True)
@@ -205,11 +158,5 @@ async def get_history():
 def read_root():
     return {"message": "Welcome to the Carbon Chatbot API"}
 
-
 if __name__ == "__main__":
-    uvicorn.run("RAG_app_copy:app", host='127.0.0.1', port=8081, reload=True)
-    
-# if __name__ == "__main__":
-#     uvicorn.run("RAG_app:app", host='cmm.ai', port=8081, reload=True, ssl_keyfile="/etc/letsencrypt/live/cmm.ai/privkey.pem", 
-#                 ssl_certfile="/etc/letsencrypt/live/cmm.ai/fullchain.pem")
-
+    uvicorn.run("RAG_app_copy:app", host='127.0.0.1', port=8081, reload=True)
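
For reference, the reworked /answer_with_history endpoint can be smoke-tested with a plain HTTP client once the app is running on 127.0.0.1:8081 as configured above. A minimal sketch, not part of the commit: the question travels as a query parameter and the chat history as the JSON body, and the sample values are illustrative only.

import requests

# Hypothetical smoke test for /answer_with_history; assumes the FastAPI app above
# is running locally on port 8081 (see the uvicorn.run call).
resp = requests.post(
    "http://127.0.0.1:8081/answer_with_history",
    params={"question": "什麼是碳排放獎勵辦法?"},                 # query parameter
    json=[{"q": "previous question", "a": "previous answer"}],   # List[ChatHistoryItem] body
)
print(resp.json()["Answer"])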

+ 61 - 144
RAG_strategy.py

@@ -9,7 +9,14 @@ from langchain_core.runnables import RunnablePassthrough
 from langchain import hub
 from langchain.globals import set_llm_cache
 from langchain import PromptTemplate
-from langchain.llms.base import LLM
+import subprocess
+import json
+from typing import Any, List, Optional, Dict
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
+from langchain_core.outputs import ChatResult, ChatGeneration
+from pydantic import Field
 
 from langchain_core.runnables import (
     RunnableBranch,
@@ -17,9 +24,6 @@ from langchain_core.runnables import (
     RunnableParallel,
     RunnablePassthrough,
 )
-from typing import Tuple, List, Optional
-from langchain_core.messages import AIMessage, HumanMessage
-
 
 from datasets import Dataset 
 from ragas import evaluate
@@ -29,92 +33,71 @@ from ragas.metrics import (
     context_recall,
     context_precision,
 )
-from typing import List
 import os
 from dotenv import load_dotenv
 load_dotenv('environment.env')
 
-########################################################################################################################
-########################################################################################################################
 from langchain.cache import SQLiteCache
-from langchain.cache import RedisSemanticCache
 from langchain_openai import OpenAIEmbeddings
 from langchain.globals import set_llm_cache
 
-########################################################################################################################
 import requests
 import openai
 openai_api_key = os.getenv("OPENAI_API_KEY")
 openai.api_key = openai_api_key
 URI = os.getenv("SUPABASE_URI")
 
-from typing import Optional, List, Any, Dict
-
-# 設置緩存,以減少對API的重複請求。使用Redis
+# 設置緩存,以減少對API的重複請求。使用SQLite
 set_llm_cache(SQLiteCache(database_path=".langchain.db"))
-# set_llm_cache(RedisSemanticCache(redis_url="redis://localhost:6379", embedding=OpenAIEmbeddings(openai_api_key=openai_api_key), score_threshold=0.0005))
 
-# # TAIDE model on Ollama https://ollama.com/jcai/llama3-taide-lx-8b-chat-alpha1
-# def interact_with_model(messages, api_url="http://localhost:11434/v1/chat/completions"):
-#     print("Using model: TAIDE")
-#     response = requests.post(api_url, json={"model": "jcai/llama3-taide-lx-8b-chat-alpha1:Q4_K_M", "messages": messages})
-#     return response.json()["choices"][0]["message"]["content"]
-
-import requests
-from typing import Optional, List, Any, Dict
-from langchain.llms.base import LLM
-
-class CustomTAIDELLM(LLM):
-    api_url: str = "http://localhost:11434/api/chat"
-    model_name: str = "taide-local"
-    system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
-
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        # Format the prompt according to TAIDE requirements
-        formatted_prompt = f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>>\n\n{prompt} [/INST]"
-        print(f"Formatted prompt being sent to TAIDE model: {formatted_prompt}")
-
-        payload = {
-            "model": self.model_name,
-            "messages": [
-                {"role": "system", "content": self.system_prompt},
-                {"role": "user", "content": prompt}
-            ]
-        }
-
-        try:
-            response = requests.post(self.api_url, json=payload)
-            response.raise_for_status()
-            result = response.json()
-            print(f"Full API response: {result}")
-            if "message" in result:
-                return result["message"]["content"].strip()
-            else:
-                print(f"Unexpected response structure: {result}")
-                return "Error: Unexpected response from the model."
-        except requests.RequestException as e:
-            print(f"Error calling Ollama API: {e}")
-            return f"Error: Unable to get response from the model. {str(e)}"
+system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+
+class OllamaChatModel(BaseChatModel):
+    model_name: str = Field(default="taide-local")
+
+    def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> ChatResult:
+        formatted_messages = []
+        for msg in messages:
+            if isinstance(msg, HumanMessage):
+                formatted_messages.append({"role": "user", "content": msg.content})
+            elif isinstance(msg, AIMessage):
+                formatted_messages.append({"role": "assistant", "content": msg.content})
+            elif isinstance(msg, SystemMessage):
+                 formatted_messages.append({"role": "system", "content": msg.content})
+
+        prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
+        for msg in formatted_messages:
+            if msg['role'] == 'user':
+                prompt += f"{msg['content']} [/INST]"
+            elif msg['role'] == "assistant":
+                prompt += f"{msg['content']} </s><s>[INST]"
+
+        command = ["ollama", "run", self.model_name, prompt]
+        result = subprocess.run(command, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            raise Exception(f"Ollama command failed: {result.stderr}")
+        
+        content = result.stdout.strip()
 
+        message = AIMessage(content=content)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
+    
     @property
     def _llm_type(self) -> str:
-        return "custom_taide"
-
-    def get_model_name(self):
-        return self.model_name
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        return {"model_name": self.model_name}
-
-# Create an instance of the custom LLM
-taide_llm = CustomTAIDELLM(api_url="http://localhost:11434/api/chat", model_name="taide-local")
+        return "ollama-chat-model"
+    
+taide_llm = OllamaChatModel(model_name="taide-local")
 
-# 生成多個不同版本的問題,進行檢索,並返回答案和參考文檔
 def multi_query(question, retriever, chat_history):
-
     def multi_query_chain():
-        # Multi Query: Different Perspectives
         template = """You are an AI language model assistant. Your task is to generate three 
         different versions of the given user question to retrieve relevant documents from a vector 
         database. By generating multiple perspectives on the user question, your goal is to help
@@ -123,27 +106,12 @@ def multi_query(question, retriever, chat_history):
 
         You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.
         
-        
         Original question: {question}"""
         prompt_perspectives = ChatPromptTemplate.from_template(template)
 
-        messages = [
-            {"role": "system", "content": template},
-            {"role": "user", "content": question},
-        ]
-        # generate_queries = interact_with_model(messages).split("\n")
-
-        
-        # llm = ChatOpenAI(model="gpt-4-1106-preview")
-        # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
-        # llm = ChatOllama(model="gemma2", temperature=0)
-        # llm = ChatOllama(model=model)
-
-
         generate_queries = (
             prompt_perspectives 
             | taide_llm
-            # | llm
             | StrOutputParser() 
             | (lambda x: x.split("\n"))
         )
@@ -151,14 +119,9 @@ def multi_query(question, retriever, chat_history):
         return generate_queries
 
     def get_unique_union(documents: List[list]):
-        """ Unique union of retrieved docs """
-        # Flatten list of lists, and convert each Document to string
         flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
-        # Get unique documents
         unique_docs = list(set(flattened_docs))
-        # Return
         return [loads(doc) for doc in unique_docs]
-    
 
     _search_query = get_search_query()
     modified_question = _search_query.invoke({"question":question, "chat_history": chat_history})
@@ -173,7 +136,6 @@
 
     return answer, docs
 
-# 根據檢索到的文檔和用戶問題生成最後答案
 def multi_query_rag_prompt(retrieval_chain, question):
     template = """Answer the following question based on this context:
 
@@ -205,25 +167,8 @@ def multi_query_rag_prompt(retrieval_chain, question):
     except Exception as e:
         print(f"Error invoking rag_chain: {e}")
         return "Error occurred while processing the question."
-    
-########################################################################################################################
 
-# 將聊天紀錄個跟進問題轉化為獨立問題
 def get_search_query():
-    # Condense a chat history and follow-up question into a standalone question
-    # 
-    # _template = """Given the following conversation and a follow up question, 
-    # rephrase the follow up question to be a standalone question to help others understand the question without having to go back to the conversation transcript.
-    # Generate standalone question in its original language.
-    # Chat History:
-    # {chat_history}
-    # Follow Up Input: {question}
-
-    # Hint:
-    # * Refer to chat history and add the subject to the question
-    # * Replace the pronouns in the question with the correct person or thing, please refer to chat history
-    
-    # Standalone question:"""  # noqa: E501
     _template = """Rewrite the following query by incorporating relevant context from the conversation history.
     The rewritten query should:
     
@@ -244,7 +189,7 @@
     """
     CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
 
-    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
+    def _format_chat_history(chat_history: List[tuple[str, str]]) -> List:
         buffer = []
         for human, ai in chat_history:
             buffer.append(HumanMessage(content=human))
@@ -252,11 +197,10 @@ def get_search_query():
         return buffer
 
     _search_query = RunnableBranch(
-        # If input includes chat_history, we condense it with the follow-up question
        (
             RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                 run_name="HasChatHistoryCheck"
-            ),  # Condense follow-up question and chat into a standalone_question
+            ),
             RunnablePassthrough.assign(
                 chat_history=lambda x: _format_chat_history(x["chat_history"])
             )
@@ -264,48 +208,31 @@ def get_search_query():
             | ChatOpenAI()
             | StrOutputParser(),
         ),
-        # Else, we have no chat history, so just pass through the question
         RunnableLambda(lambda x : x["question"]),
     )
 
     return _search_query
-########################################################################################################################
-# 檢索文檔並生成答案
-def naive_rag(question, retriever):
-    #### RETRIEVAL and GENERATION ####
 
-    # Prompt
+def naive_rag(question, retriever):
     prompt = hub.pull("rlm/rag-prompt")
 
-    # LLM
-    # llm = ChatOpenAI(model_name="gpt-3.5-turbo")
-
-    # Post-processing
     def format_docs(docs):
         return "\n\n".join(doc.page_content for doc in docs)
 
     reference = retriever.get_relevant_documents(question)
     
-    # Chain
     rag_chain = (
         {"context": retriever | format_docs, "question": RunnablePassthrough()}
         | prompt
         | taide_llm
-        # | llm
         | StrOutputParser()
     )
 
-    # Question
     answer = rag_chain.invoke(question)
 
     return answer, reference
-################################################################################################
-# 處理question-answer pairs的檢索和生成答案
-def naive_rag_for_qapairs(question, retriever):
-    #### RETRIEVAL and GENERATION ####
 
-    # Prompt
-    # prompt = hub.pull("rlm/rag-prompt")
+def naive_rag_for_qapairs(question, retriever):
     template = """You are an assistant for question-answering tasks. 
     Use the following pieces of retrieved context to answer the question. 
     Following retrieved context is question-answer pairs of historical QA, Find the suitable answer from the qa pairs
@@ -320,19 +247,13 @@ def naive_rag_for_qapairs(question, retriever):
     """
     prompt = PromptTemplate.from_template(template)
 
-    # LLM
     llm = ChatOpenAI(model_name="gpt-4-0125-preview")
-    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
-    # llm = ChatOllama(model="gemma2", num_gpu=1, temperature=0)
 
-
-    # Post-processing
     def format_docs(docs):
         return "\n\n".join(doc.page_content for doc in docs)
 
     reference = retriever.get_relevant_documents(question)
     
-    # Chain
     rag_chain = (
         {"context": retriever | format_docs, "question": RunnablePassthrough()}
         | prompt
@@ -340,19 +261,16 @@ def naive_rag_for_qapairs(question, retriever):
         | StrOutputParser()
     )
 
-    # Question
     answer = rag_chain.invoke(question)
 
     return answer, reference
-########################################################################################################################
 
 def rag_score(question, ground_truth, answer, reference_docs):
-    
     datasets = {
-              "question": [question],       # question: list[str]
-              "answer": [answer],           # answer: list[str]
-              "contexts": [reference_docs], # contexts: list[list[str]]
-              "ground_truths": [[ground_truth]] # ground_truth: list[list[str]]
+              "question": [question],
+              "answer": [answer],
+              "contexts": [reference_docs],
+              "ground_truths": [[ground_truth]]
             }
     evalsets = Dataset.from_dict(datasets)
 
@@ -371,9 +289,8 @@ def rag_score(question, ground_truth, answer, reference_docs):
     result_df.to_csv('ragas_rag.csv')
     return result
 
-
 def print_current_model(llm):
-    if isinstance(llm, CustomTAIDELLM):
-        print(f"Currently using model: {llm.get_model_name()}")
+    if isinstance(llm, OllamaChatModel):
+        print(f"Currently using model: {llm.model_name}")
     else:
         pass
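
As a quick sanity check of the new subprocess-backed OllamaChatModel, the exported taide_llm can be invoked directly. A minimal sketch, not part of the commit, assuming `ollama` is on PATH and a local model tagged taide-local exists (the defaults used above):

from langchain_core.messages import HumanMessage
from RAG_strategy import taide_llm, print_current_model

# Assumes `ollama run taide-local` works on this machine (see OllamaChatModel above).
print_current_model(taide_llm)
reply = taide_llm.invoke([HumanMessage(content="什麼是碳排放獎勵辦法?")])
print(reply.content)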

+ 0 - 29
ragas_data_generation.py

@@ -1,29 +0,0 @@
-from dotenv import load_dotenv
-load_dotenv('environment.env')
-
-
-from ragas.testset.generator import TestsetGenerator
-from ragas.testset.evolutions import simple, reasoning, multi_context 
-from langchain_openai import ChatOpenAi, OpenAIEmbeddings
-from langchain_community.document_loaders import DirectoryLoader
-from langchain_community.document_loaders import PyPDFLoader
-
-loader = DirectoryLoader("Documents")
-for file in 
-documents = loader.load()
-
-
-for document in documents:
-    document.metadata['filename'] = document.metadata['source']
-
-generator_llm = ChatOpenAi(model = "gpt-3.5-turbo-16k")
-critic_llm = ChatOpenAI(model="gpt-4")
-embeddings = OpenAIEmbeddings()
-
-generator = TestGenerator.from_langchain(
-    generator_llm,
-    critic_llm,
-    embeddings
-)
-# Generate testset
-testset = generator.generate_with_langchain_docs(documents, test_size=10, distributions={simple: 0.5, reasoning: 0.25, multi_context: 0.25})
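
Note that the deleted script above was not runnable as committed: the dangling `for file in` line, the `ChatOpenAi` spelling, and `TestGenerator` (the import is `TestsetGenerator`) would all raise errors. For reference, a corrected sketch of what it appears to have intended, assuming the ragas testset API it imports:

from dotenv import load_dotenv
load_dotenv('environment.env')

from ragas.testset.generator import TestsetGenerator
from ragas.testset.evolutions import simple, reasoning, multi_context
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.document_loaders import DirectoryLoader

# Load source documents and tag each with a filename, as ragas expects.
loader = DirectoryLoader("Documents")
documents = loader.load()
for document in documents:
    document.metadata['filename'] = document.metadata['source']

generator_llm = ChatOpenAI(model="gpt-3.5-turbo-16k")
critic_llm = ChatOpenAI(model="gpt-4")
embeddings = OpenAIEmbeddings()

generator = TestsetGenerator.from_langchain(generator_llm, critic_llm, embeddings)
testset = generator.generate_with_langchain_docs(
    documents,
    test_size=10,
    distributions={simple: 0.5, reasoning: 0.25, multi_context: 0.25},
)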

+ 0 - 31
run.sh

@@ -1,37 +1,6 @@
 #!/bin/bash
 
-# Function to check if Docker is running
-docker_running() {
-    docker info >/dev/null 2>&1
-}
 
-# Start Docker if it's not already running
-if ! docker_running; then
-    echo "Starting Docker..."
-    open -a Docker
-    
-    # Wait for Docker to start
-    while ! docker_running; do
-        echo "Waiting for Docker to start..."
-        sleep 5
-    done
-    echo "Docker is now running"
-fi
-
-# Get the script directory
-script_dir=$(dirname "$0")
-cd "$script_dir"
-
-# Start the services defined in docker-compose.yml
-echo "Starting services with Docker Compose..."
-docker-compose up -d
-echo "Waiting for services to start..."
-sleep 20 
-
-# Change to the directory containing Python script
-cd "$script_dir/systex-RAG-sherry"
-echo "Running RAG application..."
-python ollama_chat.py
 
 # 使脚本文件可执行:
 # chmod +x run.sh

+ 30 - 0
taide_rag.py

@@ -0,0 +1,30 @@
+from dotenv import load_dotenv
+from langchain.vectorstores import Chroma
+import os
+load_dotenv('environment.env')
+openai_api_key = os.getenv("OPENAI_API_KEY")
+from RAG_strategy import taide_llm, multi_query, naive_rag
+from langchain.vectorstores import FAISS
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.document_loaders import TextLoader
+from langchain.text_splitter import CharacterTextSplitter
+
+
+
+# Load and prepare a sample document
+loader = TextLoader("test_data.txt")
+documents = loader.load()
+text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+docs = text_splitter.split_documents(documents)
+
+# Create a vector store
+embeddings = OpenAIEmbeddings()
+vectorstore = Chroma.from_documents(docs, embeddings)
+retriever = vectorstore.as_retriever()
+
+# Test multi_query
+print("\nTesting multi_query:")
+question = "什麼是碳排放獎勵辦法?"
+answer, docs = multi_query(question, retriever, [])
+print(f"Question: {question}")
+print(f"Answer: {answer}")
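
The new test script only exercises multi_query; since naive_rag is already imported there, a matching check could be appended in the same style. An illustrative sketch (not part of the commit), reusing the question and retriever built above:

# Test naive_rag with the same retriever
print("\nTesting naive_rag:")
answer, docs = naive_rag(question, retriever)
print(f"Question: {question}")
print(f"Answer: {answer}")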

+ 0 - 102
test_connection.py

@@ -1,102 +0,0 @@
-# import os
-# import sys
-
-# from supabase import create_client, Client
-
-# # # Load environment variables
-# from dotenv import load_dotenv
-# load_dotenv('environment.env')
-
-# # Get Supabase configuration from environment variables
-# SUPABASE_URL = os.getenv("SUPABASE_URL")
-# SUPABASE_KEY = os.getenv("SUPABASE_KEY")
-# SUPABASE_URI = os.getenv("SUPABASE_URI")
-# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
-
-# # Check if environment variables are successfully loaded
-# if not SUPABASE_URL or not SUPABASE_KEY or not OPENAI_API_KEY or not SUPABASE_URI:
-#     print("Please ensure SUPABASE_URL, SUPABASE_KEY, and OPENAI_API_KEY are correctly set in the .env file.")
-#     sys.exit(1)
-# else:
-#     print("Connection successful.")
-#     try:
-#         supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
-#         print("Client created successfully.")
-#     except Exception as e:
-#         print("Client creation failed:", e)
-#         sys.exit(1)
-
-# # List all table names
-# try:
-#     response = supabase.table('information_schema.tables').select('table_name').eq('table_schema', 'public').execute()
-#     table_names = [table['table_name'] for table in response.data]
-#     print("All table names:")
-#     for name in table_names:
-#         print(name)
-# except Exception as e:
-#     print("Connection failed:", e)
-#     sys.exit(1)
-
-
-# ### Test hugging face tokens for the TAIDE local model. ######################################################
-# from transformers import AutoTokenizer, AutoModelForCausalLM
-
-# token = os.getenv("HF_API_KEY_7B4BIT")
-
-# # Check if the token is loaded correctly
-# if token is None:
-#     raise ValueError("Hugging Face API token is not set. Please check your environment.env file.")
-
-# # Load the tokenizer and model with the token
-# try:
-#     tokenizer = AutoTokenizer.from_pretrained("../TAIDE-LX-7B-Chat-4bit", token=token)  
-#     model = AutoModelForCausalLM.from_pretrained("../TAIDE-LX-7B-Chat-4bit", token=token)
-    
-#     # Verify the model and tokenizer
-#     print(f"Loaded tokenizer: {tokenizer.name_or_path}")
-#     print(f"Loaded model: {model.name_or_path}")
-
-#     # Optional: Print model and tokenizer configuration for more details
-#     print(f"Model configuration: {model.config}")
-#     print(f"Tokenizer configuration: {tokenizer}")
-
-# except Exception as e:
-#     print(f"Error loading model or tokenizer: {e}")
-
-#################################################################################################################
-# import torch
-# from transformers import AutoModelForCausalLM, AutoTokenizer
-# from huggingface_hub import hf_hub_download
-# from llama_cpp import Llama
-
-# ## Download the GGUF model
-# model_name = "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF"
-# model_file = "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf" # this is the specific model file we'll use in this example. It's a 4-bit quant, but other levels of quantization are available in the model repo if preferred
-# model_path = hf_hub_download(model_name, filename=model_file)
-
-
-
-
-# import requests
-
-# def generate_response(input_text, max_length=512, temperature=0.7):
-#     # URL to interact with the model
-#     url = "http://localhost:11434/v1/chat/completions"  # Adjust based on how Ollama exposes the model
-
-#     # Payload to send to the model
-#     payload = {
-#         "input": input_text,
-#         "parameters": {
-#             "max_length": max_length,
-#             "temperature": temperature
-#         }
-#     }
-
-#     # Make a request to the model
-#     response = requests.post(url, json=payload)
-#     return response.json()["output"]
-
-# if __name__ == "__main__":
-#     input_text = "I believe the meaning of life is"
-#     response = generate_response(input_text, max_length=128, temperature=0.5)
-#     print(f"Model: {response}")