post_processing_sqlparse.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import sqlparse
  2. from sqlparse.sql import Comparison, Parenthesis, Token
  3. from sqlparse.tokens import Literal
  4. from langchain_community.utilities import SQLDatabase
  5. from dotenv import load_dotenv
  6. import os
  7. load_dotenv()
  8. # URI = os.getenv("SUPABASE_URI")
  9. # db = SQLDatabase.from_uri(URI, sample_rows_in_table_info=5)
  10. def get_query_columns(sql, get_real_name=False):
  11. stmt = sqlparse.parse(sql)[0]
  12. columns = []
  13. column_identifiers = []
  14. # get column_identifieres
  15. in_select = False
  16. for token in stmt.tokens:
  17. if isinstance(token, sqlparse.sql.Comment):
  18. continue
  19. if str(token).lower() == 'select':
  20. in_select = True
  21. elif in_select and token.ttype is None:
  22. if isinstance(token, sqlparse.sql.IdentifierList):
  23. for identifier in token.get_identifiers():
  24. column_identifiers.append(identifier)
  25. else:
  26. column_identifiers.append(token)
  27. break
  28. # get column names
  29. for column_identifier in column_identifiers:
  30. if get_real_name:
  31. columns.append(column_identifier.get_real_name())
  32. else:
  33. columns.append(column_identifier.get_name())
  34. return columns
  35. def extract_comparison_value(tokens, target):
  36. """Helper function to extract value based on a comparison target."""
  37. is_target = False
  38. for token in tokens:
  39. if token.value.strip("'\"") == target:
  40. is_target = True
  41. elif is_target and token.ttype is Literal.String.Single:
  42. return token.value.strip("'\"")
  43. elif is_target and isinstance(token, Parenthesis):
  44. data = db.run(token.value.strip("()"))
  45. return eval(data)[0][0]
  46. return None
  47. def parse_sql_where(sql):
  48. """Parse the SQL statement to extract 排放源, 類別"""
  49. stmt = sqlparse.parse(sql)[0]
  50. column_dict = {
  51. "排放源": None,
  52. "類別": None
  53. }
  54. def get_column_details(token, column_args):
  55. if isinstance(token, Comparison):
  56. print(token, type(token))
  57. for column_name in column_args.keys():
  58. if column_args[column_name] is None:
  59. column_args[column_name] = extract_comparison_value(token.tokens, column_name)
  60. return column_args
  61. for token in stmt.tokens:
  62. if isinstance(token, sqlparse.sql.Comment):
  63. continue
  64. if token.value.lower().startswith('where'):
  65. for token2 in token.tokens:
  66. # print(token2, type(token2))
  67. if isinstance(token2, Comparison):
  68. column_dict = get_column_details(token2, column_dict)
  69. elif isinstance(token2, Parenthesis):
  70. # print(token2, type(token2))
  71. for token3 in token2.tokens:
  72. column_dict = get_column_details(token3, column_dict)
  73. column_values = [column_dict[column_name].replace("%", "") for column_name in column_dict.keys()]
  74. return column_values
  75. def get_table_name(sql):
  76. stmt = sqlparse.parse(sql)[0]
  77. in_from = False
  78. for token in stmt.tokens:
  79. if isinstance(token, sqlparse.sql.Comment):
  80. continue
  81. if str(token).lower() == 'from':
  82. in_from = True
  83. elif in_from and token.ttype is None:
  84. if isinstance(token, sqlparse.sql.Identifier):
  85. # print(token, token.ttype)
  86. return token.value
  87. if __name__ == "__main__":
  88. sql_query = """
  89. SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
  90. FROM "104_112碳排放公開及建準資料"
  91. WHERE "事業名稱" like '%建準%'
  92. AND "排放源" = '固定燃燒'
  93. AND "盤查標準" = 'GHG'
  94. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;
  95. """
  96. print(get_query_columns(sql_query, get_real_name=True))
  97. print(parse_sql_for_stock_info(sql_query))
  98. print(get_table_name(sql_query))