sse_app.py 12 KB


  1. # sse_app.py — 將 main_translate.py 與 SSE/前端整合為同一支 FastAPI 應用
  2. # 依賴:
# pip install fastapi uvicorn[standard] transformers sentencepiece torch opencc-python-reimplemented sounddevice whisper-live
  4. #
  5. # 啟動:
  6. # python sse_app.py --host 0.0.0.0 --port 7010
  7. import os
  8. import sys
  9. import threading
  10. import queue
  11. import atexit
  12. import asyncio
  13. import signal
  14. from typing import Optional, AsyncIterator, Set
  15. from contextlib import asynccontextmanager
  16. import torch
  17. from whisper_live.client import TranscriptionClient
  18. from transformers import MarianMTModel, MarianTokenizer
  19. import opencc # 簡→繁(台灣用詞)
  20. from fastapi import FastAPI
  21. from fastapi.responses import HTMLResponse, StreamingResponse, PlainTextResponse
  22. from fastapi.middleware.cors import CORSMiddleware
  23. import uvicorn
  24. from pathlib import Path
  25. # ===== 可調參數(你原本的) =====
  26. ASR_HOST = "192.168.192.83"
  27. ASR_PORT = 9090
  28. LANG = "en"
  29. WHISPER_MODEL = "small"
  30. USE_VAD = False
  31. PRINT_ASR_TO_STDERR = True # 英文印 stderr、中文印 stdout
  32. # ============================
  33. # ---------- 不緩衝 / 行緩衝輸出 ----------
  34. os.environ["PYTHONUNBUFFERED"] = "1"
  35. if hasattr(sys.stdout, "reconfigure"):
  36. try:
  37. sys.stdout.reconfigure(line_buffering=True)
  38. except Exception:
  39. pass
  40. if hasattr(sys.stderr, "reconfigure"):
  41. try:
  42. sys.stderr.reconfigure(line_buffering=True)
  43. except Exception:
  44. pass
  45. # ========== FastAPI 與 SSE 相關 ==========
  46. HOST_DEFAULT = "0.0.0.0"
  47. PORT_DEFAULT = 7010
  48. KEEPALIVE = 15.0
  49. QUEUE_SIZE = 500
  50. app = FastAPI()
  51. app.add_middleware(
  52. CORSMiddleware,
  53. allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
  54. )
  55. # Asyncio event loop(在 lifespan 啟動時取得)
  56. _event_loop: Optional[asyncio.AbstractEventLoop] = None
  57. # SSE 訂閱者(每個訂閱者一個 asyncio.Queue)
  58. _asr_subscribers: Set[asyncio.Queue[str]] = set()
  59. _trans_subscribers: Set[asyncio.Queue[str]] = set()
  60. def _threadsafe_broadcast(asr_or_trans: str, line: str):
  61. """從任何執行緒安全地廣播到對應的 asyncio.Queue。"""
  62. global _event_loop
  63. subs = _asr_subscribers if asr_or_trans == "asr" else _trans_subscribers
  64. if _event_loop is None:
  65. return
  66. for q in list(subs):
  67. try:
  68. _event_loop.call_soon_threadsafe(q.put_nowait, line)
  69. except Exception:
  70. pass
  71. async def _event_stream(kind: str) -> AsyncIterator[bytes]:
  72. """SSE generator:kind ∈ {"asr","trans"}。"""
  73. q: asyncio.Queue[str] = asyncio.Queue(maxsize=QUEUE_SIZE)
  74. subs = _asr_subscribers if kind == "asr" else _trans_subscribers
  75. subs.add(q)
  76. try:
  77. # 初次 keepalive
  78. yield b": connected\n\n"
  79. while True:
  80. try:
  81. item = await asyncio.wait_for(q.get(), timeout=KEEPALIVE)
  82. data = item.replace("\r", "").replace("\n", "\\n")
  83. yield f"data: {data}\n\n".encode("utf-8")
  84. except asyncio.TimeoutError:
  85. yield b": keepalive\n\n"
  86. finally:
  87. subs.discard(q)
  88. # ========== 你的翻譯執行緒(原封不動 + 廣播) ==========
  89. def log_asr(*args, **kwargs):
  90. text = " ".join(str(a) for a in args)
  91. # 終端 tee
  92. print(*args, file=(sys.stderr if PRINT_ASR_TO_STDERR else sys.stdout), flush=True, **kwargs)
  93. # SSE 廣播
  94. _threadsafe_broadcast("asr", text)
  95. def log_trans(*args, **kwargs):
  96. text = " ".join(str(a) for a in args)
  97. # 終端 tee(中文用 stdout)
  98. print(*args, file=sys.stdout, flush=True, **kwargs)
  99. # SSE 廣播
  100. _threadsafe_broadcast("trans", text)
  101. class TranslatorWorker(threading.Thread):
  102. def __init__(self, model_name: str = "Helsinki-NLP/opus-mt-en-zh", max_len: int = 256):
  103. super().__init__(daemon=True)
  104. self.q: "queue.Queue[Optional[str]]" = queue.Queue(maxsize=500)
  105. self.max_len = max_len
  106. self.device = "cuda" if torch.cuda.is_available() else "cpu"
  107. self.tok = MarianTokenizer.from_pretrained(model_name)
  108. self.mt = MarianMTModel.from_pretrained(model_name)
  109. self.mt.to(self.device)
  110. self.mt.eval()
  111. self.cc = opencc.OpenCC("s2twp")
  112. self._stop_evt = threading.Event()
  113. def translate(self, text: str) -> str:
  114. if not text or not text.strip():
  115. return ""
  116. with torch.no_grad():
  117. batch = self.tok([text.strip()], return_tensors="pt", padding=True,
  118. truncation=True, max_length=self.max_len)
  119. batch = {k: v.to(self.device) for k, v in batch.items()}
  120. gen = self.mt.generate(**batch, max_new_tokens=self.max_len)
  121. zh_simplified = self.tok.batch_decode(gen, skip_special_tokens=True)[0].strip()
  122. return self.cc.convert(zh_simplified)
  123. def run(self):
  124. while not self._stop_evt.is_set():
  125. item = self.q.get()
  126. if item == "__STOP__":
  127. break
  128. try:
  129. zh = self.translate(item)
  130. if zh:
  131. log_trans("TRANS_TW:", zh)
  132. except Exception as e:
  133. log_trans(f"[TranslatorWorker] 翻譯失敗: {e}")
  134. def submit(self, text: str):
  135. if not text or not text.strip():
  136. return
  137. try:
  138. self.q.put_nowait(text)
  139. except queue.Full:
  140. try:
  141. _ = self.q.get_nowait()
  142. except queue.Empty:
  143. pass
  144. self.q.put_nowait(text)
  145. def stop(self):
  146. self._stop_evt.set()
  147. try:
  148. self.q.put_nowait("__STOP__")
  149. except queue.Full:
  150. try:
  151. _ = self.q.get_nowait()
  152. except queue.Empty:
  153. pass
  154. self.q.put_nowait("__STOP__")
  155. translator = TranslatorWorker(model_name="Helsinki-NLP/opus-mt-en-zh", max_len=256)
  156. # ---------- 英文輸出「去重」狀態 ----------
  157. _last_asr_printed = None # 僅用於避免連續印出完全相同的一段英文
  158. def on_asr(text: str, segments):
  159. global _last_asr_printed
  160. if not text:
  161. return
  162. normalized = text.strip()
  163. if normalized != (_last_asr_printed or ""):
  164. log_asr("ASR:", text)
  165. _last_asr_printed = normalized
  166. translator.submit(text)
  167. # ========== 啟動 / 停止 Whisper Live 的執行緒 ==========
  168. _tc: Optional[TranscriptionClient] = None
  169. _tc_thread: Optional[threading.Thread] = None
  170. def _tc_runner():
  171. global _tc
  172. try:
  173. _tc()
  174. except KeyboardInterrupt:
  175. pass
  176. except Exception as e:
  177. log_asr(f"[ASR Thread] error: {e}")
  178. def start_services():
  179. """在應用啟動時啟動 Translator 與 TranscriptionClient 執行緒。"""
  180. global _tc, _tc_thread
  181. if not translator.is_alive():
  182. translator.start()
  183. _tc = TranscriptionClient(
  184. host=ASR_HOST,
  185. port=ASR_PORT,
  186. lang=LANG,
  187. translate=False,
  188. model=WHISPER_MODEL,
  189. use_vad=USE_VAD,
  190. mute_audio_playback=True,
  191. save_output_recording=True,
  192. output_recording_filename="./output_recording.wav",
  193. transcription_callback=on_asr,
  194. )
  195. _tc_thread = threading.Thread(target=_tc_runner, name="ASRThread", daemon=True)
  196. _tc_thread.start()
  197. log_asr("[bridge] services started.")
  198. def stop_services():
  199. """在應用關閉時嘗試優雅停止。"""
  200. try:
  201. translator.stop()
  202. except Exception:
  203. pass
  204. # TranscriptionClient 沒有穩定的 stop API,就讓程序結束時一起回收
  205. log_asr("[bridge] services stopping...")
  206. # FastAPI lifespan(避免 on_event deprecation)
  207. @asynccontextmanager
  208. async def lifespan(app_: FastAPI):
  209. global _event_loop
  210. _event_loop = asyncio.get_running_loop()
  211. start_services()
  212. try:
  213. yield
  214. finally:
  215. stop_services()
  216. app.router.lifespan_context = lifespan
  217. # ========== 路由 ==========
  218. @app.get("/", response_class=HTMLResponse)
  219. async def index():
  220. html_path = Path(__file__).parent / "index.html"
  221. return HTMLResponse(content=html_path.read_text(encoding="utf-8"))
  222. @app.get("/stream/asr")
  223. async def stream_asr():
  224. return StreamingResponse(
  225. _event_stream("asr"),
  226. media_type="text/event-stream",
  227. headers={"Content-Type": "text/event-stream; charset=utf-8"},
  228. )
  229. @app.get("/stream/trans")
  230. async def stream_trans():
  231. return StreamingResponse(
  232. _event_stream("trans"),
  233. media_type="text/event-stream",
  234. headers={"Content-Type": "text/event-stream; charset=utf-8"},
  235. )
  236. @app.get("/health", response_class=PlainTextResponse)
  237. async def health():
  238. ok = translator.is_alive() and (_tc_thread is not None and _tc_thread.is_alive())
  239. return "ok" if ok else "dead"
  240. # 內建頁面:左右分欄(ASR / TRANS_TW)
  241. CLIENT_HTML = """
  242. <!doctype html>
  243. <html lang="zh-Hant">
  244. <head>
  245. <meta charset="utf-8" />
  246. <meta name="viewport" content="width=device-width,initial-scale=1" />
  247. <title>ASR / 翻譯 分欄顯示</title>
  248. <style>
  249. :root { --bg:#0b1020; --panel:#121833; --fg:#e7ebff; --muted:#9aa4c7; --blue:#7aa2ff; --green:#93e1a4; }
  250. * { box-sizing: border-box; }
  251. body { margin:0; background:var(--bg); color:var(--fg); font:14px/1.5 ui-monospace, SFMono-Regular, Menlo, Consolas, "Noto Sans TC", monospace; height:100vh; display:flex; flex-direction:column; }
  252. header { background:#101739; padding:10px 14px; border-bottom:1px solid #1f2650; display:flex; justify-content:space-between; align-items:center; }
  253. h1 { margin:0; font-size:16px; }
  254. .wrap { flex:1; display:grid; grid-template-columns:1fr 1fr; gap:8px; padding:8px; min-height:0; }
  255. .col { display:flex; flex-direction:column; border:1px solid #1f2650; border-radius:10px; background:var(--panel); min-height:0; }
  256. .col h2 { margin:0; padding:8px 10px; font-size:13px; border-bottom:1px solid #1f2650; }
  257. .stream { flex:1; overflow:auto; padding:10px; white-space:pre-wrap; word-break:break-word; }
  258. .asr .line::before { content:"ASR "; color:var(--blue); font-weight:700; margin-right:6px; }
  259. .trans .line::before { content:"TRANS_TW "; color:var(--green); font-weight:700; margin-right:6px; }
  260. .controls { display:flex; gap:8px; }
  261. button { background:var(--panel); color:var(--fg); border:1px solid #2a3570; border-radius:999px; padding:6px 12px; cursor:pointer; }
  262. button:hover { border-color:var(--blue); }
  263. </style>
  264. </head>
  265. <body>
  266. <header>
  267. <h1>ASR / 翻譯 分欄顯示</h1>
  268. <div class="controls">
  269. <button id="clear">清除</button>
  270. <button id="freeze">暫停自動捲動</button>
  271. </div>
  272. </header>
  273. <div class="wrap">
  274. <section class="col asr">
  275. <h2>英文 ASR</h2>
  276. <div id="asr" class="stream"></div>
  277. </section>
  278. <section class="col trans">
  279. <h2>中文 翻譯</h2>
  280. <div id="trans" class="stream"></div>
  281. </section>
  282. </div>
  283. <script>
  284. let autoScroll = true;
  285. const asrEl = document.getElementById("asr");
  286. const transEl = document.getElementById("trans");
  287. document.getElementById("clear").onclick = () => { asrEl.textContent=""; transEl.textContent=""; };
  288. document.getElementById("freeze").onclick = (e) => { autoScroll = !autoScroll; e.target.textContent = autoScroll ? "暫停自動捲動" : "恢復自動捲動"; };
  289. function connect(url, container){
  290. const es = new EventSource(url);
  291. es.onmessage = (e) => {
  292. const text = (e.data || "").replaceAll("\\\\n","\\n");
  293. const div = document.createElement("div");
  294. div.className = "line";
  295. div.textContent = text;
  296. container.appendChild(div);
  297. if (autoScroll) container.scrollTop = container.scrollHeight;
  298. };
  299. es.onerror = () => {
  300. const div = document.createElement("div");
  301. div.className="line";
  302. div.style.color="#ff9b9b";
  303. div.textContent="[連線中斷,稍後自動重試]";
  304. container.appendChild(div);
  305. };
  306. }
  307. connect("/stream/asr", asrEl);
  308. connect("/stream/trans", transEl);
  309. </script>
  310. </body>
  311. </html>
  312. """
  313. def main():
  314. import argparse
  315. p = argparse.ArgumentParser(description="ASR + Translator integrated with SSE/Front-end")
  316. p.add_argument("--host", default=HOST_DEFAULT)
  317. p.add_argument("--port", type=int, default=PORT_DEFAULT)
  318. args = p.parse_args()
  319. uvicorn.run("sse_app:app", host=args.host, port=args.port, reload=False)
  320. if __name__ == "__main__":
  321. main()