__init__.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. """
  2. Defines autosub's main functionality.
  3. """
  4. #!/usr/bin/env python
  5. from __future__ import absolute_import, print_function, unicode_literals
  6. import argparse
  7. import audioop
  8. import math
  9. import multiprocessing
  10. import os
  11. from json import JSONDecodeError
  12. import subprocess
  13. import sys
  14. import tempfile
  15. import wave
  16. import json
  17. import requests
  18. try:
  19. from json.decoder import JSONDecodeError
  20. except ImportError:
  21. JSONDecodeError = ValueError
  22. from googleapiclient.discovery import build
  23. from progressbar import ProgressBar, Percentage, Bar, ETA
  24. from autosub.constants import (
  25. LANGUAGE_CODES, GOOGLE_SPEECH_API_KEY, GOOGLE_SPEECH_API_URL,
  26. )
  27. from autosub.formatters import FORMATTERS
  28. DEFAULT_SUBTITLE_FORMAT = 'srt'
  29. DEFAULT_CONCURRENCY = 10
  30. DEFAULT_SRC_LANGUAGE = 'en'
  31. DEFAULT_DST_LANGUAGE = 'en'
  32. def percentile(arr, percent):
  33. """
  34. Calculate the given percentile of arr.
  35. """
  36. arr = sorted(arr)
  37. index = (len(arr) - 1) * percent
  38. floor = math.floor(index)
  39. ceil = math.ceil(index)
  40. if floor == ceil:
  41. return arr[int(index)]
  42. low_value = arr[int(floor)] * (ceil - index)
  43. high_value = arr[int(ceil)] * (index - floor)
  44. return low_value + high_value
  45. class FLACConverter(object): # pylint: disable=too-few-public-methods
  46. """
  47. Class for converting a region of an input audio or video file into a FLAC audio file
  48. """
  49. def __init__(self, source_path, include_before=0.25, include_after=0.25):
  50. self.source_path = source_path
  51. self.include_before = include_before
  52. self.include_after = include_after
  53. def __call__(self, region):
  54. try:
  55. start, end = region
  56. start = max(0, start - self.include_before)
  57. end += self.include_after
  58. #delete=False necessary for running on Windows
  59. temp = tempfile.NamedTemporaryFile(suffix='.flac', delete=False)
  60. program_ffmpeg = which("ffmpeg")
  61. command = [str(program_ffmpeg), "-ss", str(start), "-t", str(end - start),
  62. "-y", "-i", self.source_path,
  63. "-loglevel", "error", temp.name]
  64. use_shell = True if os.name == "nt" else False
  65. subprocess.check_output(command, stdin=open(os.devnull), shell=use_shell)
  66. read_data = temp.read()
  67. temp.close()
  68. os.unlink(temp.name)
  69. return read_data
  70. except KeyboardInterrupt:
  71. return None
  72. class SpeechRecognizer(object): # pylint: disable=too-few-public-methods
  73. """
  74. Class for performing speech-to-text for an input FLAC file.
  75. """
  76. def __init__(self, language="en", rate=44100, retries=3, api_key=GOOGLE_SPEECH_API_KEY):
  77. self.language = language
  78. self.rate = rate
  79. self.api_key = api_key
  80. self.retries = retries
  81. def __call__(self, data):
  82. try:
  83. for _ in range(self.retries):
  84. url = GOOGLE_SPEECH_API_URL.format(lang=self.language, key=self.api_key)
  85. headers = {"Content-Type": "audio/x-flac; rate=%d" % self.rate}
  86. try:
  87. resp = requests.post(url, data=data, headers=headers)
  88. except requests.exceptions.ConnectionError:
  89. continue
  90. for line in resp.content.decode('utf-8').split("\n"):
  91. try:
  92. line = json.loads(line)
  93. line = line['result'][0]['alternative'][0]['transcript']
  94. return line[:1].upper() + line[1:]
  95. except IndexError:
  96. # no result
  97. continue
  98. except JSONDecodeError:
  99. continue
  100. except KeyboardInterrupt:
  101. return None
  102. class Translator(object): # pylint: disable=too-few-public-methods
  103. """
  104. Class for translating a sentence from a one language to another.
  105. """
  106. def __init__(self, language, api_key, src, dst):
  107. self.language = language
  108. self.api_key = api_key
  109. self.service = build('translate', 'v2',
  110. developerKey=self.api_key)
  111. self.src = src
  112. self.dst = dst
  113. def __call__(self, sentence):
  114. try:
  115. if not sentence:
  116. return None
  117. result = self.service.translations().list( # pylint: disable=no-member
  118. source=self.src,
  119. target=self.dst,
  120. q=[sentence]
  121. ).execute()
  122. if 'translations' in result and result['translations'] and \
  123. 'translatedText' in result['translations'][0]:
  124. return result['translations'][0]['translatedText']
  125. return None
  126. except KeyboardInterrupt:
  127. return None
  128. def which(program):
  129. """
  130. Return the path for a given executable.
  131. """
  132. def is_exe(file_path):
  133. """
  134. Checks whether a file is executable.
  135. """
  136. return os.path.isfile(file_path) and os.access(file_path, os.X_OK)
  137. #necessary to run on Windows
  138. if os.name == "nt":
  139. program += ".exe"
  140. fpath, _ = os.path.split(program)
  141. if fpath:
  142. if is_exe(program):
  143. return program
  144. else:
  145. #looks for file in the script execution folder before checking on system path
  146. current_dir = os.getcwd()
  147. local_program = os.path.join(current_dir, program)
  148. if is_exe(local_program):
  149. return local_program
  150. else:
  151. for path in os.environ["PATH"].split(os.pathsep):
  152. path = path.strip('"')
  153. exe_file = os.path.join(path, program)
  154. if is_exe(exe_file):
  155. return exe_file
  156. return None
  157. def extract_audio(filename, channels=1, rate=16000):
  158. """
  159. Extract audio from an input file to a temporary WAV file.
  160. """
  161. temp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
  162. if not os.path.isfile(filename):
  163. print("The given file does not exist: {}".format(filename))
  164. raise Exception("Invalid filepath: {}".format(filename))
  165. program_ffmpeg = which("ffmpeg")
  166. if not program_ffmpeg:
  167. print("ffmpeg: Executable not found on machine.")
  168. raise Exception("Dependency not found: ffmpeg")
  169. command = [str(program_ffmpeg), "-y", "-i", filename,
  170. "-ac", str(channels), "-ar", str(rate),
  171. "-loglevel", "error", temp.name]
  172. use_shell = True if os.name == "nt" else False
  173. subprocess.check_output(command, stdin=open(os.devnull), shell=use_shell)
  174. return temp.name, rate
  175. def find_speech_regions(filename, frame_width=4096, min_region_size=0.5, max_region_size=6): # pylint: disable=too-many-locals
  176. """
  177. Perform voice activity detection on a given audio file.
  178. """
  179. reader = wave.open(filename)
  180. sample_width = reader.getsampwidth()
  181. rate = reader.getframerate()
  182. n_channels = reader.getnchannels()
  183. chunk_duration = float(frame_width) / rate
  184. n_chunks = int(math.ceil(reader.getnframes()*1.0 / frame_width))
  185. energies = []
  186. for _ in range(n_chunks):
  187. chunk = reader.readframes(frame_width)
  188. energies.append(audioop.rms(chunk, sample_width * n_channels))
  189. threshold = percentile(energies, 0.2)
  190. elapsed_time = 0
  191. regions = []
  192. region_start = None
  193. for energy in energies:
  194. is_silence = energy <= threshold
  195. max_exceeded = region_start and elapsed_time - region_start >= max_region_size
  196. if (max_exceeded or is_silence) and region_start:
  197. if elapsed_time - region_start >= min_region_size:
  198. regions.append((region_start, elapsed_time))
  199. region_start = None
  200. elif (not region_start) and (not is_silence):
  201. region_start = elapsed_time
  202. elapsed_time += chunk_duration
  203. return regions
  204. def generate_subtitles( # pylint: disable=too-many-locals,too-many-arguments
  205. source_path,
  206. output=None,
  207. concurrency=DEFAULT_CONCURRENCY,
  208. src_language=DEFAULT_SRC_LANGUAGE,
  209. dst_language=DEFAULT_DST_LANGUAGE,
  210. subtitle_file_format=DEFAULT_SUBTITLE_FORMAT,
  211. api_key=None,
  212. ):
  213. """
  214. Given an input audio/video file, generate subtitles in the specified language and format.
  215. """
  216. if os.name != "nt" and "Darwin" in os.uname():
  217. #the default unix fork method does not work on Mac OS
  218. #need to use forkserver
  219. if 'forkserver' != multiprocessing.get_start_method(allow_none=True):
  220. multiprocessing.set_start_method('forkserver')
  221. audio_filename, audio_rate = extract_audio(source_path)
  222. regions = find_speech_regions(audio_filename)
  223. pool = multiprocessing.Pool(concurrency)
  224. converter = FLACConverter(source_path=audio_filename)
  225. recognizer = SpeechRecognizer(language=src_language, rate=audio_rate,
  226. api_key=GOOGLE_SPEECH_API_KEY)
  227. transcripts = []
  228. if regions:
  229. try:
  230. widgets = ["Converting speech regions to FLAC files: ", Percentage(), ' ', Bar(), ' ',
  231. ETA()]
  232. pbar = ProgressBar(widgets=widgets, maxval=len(regions)).start()
  233. extracted_regions = []
  234. for i, extracted_region in enumerate(pool.imap(converter, regions)):
  235. extracted_regions.append(extracted_region)
  236. pbar.update(i)
  237. pbar.finish()
  238. widgets = ["Performing speech recognition: ", Percentage(), ' ', Bar(), ' ', ETA()]
  239. pbar = ProgressBar(widgets=widgets, maxval=len(regions)).start()
  240. for i, transcript in enumerate(pool.imap(recognizer, extracted_regions)):
  241. transcripts.append(transcript)
  242. pbar.update(i)
  243. pbar.finish()
  244. if src_language.split("-")[0] != dst_language.split("-")[0]:
  245. if api_key:
  246. google_translate_api_key = api_key
  247. translator = Translator(dst_language, google_translate_api_key,
  248. dst=dst_language,
  249. src=src_language)
  250. prompt = "Translating from {0} to {1}: ".format(src_language, dst_language)
  251. widgets = [prompt, Percentage(), ' ', Bar(), ' ', ETA()]
  252. pbar = ProgressBar(widgets=widgets, maxval=len(regions)).start()
  253. translated_transcripts = []
  254. for i, transcript in enumerate(pool.imap(translator, transcripts)):
  255. translated_transcripts.append(transcript)
  256. pbar.update(i)
  257. pbar.finish()
  258. transcripts = translated_transcripts
  259. else:
  260. print(
  261. "Error: Subtitle translation requires specified Google Translate API key. "
  262. "See --help for further information."
  263. )
  264. return 1
  265. except KeyboardInterrupt:
  266. pbar.finish()
  267. pool.terminate()
  268. pool.join()
  269. print("Cancelling transcription")
  270. raise
  271. timed_subtitles = [(r, t) for r, t in zip(regions, transcripts) if t]
  272. formatter = FORMATTERS.get(subtitle_file_format)
  273. formatted_subtitles = formatter(timed_subtitles)
  274. dest = output
  275. if not dest:
  276. base = os.path.splitext(source_path)[0]
  277. dest = "{base}.{format}".format(base=base, format=subtitle_file_format)
  278. with open(dest, 'wb') as output_file:
  279. output_file.write(formatted_subtitles.encode("utf-8"))
  280. os.remove(audio_filename)
  281. return dest
  282. def validate(args):
  283. """
  284. Check that the CLI arguments passed to autosub are valid.
  285. """
  286. if args.format not in FORMATTERS:
  287. print(
  288. "Subtitle format not supported. "
  289. "Run with --list-formats to see all supported formats."
  290. )
  291. return False
  292. if args.src_language not in LANGUAGE_CODES.keys():
  293. print(
  294. "Source language not supported. "
  295. "Run with --list-languages to see all supported languages."
  296. )
  297. return False
  298. if args.dst_language not in LANGUAGE_CODES.keys():
  299. print(
  300. "Destination language not supported. "
  301. "Run with --list-languages to see all supported languages."
  302. )
  303. return False
  304. if not args.source_path:
  305. print("Error: You need to specify a source path.")
  306. return False
  307. return True
  308. def main():
  309. """
  310. Run autosub as a command-line program.
  311. """
  312. parser = argparse.ArgumentParser()
  313. parser.add_argument('source_path', help="Path to the video or audio file to subtitle",
  314. nargs='?')
  315. parser.add_argument('-C', '--concurrency', help="Number of concurrent API requests to make",
  316. type=int, default=DEFAULT_CONCURRENCY)
  317. parser.add_argument('-o', '--output',
  318. help="Output path for subtitles (by default, subtitles are saved in \
  319. the same directory and name as the source path)")
  320. parser.add_argument('-F', '--format', help="Destination subtitle format",
  321. default=DEFAULT_SUBTITLE_FORMAT)
  322. parser.add_argument('-S', '--src-language', help="Language spoken in source file",
  323. default=DEFAULT_SRC_LANGUAGE)
  324. parser.add_argument('-D', '--dst-language', help="Desired language for the subtitles",
  325. default=DEFAULT_DST_LANGUAGE)
  326. parser.add_argument('-K', '--api-key',
  327. help="The Google Translate API key to be used. \
  328. (Required for subtitle translation)")
  329. parser.add_argument('--list-formats', help="List all available subtitle formats",
  330. action='store_true')
  331. parser.add_argument('--list-languages', help="List all available source/destination languages",
  332. action='store_true')
  333. args = parser.parse_args()
  334. if args.list_formats:
  335. print("List of formats:")
  336. for subtitle_format in FORMATTERS:
  337. print("{format}".format(format=subtitle_format))
  338. return 0
  339. if args.list_languages:
  340. print("List of all languages:")
  341. for code, language in sorted(LANGUAGE_CODES.items()):
  342. print("{code}\t{language}".format(code=code, language=language))
  343. return 0
  344. if not validate(args):
  345. return 1
  346. try:
  347. subtitle_file_path = generate_subtitles(
  348. source_path=args.source_path,
  349. concurrency=args.concurrency,
  350. src_language=args.src_language,
  351. dst_language=args.dst_language,
  352. api_key=args.api_key,
  353. subtitle_file_format=args.format,
  354. output=args.output,
  355. )
  356. print("Subtitles file created at {}".format(subtitle_file_path))
  357. except KeyboardInterrupt:
  358. return 1
  359. return 0
  360. if __name__ == '__main__':
  361. sys.exit(main())