# ----------------------------------------
# Import packages
# ----------------------------------------
import warnings
import os
import time
import json
import pickle
import operator
import argparse
import pprint
from random import sample
from collections import Counter
from itertools import compress
from string import punctuation
from typing import List

import pandas as pd
import dataset
import jieba
import jieba.posseg as pseg
import jieba.analyse
import paddle
import hdbscan
import torch
import opencc
import pke
import pymysql
from sentence_transformers import SentenceTransformer
from nltk.corpus import stopwords
from tqdm import tqdm

pymysql.install_as_MySQLdb()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

def chunkstring(string, length):
    # Split a string into fixed-length chunks (the last chunk may be shorter)
    return (string[0+i:length+i] for i in range(0, len(string), length))


def print_block(text):
    # Print the text centered inside a box, wrapping it at `width` characters
    width = 60
    print('\n┌─' + '─' * width + '─┐')
    for line in chunkstring(text, width):
        print('| {0:^{1}} |'.format(line, width))
    print('└─' + '─' * width + '─┘')
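
# Usage sketch (illustrative only): print_block('Loading data ...') draws the
# message inside a 60-character wide box, and because chunkstring() yields
# fixed-length slices, a longer message simply wraps onto additional box rows.
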
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--topK', required=False)
    parser.add_argument('--target_domain_range', required=False)
    args = parser.parse_args()
    start_time = time.time()

    # ┌──────────────────────────────────────┐
    # |                                      |
    # |    Loading data from database ...    |
    # |                                      |
    # └──────────────────────────────────────┘
    print_block('Loading data from database ...')

    db = dataset.connect(
        'mysql://choozmo:pAssw0rd@db.ptt.cx:3306/hhh?charset=utf8mb4')
    result = db.query('SELECT * FROM gnews.gnews_detail ')
    data = pd.DataFrame(result, columns=next(iter(result)).keys())
    print('Number of news:', len(data))
    print('Dropping duplicates ... ', end='')
    data = data.drop_duplicates(subset=['news_content'], keep='first').drop_duplicates(
        subset=['news_url'], keep='first')
    print('number of news:', len(data))
    db.close()
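
    # Note: the query assumes gnews.gnews_detail exposes at least the
    # news_content and news_url columns used below; duplicates are dropped on
    # both fields so each article is embedded and clustered only once.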
    # ----------------------------------------
    # Functions
    # ----------------------------------------
    # Check that a token contains no ASCII letters and no digits
    def notNumStr(instr):
        for item in instr:
            if '\u0041' <= item <= '\u005a' or '\u0061' <= item <= '\u007a' or item.isdigit():
                return False
        return True

    # Read a target test case if it is a JSON transcript file
    def readSingleTestCases(testFile):
        with open(testFile) as json_data:
            try:
                testData = json.load(json_data)
            except json.JSONDecodeError:
                # Deal with an incorrect JSON format that uses ' instead of "
                json_data.seek(0)
                data = json_data.read().replace("'", '"')
                try:
                    testData = json.loads(data)
                except json.JSONDecodeError:
                    # Empty or unparsable transcript file
                    return ""
        returnString = ""
        for item in testData:
            try:
                returnString += item['text']
            except KeyError:
                returnString += item['statement']
        return returnString
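
    # readSingleTestCases() is kept as a standalone helper for JSON transcript
    # input; the pipeline below works directly on the news_content strings
    # pulled from the database and never calls it.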

    class Word():
        def __init__(self, char, freq=0, deg=0):
            self.freq = freq
            self.deg = deg
            self.char = char

        def returnScore(self):
            return self.deg / self.freq

        def updateOccur(self, phraseLength):
            self.freq += 1
            self.deg += phraseLength

        def getChar(self):
            return self.char

        def updateFreq(self):
            self.freq += 1

        def getFreq(self):
            return self.freq
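
    # RAKE scores each word as deg / freq: freq counts how many candidate
    # phrases contain the word, while deg accumulates the lengths of those
    # phrases, so words that mostly occur inside longer phrases score higher.
    # A phrase's score is the sum of its member words' scores.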

    class Rake:
        def __init__(self):
            # Set to True only once both word lists are found and loaded
            self.initialized = False
            self.stopWordList = list()
            self.delimWordList = list()

        def initializeFromPath(self, stopwordPath: str = "", delimWordPath: str = ""):
            if not os.path.exists(stopwordPath):
                print("Stop Word Path invalid")
                return
            if not os.path.exists(delimWordPath):
                print("Delim Word Path invalid")
                return
            with open(stopwordPath, 'r') as f:
                swLibList = [line.rstrip('\n')
                             for line in converter.convert(f.read()).split('\n')]
            with open(delimWordPath, 'r') as f:
                conjLibList = [line.rstrip('\n')
                               for line in converter.convert(f.read()).split('\n')]
            self.initializeFromList(swLibList, conjLibList)
            return

        def initializeFromList(self, swList: List = None, dwList: List = None):
            self.stopWordList = swList
            self.delimWordList = dwList
            if len(self.stopWordList) == 0 or len(self.delimWordList) == 0:
                print("Empty stop word list or delimiter word list, uninitialized")
                return
            else:
                self.initialized = True

        def extractKeywordFromPath(self, text: str, num_kw: int = 10):
            if not self.initialized:
                print("Not initialized")
                return
            with open(text, 'r') as fp:
                text = fp.read()
            return self.extractKeywordFromString(text, num_kw=num_kw)
        def extractKeywordFromString(self, text: str, num_kw: int = 10):
            rawtextList = pseg.cut(text)
            # Construct the list of phrases and the preliminary textList
            textList = []
            listofSingleWord = dict()
            lastWord = ''
            # jieba POS tags that act as phrase delimiters
            poSPrty = ['zg', 'm', 'x', 'uj', 'ul', 'mq', 'u', 'v', 'f', 't', 'vd', 'q', 'r', 'd', 'p', 'nr',
                       'c', 'TIME', 'xc', 'a', 'ad', 'an', 'nrt', 'df', 'b', 'vn', 'l', 'y', 'o', 'i']
            meaningfulCount = 0
            checklist = []
            for eachWord, flag in rawtextList:
                check_pos_list.append(str(eachWord) + '/' + str(flag))
                checklist.append([eachWord, flag])
                if eachWord in self.delimWordList or not notNumStr(eachWord) or eachWord in self.stopWordList or flag in poSPrty or eachWord == '\n':
                    if lastWord != '|':
                        textList.append("|")
                        lastWord = "|"
                elif eachWord not in self.stopWordList and eachWord != '\n':
                    textList.append(eachWord)
                    meaningfulCount += 1
                    if eachWord not in listofSingleWord:
                        listofSingleWord[eachWord] = Word(eachWord)
                    lastWord = ''
            # Construct a list of lists whose items are the words of each phrase
            newList = []
            tempList = []
            for everyWord in textList:
                if everyWord != '|':
                    tempList.append(everyWord)
                else:
                    newList.append(tempList)
                    tempList = []
            tempStr = ''
            for everyWord in textList:
                if everyWord != '|':
                    tempStr += everyWord + '|'
                else:
                    if tempStr[:-1] not in listofSingleWord:
                        listofSingleWord[tempStr[:-1]] = Word(tempStr[:-1])
                    tempStr = ''
            # Update the entire list
            for everyPhrase in newList:
                res = ''
                for everyWord in everyPhrase:
                    listofSingleWord[everyWord].updateOccur(len(everyPhrase))
                    res += everyWord + '|'
                phraseKey = res[:-1]
                if phraseKey not in listofSingleWord:
                    listofSingleWord[phraseKey] = Word(phraseKey)
                else:
                    listofSingleWord[phraseKey].updateFreq()
            # Score every phrase in the set
            outputList = dict()
            for everyPhrase in newList:
                if len(everyPhrase) > 5:
                    continue
                score = 0
                phraseString = ''
                outStr = ''
                for everyWord in everyPhrase:
                    score += listofSingleWord[everyWord].returnScore()
                    phraseString += everyWord + '|'
                    outStr += everyWord
                phraseKey = phraseString[:-1]
                freq = listofSingleWord[phraseKey].getFreq()
                if freq / meaningfulCount < 0.05 and freq < 3:
                    continue
                outputList[outStr] = score
            sorted_list = sorted(outputList.items(),
                                 key=operator.itemgetter(1), reverse=True)
            sorted_list = [s[0] for s in sorted_list]
            return sorted_list[:num_kw]
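
    # Illustrative usage (not executed here): once initializeFromPath() has
    # loaded the stop word and delimiter lists,
    #     obj.extractKeywordFromString(news_text, num_kw=topK)
    # returns up to num_kw candidate phrases, highest RAKE score first.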

    def pke_MultipartiteRank(text):
        # initialize a MultipartiteRank extractor
        extractor = pke.unsupervised.MultipartiteRank()
        # load the content of the document (Chinese, no extra normalization)
        extractor.load_document(input=text,
                                language='zh',
                                normalization=None)
        # keyphrase candidate selection: sequences of nouns and proper nouns
        # as defined by the Universal PoS tagset
        extractor.candidate_selection(
            pos={"NOUN", "PROPN"}, stoplist=customized_stopwords)
        # candidate weighting using a random walk over the topic graph
        extractor.candidate_weighting(alpha=1.1,
                                      threshold=0.65,
                                      method='average')
        # N-best selection: the 40 highest scored candidates as
        # (keyphrase, score) tuples
        keyphrases = extractor.get_n_best(n=40)  # topK
        return [j for sub in [k[0].split() for k in keyphrases] for j in sub]
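
    # MultipartiteRank groups candidate phrases into topics and ranks them
    # with a random walk over the resulting graph, so it tends to surface
    # multi-word topical phrases that complement the RAKE, jieba TF-IDF and
    # jieba TextRank extractors combined with it in find_tags() below.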

    def half2full(s):
        # Convert half-width ASCII characters to their full-width equivalents
        n = []
        for c in list(s):
            num = ord(c)
            if num == 32:
                num = 0x3000
            elif 0x21 <= num <= 0x7E:
                num += 0xfee0
            num = chr(num)
            n.append(num)
        return ''.join(n)
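
    # Example: half2full('ab?') returns 'ａｂ？', so the punctuation table
    # built in find_tags() can strip both the half- and full-width variants.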

    def find_tags(doc_list):
        # POS tags accepted by the jieba-based extractors
        pos_list = ['ns', 'n', 'nt', 'nz', 'x', 'ns', 'nrfg', 'an']  # , 'vn'
        punctuations = punctuation + \
            half2full(punctuation) + '、●{}「」[]【】()()<>《》〈〉『』〔〕'
        tag_list = []
        pbar = tqdm(range(len(doc_list)))
        pbar.set_description("[Extracting keywords...]")
        fail_count = 0
        for d in doc_list:
            pbar.update()
            # Replace every punctuation character with a space
            d = d.translate(d.maketrans(
                punctuations, ' ' * len(punctuations), ""))
            try:
                result = []
                result.extend(obj.extractKeywordFromString(d, num_kw=topK))
                result.extend(jieba.analyse.extract_tags(
                    d, topK=topK, allowPOS=(pos_list)))
                result.extend(jieba.analyse.textrank(
                    d, topK=topK, allowPOS=(pos_list)))
                result.extend(pke_MultipartiteRank(d))
            except Exception:
                result = []
                fail_count += 1
                print('-' * 80)
                print('Keywords of this news are not available:\n',
                      '...', d[50:150], '...\n', '-' * 80)
            tags = list(filter(lambda x: len(x) > 1 and notNumStr(
                x) and x not in customized_stopwords, result))
            tag_list.extend(tags)
        tag_list = list(set(tag_list))
        pbar.close()
        print('Num of keywords:', len(tag_list), end=', ')
        print('Fail:', fail_count, '\n')
        return tag_list
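
    # find_tags() pools four extractors per document (RAKE, jieba TF-IDF,
    # jieba TextRank, pke MultipartiteRank), keeps only the multi-character,
    # non-numeric candidates that survive the customized stopword filter, and
    # returns the deduplicated union.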

    # ┌─────────────────────────────┐
    # |                             |
    # |    Loading stopwords ...    |
    # |                             |
    # └─────────────────────────────┘
    print_block('Loading stopwords ...')
    paddle.enable_static()
    jieba.enable_paddle()
    jieba.set_dictionary('dict.txt.big')
    jieba.load_userdict('jieba_add_word.txt')
    jieba.load_userdict('jieba_add_word_kw_with_weighting.txt')
    check_pos_list = []
    converter = opencc.OpenCC('s2t.json')
    converter.convert('汉字')  # 漢字 (simplified-to-traditional sanity check)
    cc_stopwords = converter.convert(
        open("cn_stopwords.txt", "r").read()).split('\n')
    if not args.topK:
        topK = 80
    else:
        topK = int(args.topK)
    obj = Rake()
    stop_path = "./stoplist/中文停用词表(1208个).txt"
    conj_path = "./stoplist/中文分隔词词库.txt"
    obj.initializeFromPath(stop_path, conj_path)
    # Build the customized stopword list
    with open('customized_stopwords.pickle', 'rb') as handle:
        customized_stopwords = pickle.load(handle)
    customized_stopwords.extend(stopwords.words('english'))
    customized_stopwords.extend(cc_stopwords)

    # ┌───────────────────────────────┐
    # |                               |
    # |    Documents embedding ...    |
    # |                               |
    # └───────────────────────────────┘
    print_block('Documents embedding ...')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # documents embedding
    print('【Transformer】Documents embedding ... ', end='')
    model = SentenceTransformer('distiluse-base-multilingual-cased-v1')
    embeddings = model.encode(data['news_content'].tolist())
    print('DONE! ', 'Embeddings.shape:', embeddings.shape)
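
    # distiluse-base-multilingual-cased-v1 maps every article to a dense
    # sentence embedding (512 dimensions for this model); SentenceTransformer
    # selects the GPU on its own when one is available, so the device variable
    # above is informational only.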

    # ┌────────────────────────────────────────┐
    # |                                        |
    # |    Clustering ... (Topic detection)    |
    # |                                        |
    # └────────────────────────────────────────┘
    print_block('Clustering ... (Topic detection)')
    # HDBSCAN clustering
    print('【HDBSCAN】Clustering ...', end='')
    hclusterer = hdbscan.HDBSCAN(prediction_data=True).fit(embeddings)
    print('DONE!\n', '-' * 80)
    print('Number of clusters:', len(Counter(hclusterer.labels_)))
    print('Noise ratio:', round(list(hclusterer.labels_).count(-1) / len(embeddings), 3)
          * 100, '% ', list(hclusterer.labels_).count(-1), '/', len(embeddings))
    print('Clusters:')
    print(dict(Counter(hclusterer.labels_)))
    print('-' * 80)
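
    # HDBSCAN runs with its default hyperparameters here (e.g.
    # min_cluster_size=5); prediction_data=True is what later lets
    # approximate_predict() assign the probe document to an existing cluster
    # without refitting the model.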

    # approximate-predict the cluster of a probe document to find the target domain
    predict_doc = open("predict_doc.txt", "r").read()
    print('Predict doc:')
    print('...', predict_doc[50:150], '...\n')
    test_labels, strengths = hdbscan.approximate_predict(
        hclusterer, model.encode([predict_doc]))
    target_domain_cluster = test_labels[0]
    print('Predict_doc (target domain) is predicted to be in cluster #',
          target_domain_cluster, end='')
    if target_domain_cluster == -1:
        temp = Counter(hclusterer.labels_)
        del temp[-1]
        target_domain_cluster = max(
            temp.items(), key=operator.itemgetter(1))[0]
        print(' (noise)\n -> Replace with the largest cluster #',
              target_domain_cluster)
    labels, values = zip(
        *sorted(Counter(hclusterer.labels_[hclusterer.labels_ != -1]).items()))
    # check the news in the target cluster
    cluster_num = target_domain_cluster
    fil = [l == cluster_num for l in hclusterer.labels_]
    doc_list = list(compress(data['news_content'].tolist(), fil))
    print('num of news in the cluster #', cluster_num, ':', len(doc_list))
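
    # If the probe document falls into HDBSCAN's noise bucket (label -1), the
    # largest real cluster is used as the target domain instead; the labels
    # tuple built above (noise excluded) bounds the candidate-cluster window
    # selected in the next step.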

    # ┌──────────────────────────────┐
    # |                              |
    # |    Keyword extracting ...    |
    # |                              |
    # └──────────────────────────────┘
    print_block('Keyword extracting ...')

    # candidate target domain clusters
    if not args.target_domain_range:
        target_domain_range = 1
    else:
        target_domain_range = int(args.target_domain_range)
    cluster_list = list(range(max(labels[0], target_domain_cluster - target_domain_range), min(
        labels[-1], target_domain_cluster + target_domain_range) + 1))
    # if target_domain_cluster is 6 and target_domain_range is 1, cluster_list will be [5, 6, 7]
    # get all news in all candidate clusters
    doc_list = []
    for c in cluster_list:
        fil = [l == c for l in hclusterer.labels_]
        doc_list.extend(list(compress(data['news_content'].tolist(), fil)))
    print('Num of news in the cluster #', ' #'.join(
        [str(c) for c in cluster_list]), ':', len(doc_list), '\n')
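
    # Keywords are gathered from three sources below: the renewhouse reference
    # pages, one merged document per candidate cluster (each sampled down to
    # at most 150 articles), and one merged document spanning all candidate
    # clusters; presumably the sampling keeps the merged texts short enough
    # for the extractors to handle.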
    tag_list = []
    # find tags from the renewhouse website
    print('Find tags from renewhouse website ...')
    with open('renewhouse_list.pickle', 'rb') as handle:
        renewhouse_list = pickle.load(handle)
    tag_list.extend(find_tags(renewhouse_list))
    # merge the news of each candidate cluster into one doc and find tags
    print('Merge news in a candidate cluster to a doc and find tags ...')
    count = 0
    for c in cluster_list:
        count += 1
        print('Extracting keywords in the cluster #', str(
            c), '... (', count, '/', len(cluster_list), ')')
        fil = [l == c for l in hclusterer.labels_]
        doc_list_ = list(compress(data['news_content'].tolist(), fil))
        doc_list_ = sample(doc_list_, min(150, len(doc_list_)))
        tag_list.extend(find_tags(['\n'.join(doc_list_)]))
    # merge the news of all candidate clusters into one doc and find tags
    print('Merge news in all candidate clusters to a doc and find tags ...')
    doc_list_ = sample(doc_list, min(150 * len(cluster_list), len(doc_list)))
    tag_list.extend(find_tags(['\n'.join(doc_list_)]))
    tag_list = list(filter(lambda x: x not in customized_stopwords,
                           list(set(tag_list))))
    print('=' * 80, '\n')
    print('Num of keywords:', len(set(tag_list)))
    # save to csv
    pd.DataFrame(data={'id': list(range(1, len(set(tag_list)) + 1)), 'kw': list(
        set(tag_list))}).to_csv('tag_list.csv', index=False, encoding='utf-8-sig')
    print('Save to "tag_list.csv"')
    print("\n--- %s seconds ---\n" % (time.time() - start_time))