__init__-0.4.0.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. """
  2. Defines autosub's main functionality.
  3. """
  4. #!/usr/bin/env python
  5. from __future__ import absolute_import, print_function, unicode_literals
  6. import argparse
  7. import audioop
  8. import json
  9. import math
  10. import multiprocessing
  11. import os
  12. import subprocess
  13. import sys
  14. import tempfile
  15. import wave
  16. import requests
  17. from googleapiclient.discovery import build
  18. from progressbar import ProgressBar, Percentage, Bar, ETA
  19. from autosub.constants import (
  20. LANGUAGE_CODES, GOOGLE_SPEECH_API_KEY, GOOGLE_SPEECH_API_URL,
  21. )
  22. from autosub.formatters import FORMATTERS
  23. DEFAULT_SUBTITLE_FORMAT = 'srt'
  24. DEFAULT_CONCURRENCY = 10
  25. DEFAULT_SRC_LANGUAGE = 'en'
  26. DEFAULT_DST_LANGUAGE = 'en'
  27. def percentile(arr, percent):
  28. """
  29. Calculate the given percentile of arr.
  30. """
  31. arr = sorted(arr)
  32. index = (len(arr) - 1) * percent
  33. floor = math.floor(index)
  34. ceil = math.ceil(index)
  35. if floor == ceil:
  36. return arr[int(index)]
  37. low_value = arr[int(floor)] * (ceil - index)
  38. high_value = arr[int(ceil)] * (index - floor)
  39. return low_value + high_value
  40. class FLACConverter(object): # pylint: disable=too-few-public-methods
  41. """
  42. Class for converting a region of an input audio or video file into a FLAC audio file
  43. """
  44. def __init__(self, source_path, include_before=0.25, include_after=0.25):
  45. self.source_path = source_path
  46. self.include_before = include_before
  47. self.include_after = include_after
  48. def __call__(self, region):
  49. try:
  50. start, end = region
  51. start = max(0, start - self.include_before)
  52. end += self.include_after
  53. temp = tempfile.NamedTemporaryFile(suffix='.flac')
  54. command = ["ffmpeg", "-ss", str(start), "-t", str(end - start),
  55. "-y", "-i", self.source_path,
  56. "-loglevel", "error", temp.name]
  57. use_shell = True if os.name == "nt" else False
  58. subprocess.check_output(command, stdin=open(os.devnull), shell=use_shell)
  59. return temp.read()
  60. except KeyboardInterrupt:
  61. return None
  62. class SpeechRecognizer(object): # pylint: disable=too-few-public-methods
  63. """
  64. Class for performing speech-to-text for an input FLAC file.
  65. """
  66. def __init__(self, language="en", rate=44100, retries=3, api_key=GOOGLE_SPEECH_API_KEY):
  67. self.language = language
  68. self.rate = rate
  69. self.api_key = api_key
  70. self.retries = retries
  71. def __call__(self, data):
  72. try:
  73. for _ in range(self.retries):
  74. url = GOOGLE_SPEECH_API_URL.format(lang=self.language, key=self.api_key)
  75. headers = {"Content-Type": "audio/x-flac; rate=%d" % self.rate}
  76. try:
  77. resp = requests.post(url, data=data, headers=headers)
  78. except requests.exceptions.ConnectionError:
  79. continue
  80. for line in resp.content.decode('utf-8').split("\n"):
  81. try:
  82. line = json.loads(line)
  83. line = line['result'][0]['alternative'][0]['transcript']
  84. return line[:1].upper() + line[1:]
  85. except IndexError:
  86. # no result
  87. continue
  88. except KeyboardInterrupt:
  89. return None
  90. class Translator(object): # pylint: disable=too-few-public-methods
  91. """
  92. Class for translating a sentence from a one language to another.
  93. """
  94. def __init__(self, language, api_key, src, dst):
  95. self.language = language
  96. self.api_key = api_key
  97. self.service = build('translate', 'v2',
  98. developerKey=self.api_key)
  99. self.src = src
  100. self.dst = dst
  101. def __call__(self, sentence):
  102. try:
  103. if not sentence:
  104. return None
  105. result = self.service.translations().list( # pylint: disable=no-member
  106. source=self.src,
  107. target=self.dst,
  108. q=[sentence]
  109. ).execute()
  110. if 'translations' in result and result['translations'] and \
  111. 'translatedText' in result['translations'][0]:
  112. return result['translations'][0]['translatedText']
  113. return None
  114. except KeyboardInterrupt:
  115. return None
  116. def which(program):
  117. """
  118. Return the path for a given executable.
  119. """
  120. def is_exe(file_path):
  121. """
  122. Checks whether a file is executable.
  123. """
  124. return os.path.isfile(file_path) and os.access(file_path, os.X_OK)
  125. fpath, _ = os.path.split(program)
  126. if fpath:
  127. if is_exe(program):
  128. return program
  129. else:
  130. for path in os.environ["PATH"].split(os.pathsep):
  131. path = path.strip('"')
  132. exe_file = os.path.join(path, program)
  133. if is_exe(exe_file):
  134. return exe_file
  135. return None
  136. def extract_audio(filename, channels=1, rate=16000):
  137. """
  138. Extract audio from an input file to a temporary WAV file.
  139. """
  140. temp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
  141. if not os.path.isfile(filename):
  142. print("The given file does not exist: {}".format(filename))
  143. raise Exception("Invalid filepath: {}".format(filename))
  144. if not which("ffmpeg"):
  145. print("ffmpeg: Executable not found on machine.")
  146. raise Exception("Dependency not found: ffmpeg")
  147. command = ["ffmpeg", "-y", "-i", filename,
  148. "-ac", str(channels), "-ar", str(rate),
  149. "-loglevel", "error", temp.name]
  150. use_shell = True if os.name == "nt" else False
  151. subprocess.check_output(command, stdin=open(os.devnull), shell=use_shell)
  152. return temp.name, rate
  153. def find_speech_regions(filename, frame_width=4096, min_region_size=0.5, max_region_size=6): # pylint: disable=too-many-locals
  154. """
  155. Perform voice activity detection on a given audio file.
  156. """
  157. reader = wave.open(filename)
  158. sample_width = reader.getsampwidth()
  159. rate = reader.getframerate()
  160. n_channels = reader.getnchannels()
  161. chunk_duration = float(frame_width) / rate
  162. n_chunks = int(math.ceil(reader.getnframes()*1.0 / frame_width))
  163. energies = []
  164. for _ in range(n_chunks):
  165. chunk = reader.readframes(frame_width)
  166. energies.append(audioop.rms(chunk, sample_width * n_channels))
  167. threshold = percentile(energies, 0.2)
  168. elapsed_time = 0
  169. regions = []
  170. region_start = None
  171. for energy in energies:
  172. is_silence = energy <= threshold
  173. max_exceeded = region_start and elapsed_time - region_start >= max_region_size
  174. if (max_exceeded or is_silence) and region_start:
  175. if elapsed_time - region_start >= min_region_size:
  176. regions.append((region_start, elapsed_time))
  177. region_start = None
  178. elif (not region_start) and (not is_silence):
  179. region_start = elapsed_time
  180. elapsed_time += chunk_duration
  181. return regions
  182. def generate_subtitles( # pylint: disable=too-many-locals,too-many-arguments
  183. source_path,
  184. output=None,
  185. concurrency=DEFAULT_CONCURRENCY,
  186. src_language=DEFAULT_SRC_LANGUAGE,
  187. dst_language=DEFAULT_DST_LANGUAGE,
  188. subtitle_file_format=DEFAULT_SUBTITLE_FORMAT,
  189. api_key=None,
  190. ):
  191. """
  192. Given an input audio/video file, generate subtitles in the specified language and format.
  193. """
  194. audio_filename, audio_rate = extract_audio(source_path)
  195. regions = find_speech_regions(audio_filename)
  196. pool = multiprocessing.Pool(concurrency)
  197. converter = FLACConverter(source_path=audio_filename)
  198. recognizer = SpeechRecognizer(language=src_language, rate=audio_rate,
  199. api_key=GOOGLE_SPEECH_API_KEY)
  200. transcripts = []
  201. if regions:
  202. try:
  203. widgets = ["Converting speech regions to FLAC files: ", Percentage(), ' ', Bar(), ' ',
  204. ETA()]
  205. pbar = ProgressBar(widgets=widgets, maxval=len(regions)).start()
  206. extracted_regions = []
  207. for i, extracted_region in enumerate(pool.imap(converter, regions)):
  208. extracted_regions.append(extracted_region)
  209. pbar.update(i)
  210. pbar.finish()
  211. widgets = ["Performing speech recognition: ", Percentage(), ' ', Bar(), ' ', ETA()]
  212. pbar = ProgressBar(widgets=widgets, maxval=len(regions)).start()
  213. for i, transcript in enumerate(pool.imap(recognizer, extracted_regions)):
  214. transcripts.append(transcript)
  215. pbar.update(i)
  216. pbar.finish()
  217. if src_language.split("-")[0] != dst_language.split("-")[0]:
  218. if api_key:
  219. google_translate_api_key = api_key
  220. translator = Translator(dst_language, google_translate_api_key,
  221. dst=dst_language,
  222. src=src_language)
  223. prompt = "Translating from {0} to {1}: ".format(src_language, dst_language)
  224. widgets = [prompt, Percentage(), ' ', Bar(), ' ', ETA()]
  225. pbar = ProgressBar(widgets=widgets, maxval=len(regions)).start()
  226. translated_transcripts = []
  227. for i, transcript in enumerate(pool.imap(translator, transcripts)):
  228. translated_transcripts.append(transcript)
  229. pbar.update(i)
  230. pbar.finish()
  231. transcripts = translated_transcripts
  232. else:
  233. print(
  234. "Error: Subtitle translation requires specified Google Translate API key. "
  235. "See --help for further information."
  236. )
  237. return 1
  238. except KeyboardInterrupt:
  239. pbar.finish()
  240. pool.terminate()
  241. pool.join()
  242. print("Cancelling transcription")
  243. raise
  244. timed_subtitles = [(r, t) for r, t in zip(regions, transcripts) if t]
  245. formatter = FORMATTERS.get(subtitle_file_format)
  246. formatted_subtitles = formatter(timed_subtitles)
  247. dest = output
  248. if not dest:
  249. base = os.path.splitext(source_path)[0]
  250. dest = "{base}.{format}".format(base=base, format=subtitle_file_format)
  251. with open(dest, 'wb') as output_file:
  252. output_file.write(formatted_subtitles.encode("utf-8"))
  253. os.remove(audio_filename)
  254. return dest
  255. def validate(args):
  256. """
  257. Check that the CLI arguments passed to autosub are valid.
  258. """
  259. if args.format not in FORMATTERS:
  260. print(
  261. "Subtitle format not supported. "
  262. "Run with --list-formats to see all supported formats."
  263. )
  264. return False
  265. if args.src_language not in LANGUAGE_CODES.keys():
  266. print(
  267. "Source language not supported. "
  268. "Run with --list-languages to see all supported languages."
  269. )
  270. return False
  271. if args.dst_language not in LANGUAGE_CODES.keys():
  272. print(
  273. "Destination language not supported. "
  274. "Run with --list-languages to see all supported languages."
  275. )
  276. return False
  277. if not args.source_path:
  278. print("Error: You need to specify a source path.")
  279. return False
  280. return True
  281. def main():
  282. """
  283. Run autosub as a command-line program.
  284. """
  285. parser = argparse.ArgumentParser()
  286. parser.add_argument('source_path', help="Path to the video or audio file to subtitle",
  287. nargs='?')
  288. parser.add_argument('-C', '--concurrency', help="Number of concurrent API requests to make",
  289. type=int, default=DEFAULT_CONCURRENCY)
  290. parser.add_argument('-o', '--output',
  291. help="Output path for subtitles (by default, subtitles are saved in \
  292. the same directory and name as the source path)")
  293. parser.add_argument('-F', '--format', help="Destination subtitle format",
  294. default=DEFAULT_SUBTITLE_FORMAT)
  295. parser.add_argument('-S', '--src-language', help="Language spoken in source file",
  296. default=DEFAULT_SRC_LANGUAGE)
  297. parser.add_argument('-D', '--dst-language', help="Desired language for the subtitles",
  298. default=DEFAULT_DST_LANGUAGE)
  299. parser.add_argument('-K', '--api-key',
  300. help="The Google Translate API key to be used. \
  301. (Required for subtitle translation)")
  302. parser.add_argument('--list-formats', help="List all available subtitle formats",
  303. action='store_true')
  304. parser.add_argument('--list-languages', help="List all available source/destination languages",
  305. action='store_true')
  306. args = parser.parse_args()
  307. if args.list_formats:
  308. print("List of formats:")
  309. for subtitle_format in FORMATTERS:
  310. print("{format}".format(format=subtitle_format))
  311. return 0
  312. if args.list_languages:
  313. print("List of all languages:")
  314. for code, language in sorted(LANGUAGE_CODES.items()):
  315. print("{code}\t{language}".format(code=code, language=language))
  316. return 0
  317. if not validate(args):
  318. return 1
  319. try:
  320. subtitle_file_path = generate_subtitles(
  321. source_path=args.source_path,
  322. concurrency=args.concurrency,
  323. src_language=args.src_language,
  324. dst_language=args.dst_language,
  325. api_key=args.api_key,
  326. subtitle_file_format=args.format,
  327. output=args.output,
  328. )
  329. print("Subtitles file created at {}".format(subtitle_file_path))
  330. except KeyboardInterrupt:
  331. return 1
  332. return 0
  333. if __name__ == '__main__':
  334. sys.exit(main())