
integrated taide model to RAG_strategy

SherryLiu, 7 months ago
parent
commit
ef69f16e46
6 changed files with 102 additions and 370 deletions
  1. RAG_app_copy.py (+11, -64)
  2. RAG_strategy.py (+61, -144)
  3. ragas_data_generation.py (+0, -29)
  4. run.sh (+0, -31)
  5. taide_rag.py (+30, -0)
  6. test_connection.py (+0, -102)

+ 11 - 64
RAG_app_copy.py

@@ -1,39 +1,27 @@
 from dotenv import load_dotenv
 load_dotenv('environment.env')
 
-from fastapi import FastAPI, Request, HTTPException, status, Body
-# from fastapi.templating import Jinja2Templates
+from fastapi import FastAPI, HTTPException, status, Body, Depends
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse
-from fastapi import Depends
 from contextlib import asynccontextmanager
 from pydantic import BaseModel
 from typing import List, Optional
 import uvicorn
 
-import sqlparse
 from sqlalchemy import create_engine
 import pandas as pd
-#from retrying import retry
 import datetime
 import json
 from json import loads
 import time
 from langchain.callbacks import get_openai_callback
 
-from langchain_community.vectorstores import Chroma
 from langchain_openai import OpenAIEmbeddings
 from RAG_strategy import multi_query, naive_rag, naive_rag_for_qapairs
-from Indexing_Split import create_retriever as split_retriever
-from Indexing_Split import gen_doc_from_database, gen_doc_from_history
 
 import os
-from langchain_community.vectorstores import SupabaseVectorStore
-from langchain_openai import OpenAIEmbeddings
 from supabase.client import Client, create_client
 from add_vectordb import GetVectorStore
-from langchain_community.cache import RedisSemanticCache  # 更新导入路径
-from langchain_core.prompts import PromptTemplate
 import openai
 
 # Get API log
@@ -44,23 +32,17 @@ openai_api_key = os.getenv("OPENAI_API_KEY")
 URI = os.getenv("SUPABASE_URI")
 openai.api_key = openai_api_key
 
-
 global_retriever = None
 
-# 定義FastAPI的生命週期管理器,在啟動和關閉時執行特定操作
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global global_retriever
     global vector_store
     
     start = time.time()
-    # global_retriever = split_retriever(path='./Documents', extension="docx")
-    # global_retriever = raptor_retriever(path='../Documents', extension="txt")
-    # global_retriever = unstructured_retriever(path='../Documents')
 
     supabase_url = os.getenv("SUPABASE_URL")
     supabase_key = os.getenv("SUPABASE_KEY")
-    URI = os.getenv("SUPABASE_URI")
     document_table = "documents"
     supabase: Client = create_client(supabase_url, supabase_key)
 
@@ -68,21 +50,17 @@ async def lifespan(app: FastAPI):
     vector_store = GetVectorStore(embeddings, supabase, document_table)
     global_retriever = vector_store.as_retriever(search_kwargs={"k": 4})
 
-    print(time.time() - start)
+    print(f"Initialization time: {time.time() - start}")
     yield
 
-# 定義依賴注入函數,用於在請求處理過程中獲取全局變量
 def get_retriever():
     return global_retriever
 
-
 def get_vector_store():
     return vector_store
 
-# 創建FastAPI應用實例並配置以及中間件
 app = FastAPI(lifespan=lifespan)
 
-# templates = Jinja2Templates(directory="temp")
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -91,31 +69,16 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
-# 定義API路由和處理函數
-# 處理傳入的問題並返回答案
 @app.get("/answer2")
 def multi_query_answer(question, retriever=Depends(get_retriever)):
     try:
         start = time.time()
 
         with get_openai_callback() as cb:
-            # qa_doc = gen_doc_from_database()
-            # qa_history_doc = gen_doc_from_history()
-            # qa_doc.extend(qa_history_doc)
-            # vectorstore = Chroma.from_documents(documents=qa_doc, embedding=OpenAIEmbeddings(), collection_name="qa_pairs")
-            # retriever_qa = vectorstore.as_retriever(search_kwargs={"k": 3})
-            # final_answer, reference_docs = naive_rag_for_qapairs(question, retriever_qa)
-            final_answer = 'False'
-            if final_answer == 'False':
-                final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
-
-        # print(CHAT_HISTORY)
-        
-        # with get_openai_callback() as cb:
-        #     final_answer, reference_docs = multi_query(question, retriever)
+            final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
+
         processing_time = time.time() - start
-        print(processing_time)
+        print(f"Processing time: {processing_time}")
         save_history(question, final_answer, reference_docs, cb, processing_time)
 
         return {"Answer": final_answer}
@@ -127,48 +90,39 @@ class ChatHistoryItem(BaseModel):
     q: str
     a: str
 
-# 處理帶有歷史聊天紀錄的問題並返回答案
 @app.post("/answer_with_history")
 def multi_query_answer(question: Optional[str] = '', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
     start = time.time()
     
     chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
-    print(chat_history)
-
-    # TODO: similarity search
+    print(f"Chat history: {chat_history}")
     
     with get_openai_callback() as cb:
         final_answer, reference_docs = multi_query(question, retriever, chat_history)
     processing_time = time.time() - start
-    print(processing_time)
+    print(f"Processing time: {processing_time}")
     save_history(question, final_answer, reference_docs, cb, processing_time)
 
     return {"Answer": final_answer}
 
-# 處理帶有聊天歷史紀錄和文件名過濾的問題,並返回答案
 @app.post("/answer_with_history2")
 def multi_query_answer(question: Optional[str] = '', extension: Optional[str] = 'pdf', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
     start = time.time()
 
-    retriever = vector_store.as_retriever(search_kwargs={"k": 4,
-                                                         'filter': {'extension':extension}})
+    retriever = vector_store.as_retriever(search_kwargs={"k": 4, 'filter': {'extension':extension}})
     
     chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
-    print(chat_history)
-
-    # TODO: similarity search
+    print(f"Chat history: {chat_history}")
     
     with get_openai_callback() as cb:
         final_answer, reference_docs = multi_query(question, retriever, chat_history)
     processing_time = time.time() - start
-    print(processing_time)
+    print(f"Processing time: {processing_time}")
     save_history(question, final_answer, reference_docs, cb, processing_time)
 
     return {"Answer": final_answer}
 
-# 保存歷史。將處理結果儲存到數據庫
 def save_history(question, answer, reference, cb, processing_time):
-    # reference = [doc.dict() for doc in reference]
     record = {
         'Question': [question],
         'Answer': [answer],
@@ -190,7 +144,6 @@ class history_output(BaseModel):
     Processing_time: float
     Time: datetime.datetime
 
-# 定義獲取歷史紀錄的路由
 @app.get('/history', response_model=List[history_output])
 async def get_history():
     engine = create_engine(URI, echo=True)
@@ -205,11 +158,5 @@ async def get_history():
 def read_root():
     return {"message": "Welcome to the Carbon Chatbot API"}
 
-
 if __name__ == "__main__":
-    uvicorn.run("RAG_app_copy:app", host='127.0.0.1', port=8081, reload=True)
-    
-# if __name__ == "__main__":
-#     uvicorn.run("RAG_app:app", host='cmm.ai', port=8081, reload=True, ssl_keyfile="/etc/letsencrypt/live/cmm.ai/privkey.pem", 
-#                 ssl_certfile="/etc/letsencrypt/live/cmm.ai/fullchain.pem")
-
+    uvicorn.run("RAG_app_copy:app", host='127.0.0.1', port=8081, reload=True)
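
For reference, a minimal client call against the reworked /answer_with_history endpoint might look as follows. This is a hypothetical sketch: it assumes the app is running locally on 127.0.0.1:8081 as configured above, reuses the sample question from taide_rag.py below, and invents the history turn and follow-up question.

import requests

# Previous turns; the endpoint drops items whose answer "a" is empty.
chat_history = [
    {"q": "什麼是碳排放獎勵辦法?", "a": "碳排放獎勵辦法是..."},
]

resp = requests.post(
    "http://127.0.0.1:8081/answer_with_history",
    params={"question": "請再說明申請流程"},  # question is passed as a query parameter (invented example)
    json=chat_history,                        # List[ChatHistoryItem] is sent as the raw JSON body
)
print(resp.json()["Answer"])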

+ 61 - 144
RAG_strategy.py

@@ -9,7 +9,14 @@ from langchain_core.runnables import RunnablePassthrough
 from langchain import hub
 from langchain.globals import set_llm_cache
 from langchain import PromptTemplate
-from langchain.llms.base import LLM
+import subprocess
+import json
+from typing import Any, List, Optional, Dict
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
+from langchain_core.outputs import ChatResult, ChatGeneration
+from pydantic import Field
 
 from langchain_core.runnables import (
     RunnableBranch,
@@ -17,9 +24,6 @@ from langchain_core.runnables import (
     RunnableParallel,
     RunnablePassthrough,
 )
-from typing import Tuple, List, Optional
-from langchain_core.messages import AIMessage, HumanMessage
-
 
 from datasets import Dataset 
 from ragas import evaluate
@@ -29,92 +33,71 @@ from ragas.metrics import (
     context_recall,
     context_precision,
 )
-from typing import List
 import os
 from dotenv import load_dotenv
 load_dotenv('environment.env')
 
-########################################################################################################################
-########################################################################################################################
 from langchain.cache import SQLiteCache
-from langchain.cache import RedisSemanticCache
 from langchain_openai import OpenAIEmbeddings
 from langchain.globals import set_llm_cache
 
-########################################################################################################################
 import requests
 import openai
 openai_api_key = os.getenv("OPENAI_API_KEY")
 openai.api_key = openai_api_key
 URI = os.getenv("SUPABASE_URI")
 
-from typing import Optional, List, Any, Dict
-
-# 設置緩存,以減少對API的重複請求。使用Redis
+# Set up caching to reduce repeated API requests; uses SQLite.
 set_llm_cache(SQLiteCache(database_path=".langchain.db"))
-# set_llm_cache(RedisSemanticCache(redis_url="redis://localhost:6379", embedding=OpenAIEmbeddings(openai_api_key=openai_api_key), score_threshold=0.0005))
 
-# # TAIDE model on Ollama https://ollama.com/jcai/llama3-taide-lx-8b-chat-alpha1
-# def interact_with_model(messages, api_url="http://localhost:11434/v1/chat/completions"):
-#     print("Using model: TAIDE")
-#     response = requests.post(api_url, json={"model": "jcai/llama3-taide-lx-8b-chat-alpha1:Q4_K_M", "messages": messages})
-#     return response.json()["choices"][0]["message"]["content"]
-
-import requests
-from typing import Optional, List, Any, Dict
-from langchain.llms.base import LLM
-
-class CustomTAIDELLM(LLM):
-    api_url: str = "http://localhost:11434/api/chat"
-    model_name: str = "taide-local"
-    system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
-
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        # Format the prompt according to TAIDE requirements
-        formatted_prompt = f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>>\n\n{prompt} [/INST]"
-        print(f"Formatted prompt being sent to TAIDE model: {formatted_prompt}")
-
-        payload = {
-            "model": self.model_name,
-            "messages": [
-                {"role": "system", "content": self.system_prompt},
-                {"role": "user", "content": prompt}
-            ]
-        }
-
-        try:
-            response = requests.post(self.api_url, json=payload)
-            response.raise_for_status()
-            result = response.json()
-            print(f"Full API response: {result}")
-            if "message" in result:
-                return result["message"]["content"].strip()
-            else:
-                print(f"Unexpected response structure: {result}")
-                return "Error: Unexpected response from the model."
-        except requests.RequestException as e:
-            print(f"Error calling Ollama API: {e}")
-            return f"Error: Unable to get response from the model. {str(e)}"
+system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+
+class OllamaChatModel(BaseChatModel):
+    model_name: str = Field(default="taide-local")
+
+    def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> ChatResult:
+        formatted_messages = []
+        for msg in messages:
+            if isinstance(msg, HumanMessage):
+                formatted_messages.append({"role": "user", "content": msg.content})
+            elif isinstance(msg, AIMessage):
+                formatted_messages.append({"role": "assistant", "content": msg.content})
+            elif isinstance(msg, SystemMessage):
+                 formatted_messages.append({"role": "system", "content": msg.content})
+
+        prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
+        for msg in formatted_messages:
+            if msg['role'] == 'user':
+                prompt += f"{msg['content']} [/INST]"
+            elif msg['role'] == "assistant":
+                prompt += f"{msg['content']} </s><s>[INST]"
+
+        command = ["ollama", "run", self.model_name, prompt]
+        result = subprocess.run(command, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            raise Exception(f"Ollama command failed: {result.stderr}")
+        
+        content = result.stdout.strip()
 
+        message = AIMessage(content=content)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
+    
     @property
     def _llm_type(self) -> str:
-        return "custom_taide"
-
-    def get_model_name(self):
-        return self.model_name
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        return {"model_name": self.model_name}
-
-# Create an instance of the custom LLM
-taide_llm = CustomTAIDELLM(api_url="http://localhost:11434/api/chat", model_name="taide-local")
+        return "ollama-chat-model"
+    
+taide_llm = OllamaChatModel(model_name="taide-local")
 
-# 生成多個不同版本的問題,進行檢索,並返回答案和參考文檔
 def multi_query(question, retriever, chat_history):
-
     def multi_query_chain():
-        # Multi Query: Different Perspectives
         template = """You are an AI language model assistant. Your task is to generate three 
         different versions of the given user question to retrieve relevant documents from a vector 
         database. By generating multiple perspectives on the user question, your goal is to help
@@ -123,27 +106,12 @@ def multi_query(question, retriever, chat_history):
 
         You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.
         
-        
         Original question: {question}"""
         prompt_perspectives = ChatPromptTemplate.from_template(template)
 
-        messages = [
-            {"role": "system", "content": template},
-            {"role": "user", "content": question},
-        ]
-        # generate_queries = interact_with_model(messages).split("\n")
-
-        
-        # llm = ChatOpenAI(model="gpt-4-1106-preview")
-        # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
-        # llm = ChatOllama(model="gemma2", temperature=0)
-        # llm = ChatOllama(model=model)
-
-
         generate_queries = (
             prompt_perspectives 
             | taide_llm
-            # | llm
             | StrOutputParser() 
             | (lambda x: x.split("\n"))
         )
@@ -151,14 +119,9 @@ def multi_query(question, retriever, chat_history):
         return generate_queries
 
     def get_unique_union(documents: List[list]):
-        """ Unique union of retrieved docs """
-        # Flatten list of lists, and convert each Document to string
         flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
-        # Get unique documents
         unique_docs = list(set(flattened_docs))
-        # Return
         return [loads(doc) for doc in unique_docs]
-    
 
     _search_query = get_search_query()
     modified_question = _search_query.invoke({"question":question, "chat_history": chat_history})
@@ -173,7 +136,6 @@ def multi_query(question, retriever, chat_history):
 
     return answer, docs
 
-# 根據檢索到的文檔和用戶問題生成最後答案
 def multi_query_rag_prompt(retrieval_chain, question):
     template = """Answer the following question based on this context:
 
@@ -205,25 +167,8 @@ def multi_query_rag_prompt(retrieval_chain, question):
     except Exception as e:
         print(f"Error invoking rag_chain: {e}")
         return "Error occurred while processing the question."
-    
-########################################################################################################################
 
-# 將聊天紀錄個跟進問題轉化為獨立問題
 def get_search_query():
-    # Condense a chat history and follow-up question into a standalone question
-    # 
-    # _template = """Given the following conversation and a follow up question, 
-    # rephrase the follow up question to be a standalone question to help others understand the question without having to go back to the conversation transcript.
-    # Generate standalone question in its original language.
-    # Chat History:
-    # {chat_history}
-    # Follow Up Input: {question}
-
-    # Hint:
-    # * Refer to chat history and add the subject to the question
-    # * Replace the pronouns in the question with the correct person or thing, please refer to chat history
-    
-    # Standalone question:"""  # noqa: E501
     _template = """Rewrite the following query by incorporating relevant context from the conversation history.
     The rewritten query should:
     
@@ -244,7 +189,7 @@ def get_search_query():
     """
     CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
 
-    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
+    def _format_chat_history(chat_history: List[tuple[str, str]]) -> List:
         buffer = []
         for human, ai in chat_history:
             buffer.append(HumanMessage(content=human))
@@ -252,11 +197,10 @@ def get_search_query():
         return buffer
 
     _search_query = RunnableBranch(
-        # If input includes chat_history, we condense it with the follow-up question
         (
             RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                 run_name="HasChatHistoryCheck"
-            ),  # Condense follow-up question and chat into a standalone_question
+            ),
             RunnablePassthrough.assign(
                 chat_history=lambda x: _format_chat_history(x["chat_history"])
             )
@@ -264,48 +208,31 @@ def get_search_query():
             | ChatOpenAI()
             | StrOutputParser(),
         ),
-        # Else, we have no chat history, so just pass through the question
         RunnableLambda(lambda x : x["question"]),
     )
 
     return _search_query
-########################################################################################################################
-# 檢索文檔並生成答案
-def naive_rag(question, retriever):
-    #### RETRIEVAL and GENERATION ####
 
-    # Prompt
+def naive_rag(question, retriever):
     prompt = hub.pull("rlm/rag-prompt")
 
-    # LLM
-    # llm = ChatOpenAI(model_name="gpt-3.5-turbo")
-
-    # Post-processing
     def format_docs(docs):
         return "\n\n".join(doc.page_content for doc in docs)
 
     reference = retriever.get_relevant_documents(question)
     
-    # Chain
     rag_chain = (
         {"context": retriever | format_docs, "question": RunnablePassthrough()}
         | prompt
         | taide_llm
-        # | llm
         | StrOutputParser()
     )
 
-    # Question
     answer = rag_chain.invoke(question)
 
     return answer, reference
-################################################################################################
-# 處理question-answer pairs的檢索和生成答案
-def naive_rag_for_qapairs(question, retriever):
-    #### RETRIEVAL and GENERATION ####
 
-    # Prompt
-    # prompt = hub.pull("rlm/rag-prompt")
+def naive_rag_for_qapairs(question, retriever):
     template = """You are an assistant for question-answering tasks. 
     Use the following pieces of retrieved context to answer the question. 
     Following retrieved context is question-answer pairs of historical QA, Find the suitable answer from the qa pairs
@@ -320,19 +247,13 @@ def naive_rag_for_qapairs(question, retriever):
     """
     prompt = PromptTemplate.from_template(template)
 
-    # LLM
     llm = ChatOpenAI(model_name="gpt-4-0125-preview")
-    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
-    # llm = ChatOllama(model="gemma2", num_gpu=1, temperature=0)
 
-
-    # Post-processing
     def format_docs(docs):
         return "\n\n".join(doc.page_content for doc in docs)
 
     reference = retriever.get_relevant_documents(question)
     
-    # Chain
     rag_chain = (
         {"context": retriever | format_docs, "question": RunnablePassthrough()}
         | prompt
@@ -340,19 +261,16 @@ def naive_rag_for_qapairs(question, retriever):
         | StrOutputParser()
     )
 
-    # Question
     answer = rag_chain.invoke(question)
 
     return answer, reference
-########################################################################################################################
 
 def rag_score(question, ground_truth, answer, reference_docs):
-    
     datasets = {
-              "question": [question],       # question: list[str]
-              "answer": [answer],           # answer: list[str]
-              "contexts": [reference_docs], # contexts: list[list[str]]
-              "ground_truths": [[ground_truth]] # ground_truth: list[list[str]]
+              "question": [question],
+              "answer": [answer],
+              "contexts": [reference_docs],
+              "ground_truths": [[ground_truth]]
             }
     evalsets = Dataset.from_dict(datasets)
 
@@ -371,9 +289,8 @@ def rag_score(question, ground_truth, answer, reference_docs):
     result_df.to_csv('ragas_rag.csv')
     return result
 
-
 def print_current_model(llm):
-    if isinstance(llm, CustomTAIDELLM):
-        print(f"Currently using model: {llm.get_model_name()}")
+    if isinstance(llm, OllamaChatModel):
+        print(f"Currently using model: {llm.model_name}")
     else:
         pass
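
The new OllamaChatModel wrapper shells out to the Ollama CLI via subprocess.run rather than calling the HTTP API the removed CustomTAIDELLM used, so each call is blocking and unstreamed, and the module-level system_prompt is always prepended. A minimal usage sketch, assuming a model named taide-local has already been created in Ollama and the ollama binary is on PATH:

from langchain_core.messages import HumanMessage
from RAG_strategy import taide_llm, print_current_model

print_current_model(taide_llm)  # prints: Currently using model: taide-local
# invoke() ends up in _generate(), which runs `ollama run taide-local <formatted prompt>`.
reply = taide_llm.invoke([HumanMessage(content="什麼是碳排放獎勵辦法?")])
print(reply.content)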

+ 0 - 29
ragas_data_generation.py

@@ -1,29 +0,0 @@
-from dotenv import load_dotenv
-load_dotenv('environment.env')
-
-
-from ragas.testset.generator import TestsetGenerator
-from ragas.testset.evolutions import simple, reasoning, multi_context 
-from langchain_openai import ChatOpenAi, OpenAIEmbeddings
-from langchain_community.document_loaders import DirectoryLoader
-from langchain_community.document_loaders import PyPDFLoader
-
-loader = DirectoryLoader("Documents")
-for file in 
-documents = loader.load()
-
-
-for document in documents:
-    document.metadata['filename'] = document.metadata['source']
-
-generator_llm = ChatOpenAi(model = "gpt-3.5-turbo-16k")
-critic_llm = ChatOpenAI(model="gpt-4")
-embeddings = OpenAIEmbeddings()
-
-generator = TestGenerator.from_langchain(
-    generator_llm,
-    critic_llm,
-    embeddings
-)
-# Generate testset
-testset = generator.generate_with_langchain_docs(documents, test_size=10, distributions={simple: 0.5, reasoning: 0.25, multi_context: 0.25})
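
Note that the deleted script also contained several errors: ChatOpenAi instead of ChatOpenAI, an unfinished "for file in" line, and TestGenerator instead of TestsetGenerator. For the record, a corrected sketch of what it appears to have been doing, assuming the ragas 0.1-style testset API it imports:

from dotenv import load_dotenv
load_dotenv('environment.env')

from ragas.testset.generator import TestsetGenerator
from ragas.testset.evolutions import simple, reasoning, multi_context
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.document_loaders import DirectoryLoader

# Load the source documents and tag each one with its filename, as ragas expects.
loader = DirectoryLoader("Documents")
documents = loader.load()
for document in documents:
    document.metadata['filename'] = document.metadata['source']

generator_llm = ChatOpenAI(model="gpt-3.5-turbo-16k")
critic_llm = ChatOpenAI(model="gpt-4")
embeddings = OpenAIEmbeddings()

generator = TestsetGenerator.from_langchain(generator_llm, critic_llm, embeddings)

# Generate a small synthetic test set with a mix of question types.
testset = generator.generate_with_langchain_docs(
    documents,
    test_size=10,
    distributions={simple: 0.5, reasoning: 0.25, multi_context: 0.25},
)
print(testset.to_pandas())  # inspect the generated question/ground-truth pairs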

+ 0 - 31
run.sh

@@ -1,37 +1,6 @@
 #!/bin/bash
 
-# Function to check if Docker is running
-docker_running() {
-    docker info >/dev/null 2>&1
-}
 
-# Start Docker if it's not already running
-if ! docker_running; then
-    echo "Starting Docker..."
-    open -a Docker
-    
-    # Wait for Docker to start
-    while ! docker_running; do
-        echo "Waiting for Docker to start..."
-        sleep 5
-    done
-    echo "Docker is now running"
-fi
-
-# Get the script directory
-script_dir=$(dirname "$0")
-cd "$script_dir"
-
-# Start the services defined in docker-compose.yml
-echo "Starting services with Docker Compose..."
-docker-compose up -d
-echo "Waiting for services to start..."
-sleep 20 
-
-# Change to the directory containing Python script
-cd "$script_dir/systex-RAG-sherry"
-echo "Running RAG application..."
-python ollama_chat.py
 
 # Make this script executable:
 # chmod +x run.sh

+ 30 - 0
taide_rag.py

@@ -0,0 +1,30 @@
+from dotenv import load_dotenv
+from langchain.vectorstores import Chroma
+import os
+load_dotenv('environment.env')
+openai_api_key = os.getenv("OPENAI_API_KEY")
+from RAG_strategy import taide_llm, multi_query, naive_rag
+from langchain.vectorstores import FAISS
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.document_loaders import TextLoader
+from langchain.text_splitter import CharacterTextSplitter
+
+
+
+# Load and prepare a sample document
+loader = TextLoader("test_data.txt")
+documents = loader.load()
+text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+docs = text_splitter.split_documents(documents)
+
+# Create a vector store
+embeddings = OpenAIEmbeddings()
+vectorstore = Chroma.from_documents(docs, embeddings)
+retriever = vectorstore.as_retriever()
+
+# Test multi_query
+print("\nTesting multi_query:")
+question = "什麼是碳排放獎勵辦法?"
+answer, docs = multi_query(question, retriever, [])
+print(f"Question: {question}")
+print(f"Answer: {answer}")

+ 0 - 102
test_connection.py

@@ -1,102 +0,0 @@
-# import os
-# import sys
-
-# from supabase import create_client, Client
-
-# # # Load environment variables
-# from dotenv import load_dotenv
-# load_dotenv('environment.env')
-
-# # Get Supabase configuration from environment variables
-# SUPABASE_URL = os.getenv("SUPABASE_URL")
-# SUPABASE_KEY = os.getenv("SUPABASE_KEY")
-# SUPABASE_URI = os.getenv("SUPABASE_URI")
-# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
-
-# # Check if environment variables are successfully loaded
-# if not SUPABASE_URL or not SUPABASE_KEY or not OPENAI_API_KEY or not SUPABASE_URI:
-#     print("Please ensure SUPABASE_URL, SUPABASE_KEY, and OPENAI_API_KEY are correctly set in the .env file.")
-#     sys.exit(1)
-# else:
-#     print("Connection successful.")
-#     try:
-#         supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
-#         print("Client created successfully.")
-#     except Exception as e:
-#         print("Client creation failed:", e)
-#         sys.exit(1)
-
-# # List all table names
-# try:
-#     response = supabase.table('information_schema.tables').select('table_name').eq('table_schema', 'public').execute()
-#     table_names = [table['table_name'] for table in response.data]
-#     print("All table names:")
-#     for name in table_names:
-#         print(name)
-# except Exception as e:
-#     print("Connection failed:", e)
-#     sys.exit(1)
-
-
-# ### Test hugging face tokens for the TAIDE local model. ######################################################
-# from transformers import AutoTokenizer, AutoModelForCausalLM
-
-# token = os.getenv("HF_API_KEY_7B4BIT")
-
-# # Check if the token is loaded correctly
-# if token is None:
-#     raise ValueError("Hugging Face API token is not set. Please check your environment.env file.")
-
-# # Load the tokenizer and model with the token
-# try:
-#     tokenizer = AutoTokenizer.from_pretrained("../TAIDE-LX-7B-Chat-4bit", token=token)  
-#     model = AutoModelForCausalLM.from_pretrained("../TAIDE-LX-7B-Chat-4bit", token=token)
-    
-#     # Verify the model and tokenizer
-#     print(f"Loaded tokenizer: {tokenizer.name_or_path}")
-#     print(f"Loaded model: {model.name_or_path}")
-
-#     # Optional: Print model and tokenizer configuration for more details
-#     print(f"Model configuration: {model.config}")
-#     print(f"Tokenizer configuration: {tokenizer}")
-
-# except Exception as e:
-#     print(f"Error loading model or tokenizer: {e}")
-
-#################################################################################################################
-# import torch
-# from transformers import AutoModelForCausalLM, AutoTokenizer
-# from huggingface_hub import hf_hub_download
-# from llama_cpp import Llama
-
-# ## Download the GGUF model
-# model_name = "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF"
-# model_file = "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf" # this is the specific model file we'll use in this example. It's a 4-bit quant, but other levels of quantization are available in the model repo if preferred
-# model_path = hf_hub_download(model_name, filename=model_file)
-
-
-
-
-# import requests
-
-# def generate_response(input_text, max_length=512, temperature=0.7):
-#     # URL to interact with the model
-#     url = "http://localhost:11434/v1/chat/completions"  # Adjust based on how Ollama exposes the model
-
-#     # Payload to send to the model
-#     payload = {
-#         "input": input_text,
-#         "parameters": {
-#             "max_length": max_length,
-#             "temperature": temperature
-#         }
-#     }
-
-#     # Make a request to the model
-#     response = requests.post(url, json=payload)
-#     return response.json()["output"]
-
-# if __name__ == "__main__":
-#     input_text = "I believe the meaning of life is"
-#     response = generate_response(input_text, max_length=128, temperature=0.5)
-#     print(f"Model: {response}")