Indexing_Split.py

from dotenv import load_dotenv
load_dotenv('environment.env')  # load env vars before any client libraries read them

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import TextLoader, PyPDFLoader, Docx2txtLoader, WebBaseLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain import hub
from PyPDF2 import PdfReader
from json import loads
from sqlalchemy import create_engine
from tqdm import tqdm
# __import__('pysqlite3')
# import sys
# sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import (
    answer_relevancy,
    faithfulness,
    context_recall,
    context_precision,
)
import pandas as pd
import os
import glob
import openai

URI = os.getenv("SUPABASE_URI")
openai_api_key = os.getenv("OPENAI_API_KEY")
openai.api_key = openai_api_key

from RAG_strategy import multi_query, naive_rag
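
# `RAG_strategy` is a local module whose source is not shown in this file. As
# an illustration only, a multi-query helper with this signature typically
# works like the sketch below: rephrase the question with an LLM, retrieve for
# each variant, de-duplicate, then answer over the merged context. Every model
# name and prompt in it is an assumption, not the module's actual code.
def _multi_query_sketch(question, retriever, llm=None):
    llm = llm or ChatOpenAI(model="gpt-4o-mini", temperature=0)  # assumed model
    rewrite = ChatPromptTemplate.from_template(
        "Rewrite the following question in three different ways, one per line:\n{question}"
    ) | llm | StrOutputParser()
    variants = [question] + [q.strip() for q in rewrite.invoke({"question": question}).splitlines() if q.strip()]

    # Retrieve for every variant and keep the first occurrence of each chunk.
    seen, docs = set(), []
    for q in variants:
        for d in retriever.invoke(q):
            if d.page_content not in seen:
                seen.add(d.page_content)
                docs.append(d)

    answer = (ChatPromptTemplate.from_template(
        "Answer using only this context:\n{context}\n\nQuestion: {question}"
    ) | llm | StrOutputParser()).invoke(
        {"context": "\n\n".join(d.page_content for d in docs), "question": question}
    )
    return answer, docs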
def create_retriever(path='Documents', extension="pdf"):
    # Collect every file under `path` matching the requested extension.
    file_list = glob.glob(os.path.join(path, f"*.{extension}"))

    def load_and_split(file_list):
        chunks = []
        for file in file_list:
            if file.endswith(".txt"):
                loader = TextLoader(file, encoding='utf-8')
            elif file.endswith(".pdf"):
                loader = PyPDFLoader(file)
            elif file.endswith(".docx"):
                loader = Docx2txtLoader(file)
            else:
                raise ValueError(f"Unsupported file extension: {file}")
            docs = loader.load()

            # Split: .docx files are cut on the "○ 第…條" / "● 第…條" article
            # headings; everything else is chunked by tiktoken token count.
            if file.endswith(".docx"):
                # separators = ["\n\n\u25cb", "\n\n\u25cf"]
                # text_splitter = RecursiveCharacterTextSplitter(separators=separators, chunk_size=500, chunk_overlap=0)
                separators = [r'\u25cb\s*第.*?條', r'\u25cf\s*第.*?條']  # raw strings: these are regex patterns
                text_splitter = RecursiveCharacterTextSplitter(
                    is_separator_regex=True, separators=separators, chunk_size=300, chunk_overlap=0)
            else:
                text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=500, chunk_overlap=0)
            splits = text_splitter.split_documents(docs)
            chunks.extend(splits)
        return chunks

    # Index
    docs = load_and_split(file_list)
    qa_history_doc = gen_doc_from_history()
    docs.extend(qa_history_doc)
    # web_doc = web_data(os.path.join(path, 'web_url.csv'))
    # docs.extend(web_doc)

    # vectorstore
    # vectorstore = Chroma.from_texts(texts=docs, embedding=OpenAIEmbeddings())
    # vectorstore = Chroma.from_documents(documents=docs, embedding=OpenAIEmbeddings(openai_api_key=openai_api_key))
    # vectorstore = Chroma.from_documents(documents=docs, embedding=OllamaEmbeddings(model="llama3", num_gpu=1))
    vectorstore = Chroma.from_documents(documents=docs,
                                        embedding=OllamaEmbeddings(model="gemma2"),
                                        persist_directory="./chroma_db")  # persist() raises without a persist_directory; the path is an assumed default
    vectorstore.persist()
    retriever = vectorstore.as_retriever()
    return retriever
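
# Re-indexing on every call is expensive. Assuming the ./chroma_db location
# used above, an already-built store can be reloaded without re-embedding
# (a sketch, not part of the original module):
def _load_existing_retriever(persist_directory="./chroma_db"):
    vectorstore = Chroma(persist_directory=persist_directory,
                         embedding_function=OllamaEmbeddings(model="gemma2"))
    return vectorstore.as_retriever()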
def web_data(url_file):
    df = pd.read_csv(url_file, header=0)
    url_list = df['url'].to_list()
    loader = WebBaseLoader(url_list)
    docs = loader.load()
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1000, chunk_overlap=0)
    splits = text_splitter.split_documents(docs)
    return splits
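
# web_data() expects a CSV with a header row and a `url` column, e.g.:
#
#   url
#   https://example.com/page-a
#   https://example.com/page-b
#
# The commented-out lines in create_retriever() show where its output would be
# merged into the index: docs.extend(web_data(os.path.join(path, 'web_url.csv')))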
def gen_doc_from_history():
    engine = create_engine(URI, echo=True)
    df = pd.read_sql_table("systex_records", engine.connect())
    df.fillna('', inplace=True)
    # Round-trip through JSON to normalize the frame before row-wise iteration.
    result = df.to_json(orient='index', force_ascii=False)
    result = loads(result)
    df = pd.DataFrame(result).T

    qa_history_doc = []
    for i in range(len(df)):
        if df.iloc[i]['used_as_document'] is not True:
            continue
        question = df.iloc[i]['Question']
        answer = df.iloc[i]['Answer']
        context = f'Question: {question}\nAnswer: {answer}'
        doc = Document(page_content=context, metadata={"source": "History"})
        qa_history_doc.append(doc)
        # print(doc)
    return qa_history_doc
def gen_doc_from_database():
    engine = create_engine(URI, echo=True)
    df = pd.read_sql_table("QA_database", engine.connect())
    # df.fillna('', inplace=True)
    result = df[['Question', 'Answer']].to_json(orient='index', force_ascii=False)
    result = loads(result)
    df = pd.DataFrame(result).T

    qa_doc = []
    for i in range(len(df)):
        # if df.iloc[i]['used_as_document'] is not True: continue
        question = df.iloc[i]['Question']
        answer = df.iloc[i]['Answer']
        context = f'Question: {question}\nAnswer: {answer}'
        doc = Document(page_content=context, metadata={"source": "History"})
        qa_doc.append(doc)
        # print(doc)
    return qa_doc
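
# The datasets/ragas imports at the top are not used anywhere in this file.
# The sketch below shows one plausible way to wire them up, following the
# ragas question/answer/contexts/ground_truth column convention; it is an
# assumption about intent, not code from the original project.
def _evaluate_with_ragas(questions, answers, contexts, ground_truths):
    dataset = Dataset.from_dict({
        "question": questions,          # queries that were asked
        "answer": answers,              # generated answers
        "contexts": contexts,           # list of retrieved chunk texts per question
        "ground_truth": ground_truths,  # reference answers
    })
    result = evaluate(dataset, metrics=[
        answer_relevancy, faithfulness, context_recall, context_precision,
    ])
    return result.to_pandas()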
if __name__ == "__main__":
    retriever = create_retriever(path='./Documents', extension="pdf")

    question = 'CEV系統可以支援盤查到什麼程度'  # "To what extent can the CEV system support inventory work?"
    final_answer, reference_docs = multi_query(question, retriever)
    print(question, final_answer)

    question = 'CEV系統依循標準為何'  # "Which standards does the CEV system follow?"
    final_answer, reference_docs = multi_query(question, retriever)
    print(question, final_answer)