Indexing_Split.py

import os
import glob
from json import loads

import pandas as pd
from dotenv import load_dotenv
from sqlalchemy import create_engine
from tqdm import tqdm
from PyPDF2 import PdfReader
from datasets import Dataset

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import (
    TextLoader,
    PyPDFLoader,
    Docx2txtLoader,
    WebBaseLoader,
)
from langchain.text_splitter import CharacterTextSplitter
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain import hub

# Workaround for environments whose bundled sqlite3 is too old for Chroma.
# __import__('pysqlite3')
# import sys
# sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

# Optional RAG evaluation with ragas (currently disabled).
# from ragas import evaluate
# from ragas.metrics import (
#     answer_relevancy,
#     faithfulness,
#     context_recall,
#     context_precision,
# )
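# The .env file is expected to provide at least SUPABASE_URI (read below) and
# an OpenAI API key for OpenAIEmbeddings / ChatOpenAI. A minimal sketch, with
# illustrative values only:
#
#   SUPABASE_URI=postgresql://user:password@host:5432/postgres
#   OPENAI_API_KEY=sk-...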
load_dotenv()
URI = os.getenv("SUPABASE_URI")

from RAG_strategy import multi_query, naive_rag
def create_retriever(path='Documents', extension="pdf"):
    """Load documents from `path`, split them into chunks, embed them into a
    Chroma vector store, and return the store as a retriever."""
    # Collect the paths of every file under `path` matching the extension.
    txt_files = glob.glob(os.path.join(path, f"*.{extension}"))
    doc = []
    for file_path in txt_files:
        doc.append(file_path)

    def load_and_split(file_list):
        chunks = []
        for file in file_list:
            # Pick a loader based on the file extension.
            if file.endswith(".txt"):
                loader = TextLoader(file, encoding='utf-8')
            elif file.endswith(".pdf"):
                loader = PyPDFLoader(file)
            elif file.endswith(".docx"):
                loader = Docx2txtLoader(file)
            else:
                raise ValueError(f"Unsupported file extension: {file}")

            docs = loader.load()

            # Split
            if file.endswith(".docx"):
                # separators = ["\n\n\u25cb", "\n\n\u25cf"]
                # text_splitter = RecursiveCharacterTextSplitter(separators=separators, chunk_size=500, chunk_overlap=0)
                # Split .docx files on article headings such as "○ 第…條" / "● 第…條".
                # The backslash before `s` is escaped so that \s reaches the regex
                # engine instead of being parsed as a (deprecated) string escape.
                separators = ['\u25cb\\s*第.*?條', '\u25cf\\s*第.*?條']
                text_splitter = RecursiveCharacterTextSplitter(is_separator_regex=True, separators=separators, chunk_size=300, chunk_overlap=0)
            else:
                # For .txt/.pdf, chunk by token count using the tiktoken encoder.
                text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=500, chunk_overlap=0)
            splits = text_splitter.split_documents(docs)
            chunks.extend(splits)
        return chunks

    # Index
    docs = load_and_split(doc)
    qa_history_doc = gen_doc_from_history()
    docs.extend(qa_history_doc)
    # web_doc = web_data(os.path.join(path, 'web_url.csv'))
    # docs.extend(web_doc)

    # vectorstore
    # vectorstore = Chroma.from_texts(texts=docs, embedding=OpenAIEmbeddings())
    vectorstore = Chroma.from_documents(documents=docs, embedding=OpenAIEmbeddings())
    # vectorstore = Chroma.from_documents(documents=docs, embedding=OllamaEmbeddings(model="llama3", num_gpu=1))
    # Note: without a persist_directory argument the Chroma collection lives in
    # memory; depending on the installed langchain/chromadb versions, persist()
    # may be a no-op or require persist_directory to be set.
    vectorstore.persist()
    retriever = vectorstore.as_retriever()
    return retriever
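# A rough usage sketch (assuming ./Documents holds PDFs and the API keys are
# set); besides being passed to multi_query below, the retriever can be
# queried directly:
#
#   retriever = create_retriever(path='./Documents', extension="pdf")
#   hits = retriever.get_relevant_documents("CEV系統依循標準為何")
#   for hit in hits:
#       print(hit.metadata.get("source"), hit.page_content[:80])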
def web_data(url_file):
    """Load the web pages listed in a CSV file and split them into chunks."""
    df = pd.read_csv(url_file, header=0)
    url_list = df['url'].to_list()
    loader = WebBaseLoader(url_list)
    docs = loader.load()
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1000, chunk_overlap=0)
    splits = text_splitter.split_documents(docs)
    return splits
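# web_data() expects a CSV with a header row and a 'url' column; illustrative
# contents for Documents/web_url.csv (the path used by the commented-out call
# in create_retriever):
#
#   url
#   https://example.com/page-1
#   https://example.com/page-2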
def gen_doc_from_history():
    """Turn previously answered Q&A records from the database into Documents."""
    engine = create_engine(URI, echo=True)
    df = pd.read_sql_table("systex_records", engine.connect())
    df.fillna('', inplace=True)
    result = df.to_json(orient='index', force_ascii=False)
    result = loads(result)
    df = pd.DataFrame(result).T

    qa_history_doc = []
    for i in range(len(df)):
        # Only keep records explicitly marked as usable documents.
        if df.iloc[i]['used_as_document'] is not True:
            continue
        Question = df.iloc[i]['Question']
        Answer = df.iloc[i]['Answer']
        context = f'Question: {Question}\nAnswer: {Answer}'
        doc = Document(page_content=context, metadata={"source": "History"})
        qa_history_doc.append(doc)
        # print(doc)
    return qa_history_doc
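# gen_doc_from_history() assumes the "systex_records" table provides at least
# the columns Question, Answer, and used_as_document (a boolean flag); only
# rows where used_as_document is exactly True become retrievable Documents.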
def gen_doc_from_database():
    """Turn every Question/Answer pair in the QA database into Documents."""
    engine = create_engine(URI, echo=True)
    df = pd.read_sql_table("QA_database", engine.connect())
    # df.fillna('', inplace=True)
    result = df[['Question', 'Answer']].to_json(orient='index', force_ascii=False)
    result = loads(result)
    df = pd.DataFrame(result).T

    qa_doc = []
    for i in range(len(df)):
        # if df.iloc[i]['used_as_document'] is not True: continue
        Question = df.iloc[i]['Question']
        Answer = df.iloc[i]['Answer']
        context = f'Question: {Question}\nAnswer: {Answer}'
        doc = Document(page_content=context, metadata={"source": "History"})
        qa_doc.append(doc)
        # print(doc)
    return qa_doc
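# gen_doc_from_database() mirrors gen_doc_from_history() but reads the
# "QA_database" table and keeps every row; it is not called anywhere in this
# module.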
if __name__ == "__main__":
    retriever = create_retriever(path='./Documents', extension="pdf")

    question = 'CEV系統可以支援盤查到什麼程度'
    final_answer, reference_docs = multi_query(question, retriever)
    print(question, final_answer)

    question = 'CEV系統依循標準為何'
    final_answer, reference_docs = multi_query(question, retriever)
    print(question, final_answer)