# post_processing_sqlparse.py
  1. import sqlparse
  2. from sqlparse.sql import Comparison, Parenthesis, Token
  3. from sqlparse.tokens import Literal
  4. from langchain_community.utilities import SQLDatabase
  5. from dotenv import load_dotenv
  6. import os
  7. load_dotenv()
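# NOTE(review): the database connection below is commented out, but
# extract_comparison_value() still calls a module-level `db` when it meets a
# parenthesised sub-query. Re-enable these two lines (or define `db`
# elsewhere) before relying on that path.
# URI = os.getenv("SUPABASE_URI")
# db = SQLDatabase.from_uri(URI, sample_rows_in_table_info=5)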
  10. def get_query_columns(sql, get_real_name=False):
  11. stmt = sqlparse.parse(sql)[0]
  12. columns = []
  13. column_identifiers = []
  14. # get column_identifieres
  15. in_select = False
  16. for token in stmt.tokens:
  17. if isinstance(token, sqlparse.sql.Comment):
  18. continue
  19. if str(token).lower() == 'select':
  20. in_select = True
  21. elif in_select and token.ttype is None:
  22. if isinstance(token, sqlparse.sql.IdentifierList):
  23. for identifier in token.get_identifiers():
  24. column_identifiers.append(identifier)
  25. else:
  26. column_identifiers.append(token)
  27. break
  28. # get column names
  29. for column_identifier in column_identifiers:
  30. if get_real_name:
  31. columns.append(column_identifier.get_real_name())
  32. else:
  33. columns.append(column_identifier.get_name())
  34. return columns
  35. def extract_comparison_value(tokens, target):
  36. """Helper function to extract value based on a comparison target."""
  37. is_target = False
  38. for token in tokens:
  39. if token.value.strip("'\"") == target:
  40. is_target = True
  41. elif is_target and token.ttype is Literal.String.Single:
  42. return token.value.strip("'\"")
  43. elif is_target and isinstance(token, Parenthesis):
  44. data = db.run(token.value.strip("()"))
  45. return eval(data)[0][0]
  46. return None
  47. def parse_sql_where(sql):
  48. """Parse the SQL statement to extract 排放源, 類別"""
  49. stmt = sqlparse.parse(sql)[0]
  50. column_dict = {
  51. "排放源": None,
  52. "類別": None,
  53. "類別項目": None,
  54. "項目": None,
  55. }
  56. def get_column_details(token, column_args):
  57. if isinstance(token, Comparison):
  58. # print(token, type(token))
  59. for column_name in column_args.keys():
  60. if column_args[column_name] is None:
  61. column_args[column_name] = extract_comparison_value(token.tokens, column_name)
  62. return column_args
  63. for token in stmt.tokens:
  64. if isinstance(token, sqlparse.sql.Comment):
  65. continue
  66. if token.value.lower().startswith('where'):
  67. for token2 in token.tokens:
  68. # print(token2, type(token2))
  69. if isinstance(token2, Comparison):
  70. column_dict = get_column_details(token2, column_dict)
  71. elif isinstance(token2, Parenthesis):
  72. # print(token2, type(token2))
  73. for token3 in token2.tokens:
  74. column_dict = get_column_details(token3, column_dict)
  75. column_values = [column_dict[column_name].replace("%", "") for column_name in column_dict.keys() if column_dict[column_name] is not None]
  76. return column_values
  77. def get_table_name(sql):
  78. stmt = sqlparse.parse(sql)[0]
  79. in_from = False
  80. for token in stmt.tokens:
  81. if isinstance(token, sqlparse.sql.Comment):
  82. continue
  83. if str(token).lower() == 'from':
  84. in_from = True
  85. elif in_from and token.ttype is None:
  86. if isinstance(token, sqlparse.sql.Identifier):
  87. # print(token, token.ttype)
  88. return token.value
  89. if __name__ == "__main__":
  90. sql_query = """
  91. SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
  92. FROM "104_112碳排放公開及建準資料"
  93. WHERE "事業名稱" like '%建準%'
  94. AND "排放源" = '固定燃燒'
  95. AND "盤查標準" = 'GHG'
  96. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;
  97. """
  98. sql_query = """
  99. SELECT SUM("用電度數(kwh)") AS "綠電使用量"
  100. FROM "用電度數"
  101. WHERE "項目" = '自產電力(綠電)'
  102. AND "盤查標準" = 'GHG'
  103. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1
  104. """
  105. print(get_query_columns(sql_query, get_real_name=True))
  106. print(parse_sql_where(sql_query))
  107. print(get_table_name(sql_query))