add_vectordb.py

import glob
import os

from dotenv import load_dotenv
from langchain_community.document_loaders import (
    CSVLoader,
    Docx2txtLoader,
    PyPDFLoader,
    TextLoader,
)
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from supabase.client import Client, create_client

# SUPABASE_URL / SUPABASE_KEY (and the OpenAI key used by OpenAIEmbeddings)
# are read from the project-level .env file.
load_dotenv("../.env")

supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_KEY")
document_table = "documents2"
supabase: Client = create_client(supabase_url, supabase_key)

def get_data_list(data_list=None, path=None, extension=None, update=False):
    # Collect candidate files; unless update=True, skip files whose basename is
    # already present in the vector table.
    files = data_list or glob.glob(os.path.join(path, f"*.{extension}"))
    if update:
        doc = files.copy()
    else:
        existed_data = check_existed_data(supabase)
        doc = []
        for file_path in files:
            filename = os.path.basename(file_path)
            if filename not in existed_data:
                doc.append(file_path)
    return doc

def read_and_split_files(data_list=None, path=None, extension=None, update=False):

    def read_csv(path):
        extension = "csv"
        # path = r"./Phase2/"
        files = glob.glob(os.path.join(path, f"*.{extension}"))
        if not files:
            return None
        documents = []
        for file_path in files:
            print(file_path)
            loader = CSVLoader(file_path, encoding="utf-8")
            doc = loader.load()
            documents.extend(doc)
        return documents

    def load_and_split(file_list):
        chunks = []
        # Regulation .docx files that should be split at article ("第 N 條") boundaries.
        rules = ['低碳產品獎勵辦法.docx', '公私場所固定污染源空氣污染物排放量申報管理辦法.docx', '氣候變遷因應法.docx', '氣候變遷因應法施行細則.docx',
                 '淘汰老舊機車換購電動機車溫室氣體減量獎勵辦法訂定總說明及逐條說明.docx', '溫室氣體抵換專案管理辦法.docx', '溫室氣體排放源符合效能標準獎勵辦法.docx',
                 '溫室氣體排放量增量抵換管理辦法.docx', '溫室氣體排放量增量抵換管理辦法訂定總說明及逐條說明.docx', '溫室氣體排放量盤查登錄及查驗管理辦法修正條文.docx',
                 '溫室氣體排放量盤查登錄管理辦法(溫室氣體排放量盤查登錄及查驗管理辦法修正條文前身).docx', '溫室氣體自願減量專案管理辦法.docx',
                 '溫室氣體自願減量專案管理辦法中華民國112年10月12日訂定總說明及逐條說明.docx', '溫室氣體認證機構及查驗機構管理辦法.docx',
                 '溫室氣體階段管制目標及管制方式作業準則.docx', '碳足跡產品類別規則訂定、引用及修訂指引.docx',
                 '老舊汽車汰舊換新溫室氣體減量獎勵辦法中華民國112年1月11日訂定總說明及逐條說明.docx']
        for file in file_list:
            # Load
            if file.endswith(".txt") or file.endswith(".md"):
                loader = TextLoader(file, encoding='utf-8')
                docs = loader.load()
            elif file.endswith(".pdf"):
                loader = PyPDFLoader(file)
                docs = loader.load()
            elif file.endswith(".docx"):
                loader = Docx2txtLoader(file)
                docs = loader.load()
            else:
                print(f"Unsupported file extension: {file}")
                continue

            # Split
            print(os.path.basename(file))
            if file.endswith(".docx") and os.path.basename(file) in rules:
                # Split regulation text at article markers such as "○ 第一條" / "● 第一條".
                separators = [r'○\s*第.*?條', r'●\s*第.*?條']
                text_splitter = RecursiveCharacterTextSplitter(is_separator_regex=True, separators=separators, chunk_size=500, chunk_overlap=0)
                splits = text_splitter.split_documents(docs)
            elif os.path.basename(file) in ["new_information.docx"]:
                print(file)
                separators = ['###']
                text_splitter = RecursiveCharacterTextSplitter(is_separator_regex=True, separators=separators, chunk_size=500, chunk_overlap=0)
                splits = text_splitter.split_documents(docs)
            elif file.endswith(".md"):
                headers_to_split_on = [
                    ("#", "Header 1"),
                    ("##", "Header 2"),
                    ("###", "Header 3"),
                ]
                markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
                splits = markdown_splitter.split_text(docs[0].page_content)
            else:
                text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=500, chunk_overlap=0)
                splits = text_splitter.split_documents(docs)
            chunks.extend(splits)
        # doc = read_csv(path)
        # chunks.extend(doc)
        return chunks
    # specific data type
    doc = get_data_list(data_list=data_list, path=path, extension=extension, update=update)
    # web url
    # csv
    # doc = read_csv(path)
    # Index
    docs = load_and_split(doc)
    return docs
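
# Usage sketch (hypothetical arguments): split every PDF under ./Documents that is not
# already present in the vector table, without touching files that are.
# chunks = read_and_split_files(path="./Documents", extension="pdf", update=False)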

def create_ids(docs):
    # Count chunks per (source, page) so each chunk gets a sequential number
    page_counter = {}
    # List to store the resulting IDs
    document_ids = []
    # Generate IDs of the form "<file name> | page <n> | chunk <m>"
    for metadata in (doc.metadata for doc in docs):
        if "source" in metadata:
            source = metadata['source']
            file_name = os.path.basename(source).split('.')[0]
        else:
            source = "supplement"
            file_name = "supplement"
        if "page" in metadata:
            page = metadata['page']
            key = f"{source}_{page}"
        else:
            key = f"{source}"
        if key not in page_counter:
            page_counter[key] = 1
        else:
            page_counter[key] += 1
        if "page" in metadata:
            doc_id = f"{file_name} | page {page} | chunk {page_counter[key]}"
        else:
            doc_id = f"{file_name} | chunk {page_counter[key]}"
        document_ids.append(doc_id)
    return document_ids
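
# For reference, create_ids() yields IDs such as
#   "溫室氣體自願減量專案管理辦法 | chunk 3"                  (no page metadata, e.g. .docx)
#   "溫室氣體排放量盤查作業指引113年版 | page 12 | chunk 2"    (paged sources such as PDF)
# where the page/chunk numbers here are illustrative.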

def get_document(data_list=None, path=None, extension=None, update=False):
    docs = read_and_split_files(data_list=data_list, path=path, extension=extension, update=update)
    document_ids = create_ids(docs)
    # Normalise every source to its basename; chunks without a source are
    # attributed to "supplement.md".
    for doc in docs:
        doc.metadata['source'] = os.path.basename(doc.metadata['source']) if 'source' in doc.metadata else "supplement.md"
        # print(doc.metadata)
    document_metadatas = []
    for doc, doc_id in zip(docs, document_ids):
        doc.metadata['chunk'] = int(doc_id.split("chunk ")[-1])
        doc.metadata['extension'] = os.path.basename(doc.metadata['source']).split(".")[-1]
        document_metadatas.append(doc.metadata)
    # Prepend the file name (without extension) to each chunk's text before embedding.
    documents = [doc.metadata['source'].split(".")[0] + doc.page_content for doc in docs]
    return document_ids, documents, document_metadatas

def check_existed_data(supabase):
    response = supabase.table(document_table).select("id, metadata").execute()
    existed_data = list(set([data['metadata']['source'] for data in response.data]))
    # existed_data = [(data['id'], data['metadata']['source']) for data in response.data]
    return existed_data

class GetVectorStore(SupabaseVectorStore):
    def __init__(self, embeddings, supabase, table_name):
        super().__init__(embedding=embeddings, client=supabase, table_name=table_name, query_name="match_documents")

    def insert(self, documents, document_metadatas):
        self.add_texts(
            texts=documents,
            metadatas=document_metadatas,
        )

    def delete(self, file_list):
        # Remove every chunk whose metadata->>source matches one of the given file names.
        for file in file_list:
            file_name = os.path.basename(file)
            self._client.table(self.table_name).delete().eq('metadata->>source', file_name).execute()

    def update(self, documents, document_metadatas, update_existing_data=False):
        if not document_metadatas:  # no new data
            return
        if update_existing_data:
            # Drop the old chunks of every file that is about to be re-inserted.
            file_list = list(set(metadata['source'] for metadata in document_metadatas))
            self.delete(file_list)
        self.insert(documents, document_metadatas)
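
# Note (assumption about the database side): SupabaseVectorStore expects the standard
# LangChain/Supabase setup to already exist, i.e. a table ("documents" / "documents2")
# with content, metadata (jsonb) and embedding (pgvector) columns, plus a
# "match_documents" SQL function matching the query_name above; insert/as_retriever
# calls will fail server-side if that schema has not been created.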

if __name__ == "__main__":
    load_dotenv("../.env")
    supabase_url = os.environ.get("SUPABASE_URL")
    supabase_key = os.environ.get("SUPABASE_KEY")
    document_table = "documents"
    supabase: Client = create_client(supabase_url, supabase_key)
    embeddings = OpenAIEmbeddings()

    ###################################################################################
    # get vector store
    vector_store = GetVectorStore(embeddings, supabase, document_table)

    ###################################################################################
    # update data (old + new / all new / all old)
    path = "/home/mia/systex/Documents"
    extension = "pdf"
    # file = None
    # file_list = ["溫室氣體排放量盤查作業指引113年版.pdf"]
    # file = [os.path.join(path, file) for file in file_list]
    # file_list = glob.glob(os.path.join(path, "*"))
    file_list = glob.glob(os.path.join(path, f"*.{extension}"))
    # print(file_list)
    # update = False
    # document_ids, documents, document_metadatas = get_document(data_list=file_list, path=path, extension=extension, update=update)
    # vector_store.update(documents, document_metadatas, update_existing_data=update)

    ###################################################################################
    # insert new data (all new)
    # vector_store.insert(documents, document_metadatas)

    ###################################################################################
    # delete data
    # file_list = ["溫室氣體排放量盤查作業指引113年版.pdf"]
    # file_list = glob.glob(os.path.join(path, "*.docx"))
    # file_list = [os.path.basename(file_path) for file_path in file_list]
    # print(file_list)
    # vector_store.delete(file_list)

    ###################################################################################
    # get retriever
    # retriever = vector_store.as_retriever(search_kwargs={"k": 6})
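
    ###################################################################################
    # query example (sketch; the query string is illustrative and assumes the table is
    # already populated)
    # docs = vector_store.similarity_search("溫室氣體排放量盤查申報", k=6)
    # for doc in docs:
    #     print(doc.metadata.get("source"), doc.metadata.get("page"), doc.metadata.get("chunk"))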