add_vectordb.py

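"""Build and maintain a Supabase vector store from local documents.

Loads .txt / .pdf / .docx / .csv files, splits them into chunks, and inserts
the chunks (with OpenAI embeddings) into the Supabase `documents` table.
Run as a script to insert, update, or delete documents.
"""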
from dotenv import load_dotenv

load_dotenv("../.env")

import os
import glob

from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma, SupabaseVectorStore
from langchain_community.document_loaders import (
    TextLoader,
    PyPDFLoader,
    Docx2txtLoader,
    CSVLoader,
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from supabase.client import Client, create_client

document_table = "documents"
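
# Returns the files that still need to be ingested. With update=True every file
# in `data_list`/`path` is returned; otherwise files whose basename already
# exists in the vector table are skipped. Note: relies on the module-level
# `supabase` client created in the __main__ block below.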
def get_data_list(data_list=None, path=None, extension=None, update=False):
    files = data_list or glob.glob(os.path.join(path, f"*.{extension}"))
    if update:
        doc = files.copy()
    else:
        existed_data = check_existed_data(supabase)
        doc = []
        for file_path in files:
            filename = os.path.basename(file_path)
            if filename not in existed_data:
                doc.append(file_path)
    return doc
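
# Loads the selected files and splits them into chunks: .txt/.pdf/.docx go
# through the LangChain loaders above, .docx is split on "第…條" article
# markers, other formats on ~500-token windows, and every .csv under `path`
# is appended row by row via CSVLoader.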
def read_and_split_files(data_list=None, path=None, extension=None, update=False):
    def read_csv(path):
        extension = "csv"
        # path = r"./Phase2/"
        files = glob.glob(os.path.join(path, f"*.{extension}"))
        documents = []
        for file_path in files:
            print(file_path)
            loader = CSVLoader(file_path, encoding="utf-8")
            doc = loader.load()
            documents.extend(doc)
        return documents

    def load_and_split(file_list):
        chunks = []
        for file in file_list:
            if file.endswith(".txt"):
                loader = TextLoader(file, encoding="utf-8")
            elif file.endswith(".pdf"):
                loader = PyPDFLoader(file)
            elif file.endswith(".docx"):
                loader = Docx2txtLoader(file)
            else:
                print(f"Unsupported file extension: {file}")
                continue
            docs = loader.load()

            # Split: .docx files on "○第…條" / "●第…條" article markers,
            # everything else on tiktoken token counts.
            if file.endswith(".docx"):
                separators = [r"\u25cb\s*第.*?條", r"\u25cf\s*第.*?條"]
                text_splitter = RecursiveCharacterTextSplitter(
                    is_separator_regex=True,
                    separators=separators,
                    chunk_size=300,
                    chunk_overlap=0,
                )
            else:
                text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                    chunk_size=500, chunk_overlap=0
                )
            splits = text_splitter.split_documents(docs)
            chunks.extend(splits)

        doc = read_csv(path)
        chunks.extend(doc)
        return chunks

    # specific data type
    doc = get_data_list(data_list=data_list, path=path, extension=extension, update=update)
    # web url
    # csv
    # doc = read_csv(path)

    # Index
    docs = load_and_split(doc)
    return docs
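
# Builds a human-readable ID per chunk, e.g. "<file> | page 3 | chunk 2",
# counting chunks per (source, page) pair; sources without page metadata get
# "<file> | chunk N".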
def create_ids(docs):
    # Create a dictionary to count occurrences of each page in each document
    page_counter = {}
    # List to store the resulting IDs
    document_ids = []

    # Generate IDs
    for doc in [d.metadata for d in docs]:
        source = doc['source']
        file_name = os.path.basename(source).split('.')[0]
        if "page" in doc.keys():
            page = doc['page']
            key = f"{source}_{page}"
        else:
            key = f"{source}"
        if key not in page_counter:
            page_counter[key] = 1
        else:
            page_counter[key] += 1
        if "page" in doc.keys():
            doc_id = f"{file_name} | page {page} | chunk {page_counter[key]}"
        else:
            doc_id = f"{file_name} | chunk {page_counter[key]}"
        document_ids.append(doc_id)

    return document_ids
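
# End-to-end preparation step: returns (ids, texts, metadatas) ready for the
# vector store. The source metadata is reduced to the basename, `chunk` and
# `extension` fields are added, and the file name is prepended to each chunk's
# text before embedding.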
def get_document(data_list=None, path=None, extension=None, update=False):
    docs = read_and_split_files(data_list=data_list, path=path, extension=extension, update=update)
    document_ids = create_ids(docs)
    for doc in docs:
        doc.metadata['source'] = os.path.basename(doc.metadata['source'])
        # print(doc.metadata)

    # document_metadatas = [{'source': doc.metadata['source'], 'page': doc.metadata['page'], 'chunk': int(id.split("chunk ")[-1])} for doc, id in zip(docs, document_ids)]
    document_metadatas = []
    for doc, doc_id in zip(docs, document_ids):
        chunk_number = int(doc_id.split("chunk ")[-1])
        doc.metadata['chunk'] = chunk_number
        doc.metadata['extension'] = os.path.basename(doc.metadata['source']).split(".")[-1]
        document_metadatas.append(doc.metadata)

    documents = [doc.metadata['source'].split(".")[0] + doc.page_content for doc in docs]
    return document_ids, documents, document_metadatas
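
# Queries the Supabase table and returns the distinct `metadata->source`
# values, i.e. the file names that are already stored.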
def check_existed_data(supabase):
    response = supabase.table(document_table).select("id, metadata").execute()
    existed_data = list(set([data['metadata']['source'] for data in response.data]))
    # existed_data = [(data['id'], data['metadata']['source']) for data in response.data]
    return existed_data
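
# Thin wrapper around LangChain's SupabaseVectorStore. Assumes the standard
# LangChain Supabase setup: a `documents` table with an embedding column and a
# `match_documents` SQL function for similarity search. `delete()` removes rows
# by the `source` key inside the JSON metadata column.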
class GetVectorStore(SupabaseVectorStore):
    def __init__(self, embeddings, supabase, table_name):
        super().__init__(
            embedding=embeddings,
            client=supabase,
            table_name=table_name,
            query_name="match_documents",
        )

    def insert(self, documents, document_metadatas):
        self.add_texts(
            texts=documents,
            metadatas=document_metadatas,
        )

    def delete(self, file_list):
        for file_name in file_list:
            self._client.table(self.table_name).delete().eq('metadata->>source', file_name).execute()

    def update(self, documents, document_metadatas, update_existing_data=False):
        if not document_metadatas:  # no new data
            return
        if update_existing_data:
            file_list = list(set(metadata['source'] for metadata in document_metadatas))
            self.delete(file_list)
        self.insert(documents, document_metadatas)
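
# Script entry point. Requires SUPABASE_URL and SUPABASE_KEY in ../.env (and an
# OpenAI key for OpenAIEmbeddings, typically OPENAI_API_KEY). As committed, only
# the delete step is active; uncomment the update/insert lines to ingest data.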
if __name__ == "__main__":
    load_dotenv("../.env")
    supabase_url = os.environ.get("SUPABASE_URL")
    supabase_key = os.environ.get("SUPABASE_KEY")
    document_table = "documents"
    supabase: Client = create_client(supabase_url, supabase_key)
    embeddings = OpenAIEmbeddings()

    ###################################################################################
    # get vector store
    vector_store = GetVectorStore(embeddings, supabase, document_table)

    ###################################################################################
    # update data (old + new / all new / all old)
    path = "/home/mia/systex/Documents"
    extension = "pdf"
    # file = None
    # file_list = ["溫室氣體排放量盤查作業指引113年版.pdf"]
    # file = [os.path.join(path, file) for file in file_list]
    # file_list = glob.glob(os.path.join(path, "*"))
    file_list = glob.glob(os.path.join(path, f"*.{extension}"))
    # print(file_list)
    # update = False
    # document_ids, documents, document_metadatas = get_document(data_list=file_list, path=path, extension=extension, update=update)
    # vector_store.update(documents, document_metadatas, update_existing_data=update)

    ###################################################################################
    # insert new data (all new)
    # vector_store.insert(documents, document_metadatas)

    ###################################################################################
    # delete data
    # file_list = ["溫室氣體排放量盤查作業指引113年版.pdf"]
    file_list = glob.glob(os.path.join(path, "*.docx"))
    file_list = [os.path.basename(file_path) for file_path in file_list]
    print(file_list)
    vector_store.delete(file_list)

    ###################################################################################
    # get retriever
    # retriever = vector_store.as_retriever(search_kwargs={"k": 6})
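    # Minimal usage sketch (assuming a recent LangChain where retrievers expose
    # `.invoke()`); the query string is only an illustration:
    # results = retriever.invoke("溫室氣體排放量盤查")
    # for doc in results:
    #     print(doc.metadata["source"], doc.metadata.get("page"))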