SherryLiu committed 7 months ago
fb23c66940
7 changed files with 384 additions and 0 deletions
  1. RAG/config.py (+13, -0)
  2. RAG/embeddings.py (+140, -0)
  3. RAG/main.py (+61, -0)
  4. RAG/models.py (+55, -0)
  5. RAG/rag_chain.py (+62, -0)
  6. RAG/run.sh (+24, -0)
  7. RAG/text_processing.py (+29, -0)

+ 13 - 0
RAG/config.py

@@ -0,0 +1,13 @@
+import os
+from dotenv import load_dotenv
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+parent_dir = os.path.dirname(current_dir)
+env_path = os.path.join(parent_dir, 'environment.env')
+load_dotenv(env_path)
+
+EMBEDDINGS_FILE = 'qa_embeddings.pkl'
+FAISS_INDEX_FILE = 'qa_faiss_index.bin'
+CSV_FILE = 'log_record_rows.csv'
+
+# zh-TW system prompt: "You are an AI assistant from Taiwan named TAIDE, happy to
+# help users from a Taiwanese standpoint, answering questions in Traditional Chinese."
+system_prompt = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"

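Both config.py and embeddings.py load environment.env from the directory above RAG/. A minimal sketch of that file, with placeholder values for the keys the modules read (SUPABASE_URI in embeddings.py; OPENAI_API_KEY in embeddings.py and rag_chain.py):

    # environment.env (placeholder values)
    SUPABASE_URI=postgresql://user:password@host:5432/dbname
    OPENAI_API_KEY=sk-...
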
+ 140 - 0
RAG/embeddings.py

@@ -0,0 +1,140 @@
+import os
+from dotenv import load_dotenv
+from langchain_core.documents import Document
+from langchain_openai import OpenAIEmbeddings
+from sqlalchemy import create_engine
+import pandas as pd
+import numpy as np
+import faiss
+import pickle
+
+# Get the current script's directory
+current_dir = os.path.dirname(os.path.abspath(__file__))
+parent_dir = os.path.dirname(current_dir)
+env_path = os.path.join(parent_dir, 'environment.env')
+load_dotenv(env_path)
+
+URI = os.getenv("SUPABASE_URI")
+openai_api_key = os.getenv("OPENAI_API_KEY")
+
+EMBEDDINGS_FILE = 'qa_embeddings.pkl'
+FAISS_INDEX_FILE = 'qa_faiss_index.bin'
+CSV_FILE = 'log_record_rows.csv'
+
+def gen_doc_from_database():
+    engine = create_engine(URI, echo=True)
+    df = pd.read_sql_table("log_record", engine.connect())
+    df = df[['question', 'answer']].drop_duplicates(keep='first')
+    print(f"Number of records after removing duplicates: {len(df)}")
+    
+    qa_doc = []
+    for i in range(len(df)):
+        Question = df.iloc[i]['question']
+        Answer = df.iloc[i]['answer']
+        context = f'question: {Question}\nanswer: {Answer}'
+        doc = Document(page_content=context)
+        qa_doc.append(doc)
+    return qa_doc, df
+
+def gen_doc_from_csv(csv_filename=CSV_FILE):
+    csv_path = os.path.join(current_dir, csv_filename)
+    df = pd.read_csv(csv_path)
+    df.drop_duplicates(subset=['question', 'answer'], keep='first', inplace=True)
+    print(f"Number of records after removing duplicates: {len(df)}")
+    
+    qa_doc = []
+    for _, row in df.iterrows():
+        Question = row['question']
+        Answer = row['answer']
+        context = f'question: {Question}\nanswer: {Answer}'
+        doc = Document(page_content=context)
+        qa_doc.append(doc)
+    return qa_doc, df
+
+def create_embeddings(docs):
+    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
+    print("Creating embeddings...")
+    doc_embeddings = embeddings.embed_documents([doc.page_content for doc in docs])
+    print(f"Created {len(doc_embeddings)} embeddings.")
+    print(f"Each embedding is a vector of length {len(doc_embeddings[0])}.")
+    return doc_embeddings
+
+def create_faiss_index(embeddings):
+    dimension = len(embeddings[0])
+    index = faiss.IndexFlatL2(dimension)
+    index.add(np.array(embeddings).astype('float32'))
+    return index
+
+def save_embeddings(embeddings, docs, df):
+    with open(os.path.join(current_dir, EMBEDDINGS_FILE), 'wb') as f:
+        pickle.dump({'embeddings': embeddings, 'docs': docs, 'df': df}, f)
+    
+    index = create_faiss_index(embeddings)
+    faiss.write_index(index, os.path.join(current_dir, FAISS_INDEX_FILE))
+    
+    print(f"Saved embeddings to {EMBEDDINGS_FILE} and FAISS index to {FAISS_INDEX_FILE}")
+
+def load_embeddings():
+    embeddings_path = os.path.join(current_dir, EMBEDDINGS_FILE)
+    faiss_path = os.path.join(current_dir, FAISS_INDEX_FILE)
+    
+    if os.path.exists(embeddings_path) and os.path.exists(faiss_path):
+        with open(embeddings_path, 'rb') as f:
+            data = pickle.load(f)
+        
+        index = faiss.read_index(faiss_path)
+        
+        print("Loaded existing embeddings and FAISS index from files")
+        return data['embeddings'], data['docs'], data['df'], index
+    else:
+        raise FileNotFoundError("Embeddings or FAISS index file not found. Please run embeddings.py first.")
+
+def similarity_search(query, index, docs, k=3, threshold=0.83, method='logistic', sigma=1.0):
+    embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))
+    query_vector = embeddings.embed_query(query)
+    
+    # Search with 2k neighbors to improve the chance of finding k unique documents
+    D, I = index.search(np.array([query_vector]).astype('float32'), k * 2)
+    
+    results = []
+    seen_docs = set()  # track documents already added to the results
+    
+    for dist, idx in zip(D[0], I[0]):
+        if method == 'logistic':
+            similarity = 1 / (1 + dist)
+        elif method == 'exponential':
+            similarity = np.exp(-dist)
+        elif method == 'gaussian':
+            similarity = np.exp(-dist**2 / (2 * sigma**2))
+        else:
+            raise ValueError("Unknown similarity method")
+        
+        if similarity >= threshold:
+            doc_content = docs[idx].page_content
+            if doc_content not in seen_docs:  # skip documents we have already added
+                results.append((docs[idx], similarity))
+                seen_docs.add(doc_content)
+                
+                if len(results) == k:  # stop once k unique documents have been found
+                    break
+    
+    return results
+
+def main():
+    # Choose which method to use for generating documents
+    use_csv = True  # Set this to False if you want to use the database instead
+    
+    if use_csv:
+        docs, df = gen_doc_from_csv()
+    else:
+        docs, df = gen_doc_from_database()
+    
+    embeddings = create_embeddings(docs)
+    save_embeddings(embeddings, docs, df)
+
+if __name__ == "__main__":
+    main()

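A minimal query sketch against the saved index (this assumes embeddings.py has been run once so the pickle and FAISS files exist; the query string is only an illustration):

    from embeddings import load_embeddings, similarity_search

    embeddings, docs, df, index = load_embeddings()
    for doc, similarity in similarity_search("如何重設密碼?", index, docs, k=3):
        print(f"{similarity:.4f}  {doc.page_content[:50]}")

Note that FAISS returns L2 distances, so similarity_search converts each distance d into a score, e.g. 1 / (1 + d) for the default 'logistic' method, before applying the 0.83 threshold.
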
+ 61 - 0
RAG/main.py

@@ -0,0 +1,61 @@
+import time
+import pandas as pd
+from config import current_dir, CSV_FILE
+from langchain.globals import set_llm_cache
+from langchain_community.cache import SQLiteCache
+from embeddings import load_embeddings
+from rag_chain import get_context, simple_rag_prompt, calculate_similarity
+
+# Set up cache
+set_llm_cache(SQLiteCache(database_path=".langchain.db"))
+
+def main():
+    # Evaluate the first n questions from the CSV
+    n = 8
+    embeddings, docs, df, index = load_embeddings()
+    
+    retrieval_chain = lambda q: get_context(q, index, docs)
+    
+    csv_path = f"{current_dir}/{CSV_FILE}"
+    qa_df = pd.read_csv(csv_path)
+    
+    output_file = 'rag_output.txt'
+    
+    with open(output_file, 'w', encoding='utf-8') as f:
+        for i in range(n):  
+            try:
+                question = qa_df.iloc[i]['question']
+                original_answer = qa_df.iloc[i]['answer']
+                
+                start_time = time.time()
+                rag_answer, similarity_score = simple_rag_prompt(retrieval_chain, question)
+                end_time = time.time()
+                
+                response_time = end_time - start_time
+                answer_similarity = calculate_similarity(original_answer, rag_answer)
+                
+                f.write(f"Question {i+1}: {question}\n")
+                f.write(f"Original Answer: {original_answer}\n")
+                f.write(f"RAG Answer: {rag_answer}\n")
+                f.write(f"Response Time: {response_time:.2f} seconds\n")
+                f.write(f"Retrieval Similarity Score: {similarity_score:.4f}\n")
+                f.write(f"Answer Similarity Score: {answer_similarity:.4f}\n")
+                f.write("-" * 50 + "\n")
+                
+                f.flush()
+                print(f"Processed question {i+1}")
+                
+                time.sleep(1) 
+            except Exception as e:
+                print(f"Error processing question {i+1}: {str(e)}")
+                f.write(f"Error processing question {i+1}: {str(e)}\n")
+                f.write("-" * 50 + "\n")
+                f.flush()
+    
+    print(f"Output has been saved to {output_file}")
+
+if __name__ == "__main__":
+    main()

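The same pieces can be wired up for a single ad-hoc question outside the batch loop (a sketch assuming the embeddings files exist; the question is a placeholder):

    from embeddings import load_embeddings
    from rag_chain import get_context, simple_rag_prompt

    _, docs, _, index = load_embeddings()
    answer, score = simple_rag_prompt(lambda q: get_context(q, index, docs), "客服電話是多少?")
    print(f"retrieval similarity {score:.4f}: {answer}")

Because main.py registers a SQLiteCache, re-running the script answers previously seen prompts from .langchain.db instead of re-invoking the model.
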
+ 55 - 0
RAG/models.py

@@ -0,0 +1,55 @@
+from langchain_core.language_models import BaseChatModel
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.outputs import ChatResult, ChatGeneration
+from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
+from pydantic import Field
+import subprocess
+from typing import Any, List, Optional
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+from config import system_prompt
+
+class OllamaChatModel(BaseChatModel):
+    model_name: str = Field(default="taide-local-3")
+
+    @retry(stop=stop_after_attempt(3), wait=wait_random_exponential(min=1, max=60))
+    def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> ChatResult:
+        formatted_messages = []
+        for msg in messages:
+            if isinstance(msg, HumanMessage):
+                formatted_messages.append({"role": "user", "content": msg.content})
+            elif isinstance(msg, AIMessage):
+                formatted_messages.append({"role": "assistant", "content": msg.content})
+            elif isinstance(msg, SystemMessage):
+                formatted_messages.append({"role": "system", "content": msg.content})
+
+        prompt = f"<s>[INST] {system_prompt}\n\n"
+        for msg in formatted_messages:
+            if msg['role'] == 'user':
+                prompt += f"{msg['content']} [/INST]"
+            elif msg['role'] == "assistant":
+                prompt += f"{msg['content']} </s><s>[INST]"
+
+        command = ["ollama", "run", self.model_name, prompt]
+        try:
+            result = subprocess.run(command, capture_output=True, text=True, timeout=60)  # 60-second timeout
+        except subprocess.TimeoutExpired:
+            raise Exception("Ollama command timed out")
+
+        if result.returncode != 0:
+            raise Exception(f"Ollama command failed: {result.stderr}")
+        
+        content = result.stdout.strip()
+
+        message = AIMessage(content=content)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
+    
+    @property
+    def _llm_type(self) -> str:
+        return "ollama-chat-model"

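A quick smoke test for the wrapper (this assumes the taide-local-3 model has already been created in Ollama, e.g. by run.sh):

    from langchain_core.messages import HumanMessage
    from models import OllamaChatModel

    llm = OllamaChatModel(model_name="taide-local-3")
    reply = llm.invoke([HumanMessage(content="你好,請自我介紹。")])
    print(reply.content)

Shelling out to the ollama CLI is simple but pays process-startup cost on every request; the retry decorator re-runs the whole command up to three times with randomized exponential backoff on failure.
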
+ 62 - 0
RAG/rag_chain.py

@@ -0,0 +1,62 @@
+from langchain.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough
+from models import OllamaChatModel
+from embeddings import similarity_search
+from text_processing import remove_unwanted_content
+from langchain_openai import OpenAIEmbeddings
+from sklearn.metrics.pairwise import cosine_similarity
+import os
+
+taide_llm = OllamaChatModel(model_name="taide-local-3")
+
+def get_context(query, index, docs):
+    results = similarity_search(query, index, docs)
+    context = "\n".join([doc.page_content for doc, _ in results])
+
+    # Print the question and a short preview of each retrieved document
+    print(f"Question: {query}")
+    print("Retrieved documents:")
+    for i, (doc, similarity) in enumerate(results):
+        print(f"Doc {i+1} (similarity: {similarity:.4f}): {doc.page_content[:50]}...")
+    print("-" * 50)
+    
+    return (context, results[0][1]) if results else (context, 0)  # context and top similarity score
+
+def simple_rag_prompt(retrieval_chain, question):
+    template = """Answer the following question based on this context:
+
+    {context}
+
+    Question: {question}
+    Answer in the user's language: if the question is in zh-tw, answer in zh-tw; if it is in English, answer in English.
+    Do not mention phrases such as "根據提供的文件內容" or anything else that refers to the documents or context.
+    If you don't know the answer, just say that "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@email.com 以便獲得更進一步的幫助,謝謝。I'm sorry I cannot answer your question. Please send your question to test@email.com for further assistance. Thank you."
+    """
+
+    prompt = ChatPromptTemplate.from_template(template)
+    context, similarity_score = retrieval_chain(question)
+
+    final_rag_chain = (
+        {"context": lambda x: context, 
+        "question": lambda x: x} 
+        | prompt
+        | taide_llm
+        | StrOutputParser()
+    )
+
+    try:
+        answer = final_rag_chain.invoke(question)
+        answer = remove_unwanted_content(answer)  # strip known boilerplate phrases
+        return answer, similarity_score
+    except Exception as e:
+        print(f"Error invoking rag_chain: {e}")
+        return "Error occurred while processing the question.", 0
+
+def calculate_similarity(text1, text2):
+    embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))
+    emb1 = embeddings.embed_query(text1)
+    emb2 = embeddings.embed_query(text2)
+    return cosine_similarity([emb1], [emb2])[0][0]

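calculate_similarity embeds both texts with OpenAI and returns their cosine similarity, so identical texts score 1.0 and paraphrases score close to it. A small illustrative sketch (the strings are placeholders):

    from rag_chain import calculate_similarity

    score = calculate_similarity("台北在哪裡?", "台北位於台灣的北部。")
    print(f"{score:.4f}")  # paraphrases score markedly higher than unrelated pairs
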
+ 24 - 0
RAG/run.sh

@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# Environment
+conda create -n conda-env python=3.9
+# 'conda activate' only works inside a script after sourcing the conda shell hook
+source "$(conda info --base)/etc/profile.d/conda.sh"
+conda activate conda-env
+pip install -r requirements.txt
+
+# Downloads (the resolve/ URL fetches the gguf weights; the tree/ URL returns an HTML page)
+wget https://huggingface.co/taide/Llama3-TAIDE-LX-8B-Chat-Alpha1-4bit/resolve/main/taide-8b-a.3-q4_k_m.gguf
+curl https://ollama.ai/install.sh | sh
+
+# Build Taide model ('ollama create -f' expects a Modelfile, not a directory)
+echo "FROM ./taide-8b-a.3-q4_k_m.gguf" > Modelfile
+ollama create taide-local-3 -f Modelfile
+
+python embeddings.py
+python main.py
+
+# Make the script executable:
+# chmod +x run.sh

+ 29 - 0
RAG/text_processing.py

@@ -0,0 +1,29 @@
+def remove_unwanted_content(answer):
+    # TODO: a similarity-score filter might catch these phrases more robustly than exact matching
+    unwanted_phrases = [
+        "<<SYS>> 你是一個來自台灣的AI助理,名字叫TAIDE,樂於用繁體中文幫助使用者,會根據問題提供相關答案。> <</SYS>>",
+        "TAIDE 敬上",
+        "你是一個來自台灣的AI助理,我的名字是 TAIDE,我很高興用繁體中文幫助您!請告訴我如何才能為您服務呢?",
+        "此處無需提供文件或上下文,因為已經在先前的回應中提及過。若真有需要,可再次詢問相關內容。",
+        "If you have any further questions or need additional assistance, please do not hesitate to contact us. Thank you!",
+        "我在這裡可以告訴你",
+        "根據提供的資訊,",
+        "根據您的問題,",
+        "請注意,以上資訊僅供參考。",
+        "如果您還有其他問題,請隨時問我。",
+        "希望這個資訊對您有幫助。",
+        "很高興能為您解答這個問題。",
+        "這些信息是根據我所知道的最新資料提供的。",
+        "如果您需要更詳細的資訊,建議您查看官方網站或直接聯繫相關單位。",
+        "我很抱歉,作為AI助理,我無法知道您所提及的特定文件或上下文。",
+        "[/INST]",
+        "[/ANS]"
+    ]
+    
+    for phrase in unwanted_phrases:
+        answer = answer.replace(phrase, "")
+    
+    answer = answer.strip()
+    answer = '\n'.join(line for line in answer.splitlines() if line.strip())
+    
+    return answer
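
A small illustrative example of the cleanup (the input string is a contrived sample):

    from text_processing import remove_unwanted_content

    raw = "根據提供的資訊,台北位於台灣北部。\n\nTAIDE 敬上"
    print(remove_unwanted_content(raw))  # -> "台北位於台灣北部。"

Exact string matching is brittle against small variations in model output, hence the TODO above about a similarity-score filter.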