import ast
import os

import sqlparse
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from sqlparse.sql import Comparison, Parenthesis, Token
from sqlparse.tokens import Literal

load_dotenv()

# Database handle used to resolve sub-queries found inside WHERE comparisons.
# It is created on first use by _get_db() so that merely importing this module
# never opens a connection.
db = None


def _get_db():
    """Return the shared SQLDatabase, creating it from SUPABASE_URI on first use.

    Raises:
        RuntimeError: if the SUPABASE_URI environment variable is not set.
    """
    global db
    if db is None:
        uri = os.getenv("SUPABASE_URI")
        if not uri:
            raise RuntimeError(
                "SUPABASE_URI is not set; cannot run sub-queries from a WHERE clause"
            )
        db = SQLDatabase.from_uri(uri, sample_rows_in_table_info=5)
    return db


def _first_statement(sql):
    """Return the first statement parsed from *sql*, or None if there is none."""
    statements = sqlparse.parse(sql)
    return statements[0] if statements else None


def get_query_columns(sql, get_real_name=False):
    """Return the names of the columns selected by the first statement in *sql*.

    Args:
        sql: SQL text; only its first statement is inspected.
        get_real_name: when True, report each column's ``get_real_name()``;
            otherwise report ``get_name()`` (which prefers an alias).

    Returns:
        A list of column names, empty when no column identifier is found.
    """
    stmt = _first_statement(sql)
    if stmt is None:
        return []

    # Collect the identifier tokens that make up the select list.
    column_identifiers = []
    in_select = False
    for token in stmt.tokens:
        if isinstance(token, sqlparse.sql.Comment):
            continue
        keyword = str(token).lower()
        if keyword == 'select':
            in_select = True
        elif in_select and keyword == 'from':
            # The select list held only typed tokens (e.g. ``*`` or a literal);
            # stop here so the table after FROM is not reported as a column.
            break
        elif in_select and token.ttype is None:
            if isinstance(token, sqlparse.sql.IdentifierList):
                column_identifiers.extend(token.get_identifiers())
            else:
                column_identifiers.append(token)
            break

    # Resolve each identifier to the requested kind of name.
    if get_real_name:
        return [identifier.get_real_name() for identifier in column_identifiers]
    return [identifier.get_name() for identifier in column_identifiers]


def extract_comparison_value(tokens, target):
    """Return the value compared against the column *target* in a comparison.

    Args:
        tokens: the child tokens of a ``sqlparse.sql.Comparison``.
        target: the column name to look for, without surrounding quotes.

    Returns:
        The single-quoted string literal that follows *target* (quotes
        removed), or the first column of the first row produced by a
        parenthesised sub-query that follows it. ``None`` when *target* is not
        part of the comparison or the sub-query yields no rows.

    Raises:
        RuntimeError: if a sub-query must be run but SUPABASE_URI is not set.
    """
    is_target = False
    for token in tokens:
        if token.value.strip("'\"") == target:
            is_target = True
        elif is_target and token.ttype is Literal.String.Single:
            return token.value.strip("'\"")
        elif is_target and isinstance(token, Parenthesis):
            # Drop only the outer pair of parentheses; strip("()") would also
            # eat a closing parenthesis that belongs to the sub-query itself,
            # e.g. "(SELECT max(x))".
            data = _get_db().run(token.value[1:-1])
            if not data:
                return None
            # NOTE(security): the result text comes from the database, so it is
            # parsed with ast.literal_eval rather than eval to avoid executing
            # arbitrary expressions.
            rows = ast.literal_eval(data) if isinstance(data, str) else data
            if not rows:
                return None
            return rows[0][0]
    return None


def parse_sql_where(sql):
    """Extract the values compared against "排放源" and "類別" in the WHERE clause.

    The two column names are presumably "emission source" and "category".
    Comparisons are looked up directly inside the WHERE clause and one level
    deep inside parenthesised groups.

    Args:
        sql: SQL text; only its first statement is inspected.

    Returns:
        A two-element list ``[排放源 value, 類別 value]``. String values have any
        ``%`` wildcard characters removed; a column that does not appear in the
        WHERE clause is reported as ``None``.
    """
    column_dict = {
        "排放源": None,
        "類別": None
    }

    def get_column_details(token, column_args):
        """Fill still-missing entries of *column_args* from a Comparison token."""
        if isinstance(token, Comparison):
            for column_name in column_args:
                if column_args[column_name] is None:
                    column_args[column_name] = extract_comparison_value(
                        token.tokens, column_name
                    )
        return column_args

    stmt = _first_statement(sql)
    if stmt is not None:
        for token in stmt.tokens:
            if isinstance(token, sqlparse.sql.Comment):
                continue
            if isinstance(token, sqlparse.sql.Where):
                for token2 in token.tokens:
                    if isinstance(token2, Comparison):
                        column_dict = get_column_details(token2, column_dict)
                    elif isinstance(token2, Parenthesis):
                        for token3 in token2.tokens:
                            column_dict = get_column_details(token3, column_dict)

    # Only strings can carry LIKE wildcards; leave None (column absent) and
    # non-string sub-query results untouched instead of crashing on them.
    return [
        value.replace("%", "") if isinstance(value, str) else value
        for value in column_dict.values()
    ]


def get_table_name(sql):
    """Return the text of the first identifier after FROM, or None if absent.

    The identifier is returned exactly as written in *sql*, including any
    surrounding quotes.
    """
    stmt = _first_statement(sql)
    if stmt is None:
        return None

    in_from = False
    for token in stmt.tokens:
        if isinstance(token, sqlparse.sql.Comment):
            continue
        if str(token).lower() == 'from':
            in_from = True
        elif in_from and token.ttype is None:
            if isinstance(token, sqlparse.sql.Identifier):
                return token.value
    return None


if __name__ == "__main__":
    sql_query = """
    SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
    FROM "104_112碳排放公開及建準資料"
    WHERE "事業名稱" like '%建準%'
    AND "排放源" = '固定燃燒'
    AND "盤查標準" = 'GHG'
    AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;
    """
    print(get_query_columns(sql_query, get_real_name=True))
    print(parse_sql_where(sql_query))
    print(get_table_name(sql_query))