post_processing_sqlparse.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import sqlparse
  2. from sqlparse.sql import Comparison, Parenthesis, Token
  3. from sqlparse.tokens import Literal
  4. from langchain_community.utilities import SQLDatabase
  5. from dotenv import load_dotenv
  6. import os
# Populate os.environ from a .env file (python-dotenv); the commented-out
# setup below expects SUPABASE_URI to be defined there.
load_dotenv()
# NOTE(review): `db` is not defined at import time, yet
# extract_comparison_value() may read a module-level `db` when a
# comparison's right-hand side is a parenthesized subquery. Define it
# (e.g. by uncommenting the lines below) before parsing such queries.
# URI = os.getenv("SUPABASE_URI")
# db = SQLDatabase.from_uri(URI, sample_rows_in_table_info=5)
  10. def get_query_columns(sql, get_real_name=False):
  11. stmt = sqlparse.parse(sql)[0]
  12. columns = []
  13. column_identifiers = []
  14. # get column_identifieres
  15. in_select = False
  16. for token in stmt.tokens:
  17. if isinstance(token, sqlparse.sql.Comment):
  18. continue
  19. if str(token).lower() == 'select':
  20. in_select = True
  21. elif in_select and token.ttype is None:
  22. if isinstance(token, sqlparse.sql.IdentifierList):
  23. for identifier in token.get_identifiers():
  24. column_identifiers.append(identifier)
  25. else:
  26. column_identifiers.append(token)
  27. break
  28. # get column names
  29. for column_identifier in column_identifiers:
  30. if get_real_name:
  31. columns.append(column_identifier.get_real_name())
  32. else:
  33. columns.append(column_identifier.get_name())
  34. return columns
  35. def extract_comparison_value(tokens, target):
  36. """Helper function to extract value based on a comparison target."""
  37. is_target = False
  38. for token in tokens:
  39. if token.value.strip("'\"") == target:
  40. is_target = True
  41. elif is_target and token.ttype is Literal.String.Single:
  42. return token.value.strip("'\"")
  43. elif is_target and isinstance(token, Parenthesis):
  44. data = db.run(token.value.strip("()"))
  45. return eval(data)[0][0]
  46. return None
  47. def parse_sql_for_stock_info(sql):
  48. """Parse the SQL statement to extract 排放源, 類別"""
  49. stmt = sqlparse.parse(sql)[0]
  50. emission, class_type = None, None
  51. for token in stmt.tokens:
  52. if isinstance(token, sqlparse.sql.Comment):
  53. continue
  54. if token.value.lower().startswith('where'):
  55. for token2 in token.tokens:
  56. if isinstance(token2, Comparison):
  57. if emission is None:
  58. emission = extract_comparison_value(token2.tokens, "排放源")
  59. if class_type is None:
  60. class_type = extract_comparison_value(token2.tokens, "類別")
  61. return emission, class_type
  62. def get_table_name(sql):
  63. stmt = sqlparse.parse(sql)[0]
  64. in_from = False
  65. for token in stmt.tokens:
  66. if isinstance(token, sqlparse.sql.Comment):
  67. continue
  68. if str(token).lower() == 'from':
  69. in_from = True
  70. elif in_from and token.ttype is None:
  71. if isinstance(token, sqlparse.sql.Identifier):
  72. # print(token, token.ttype)
  73. return token.value
  74. if __name__ == "__main__":
  75. sql_query = """
  76. SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
  77. FROM "104_112碳排放公開及建準資料"
  78. WHERE "事業名稱" like '%建準%'
  79. AND "排放源" = '固定燃燒'
  80. AND "盤查標準" = 'GHG'
  81. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;
  82. """
  83. print(get_query_columns(sql_query, get_real_name=True))
  84. print(parse_sql_for_stock_info(sql_query))
  85. print(get_table_name(sql_query))