post_processing_sqlparse.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import sqlparse
  2. from sqlparse.sql import Comparison, Parenthesis, Token
  3. from sqlparse.tokens import Literal
  4. from langchain_community.utilities import SQLDatabase
  5. from dotenv import load_dotenv
  6. import os
  7. load_dotenv()
  8. # URI = os.getenv("SUPABASE_URI")
  9. # db = SQLDatabase.from_uri(URI, sample_rows_in_table_info=5)
  10. def get_query_columns(sql, get_real_name=False):
  11. stmt = sqlparse.parse(sql)[0]
  12. columns = []
  13. column_identifiers = []
  14. # get column_identifieres
  15. in_select = False
  16. for token in stmt.tokens:
  17. if isinstance(token, sqlparse.sql.Comment):
  18. continue
  19. if str(token).lower() == 'select':
  20. in_select = True
  21. elif in_select and token.ttype is None:
  22. if isinstance(token, sqlparse.sql.IdentifierList):
  23. for identifier in token.get_identifiers():
  24. column_identifiers.append(identifier)
  25. else:
  26. column_identifiers.append(token)
  27. break
  28. # get column names
  29. for column_identifier in column_identifiers:
  30. if get_real_name:
  31. columns.append(column_identifier.get_real_name())
  32. else:
  33. columns.append(column_identifier.get_name())
  34. return columns
  35. def extract_comparison_value(tokens, target):
  36. """Helper function to extract value based on a comparison target."""
  37. is_target = False
  38. for token in tokens:
  39. if token.value.strip("'\"") == target:
  40. is_target = True
  41. elif is_target and token.ttype is Literal.String.Single:
  42. return token.value.strip("'\"")
  43. elif is_target and isinstance(token, Parenthesis):
  44. data = db.run(token.value.strip("()"))
  45. return eval(data)[0][0]
  46. return None
  47. def parse_sql_where(sql):
  48. """Parse the SQL statement to extract 排放源, 類別"""
  49. stmt = sqlparse.parse(sql)[0]
  50. column_dict = {
  51. "排放源": None,
  52. "類別": None,
  53. "範疇": None,
  54. "類別項目": None,
  55. "項目": None,
  56. }
  57. def get_column_details(token, column_args):
  58. if isinstance(token, Comparison):
  59. # print(token, type(token))
  60. for column_name in column_args.keys():
  61. if column_args[column_name] is None:
  62. column_args[column_name] = extract_comparison_value(token.tokens, column_name)
  63. return column_args
  64. for token in stmt.tokens:
  65. if isinstance(token, sqlparse.sql.Comment):
  66. continue
  67. if token.value.lower().startswith('where'):
  68. for token2 in token.tokens:
  69. # print(token2, type(token2))
  70. if isinstance(token2, Comparison):
  71. column_dict = get_column_details(token2, column_dict)
  72. elif isinstance(token2, Parenthesis):
  73. # print(token2, type(token2))
  74. for token3 in token2.tokens:
  75. column_dict = get_column_details(token3, column_dict)
  76. column_values = [column_dict[column_name].replace("%", "") for column_name in column_dict.keys() if column_dict[column_name] is not None]
  77. return column_values
  78. def get_table_name(sql):
  79. stmt = sqlparse.parse(sql)[0]
  80. in_from = False
  81. for token in stmt.tokens:
  82. if isinstance(token, sqlparse.sql.Comment):
  83. continue
  84. if str(token).lower() == 'from':
  85. in_from = True
  86. elif in_from and token.ttype is None:
  87. if isinstance(token, sqlparse.sql.Identifier):
  88. # print(token, token.ttype)
  89. return token.value
if __name__ == "__main__":
    # Manual smoke run: parse a sample query and print what each helper finds.

    # Sample 1 -- NOTE: this value is overwritten by the next assignment
    # before it is used, so it never reaches the helpers below.
    sql_query = """
SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
FROM "104_112碳排放公開及建準資料"
WHERE "事業名稱" like '%建準%'
AND "排放源" = '固定燃燒'
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;
"""
    # Sample 2 -- the query that is actually parsed.
    sql_query = """
SELECT SUM("用電度數(kwh)") AS "綠電使用量"
FROM "用電度數"
WHERE "項目" = '自產電力(綠電)'
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1
"""
    # Names of the select-list items.
    print(get_query_columns(sql_query, get_real_name=True))
    # Values compared against the known columns in the WHERE clause.
    print(parse_sql_where(sql_query))
    # Text of the first identifier after FROM.
    print(get_table_name(sql_query))