from dotenv import load_dotenv
load_dotenv()

import os
import glob
from json import loads

import pandas as pd
from sqlalchemy import create_engine
from tqdm import tqdm

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import (
    TextLoader,
    PyPDFLoader,
    Docx2txtLoader,
    WebBaseLoader,
)
from langchain.text_splitter import CharacterTextSplitter
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain import hub
from PyPDF2 import PdfReader

# __import__('pysqlite3')
# import sys
# sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

from datasets import Dataset
# from ragas import evaluate
# from ragas.metrics import (
#     answer_relevancy,
#     faithfulness,
#     context_recall,
#     context_precision,
# )

from RAG_strategy import multi_query, naive_rag

URI = os.getenv("SUPABASE_URI")

def create_retriever(path='Documents', extension="pdf"):
    # Collect every file with the requested extension under `path`.
    files = glob.glob(os.path.join(path, f"*.{extension}"))

    def load_and_split(file_list):
        """Load each file with a loader matching its extension and split it into chunks."""
        chunks = []
        for file in file_list:
            if file.endswith(".txt"):
                loader = TextLoader(file, encoding='utf-8')
            elif file.endswith(".pdf"):
                loader = PyPDFLoader(file)
            elif file.endswith(".docx"):
                loader = Docx2txtLoader(file)
            else:
                raise ValueError(f"Unsupported file extension: {file}")

            docs = loader.load()

            # Split: .docx files are cut at "○ 第…條" / "● 第…條" article headings,
            # everything else is chunked by tiktoken token count.
            if file.endswith(".docx"):
                # separators = ["\n\n\u25cb", "\n\n\u25cf"]
                # text_splitter = RecursiveCharacterTextSplitter(separators=separators, chunk_size=500, chunk_overlap=0)
                separators = ['\u25cb\\s*第.*?條', '\u25cf\\s*第.*?條']
                text_splitter = RecursiveCharacterTextSplitter(
                    is_separator_regex=True, separators=separators,
                    chunk_size=300, chunk_overlap=0)
            else:
                text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                    chunk_size=500, chunk_overlap=0)

            splits = text_splitter.split_documents(docs)
            chunks.extend(splits)
        return chunks

    # Index: file chunks plus past Q&A records pulled from the database.
    docs = load_and_split(files)
    qa_history_doc = gen_doc_from_history()
    docs.extend(qa_history_doc)
    # web_doc = web_data(os.path.join(path, 'web_url.csv'))
    # docs.extend(web_doc)

    # Vector store. Chroma only writes to disk when a persist_directory is set;
    # "./chroma_db" is an arbitrary choice, not a path from the original code.
    # vectorstore = Chroma.from_texts(texts=docs, embedding=OpenAIEmbeddings())
    vectorstore = Chroma.from_documents(documents=docs, embedding=OpenAIEmbeddings(),
                                        persist_directory="./chroma_db")
    # vectorstore = Chroma.from_documents(documents=docs, embedding=OllamaEmbeddings(model="llama3", num_gpu=1))
    vectorstore.persist()

    retriever = vectorstore.as_retriever()
    return retriever
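
# Quick sanity check of the retriever on its own (an illustrative sketch, not part
# of the original script): `invoke` is the current LangChain retriever entry point,
# older releases used `get_relevant_documents`. The query string is only an example.
#
#     retriever = create_retriever(path='./Documents', extension="pdf")
#     for d in retriever.invoke("CEV 系統 盤查"):
#         print(d.metadata.get("source"), d.page_content[:80])
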
def web_data(url_file):
    """Load the pages listed in a CSV of URLs and split them into ~1000-token chunks."""
    df = pd.read_csv(url_file, header=0)
    url_list = df['url'].to_list()

    loader = WebBaseLoader(url_list)
    docs = loader.load()

    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1000, chunk_overlap=0)
    splits = text_splitter.split_documents(docs)

    return splits
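
# web_data() expects a CSV with a single 'url' header and one page per row; the
# addresses below are placeholders, not real sources:
#
#     url
#     https://example.com/page-1
#     https://example.com/page-2
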
def gen_doc_from_history():
    """Turn past Q&A records in the systex_records table into Documents for indexing."""
    engine = create_engine(URI, echo=True)
    df = pd.read_sql_table("systex_records", engine.connect())
    df.fillna('', inplace=True)

    # Round-trip through JSON to normalise the frame before iterating row by row.
    result = df.to_json(orient='index', force_ascii=False)
    result = loads(result)
    df = pd.DataFrame(result).T

    qa_history_doc = []
    for i in range(len(df)):
        # Only rows explicitly flagged for reuse become documents.
        if df.iloc[i]['used_as_document'] is not True:
            continue
        question = df.iloc[i]['Question']
        answer = df.iloc[i]['Answer']
        context = f'Question: {question}\nAnswer: {answer}'

        doc = Document(page_content=context, metadata={"source": "History"})
        qa_history_doc.append(doc)
        # print(doc)

    return qa_history_doc
def gen_doc_from_database():
    """Turn every Question/Answer row of the QA_database table into a Document."""
    engine = create_engine(URI, echo=True)
    df = pd.read_sql_table("QA_database", engine.connect())
    # df.fillna('', inplace=True)
    result = df[['Question', 'Answer']].to_json(orient='index', force_ascii=False)
    result = loads(result)
    df = pd.DataFrame(result).T

    qa_doc = []
    for i in range(len(df)):
        # if df.iloc[i]['used_as_document'] is not True: continue
        question = df.iloc[i]['Question']
        answer = df.iloc[i]['Answer']
        context = f'Question: {question}\nAnswer: {answer}'

        doc = Document(page_content=context, metadata={"source": "History"})
        qa_doc.append(doc)
        # print(doc)

    return qa_doc
if __name__ == "__main__":
    retriever = create_retriever(path='./Documents', extension="pdf")

    # "To what level of detail can the CEV system support a (carbon) inventory?"
    question = 'CEV系統可以支援盤查到什麼程度'
    final_answer, reference_docs = multi_query(question, retriever)
    print(question, final_answer)

    # "Which standards does the CEV system follow?"
    question = 'CEV系統依循標準為何'
    final_answer, reference_docs = multi_query(question, retriever)
    print(question, final_answer)
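
# A sketch (an assumption, not in the original script) of reusing the persisted
# index without re-embedding, given that create_retriever() stored it under
# "./chroma_db" above; the embedding model must match the one used at index time.
#
#     vectorstore = Chroma(persist_directory="./chroma_db",
#                          embedding_function=OpenAIEmbeddings())
#     retriever = vectorstore.as_retriever()
#     final_answer, reference_docs = multi_query('CEV系統依循標準為何', retriever)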