123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453 |
- from langchain_community.chat_models import ChatOllama
- from langchain_core.output_parsers import JsonOutputParser
- from langchain_core.prompts import PromptTemplate
- from langchain.prompts import ChatPromptTemplate
- from langchain_core.output_parsers import StrOutputParser
- # graph usage
- from pprint import pprint
- from typing import List
- from langchain_core.documents import Document
- from typing_extensions import TypedDict
- from langgraph.graph import END, StateGraph, START
- from langgraph.pregel import RetryPolicy
# --- Supabase (PostgreSQL) connection ---------------------------------------
from langchain_community.utilities import SQLDatabase
import os
from dotenv import load_dotenv

# load_dotenv() runs first so the lookup below (and the project imports further
# down) happen after it.
load_dotenv()
URI: str = os.getenv("SUPABASE_URI")
db = SQLDatabase.from_uri(URI)

# --- LLM clients -------------------------------------------------------------
# Alternative model kept for reference: "llama3.1:8b-instruct-fp16"
local_llm = "llama3-groq-tool-use:latest"
# Same model twice: one client requests JSON-formatted output, one does not.
llm_json = ChatOllama(temperature=0, format="json", model=local_llm)
llm = ChatOllama(temperature=0, model=local_llm)

# --- RAG usage ---------------------------------------------------------------
# NOTE: the imported faiss_query name is rebound by the faiss_query function
# defined later in this module.
from faiss_index import create_faiss_retriever, faiss_query
retriever = create_faiss_retriever()

# --- Text-to-SQL usage -------------------------------------------------------
from text_to_sql2 import run, get_query, query_to_nl, table_description
def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
    """Answer *question* with *llm*, using *docs* as the grounding context.

    Args:
        question: The user's question, inserted into the prompt as-is.
        docs: The retrieved context. A string is used unchanged; a single
            object exposing ``page_content`` contributes that text; any other
            iterable is joined with blank lines, taking ``page_content`` from
            items that have it and ``str(item)`` otherwise. ``None`` yields an
            empty context.
        llm: Runnable chat model piped between the prompt and the parser.
        multi_query: Unused; kept so existing callers keep working.

    Returns:
        The model reply as parsed by ``StrOutputParser``.
    """
    # Interpolating the raw document list would put object reprs (including
    # metadata) into the prompt, so reduce it to plain text first.
    if docs is None:
        context = ""
    elif isinstance(docs, str):
        context = docs
    elif hasattr(docs, "page_content"):
        context = docs.page_content
    else:
        context = "\n\n".join(
            str(getattr(doc, "page_content", doc)) for doc in docs
        )

    system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
    template = """
<|begin_of_text|>

<|start_header_id|>system<|end_header_id|>
你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
You should not mention anything about "根據提供的文件內容" or other similar terms.
Use five sentences maximum and keep the answer concise.
如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
勿回答無關資訊
<|eot_id|>

<|start_header_id|>user<|end_header_id|>
Answer the following question based on this context:
{context}
Question: {question}
用繁體中文
<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>
"""
    prompt = ChatPromptTemplate.from_template(
        system_prompt + "\n\n" +
        template
    )
    rag_chain = prompt | llm | StrOutputParser()
    return rag_chain.invoke({"context": context, "question": question})
### Hallucination Grader
def Hallucination_Grader():
    """Build the chain that judges whether an answer is supported by facts.

    The returned runnable takes ``documents`` and ``generation`` inputs and
    pipes the prompt through ``llm_json`` and ``JsonOutputParser``. The prompt
    asks the model for a JSON object with a single ``score`` key.
    """
    grounding_template = """ <|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a grader assessing whether an answer is grounded in / supported by a set of facts.
Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts.
Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation.
Return the a JSON with a single key 'score' and no premable or explanation.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Here are the facts:
\n ------- \n
{documents}
\n ------- \n
Here is the answer: {generation}
Provide 'yes' or 'no' score as a JSON with a single key 'score' and no premable or explanation.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
    grounding_prompt = PromptTemplate(
        input_variables=["generation", "documents"],
        template=grounding_template,
    )
    return grounding_prompt | llm_json | JsonOutputParser()
### Answer Grader
def Answer_Grader():
    """Build the chain that judges whether an answer resolves a question.

    The returned runnable takes ``generation`` and ``question`` inputs and
    pipes the prompt through ``llm_json`` and ``JsonOutputParser``. The prompt
    asks the model for a JSON object with a single ``score`` key.
    """
    usefulness_template = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an
answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is
useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
<|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
\n ------- \n
{generation}
\n ------- \n
Here is the question: {question}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
    usefulness_prompt = PromptTemplate(
        input_variables=["generation", "question"],
        template=usefulness_template,
    )
    return usefulness_prompt | llm_json | JsonOutputParser()
# Text-to-SQL
def run_text_to_sql(question: str):
    """Run the full text-to-SQL pipeline and return ``(answer, query)``.

    Delegates to ``run`` with the module-level ``db`` and ``llm`` and a fixed
    list of three tables. The middle element of ``run``'s result is discarded.

    Example question: "建準去年的固定燃燒總排放量是多少?"
    """
    tables = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
    sql, _unused_result, reply = run(db, question, tables, llm)
    return reply, sql
def _get_query(question: str):
    """Return what ``get_query`` produces for *question* over the fixed tables.

    Uses the module-level ``db`` and ``llm``.
    """
    tables = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
    return get_query(db, question, tables, llm)
def _query_to_nl(question: str, query: str):
    """Return what ``query_to_nl`` produces for *question* and *query*.

    Uses the module-level ``db`` and ``llm``.
    """
    return query_to_nl(db, question, query, llm)
### SQL Grader
def SQL_Grader():
    """Build the chain that judges whether a SQL query matches a question.

    The returned runnable takes ``table_info``, ``question`` and ``sql_query``
    inputs and pipes the prompt through ``llm_json`` and ``JsonOutputParser``.
    The prompt holds two worked examples of incorrect queries and asks for a
    JSON object with a single ``score`` key.
    """
    sql_grading_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a SQL query grader assessing correctness of PostgreSQL query to a user question.
Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.

Here is database description:
{table_info}

You need to check that each where statement is correctly filtered out what user question need.

For example, if user question is "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is
"SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
FROM "104_112碳排放公開及建準資料"
WHERE "事業名稱" like '%建準%'
AND "排放源" = '下游租賃'
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
For the above example, we can find that user asked for "固定燃燒", but the PostgreSQL query gives "排放源" = '下游租賃' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.

Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is
"SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
FROM "104_112碳排放公開及建準資料"
WHERE "事業名稱" like '%台積電%'
AND "排放源" = '固定燃燒'
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
For the above example, we can find that user asked for "建準", but the PostgreSQL query gives "事業名稱" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.

and so on. You need to strictly examine whether the sql PostgreSQL query matches the user question.

If the PostgreSQL query do not exactly matches the user question, grade it as incorrect.
You need to strictly examine whether the sql PostgreSQL query matches the user question.
Give a binary score 'yes' or 'no' score to indicate whether the PostgreSQL query is correct to the question. \n
Provide the binary score as a JSON with a single key 'score' and no premable or explanation.
<|eot_id|>

<|start_header_id|>user<|end_header_id|>
Here is the PostgreSQL query: \n\n {sql_query} \n\n
Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    sql_grading_prompt = PromptTemplate(
        input_variables=["table_info", "question", "sql_query"],
        template=sql_grading_template,
    )
    return sql_grading_prompt | llm_json | JsonOutputParser()
### Router
def Router():
    """Build the chain that picks a data source for a question.

    The returned runnable takes a ``question`` input and pipes the prompt
    through ``llm_json`` and ``JsonOutputParser``. The prompt asks for a JSON
    object whose ``datasource`` key is 'company_private_data' or 'vectorstore'.
    """
    routing_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert at routing a user question to a vectorstore or company private data.
Use company private data for questions about the informations about a company's greenhouse gas emissions data.
Otherwise, use the vectorstore for questions on ESG field knowledge or news about ESG.
You do not need to be stringent with the keywords in the question related to these topics.
Give a binary choice 'company_private_data' or 'vectorstore' based on the question.
Return the a JSON with a single key 'datasource' and no premable or explanation.

Question to route: {question}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
    routing_prompt = PromptTemplate(
        input_variables=["question"],
        template=routing_template,
    )
    return routing_prompt | llm_json | JsonOutputParser()
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user's question
        generation: answer text produced for the question
        documents: documents retrieved for the question
        retry: attempt counter for SQL query generation
        sql_query: generated PostgreSQL query
    """
    question: str
    generation: str
    # NOTE(review): annotated as List[str], but retrieve_and_generation stores
    # the retriever's return value here unchanged -- confirm the element type.
    documents: List[str]
    retry: int
    sql_query: str

# Node
def retrieve_and_generation(state):
    """
    Fetch documents for the question and generate an answer from them.

    Args:
        state (dict): The current graph state; ``question`` is read from it

    Returns:
        state (dict): ``documents`` (the retriever's result), the unchanged
        ``question`` and ``generation`` (the result of ``faiss_query``)
    """
    print("---RETRIEVE---")
    user_question = state["question"]

    # TODO: correct Retrieval function
    hits = retriever.get_relevant_documents(user_question, k=30)
    answer = faiss_query(user_question, hits, llm)

    return {"documents": hits, "question": user_question, "generation": answer}
def company_private_data_get_sql_query(state):
    """
    Get PostgreSQL query according to question and count the attempts.

    Args:
        state (dict): The current graph state. ``question`` is required;
            ``retry`` may be absent (or None) on the first pass.

    Returns:
        state (dict): ``sql_query`` (result of ``_get_query``), the unchanged
        ``question`` and ``retry`` -- 0 on the first attempt and the previous
        value plus one on every later attempt.
    """
    print("---SQL QUERY---")
    question = state["question"]

    # Read the counter defensively and compare against None, not truthiness:
    # a stored 0 is a real attempt count and must advance to 1, otherwise the
    # counter stays at 0 forever and the retry limit can never be reached.
    previous_retry = state.get("retry")
    retry = 0 if previous_retry is None else previous_retry + 1

    sql_query = _get_query(question)

    return {"sql_query": sql_query, "question": question, "retry": retry}

def company_private_data_search(state):
    """
    Turn the stored PostgreSQL query into a natural-language answer.

    Args:
        state (dict): The current graph state; ``question`` and ``sql_query``
            are read from it

    Returns:
        state (dict): the unchanged ``sql_query`` and ``question`` plus
        ``generation``, the result of ``_query_to_nl``
    """
    print("---SQL TO NL---")
    user_question, stored_query = state["question"], state["sql_query"]
    answer = _query_to_nl(user_question, stored_query)
    return {
        "sql_query": stored_query,
        "question": user_question,
        "generation": answer,
    }
### Conditional edge
def route_question(state):
    """
    Route question to text-to-SQL (company private data) or RAG (vectorstore).

    Args:
        state (dict): The current graph state; ``question`` is read from it

    Returns:
        str: Next node to call, always "company_private_data" or
        "vectorstore". Missing or unrecognised router output falls back to
        "vectorstore", which the router prompt describes as the "otherwise"
        choice.
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = Router().invoke({"question": question})

    # The model's JSON is not guaranteed to have the expected key or casing.
    datasource = source.get("datasource") if isinstance(source, dict) else None
    if isinstance(datasource, str):
        datasource = datasource.strip().lower()
    print(datasource)

    if datasource == "company_private_data":
        print("---ROUTE QUESTION TO TEXT-TO-SQL---")
        return "company_private_data"

    # Returning None here would not match any branch of the conditional edge,
    # so every other value is sent down the RAG path.
    if datasource != "vectorstore":
        print("---UNRECOGNISED DATASOURCE, DEFAULTING TO RAG---")
    print("---ROUTE QUESTION TO RAG---")
    return "vectorstore"

def grade_generation_v_documents_and_question(state):
    """
    Determines whether the generation is grounded in the document and answers question.

    Args:
        state (dict): The current graph state; ``question``, ``documents`` and
            ``generation`` are read from it

    Returns:
        str: Decision for next node to call -- "useful", "not useful" or
        "not supported"
    """
    accepted = ("yes", "true", 1, "1")

    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = Hallucination_Grader().invoke(
        {"documents": documents, "generation": generation}
    )
    # A missing 'score' key counts as a rejection instead of raising KeyError.
    grade = score.get("score") if isinstance(score, dict) else None
    # Models do not always return lowercase, so normalise string grades.
    if isinstance(grade, str):
        grade = grade.strip().lower()

    # Check hallucination
    if grade not in accepted:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    # Check question-answering
    print("---GRADE GENERATION vs QUESTION---")
    score = Answer_Grader().invoke({"question": question, "generation": generation})
    grade = score.get("score") if isinstance(score, dict) else None
    if isinstance(grade, str):
        grade = grade.strip().lower()

    if grade in accepted:
        print("---DECISION: GENERATION ADDRESSES QUESTION---")
        return "useful"

    print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
    return "not useful"

def grade_sql_query(state):
    """
    Determines whether the PostgreSQL query is correct for the question.

    Args:
        state (dict): The current graph state; ``question`` and ``sql_query``
            are required, ``retry`` is treated as 0 when absent or None

    Returns:
        str: "correct" when the grader accepts the query, "failed" when it
        rejects it and ``retry`` is 5 or more, otherwise "incorrect"
    """
    print("---CHECK SQL CORRECTNESS TO QUESTION---")
    question = state["question"]
    sql_query = state["sql_query"]
    retry = state.get("retry") or 0

    # Grade the generated query against the table description.
    score = SQL_Grader().invoke(
        {"table_info": table_description(), "question": question, "sql_query": sql_query}
    )
    # A missing 'score' key counts as a rejection instead of raising KeyError.
    grade = score.get("score") if isinstance(score, dict) else None
    # Models do not always return lowercase, so normalise string grades.
    if isinstance(grade, str):
        grade = grade.strip().lower()

    if grade in ("yes", "true", 1, "1"):
        print("---GRADE: CORRECT SQL QUERY---")
        return "correct"
    if retry >= 5:
        print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
        return "failed"
    print("---GRADE: INCORRECT SQL QUERY---")
    return "incorrect"
def build_graph():
    """Assemble and compile the question-answering workflow.

    Three nodes are registered, each with ``RetryPolicy(max_attempts=5)``.
    ``route_question`` picks the first node; the RAG node loops back on itself
    until its grader returns "useful"; the SQL query node loops back on itself
    while its grader returns "incorrect", moves on to the search node on
    "correct" and ends on "failed".

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    graph = StateGraph(GraphState)

    # Define the nodes
    node_table = (
        ("company_private_data_query", company_private_data_get_sql_query),
        ("company_private_data_search", company_private_data_search),
        ("retrieve_and_generation", retrieve_and_generation),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn, retry=RetryPolicy(max_attempts=5))

    # Entry point: choose between the text-to-SQL branch and the RAG branch.
    entry_routes = {
        "company_private_data": "company_private_data_query",
        "vectorstore": "retrieve_and_generation",
    }
    graph.add_conditional_edges(START, route_question, entry_routes)

    # RAG branch: regenerate unless the answer is judged useful.
    rag_routes = {
        "not supported": "retrieve_and_generation",
        "useful": END,
        "not useful": "retrieve_and_generation",
    }
    graph.add_conditional_edges(
        "retrieve_and_generation",
        grade_generation_v_documents_and_question,
        rag_routes,
    )

    # SQL branch: regenerate the query until it is accepted or has failed.
    sql_routes = {
        "correct": "company_private_data_search",
        "incorrect": "company_private_data_query",
        "failed": END,
    }
    graph.add_conditional_edges(
        "company_private_data_query",
        grade_sql_query,
        sql_routes,
    )

    graph.add_edge("company_private_data_search", END)

    return graph.compile()
def main(question: str = "溫室氣體是什麼"):
    """Run the graph for *question* and return the last generated answer.

    Args:
        question: The question fed to the graph. Defaults to the sample
            question the script has always used; another sample is
            "建準去年的類別一排放量?".

    Returns:
        The most recent ``generation`` found in the streamed node outputs, or
        ``None`` when no node produced one.
    """
    app = build_graph()
    inputs = {"question": question}

    generation = None
    for output in app.stream(inputs, {"recursion_limit": 10}):
        for key, value in output.items():
            pprint(f"Finished running: {key}:")
            # Not every node returns a generation (the SQL query node returns
            # only sql_query/question/retry), so do not index it blindly.
            if isinstance(value, dict) and "generation" in value:
                generation = value["generation"]
                pprint(generation)

    return generation
# Script entry point: run the pipeline once and print the final answer under
# a separator line.
if __name__ == "__main__":
    result = main()
    separator = "------------------------------------------------------"
    print(f"{separator}\n{result}")
|