# ----------------------------------------
# Import packages
# ----------------------------------------
import warnings
import os
import time
import json
import operator
import pickle
import pprint
import argparse
from random import sample
from collections import Counter
from itertools import compress
from typing import List
from string import punctuation

import pandas as pd
import dataset
import jieba
import jieba.posseg as pseg
import jieba.analyse
import paddle
import hdbscan
import torch
import opencc
import pke
import pymysql
from sentence_transformers import SentenceTransformer
from nltk.corpus import stopwords
from tqdm import tqdm

pymysql.install_as_MySQLdb()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")


def chunkstring(string, length):
    return (string[0+i:length+i] for i in range(0, len(string), length))


def print_block(text):
    # sentence = input('Sentence: ')
    width = 60
    print('\n┌─' + '─' * width + '─┐')
    for line in chunkstring(text, width):
        print('| {0:^{1}} |'.format(line, width))
    print('└─' + '─' * width + '─┘')


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--topK', required=False)
    parser.add_argument('--target_domain_range', required=False)
    args = parser.parse_args()

    start_time = time.time()

    # ┌──────────────────────────────────────┐
    # |                                      |
    # |    Loading data from database ...    |
    # |                                      |
    # └──────────────────────────────────────┘
    print_block('Loading data from database ...')

    db = dataset.connect(
        'mysql://choozmo:pAssw0rd@db.ptt.cx:3306/hhh?charset=utf8mb4')
    result = db.query('SELECT * FROM gnews.gnews_detail ')
    data = pd.DataFrame(result, columns=next(iter(result)).keys())
    print('Number of news:', len(data))

    print('Dropping duplicates ... ', end='')
    data = data.drop_duplicates(subset=['news_content'], keep='first').drop_duplicates(
        subset=['news_url'], keep='first')
    print('number of news:', len(data))
    db.close()

    # ----------------------------------------
    # Functions
    # ----------------------------------------

    # Return True if the string contains no Latin letters or digits
    def notNumStr(instr):
        for item in instr:
            if '\u0041' <= item <= '\u005a' or ('\u0061' <= item <= '\u007a') or item.isdigit():
                return False
        return True

    # Read target case if JSON
    def readSingleTestCases(testFile):
        with open(testFile) as json_data:
            try:
                testData = json.load(json_data)
            except:
                # This try block deals with incorrect JSON format that has ' instead of "
                data = json_data.read().replace("'", '"')
                try:
                    testData = json.loads(data)
                except:
                    # This try block deals with an empty transcript file
                    return ""
        returnString = ""
        for item in testData:
            try:
                returnString += item['text']
            except:
                returnString += item['statement']
        return returnString

    class Word():
        def __init__(self, char, freq=0, deg=0):
            self.freq = freq
            self.deg = deg
            self.char = char

        def returnScore(self):
            return self.deg / self.freq

        def updateOccur(self, phraseLength):
            self.freq += 1
            self.deg += phraseLength

        def getChar(self):
            return self.char

        def updateFreq(self):
            self.freq += 1

        def getFreq(self):
            return self.freq

    class Rake:
        def __init__(self):  # , stopwordPath: str = None, delimWordPath: str = None):
            # True only if both word lists were found and initialized
            self.initialized = False
            self.stopWordList = list()
            self.delimWordList = list()

        def initializeFromPath(self, stopwordPath: str = "", delimWordPath: str = ""):
            if not os.path.exists(stopwordPath):
                print("Stop Word Path invalid")
                return
            if not os.path.exists(delimWordPath):
                print("Delim Word Path Invalid")
                return
            swLibList = [line.rstrip('\n') for line in converter.convert(
                open(stopwordPath, 'r').read()).split('\n')]
            conjLibList = [line.rstrip('\n') for line in converter.convert(
                open(delimWordPath, 'r').read()).split('\n')]
            self.initializeFromList(swLibList, conjLibList)
            return

        def initializeFromList(self, swList: List = None, dwList: List = None):
            self.stopWordList = swList
            self.delimWordList = dwList
            if len(self.stopWordList) == 0 or len(self.delimWordList) == 0:
                print("Empty stop word list or delimiter word list, uninitialized")
                return
            else:
                self.initialized = True

        def extractKeywordFromPath(self, text: str, num_kw: int = 10):
            if not self.initialized:
                print("Not initialized")
                return
            with open(text, 'r') as fp:
                text = fp.read()
            return self.extractKeywordFromString(text, num_kw=num_kw)

        def extractKeywordFromString(self, text: str, num_kw: int = 10):
            rawtextList = pseg.cut(text)

            # Construct list of phrases and preliminary textList
            textList = []
            listofSingleWord = dict()
            lastWord = ''
            # jieba POS tags treated as delimiters rather than keyword material
            poSPrty = ['zg', 'm', 'x', 'uj', 'ul', 'mq', 'u', 'v', 'f', 't', 'vd', 'q',
                       'r', 'd', 'p', 'nr', 'c', 'TIME', 'xc', 'a', 'ad', 'an', 'nrt',
                       'df', 'b', 'vn', 'l', 'y', 'o', 'i']
            meaningfulCount = 0
            checklist = []
            for eachWord, flag in rawtextList:
                # check_pos_list is defined in the main flow below
                check_pos_list.append(str(eachWord) + '/' + str(flag))
                checklist.append([eachWord, flag])
                if (eachWord in self.delimWordList or not notNumStr(eachWord)
                        or eachWord in self.stopWordList or flag in poSPrty
                        or eachWord == '\n'):
                    if lastWord != '|':
                        textList.append("|")
                        lastWord = "|"
                elif eachWord not in self.stopWordList and eachWord != '\n':
                    textList.append(eachWord)
                    meaningfulCount += 1
                    if eachWord not in listofSingleWord:
                        listofSingleWord[eachWord] = Word(eachWord)
                    lastWord = ''
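            # RAKE-style scoring (see Word.returnScore above): every kept word
            # accumulates freq (its occurrence count) and deg (the lengths of the
            # phrases it appears in), and a phrase's score is the sum of deg/freq
            # over its member words, so words that co-occur in longer phrases
            # contribute higher scores.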
            # Construct list of lists that has phrases as words
            newList = []
            tempList = []
            for everyWord in textList:
                if everyWord != '|':
                    tempList.append(everyWord)
                else:
                    newList.append(tempList)
                    tempList = []

            tempStr = ''
            for everyWord in textList:
                if everyWord != '|':
                    tempStr += everyWord + '|'
                else:
                    if tempStr[:-1] not in listofSingleWord:
                        listofSingleWord[tempStr[:-1]] = Word(tempStr[:-1])
                    tempStr = ''

            # Update the entire list
            for everyPhrase in newList:
                res = ''
                for everyWord in everyPhrase:
                    listofSingleWord[everyWord].updateOccur(len(everyPhrase))
                    res += everyWord + '|'
                phraseKey = res[:-1]
                if phraseKey not in listofSingleWord:
                    listofSingleWord[phraseKey] = Word(phraseKey)
                else:
                    listofSingleWord[phraseKey].updateFreq()

            # Get score for the entire set
            outputList = dict()
            for everyPhrase in newList:
                if len(everyPhrase) > 5:
                    continue
                score = 0
                phraseString = ''
                outStr = ''
                for everyWord in everyPhrase:
                    score += listofSingleWord[everyWord].returnScore()
                    phraseString += everyWord + '|'
                    outStr += everyWord
                phraseKey = phraseString[:-1]
                freq = listofSingleWord[phraseKey].getFreq()
                if freq / meaningfulCount < 0.05 and freq < 3:
                    continue
                outputList[outStr] = score

            sorted_list = sorted(outputList.items(),
                                 key=operator.itemgetter(1), reverse=True)
            sorted_list = [s[0] for s in sorted_list]
            return sorted_list[:num_kw]

    def pke_MultipartiteRank(text):
        # Initialize a MultipartiteRank extractor
        extractor = pke.unsupervised.MultipartiteRank()
        # Load the content of the (Chinese) document
        extractor.load_document(input=text, language='zh', normalization=None)
        # Keyphrase candidate selection: sequences of nouns and proper nouns
        # defined by the Universal PoS tagset
        extractor.candidate_selection(
            pos={"NOUN", "PROPN"}, stoplist=customized_stopwords)
        # Candidate weighting using a random-walk algorithm
        extractor.candidate_weighting(alpha=1.1, threshold=0.65, method='average')
        # N-best selection: keyphrases contains the highest scored candidates as
        # (keyphrase, score) tuples
        keyphrases = extractor.get_n_best(n=40)  # topK
        return [j for sub in [k[0].split() for k in keyphrases] for j in sub]

    def half2full(s):
        # Convert half-width ASCII characters to their full-width forms
        n = []
        for c in list(s):
            num = ord(c)
            if num == 32:  # half-width space -> ideographic space
                num = 0x3000
            elif 0x21 <= num <= 0x7E:
                num += 0xfee0
            num = chr(num)
            n.append(num)
        return ''.join(n)

    def find_tags(doc_list):
        pos_list = ['ns', 'n', 'nt', 'nz', 'x', 'ns', 'nrfg', 'an']  # , 'vn'
        punctuations = punctuation + \
            half2full(punctuation) + '、●{}「」[]【】()()<>《》〈〉『』〔〕'
        tag_list = []
        pbar = tqdm(range(len(doc_list)))
        pbar.set_description("[Extracting keywords...]")
        fail_count = 0
        for d in doc_list:
            pbar.update()
            # Replace punctuation with spaces before extraction
            d = d.translate(d.maketrans(punctuations, ' ' * len(punctuations), ""))
            try:
                result = []
                result.extend(obj.extractKeywordFromString(d, num_kw=topK))
                result.extend(jieba.analyse.extract_tags(
                    d, topK=topK, allowPOS=(pos_list)))
                result.extend(jieba.analyse.textrank(
                    d, topK=topK, allowPOS=(pos_list)))
                result.extend(pke_MultipartiteRank(d))
            except:
                result = []
                fail_count += 1
                print('-' * 80)
                print('Keywords of this news are not available:\n',
                      '...', d[50:150], '...\n', '-' * 80)
            tags = list(filter(lambda x: len(x) > 1 and notNumStr(x)
                               and x not in list(set(customized_stopwords)), result))
            tag_list.extend(tags)
        tag_list = list(set(tag_list))
        pbar.close()
        print('Num of keywords:', len(tag_list), end=', ')
        print('Fail:', fail_count, '\n')
        return tag_list
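    # Note: find_tags() unions the output of four extractors (the RAKE implementation
    # above, jieba TF-IDF via extract_tags, jieba TextRank, and pke MultipartiteRank),
    # then drops single characters, tokens containing Latin letters or digits, and
    # stopwords. A minimal usage sketch (the sample string is a hypothetical placeholder):
    #
    #   demo_tags = find_tags(['危老重建補助方案今日公布，舊屋翻新需求持續升溫 ...'])
    #   print(demo_tags[:10])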
    # ┌─────────────────────────────┐
    # |                             |
    # |    Loading stopwords ...    |
    # |                             |
    # └─────────────────────────────┘
    print_block('Loading stopwords ...')

    paddle.enable_static()
    jieba.enable_paddle()
    jieba.set_dictionary('dict.txt.big')
    jieba.load_userdict('jieba_add_word.txt')
    jieba.load_userdict('jieba_add_word_kw_with_weighting.txt')

    check_pos_list = []

    converter = opencc.OpenCC('s2t.json')
    converter.convert('汉字')  # 漢字

    cc_stopwords = converter.convert(
        open("cn_stopwords.txt", "r").read()).split('\n')

    if not args.topK:
        topK = 80
    else:
        topK = int(args.topK)

    obj = Rake()
    stop_path = "./stoplist/中文停用词表(1208个).txt"
    conj_path = "./stoplist/中文分隔词词库.txt"
    obj.initializeFromPath(stop_path, conj_path)

    # Build the customized stopword list
    with open('customized_stopwords.pickle', 'rb') as handle:
        customized_stopwords = pickle.load(handle)
    customized_stopwords.extend(stopwords.words('english'))
    customized_stopwords.extend(cc_stopwords)

    # ┌───────────────────────────────┐
    # |                               |
    # |    Documents embedding ...    |
    # |                               |
    # └───────────────────────────────┘
    print_block('Documents embedding ...')

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Documents embedding
    print('【Transformer】Documents embedding ... ', end='')
    model = SentenceTransformer('distiluse-base-multilingual-cased-v1')
    embeddings = model.encode(data['news_content'].tolist())
    print('DONE! ', 'Embeddings.shape:', embeddings.shape)

    # ┌────────────────────────────────────────┐
    # |                                        |
    # |    Clustering ... (Topic detection)    |
    # |                                        |
    # └────────────────────────────────────────┘
    print_block('Clustering ... (Topic detection)')

    # HDBSCAN clustering
    print('【HDBSCAN】Clustering ...', end='')
    hclusterer = hdbscan.HDBSCAN(prediction_data=True).fit(embeddings)  # embeddings_list
    print('DONE!\n', '-' * 80)
    print('Number of clusters:', len(Counter(hclusterer.labels_)))
    print('Noise ratio:', round(list(hclusterer.labels_).count(-1) / len(embeddings), 3) * 100,
          '% ', list(hclusterer.labels_).count(-1), '/', len(embeddings))
    print('Clusters:')
    print(dict(Counter(hclusterer.labels_)))
    print('-' * 80)

    # Approximate-predict the cluster of the seed document to find the target domain
    predict_doc = open("predict_doc.txt", "r").read()
    print('Predict doc:')
    print('...', predict_doc[50:150], '...\n')
    test_labels, strengths = hdbscan.approximate_predict(
        hclusterer, model.encode([predict_doc]))
    target_domain_cluster = test_labels[0]
    print('Predict_doc (target domain) is predicted to be in cluster #',
          target_domain_cluster, end='')
    if target_domain_cluster == -1:
        temp = Counter(hclusterer.labels_)
        del temp[-1]
        target_domain_cluster = max(temp.items(), key=operator.itemgetter(1))[0]
        print(' (noise)\n -> Replace with the largest cluster #', target_domain_cluster)

    labels, values = zip(
        *sorted(Counter(hclusterer.labels_[hclusterer.labels_ != -1]).items()))

    # Check news in the target cluster
    cluster_num = target_domain_cluster
    fil = [l == cluster_num for l in hclusterer.labels_]
    doc_list = list(compress(data['news_content'].tolist(), fil))
    print('num of news in the cluster #', cluster_num, ':', len(doc_list))
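    # Note: HDBSCAN labels noise points as -1, and hdbscan.approximate_predict()
    # (enabled by prediction_data=True above) assigns the embedded predict_doc to an
    # existing cluster without refitting. If it lands in noise, the largest cluster
    # is used instead, and the next stage also scans neighbouring cluster ids.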
    # ┌──────────────────────────────┐
    # |                              |
    # |    Keyword extracting ...    |
    # |                              |
    # └──────────────────────────────┘
    print_block('Keyword extracting ...')

    # Candidate target-domain clusters
    if not args.target_domain_range:
        target_domain_range = 1
    else:
        target_domain_range = int(args.target_domain_range)
    cluster_list = list(range(max(labels[0], target_domain_cluster - target_domain_range),
                              min(labels[-1], target_domain_cluster + target_domain_range) + 1))
    # If your target_domain_cluster is 6 and target_domain_range is 1,
    # cluster_list will be [5, 6, 7]

    # Get all news in all candidate clusters
    doc_list = []
    for c in cluster_list:
        fil = [l == c for l in hclusterer.labels_]
        doc_list.extend(list(compress(data['news_content'].tolist(), fil)))
    print('Num of news in the clusters #', ' #'.join(
        [str(c) for c in cluster_list]), ':', len(doc_list), '\n')

    tag_list = []

    # Find tags from the renewhouse website
    print('Find tags from renewhouse website ...')
    with open('renewhouse_list.pickle', 'rb') as handle:
        renewhouse_list = pickle.load(handle)
    tag_list.extend(find_tags(renewhouse_list))

    # Merge the news of each candidate cluster into one doc and find tags
    print('Merge news in a candidate cluster to a doc and find tags ...')
    count = 0
    for c in cluster_list:
        count += 1
        print('Extracting keywords in the cluster #', str(c),
              '... (', count, '/', len(cluster_list), ')')
        fil = [l == c for l in hclusterer.labels_]
        doc_list_ = list(compress(data['news_content'].tolist(), fil))
        doc_list_ = sample(doc_list_, min(150, len(doc_list_)))
        tag_list.extend(find_tags(['\n'.join(doc_list_)]))

    # Merge the news of all candidate clusters into one doc and find tags
    print('Merge news in all candidate clusters to a doc and find tags ...')
    doc_list_ = sample(doc_list, min(150 * len(cluster_list), len(doc_list)))
    tag_list.extend(find_tags(['\n'.join(doc_list_)]))

    tag_list = list(filter(lambda x: x not in list(
        set(customized_stopwords)), list(set(tag_list))))

    print('=' * 80, '\n')
    print('Num of keywords:', len(set(tag_list)))

    # Save to CSV
    pd.DataFrame(data={'id': list(range(1, int(len(set(tag_list))) + 1)),
                       'kw': list(set(tag_list))}).to_csv(
        'tag_list.csv', index=False, encoding='utf-8-sig')
    print('Save to "tag_list.csv"')

    print("\n--- %s seconds ---\n" % (time.time() - start_time))
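    # Output note: tag_list.csv has two columns, 'id' (a 1-based index) and 'kw'
    # (the extracted keyword); utf-8-sig adds a BOM so the Chinese text opens
    # cleanly in Excel. Assumed invocation (the script name is a placeholder):
    #   python extract_keywords.py --topK 80 --target_domain_range 1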