# gnews_keyword_extraction.py

# ----------------------------------------
# Import packages
# ----------------------------------------
import warnings
import os
import time
import json
import pickle
import operator
import argparse
import pprint
from random import sample
from collections import Counter
from itertools import compress
from string import punctuation
from typing import List

import pandas as pd
import dataset
import jieba
import jieba.posseg as pseg
import jieba.analyse
import paddle
import hdbscan
import torch
import opencc
import pke
import pymysql
from sentence_transformers import SentenceTransformer
from nltk.corpus import stopwords
from tqdm import tqdm

pymysql.install_as_MySQLdb()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")


def chunkstring(string, length):
    # Yield fixed-length slices of the string.
    return (string[0+i:length+i] for i in range(0, len(string), length))


def print_block(text):
    # Print the text centred inside a 60-character wide box.
    width = 60
    print('\n┌─' + '─' * width + '─┐')
    for line in chunkstring(text, width):
        print('| {0:^{1}} |'.format(line, width))
    print('└─' + '─' * width + '─┘')
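
# A minimal usage sketch for print_block (illustrative only, not executed as
# part of the script):
#   print_block('Loading data from database ...')
# prints the message centred between ┌─...─┐ / └─...─┘ borders, 60 characters
# wide, matching the banner comments used throughout the main block below.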

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--topK', required=False)
    parser.add_argument('--target_domain_range',
                        required=False)
    args = parser.parse_args()
    start_time = time.time()

    # ┌──────────────────────────────────────┐
    # |                                      |
    # |    Loading data from database ...    |
    # |                                      |
    # └──────────────────────────────────────┘
    print_block('Loading data from database ...')
    db = dataset.connect(
        'mysql://choozmo:pAssw0rd@db.ptt.cx:3306/hhh?charset=utf8mb4')
    result = db.query('SELECT * FROM gnews.gnews_detail ')
    data = pd.DataFrame(result, columns=next(iter(result)).keys())
    print('Number of news:', len(data))
    print('Dropping duplicates ... ', end='')
    data = data.drop_duplicates(subset=['news_content'], keep='first').drop_duplicates(
        subset=['news_url'], keep='first')
    print('number of news:', len(data))
    db.close()

    # ----------------------------------------
    # Functions
    # ----------------------------------------
    # True if the string contains no Latin letters and no digits
    def notNumStr(instr):
        for item in instr:
            if '\u0041' <= item <= '\u005a' or ('\u0061' <= item <= '\u007a') or item.isdigit():
                return False
        return True

    # Read the target case if it is a JSON transcript file
    def readSingleTestCases(testFile):
        with open(testFile) as json_data:
            try:
                testData = json.load(json_data)
            except:
                # This branch deals with incorrect JSON that uses ' instead of "
                data = json_data.read().replace("'", '"')
                try:
                    testData = json.loads(data)
                # This branch deals with an empty transcript file
                except:
                    return ""
        returnString = ""
        for item in testData:
            try:
                returnString += item['text']
            except:
                returnString += item['statement']
        return returnString
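
    # Illustrative input for readSingleTestCases (the shape is inferred from the
    # parsing code above, not from an actual transcript file): a JSON list of
    # objects keyed by 'text' (or 'statement'), e.g.
    #   [{"text": "第一句"}, {"text": "第二句"}]
    # which would be concatenated and returned as "第一句第二句".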

    class Word():
        def __init__(self, char, freq=0, deg=0):
            self.freq = freq
            self.deg = deg
            self.char = char

        def returnScore(self):
            # RAKE word score: degree / frequency
            return self.deg / self.freq

        def updateOccur(self, phraseLength):
            self.freq += 1
            self.deg += phraseLength

        def getChar(self):
            return self.char

        def updateFreq(self):
            self.freq += 1

        def getFreq(self):
            return self.freq
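
    # Worked example (illustrative only): if the word "都更" occurs in the
    # candidate phrases ["都更", "條例"] and ["都更"], updateOccur is called with
    # phrase lengths 2 and 1, giving freq = 2, deg = 3 and
    # returnScore() = 3 / 2 = 1.5.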

    class Rake:
        # def __init__(self, stopwordPath: str = None, delimWordPath: str = None):
        def __init__(self):
            # If both found and initialized
            self.initialized = False
            self.stopWordList = list()
            self.delimWordList = list()

        def initializeFromPath(self, stopwordPath: str = "", delimWordPath: str = ""):
            if not os.path.exists(stopwordPath):
                print("Stop Word Path invalid")
                return
            if not os.path.exists(delimWordPath):
                print("Delim Word Path Invalid")
                return
            swLibList = [line.rstrip('\n') for line in converter.convert(
                open(stopwordPath, 'r').read()).split('\n')]
            conjLibList = [line.rstrip('\n') for line in converter.convert(
                open(delimWordPath, 'r').read()).split('\n')]
            self.initializeFromList(swLibList, conjLibList)
            return

        def initializeFromList(self, swList: List = None, dwList: List = None):
            self.stopWordList = swList
            self.delimWordList = dwList
            if len(self.stopWordList) == 0 or len(self.delimWordList) == 0:
                print("Empty stop word list or delimiter word list, uninitialized")
                return
            else:
                self.initialized = True

        def extractKeywordFromPath(self, text: str, num_kw: int = 10):
            if not self.initialized:
                print("Not initialized")
                return
            with open(text, 'r') as fp:
                text = fp.read()
            return self.extractKeywordFromString(text, num_kw=num_kw)

        def extractKeywordFromString(self, text: str, num_kw: int = 10):
            rawtextList = pseg.cut(text)
            # Construct the list of phrases and the preliminary textList
            textList = []
            listofSingleWord = dict()
            lastWord = ''
            # jieba POS tags that act as phrase delimiters
            poSPrty = ['zg', 'm', 'x', 'uj', 'ul', 'mq', 'u', 'v', 'f', 't', 'vd', 'q', 'r', 'd', 'p', 'nr',
                       'c', 'TIME', 'xc', 'a', 'ad', 'an', 'nrt', 'df', 'b', 'vn', 'l', 'y', 'o', 'i']
            meaningfulCount = 0
            checklist = []
            for eachWord, flag in rawtextList:
                # check_pos_list is a module-level list created in the main block below
                check_pos_list.append(str(eachWord) + '/' + str(flag))
                checklist.append([eachWord, flag])
                if eachWord in self.delimWordList or not notNumStr(eachWord) or eachWord in self.stopWordList or flag in poSPrty or eachWord == '\n':
                    if lastWord != '|':
                        textList.append("|")
                        lastWord = "|"
                elif eachWord not in self.stopWordList and eachWord != '\n':
                    textList.append(eachWord)
                    meaningfulCount += 1
                    if eachWord not in listofSingleWord:
                        listofSingleWord[eachWord] = Word(eachWord)
                    lastWord = ''
            # Construct the list of lists that has phrases as words
            newList = []
            tempList = []
            for everyWord in textList:
                if everyWord != '|':
                    tempList.append(everyWord)
                else:
                    newList.append(tempList)
                    tempList = []
            tempStr = ''
            for everyWord in textList:
                if everyWord != '|':
                    tempStr += everyWord + '|'
                else:
                    if tempStr[:-1] not in listofSingleWord:
                        listofSingleWord[tempStr[:-1]] = Word(tempStr[:-1])
                    tempStr = ''
            # Update the entire list
            for everyPhrase in newList:
                res = ''
                for everyWord in everyPhrase:
                    listofSingleWord[everyWord].updateOccur(len(everyPhrase))
                    res += everyWord + '|'
                phraseKey = res[:-1]
                if phraseKey not in listofSingleWord:
                    listofSingleWord[phraseKey] = Word(phraseKey)
                else:
                    listofSingleWord[phraseKey].updateFreq()
            # Get the score for the entire set
            outputList = dict()
            for everyPhrase in newList:
                if len(everyPhrase) > 5:
                    continue
                score = 0
                phraseString = ''
                outStr = ''
                for everyWord in everyPhrase:
                    score += listofSingleWord[everyWord].returnScore()
                    phraseString += everyWord + '|'
                    outStr += everyWord
                phraseKey = phraseString[:-1]
                freq = listofSingleWord[phraseKey].getFreq()
                if freq / meaningfulCount < 0.05 and freq < 3:
                    continue
                outputList[outStr] = score
            sorted_list = sorted(outputList.items(),
                                 key=operator.itemgetter(1), reverse=True)
            sorted_list = [s[0] for s in sorted_list]
            return sorted_list[:num_kw]
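
    # Minimal usage sketch (illustrative; the tiny word lists here are
    # placeholders, the real script loads them from the ./stoplist files below):
    #   rake = Rake()
    #   rake.initializeFromList(['的', '了'], ['，', '。'])
    #   rake.extractKeywordFromString('台北市的都更案延宕多年，居民盼加速重建。', num_kw=5)
    # returns up to 5 candidate phrases, sorted by their summed word scores.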

    def pke_MultipartiteRank(text):
        # initialize a MultipartiteRank extractor
        extractor = pke.unsupervised.MultipartiteRank()
        # load the content of the document (Chinese, no extra normalization)
        extractor.load_document(input=text,
                                language='zh',
                                normalization=None)
        # keyphrase candidate selection: sequences of nouns and proper nouns
        # defined by the Universal PoS tagset
        extractor.candidate_selection(
            pos={"NOUN", "PROPN"}, stoplist=customized_stopwords)
        # candidate weighting using a random-walk algorithm
        extractor.candidate_weighting(alpha=1.1,
                                      threshold=0.65,
                                      method='average')
        # N-best selection: the 40 highest-scored candidates as
        # (keyphrase, score) tuples
        keyphrases = extractor.get_n_best(n=40)  # topK
        return [j for sub in [k[0].split() for k in keyphrases] for j in sub]
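
    # Illustrative call (assumes pke can load a Chinese spaCy pipeline for
    # tokenization and PoS tagging, which is version dependent):
    #   pke_MultipartiteRank('台北市政府今日宣布新的都市更新政策……')
    # returns a flat list of tokens taken from the top-40 scored keyphrases.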

    # Convert half-width ASCII characters to their full-width forms
    def half2full(s):
        n = []
        for c in list(s):
            num = ord(c)
            if num == 32:  # half-width space -> ideographic space
                num = 0x3000
            elif 0x21 <= num <= 0x7E:
                num += 0xfee0
            num = chr(num)
            n.append(num)
        return ''.join(n)
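
    # Example (illustrative only): half2full('a1!') == 'ａ１！'; every ASCII
    # character in the range 0x21-0x7E is shifted by 0xFEE0 into the
    # full-width block, and a half-width space becomes an ideographic space.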

    def find_tags(doc_list):
        pos_list = ['ns', 'n', 'nt', 'nz', 'x', 'nrfg', 'an']  # , 'vn'
        punctuations = punctuation + \
            half2full(punctuation) + '、●{}「」[]【】()()<>《》〈〉『』〔〕'
        tag_list = []
        pbar = tqdm(range(len(doc_list)))
        pbar.set_description("[Extracting keywords...]")
        fail_count = 0
        for d in doc_list:
            pbar.update()
            # replace punctuation with spaces before keyword extraction
            d = d.translate(d.maketrans(
                punctuations, ' '*len(punctuations), ""))
            try:
                result = []
                result.extend(obj.extractKeywordFromString(d, num_kw=topK))
                result.extend(jieba.analyse.extract_tags(
                    d, topK=topK, allowPOS=(pos_list)))
                result.extend(jieba.analyse.textrank(
                    d, topK=topK, allowPOS=(pos_list)))
                result.extend(pke_MultipartiteRank(d))
            except:
                result = []
                fail_count += 1
                print('-'*80)
                print('Keywords of this news are not available:\n',
                      '...', d[50:150], '...\n', '-'*80)
            tags = list(filter(lambda x: len(x) > 1 and notNumStr(
                x) and x not in list(set(customized_stopwords)), result))
            tag_list.extend(tags)
        tag_list = list(set(tag_list))
        pbar.close()
        print('Num of keywords:', len(tag_list), end=', ')
        print('Fail:', fail_count, '\n')
        return tag_list
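
    # find_tags depends on names defined further down in this block
    # (obj, topK, customized_stopwords). A typical call such as
    #   tag_list = find_tags(['第一篇新聞全文……', '第二篇新聞全文……'])
    # returns a deduplicated list of candidate keywords pooled from the four
    # extractors: RAKE, jieba TF-IDF, jieba TextRank and MultipartiteRank.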

    # ┌─────────────────────────────┐
    # |                             |
    # |    Loading stopwords ...    |
    # |                             |
    # └─────────────────────────────┘
    print_block('Loading stopwords ...')
    paddle.enable_static()
    jieba.enable_paddle()
    jieba.set_dictionary('dict.txt.big')
    jieba.load_userdict('jieba_add_word.txt')
    jieba.load_userdict('jieba_add_word_kw_with_weighting.txt')
    check_pos_list = []
    converter = opencc.OpenCC('s2t.json')  # Simplified -> Traditional Chinese
    converter.convert('汉字')  # 漢字
    cc_stopwords = converter.convert(
        open("cn_stopwords.txt", "r").read()).split('\n')
    if not args.topK:
        topK = 80
    else:
        topK = int(args.topK)
    obj = Rake()
    stop_path = "./stoplist/中文停用词表(1208个).txt"
    conj_path = "./stoplist/中文分隔词词库.txt"
    obj.initializeFromPath(stop_path, conj_path)
    # Build the customized stopword list
    with open('customized_stopwords.pickle', 'rb') as handle:
        customized_stopwords = pickle.load(handle)
    customized_stopwords.extend(stopwords.words('english'))
    customized_stopwords.extend(cc_stopwords)

    # ┌───────────────────────────────┐
    # |                               |
    # |    Documents embedding ...    |
    # |                               |
    # └───────────────────────────────┘
    print_block('Documents embedding ...')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # documents embedding
    print('【Transformer】Documents embedding ... ', end='')
    model = SentenceTransformer('distiluse-base-multilingual-cased-v1')
    embeddings = model.encode(data['news_content'].tolist())
    print('DONE! ', 'Embeddings.shape:', embeddings.shape)
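
    # model.encode returns one dense vector per document; for
    # distiluse-base-multilingual-cased-v1 the expected shape is
    # (len(data), 512). The `device` variable above is informational only,
    # since SentenceTransformer moves to CUDA automatically when available.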

    # ┌────────────────────────────────────────┐
    # |                                        |
    # |    Clustering ... (Topic detection)    |
    # |                                        |
    # └────────────────────────────────────────┘
    print_block('Clustering ... (Topic detection)')
    # HDBSCAN clustering
    print('【HDBSCAN】Clustering ...', end='')
    hclusterer = hdbscan.HDBSCAN(prediction_data=True).fit(
        embeddings)  # embeddings_list
    print('DONE!\n', '-'*80)
    print('Number of clusters:', len(Counter(hclusterer.labels_)))
    print('Noise ratio:', round(list(hclusterer.labels_).count(-1) / len(embeddings), 3)
          * 100, '% ', list(hclusterer.labels_).count(-1), '/', len(embeddings))
    print('Clusters:')
    print(dict(Counter(hclusterer.labels_)))
    print('-'*80)
    # approximate_predict the cluster of a probe document to locate the target domain
    predict_doc = open("predict_doc.txt", "r").read()
    print('Predict doc:')
    print('...', predict_doc[50:150], '...\n')
    test_labels, strengths = hdbscan.approximate_predict(
        hclusterer, model.encode([predict_doc]))
    target_domain_cluster = test_labels[0]
    print('Predict_doc (target domain) is predicted to be in cluster #',
          target_domain_cluster, end='')
    if target_domain_cluster == -1:
        temp = Counter(hclusterer.labels_)
        del temp[-1]
        target_domain_cluster = max(
            temp.items(), key=operator.itemgetter(1))[0]
        print(' (noise)\n -> Replace with the largest cluster #',
              target_domain_cluster)
    labels, values = zip(
        *sorted(Counter(hclusterer.labels_[hclusterer.labels_ != -1]).items()))
    # check news in the cluster
    cluster_num = target_domain_cluster
    fil = [l == cluster_num for l in hclusterer.labels_]
    doc_list = list(compress(data['news_content'].tolist(), fil))
    print('Num of news in the cluster #', cluster_num, ':', len(doc_list))
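
    # Illustrative filtering step: itertools.compress keeps only the documents
    # whose HDBSCAN label equals the target cluster, e.g. with labels
    # [0, 2, 2, -1] and cluster_num == 2 the mask is
    # [False, True, True, False] and two documents are kept.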

    # ┌──────────────────────────────┐
    # |                              |
    # |    Keyword extracting ...    |
    # |                              |
    # └──────────────────────────────┘
    print_block('Keyword extracting ...')
    # candidate target domain clusters
    if not args.target_domain_range:
        target_domain_range = 1
    else:
        target_domain_range = int(args.target_domain_range)
    cluster_list = list(range(max(labels[0], target_domain_cluster-target_domain_range), min(
        labels[-1], target_domain_cluster+target_domain_range)+1))
    # if target_domain_cluster is 6 and target_domain_range is 1, cluster_list will be [5, 6, 7]
    # get all news in all candidate clusters
    doc_list = []
    for c in cluster_list:
        fil = [l == c for l in hclusterer.labels_]
        doc_list.extend(list(compress(data['news_content'].tolist(), fil)))
    print('Num of news in the cluster #', ' #'.join(
        [str(c) for c in cluster_list]), ':', len(doc_list), '\n')
    tag_list = []
    # find tags from the renewhouse website corpus
    print('Find tags from renewhouse website ...')
    with open('renewhouse_list.pickle', 'rb') as handle:
        renewhouse_list = pickle.load(handle)
    tag_list.extend(find_tags(renewhouse_list))
    # merge the news of each candidate cluster into one doc and find tags
    print('Merge news in a candidate cluster to a doc and find tags ...')
    count = 0
    for c in cluster_list:
        count += 1
        print('Extracting keywords in the cluster #', str(
            c), '... (', count, '/', len(cluster_list), ')')
        fil = [l == c for l in hclusterer.labels_]
        doc_list_ = list(compress(data['news_content'].tolist(), fil))
        doc_list_ = sample(doc_list_, min(150, len(doc_list_)))
        tag_list.extend(find_tags(['\n'.join(doc_list_)]))
    # merge the news of all candidate clusters into one doc and find tags
    print('Merge news in all candidate clusters to a doc and find tags ...')
    doc_list_ = sample(doc_list, min(150*len(cluster_list), len(doc_list)))
    tag_list.extend(find_tags(['\n'.join(doc_list_)]))
    tag_list = list(filter(lambda x: x not in list(
        set(customized_stopwords)), list(set(tag_list))))
    print('='*80, '\n')
    print('Num of keywords:', len(set(tag_list)))
    # save to csv
    pd.DataFrame(data={'id': list(range(1, int(len(set(tag_list)))+1)), 'kw': list(
        set(tag_list))}).to_csv('tag_list.csv', index=False, encoding='utf-8-sig')
    print('Save to "tag_list.csv"')
    print("\n--- %s seconds ---\n" % (time.time() - start_time))
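
# Typical invocation (flag names come from the argparse setup above; the
# defaults are topK=80 and target_domain_range=1):
#   python gnews_keyword_extraction.py --topK 80 --target_domain_range 1
# The script expects the auxiliary files referenced above (dict.txt.big,
# jieba_add_word*.txt, cn_stopwords.txt, ./stoplist/*, predict_doc.txt,
# customized_stopwords.pickle, renewhouse_list.pickle) next to it in the
# working directory.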