123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- import sqlparse
- from sqlparse.sql import Comparison, Parenthesis, Token
- from sqlparse.tokens import Literal
- from langchain_community.utilities import SQLDatabase
- from dotenv import load_dotenv
- import os
load_dotenv()

# Database used by extract_comparison_value to execute parenthesised sub-queries
# found in WHERE clauses. Nothing is created at import time, so importing this
# module opens no connection. Assign a database before parsing such queries, e.g.:
#     db = SQLDatabase.from_uri(os.getenv("SUPABASE_URI"), sample_rows_in_table_info=5)
db = None
def get_query_columns(sql, get_real_name=False):
    """Return the names of the columns in the top-level SELECT list of *sql*.

    Only the first statement parsed from *sql* is inspected.

    Args:
        sql: SQL text.
        get_real_name: If True, report each column's underlying name
            (``get_real_name()``); otherwise report its alias/display name
            (``get_name()``).

    Returns:
        A list of column names in SELECT-list order. Columns that sqlparse
        leaves as bare tokens (``*``, numeric literals, ...) are reported by
        their raw text. An input with no SELECT list yields ``[]``.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        # Input can parse to no statements at all (e.g. an empty string).
        return []

    column_tokens = []
    in_select = False
    for token in statements[0].tokens:
        if isinstance(token, sqlparse.sql.Comment) or token.is_whitespace:
            continue
        if not in_select:
            in_select = str(token).lower() == 'select'
            continue
        if token.is_keyword:
            if token.normalized == 'FROM':
                # The SELECT list has ended. Scanning further would report
                # the table name as if it were a column.
                break
            continue  # modifiers such as DISTINCT
        # The first remaining token is the SELECT list. It is an
        # IdentifierList for several columns, otherwise a single column
        # (Identifier, Function, "*", literal, ...).
        if isinstance(token, sqlparse.sql.IdentifierList):
            column_tokens.extend(token.get_identifiers())
        else:
            column_tokens.append(token)
        break

    name_method = 'get_real_name' if get_real_name else 'get_name'
    columns = []
    for column_token in column_tokens:
        getter = getattr(column_token, name_method, None)
        # Bare tokens (e.g. "*" or "1") have no name helpers; use their text.
        columns.append(getter() if getter is not None else column_token.value)
    return columns
def extract_comparison_value(tokens, target, db=None):
    """Return the value compared against *target* within a comparison's tokens.

    *tokens* are the child tokens of a ``sqlparse.sql.Comparison`` such as
    ``"排放源" = '固定燃燒'``. Once a token whose text (surrounding quotes
    stripped) equals *target* has been seen, the next qualifying token decides
    the result:

    * a single-quoted string literal is returned with its quotes stripped;
    * a parenthesised sub-query is executed via ``db.run`` and the first
      column of its first row is returned.

    Args:
        tokens: Child tokens of a comparison.
        target: Column name to look for, without quotes.
        db: Object with a ``run(sql)`` method, e.g. a LangChain
            ``SQLDatabase``. Defaults to this module's global ``db``, if one
            has been assigned.

    Returns:
        The extracted value. Returns None when *target* is absent, when no
        literal or sub-query follows it, or when the sub-query returns no rows.

    Raises:
        RuntimeError: A sub-query has to be executed but no database is set.
    """
    import ast

    is_target = False
    for token in tokens:
        if token.value.strip("'\"") == target:
            is_target = True
        elif is_target and token.ttype is Literal.String.Single:
            return token.value.strip("'\"")
        elif is_target and isinstance(token, Parenthesis):
            if db is None:
                db = globals().get("db")
            if db is None:
                raise RuntimeError(
                    "a database is required to evaluate the sub-query for %r; "
                    "pass db=... or assign the module-level 'db'" % target
                )
            # Remove exactly one pair of outer parentheses. str.strip("()")
            # would also eat the closing parens of calls like "(SELECT max(x))".
            data = db.run(token.value[1:-1])
            if not data:
                return None
            # NOTE(review): assumes db.run returns the repr of a list of row
            # tuples, e.g. "[('x',)]". Parse it as a literal instead of
            # eval()-ing database output. Non-literal reprs (e.g. Decimal)
            # raise ValueError.
            rows = ast.literal_eval(data)
            return rows[0][0] if rows and rows[0] else None
    return None
def parse_sql_where(sql):
    """Extract the values compared against the "排放源" and "類別" columns.

    "排放源" is the emission-source column and "類別" the category column.
    The top-level WHERE clause of the first statement in *sql* is scanned for
    comparisons such as ``"排放源" = '固定燃燒'`` or ``"類別" LIKE '%x%'``,
    including comparisons nested in parentheses at any depth. For each column,
    the first comparison that yields a value wins.

    Args:
        sql: SQL text.

    Returns:
        ``[排放源 value, 類別 value]`` with every ``%`` character removed.
        A column that the WHERE clause does not constrain yields ``None``.
    """
    column_dict = {
        "排放源": None,
        "類別": None
    }

    def collect(tokens):
        # Fill still-missing columns from Comparison tokens. Recurse into
        # parenthesised groups such as ("a" = 'x' OR "b" = 'y').
        for token in tokens:
            if isinstance(token, Comparison):
                for column_name in column_dict:
                    if column_dict[column_name] is None:
                        column_dict[column_name] = extract_comparison_value(
                            token.tokens, column_name
                        )
            elif isinstance(token, Parenthesis):
                collect(token.tokens)

    statements = sqlparse.parse(sql)
    if statements:
        for token in statements[0].tokens:
            if isinstance(token, sqlparse.sql.Where):
                collect(token.tokens)

    # Drop LIKE wildcards. Keep None for columns the query does not constrain
    # instead of crashing on None.replace().
    return [
        value.replace("%", "") if value is not None else None
        for value in column_dict.values()
    ]
def get_table_name(sql):
    """Return the first table referenced in the top-level FROM clause of *sql*.

    The table is returned exactly as written in the query, including quotes
    and any alias, e.g. ``'"104_112碳排放公開及建準資料"'``. For a
    comma-separated FROM list, the first entry is returned.

    Args:
        sql: SQL text; only the first parsed statement is inspected.

    Returns:
        The table reference text, or None when there is no FROM clause or the
        token following FROM is not a table reference.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        return None

    in_from = False
    for token in statements[0].tokens:
        if isinstance(token, sqlparse.sql.Comment):
            continue
        if str(token).lower() == 'from':
            in_from = True
        elif in_from and token.ttype is None:
            # Only the first grouped token after FROM is the table list.
            # Scanning further could return an unrelated identifier,
            # e.g. an ORDER BY column.
            table = token
            if isinstance(table, sqlparse.sql.IdentifierList):
                table = next(iter(table.get_identifiers()), None)
            if isinstance(table, sqlparse.sql.Identifier):
                return table.value
            return None
    return None
if __name__ == "__main__":
    # Sample query used to exercise the parsing helpers in this module.
    sql_query = """
    SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
    FROM "104_112碳排放公開及建準資料"
    WHERE "事業名稱" like '%建準%'
    AND "排放源" = '固定燃燒'
    AND "盤查標準" = 'GHG'
    AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;
    """

    print(get_query_columns(sql_query, get_real_name=True))
    print(parse_sql_where(sql_query))
    print(get_table_name(sql_query))
|