|
@@ -0,0 +1,356 @@
|
|
|
+# sse_app.py — 將 main_translate.py 與 SSE/前端整合為同一支 FastAPI 應用
|
|
|
+# 依賴:
|
|
|
+# pip install fastapi uvicorn[standard] transformers torch opencc-python-reimplemented sounddevice
|
|
|
+#
|
|
|
+# 啟動:
|
|
|
+# python sse_app.py --host 0.0.0.0 --port 7010
|
|
|
+
|
|
|
+import os
|
|
|
+import sys
|
|
|
+import threading
|
|
|
+import queue
|
|
|
+import atexit
|
|
|
+import asyncio
|
|
|
+import signal
|
|
|
+from typing import Optional, AsyncIterator, Set
|
|
|
+from contextlib import asynccontextmanager
|
|
|
+
|
|
|
+import torch
|
|
|
+from whisper_live.client import TranscriptionClient
|
|
|
+from transformers import MarianMTModel, MarianTokenizer
|
|
|
+import opencc # 簡→繁(台灣用詞)
|
|
|
+
|
|
|
+from fastapi import FastAPI
|
|
|
+from fastapi.responses import HTMLResponse, StreamingResponse, PlainTextResponse
|
|
|
+from fastapi.middleware.cors import CORSMiddleware
|
|
|
+import uvicorn
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+# ===== 可調參數(你原本的) =====
|
|
|
+ASR_HOST = "192.168.192.83"
|
|
|
+ASR_PORT = 9090
|
|
|
+LANG = "en"
|
|
|
+WHISPER_MODEL = "small"
|
|
|
+USE_VAD = False
|
|
|
+PRINT_ASR_TO_STDERR = True # 英文印 stderr、中文印 stdout
|
|
|
+# ============================
|
|
|
+
|
|
|
# ---------- Unbuffered / line-buffered output ----------
os.environ["PYTHONUNBUFFERED"] = "1"

# Switch both standard streams to line buffering where the stream object
# offers reconfigure(); any failure is deliberately ignored (best effort).
for _stream in (sys.stdout, sys.stderr):
    if hasattr(_stream, "reconfigure"):
        try:
            _stream.reconfigure(line_buffering=True)
        except Exception:
            pass
del _stream
|
|
|
+
|
|
|
# ========== FastAPI / SSE configuration ==========
HOST_DEFAULT = "0.0.0.0"  # default value of the --host command-line option
PORT_DEFAULT = 7010       # default value of the --port command-line option
KEEPALIVE = 15.0          # seconds an SSE stream waits for data before sending a keepalive comment
QUEUE_SIZE = 500          # maxsize of each per-subscriber asyncio.Queue

app = FastAPI()
# Wildcard origins must not be combined with allow_credentials=True: the
# FastAPI CORS documentation requires explicit origins/methods/headers when
# credentials are enabled.  Nothing in this file relies on cookies or
# authentication headers, so credentials are disabled and the wildcard kept.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
+
|
|
|
# Asyncio event loop (captured when the lifespan handler starts).
_event_loop: Optional[asyncio.AbstractEventLoop] = None

# SSE subscribers: one asyncio.Queue per connected client.  The annotations are
# quoted so that `asyncio.Queue[str]` is never evaluated at import time.
_asr_subscribers: "Set[asyncio.Queue[str]]" = set()
_trans_subscribers: "Set[asyncio.Queue[str]]" = set()


def _deliver_on_loop(subs, line: str) -> None:
    """Put *line* into every queue in *subs*; must run on the event loop thread.

    A queue that is already full belongs to a consumer that is not keeping up.
    Its oldest entry is discarded to make room, so the newest line is always
    delivered and no QueueFull error reaches the loop's exception handler.
    """
    for q in list(subs):
        try:
            q.put_nowait(line)
        except asyncio.QueueFull:
            try:
                q.get_nowait()
            except asyncio.QueueEmpty:
                pass
            try:
                q.put_nowait(line)
            except asyncio.QueueFull:
                # Still no room (e.g. a zero-capacity edge case): drop the line.
                pass


def _threadsafe_broadcast(asr_or_trans: str, line: str):
    """Broadcast *line* to the subscribers of one stream from any thread.

    Args:
        asr_or_trans: ``"asr"`` selects the ASR subscribers; any other value
            selects the translation subscribers.
        line: Text to enqueue for every subscriber.

    Nothing happens while the event loop has not been captured yet, or when the
    loop has already been closed.
    """
    loop = _event_loop
    if loop is None:
        return
    subs = _asr_subscribers if asr_or_trans == "asr" else _trans_subscribers
    try:
        # The queues are touched only on the loop thread; scheduling a single
        # callback keeps the overflow handling inside that thread as well.
        loop.call_soon_threadsafe(_deliver_on_loop, subs, line)
    except RuntimeError:
        # call_soon_threadsafe raises RuntimeError once the loop is closed.
        pass
|
|
|
+
|
|
|
async def _event_stream(kind: str) -> AsyncIterator[bytes]:
    """Yield Server-Sent-Events frames for one subscriber.

    ``kind`` equal to ``"asr"`` subscribes to the ASR stream; any other value
    subscribes to the translation stream.  A comment frame is emitted right
    after subscribing and again whenever no message arrives within KEEPALIVE
    seconds.  The subscriber queue is unregistered when the generator ends.
    """
    subscribers = _asr_subscribers if kind == "asr" else _trans_subscribers
    inbox: "asyncio.Queue[str]" = asyncio.Queue(maxsize=QUEUE_SIZE)
    subscribers.add(inbox)
    try:
        # Initial comment frame so the client sees the connection immediately.
        yield b": connected\n\n"
        while True:
            try:
                message = await asyncio.wait_for(inbox.get(), timeout=KEEPALIVE)
            except asyncio.TimeoutError:
                yield b": keepalive\n\n"
            else:
                # Keep each event on a single data line: carriage returns are
                # removed and newlines become the two characters "\n".
                escaped = message.replace("\r", "").replace("\n", "\\n")
                yield f"data: {escaped}\n\n".encode("utf-8")
    finally:
        subscribers.discard(inbox)
|
|
|
+
|
|
|
+# ========== 你的翻譯執行緒(原封不動 + 廣播) ==========
|
|
|
def log_asr(*args, **kwargs):
    """Print an ASR line to the terminal and broadcast it to ASR subscribers.

    The positional arguments are joined with single spaces for the broadcast;
    they are also forwarded, together with ``kwargs``, to ``print``.  The
    terminal copy goes to stderr when PRINT_ASR_TO_STDERR is true, otherwise
    to stdout.
    """
    text = " ".join(map(str, args))
    # Terminal tee.
    target = sys.stderr if PRINT_ASR_TO_STDERR else sys.stdout
    print(*args, file=target, flush=True, **kwargs)
    # SSE broadcast.
    _threadsafe_broadcast("asr", text)
|
|
|
+
|
|
|
def log_trans(*args, **kwargs):
    """Print a translation line to stdout and broadcast it to translation subscribers.

    The positional arguments are joined with single spaces for the broadcast;
    they are also forwarded, together with ``kwargs``, to ``print``.
    """
    text = " ".join(map(str, args))
    # Terminal tee (translations always go to stdout).
    print(*args, file=sys.stdout, flush=True, **kwargs)
    # SSE broadcast.
    _threadsafe_broadcast("trans", text)
|
|
|
+
|
|
|
class TranslatorWorker(threading.Thread):
    """Daemon thread that translates queued text and publishes the result.

    Text is handed over with :meth:`submit`, translated by a MarianMT model,
    post-processed with the OpenCC ``s2twp`` configuration and published via
    ``log_trans``.  The inbox is bounded: when it is full the oldest pending
    entry is discarded so that producers never block.
    """

    # Shutdown marker.  It is compared by identity, so no submitted text (not
    # even the literal string "__STOP__") can be mistaken for a stop request.
    _STOP = object()

    def __init__(self, model_name: str = "Helsinki-NLP/opus-mt-en-zh", max_len: int = 256):
        """Load the tokenizer/model and prepare the inbox.

        Args:
            model_name: Name passed to ``from_pretrained`` for both the
                tokenizer and the model.
            max_len: Used as the tokenizer ``max_length`` (with truncation)
                and as ``max_new_tokens`` during generation.
        """
        super().__init__(daemon=True)
        self.q: "queue.Queue[object]" = queue.Queue(maxsize=500)
        self.max_len = max_len
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = MarianTokenizer.from_pretrained(model_name)
        self.mt = MarianMTModel.from_pretrained(model_name)
        self.mt.to(self.device)
        self.mt.eval()
        self.cc = opencc.OpenCC("s2twp")
        self._stop_evt = threading.Event()

    def translate(self, text: str) -> str:
        """Translate *text* and return the OpenCC-converted result.

        Returns an empty string when *text* is empty or whitespace only.
        """
        if not text or not text.strip():
            return ""
        with torch.no_grad():
            batch = self.tok([text.strip()], return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
            batch = {k: v.to(self.device) for k, v in batch.items()}
            gen = self.mt.generate(**batch, max_new_tokens=self.max_len)
        zh_simplified = self.tok.batch_decode(gen, skip_special_tokens=True)[0].strip()
        return self.cc.convert(zh_simplified)

    def run(self):
        """Consume the inbox until the stop marker arrives or the stop event is set."""
        while not self._stop_evt.is_set():
            item = self.q.get()
            if item is self._STOP:
                break
            try:
                zh = self.translate(item)
                if zh:
                    log_trans("TRANS_TW:", zh)
            except Exception as e:
                # A failed translation must not kill the worker; report it on
                # the translation stream and continue with the next item.
                log_trans(f"[TranslatorWorker] 翻譯失敗: {e}")

    def _put_drop_oldest(self, item) -> None:
        """Enqueue *item* without blocking, discarding old entries while full.

        The put is retried after each discard, so a concurrent producer that
        grabs the freed slot cannot make this method raise ``queue.Full``.
        """
        while True:
            try:
                self.q.put_nowait(item)
                return
            except queue.Full:
                try:
                    self.q.get_nowait()
                except queue.Empty:
                    pass

    def submit(self, text: str):
        """Queue *text* for translation; empty or whitespace-only text is ignored."""
        if not text or not text.strip():
            return
        self._put_drop_oldest(text)

    def stop(self):
        """Ask the worker to finish: set the stop event and enqueue the stop marker."""
        self._stop_evt.set()
        self._put_drop_oldest(self._STOP)
|
|
|
+
|
|
|
translator = TranslatorWorker(model_name="Helsinki-NLP/opus-mt-en-zh", max_len=256)

# ---------- De-duplication state for the English output ----------
# Stripped form of the last text that was logged; used only to avoid emitting
# the very same text twice in a row.
_last_asr_printed = None


def on_asr(text: str, segments):
    """Transcription callback: log new text and queue it for translation.

    Empty text is ignored.  Text whose stripped form equals the previously
    logged one is ignored as well, so it is neither logged nor translated
    again.  ``segments`` is accepted but not used.
    """
    global _last_asr_printed
    if not text:
        return
    normalized = text.strip()
    previous = _last_asr_printed or ""
    if normalized == previous:
        return
    log_asr("ASR:", text)
    _last_asr_printed = normalized
    translator.submit(text)
|
|
|
+
|
|
|
+# ========== 啟動 / 停止 Whisper Live 的執行緒 ==========
|
|
|
# Client object and the thread that runs it; both are assigned by start_services().
_tc: Optional[TranscriptionClient] = None
_tc_thread: Optional[threading.Thread] = None


def _tc_runner():
    """Thread target: call the transcription client and report any failure."""
    client = _tc
    try:
        client()
    except KeyboardInterrupt:
        # Interrupts end the thread quietly.
        pass
    except Exception as e:
        log_asr(f"[ASR Thread] error: {e}")
|
|
|
+
|
|
|
def start_services():
    """Start the translator thread and the transcription-client thread.

    The translator is started only if it is not already alive.  A new
    TranscriptionClient is built from the module-level settings and run in a
    daemon thread named "ASRThread".
    """
    global _tc, _tc_thread

    if not translator.is_alive():
        translator.start()

    client_options = dict(
        host=ASR_HOST,
        port=ASR_PORT,
        lang=LANG,
        translate=False,
        model=WHISPER_MODEL,
        use_vad=USE_VAD,
        mute_audio_playback=True,
        save_output_recording=True,
        output_recording_filename="./output_recording.wav",
        transcription_callback=on_asr,
    )
    _tc = TranscriptionClient(**client_options)

    _tc_thread = threading.Thread(target=_tc_runner, name="ASRThread", daemon=True)
    _tc_thread.start()

    log_asr("[bridge] services started.")
|
|
|
+
|
|
|
def stop_services():
    """Try to stop the background services when the application shuts down.

    Only the translator is stopped explicitly.  The transcription client is
    not stopped here; its thread is a daemon thread and goes away with the
    process.
    """
    # Announce first, so the message really precedes the stop attempt.
    log_asr("[bridge] services stopping...")
    try:
        translator.stop()
    except Exception as exc:
        # Shutdown stays best-effort, but the failure is reported, not hidden.
        log_asr(f"[bridge] translator stop failed: {exc}")
|
|
|
+
|
|
|
# FastAPI lifespan handler (used instead of the deprecated on_event hooks).
@asynccontextmanager
async def lifespan(app_: FastAPI):
    """Capture the running event loop, start the services, stop them on exit."""
    global _event_loop
    running_loop = asyncio.get_running_loop()
    _event_loop = running_loop
    start_services()
    try:
        yield
    finally:
        stop_services()


# The app object already exists at this point, so the handler is attached to
# its router rather than passed to the FastAPI constructor.
app.router.lifespan_context = lifespan
|
|
|
+
|
|
|
+# ========== 路由 ==========
|
|
|
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the front-end page.

    ``index.html`` next to this file is used when it exists; otherwise the
    built-in CLIENT_HTML page is served, so a missing file no longer turns
    into a server error.
    """
    html_path = Path(__file__).parent / "index.html"
    try:
        content = html_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        content = CLIENT_HTML
    return HTMLResponse(content=content)
|
|
|
+
|
|
|
@app.get("/stream/asr")
async def stream_asr():
    """Stream the ASR lines as Server-Sent Events."""
    sse_headers = {"Content-Type": "text/event-stream; charset=utf-8"}
    return StreamingResponse(
        _event_stream("asr"),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|
+
|
|
|
@app.get("/stream/trans")
async def stream_trans():
    """Stream the translation lines as Server-Sent Events."""
    sse_headers = {"Content-Type": "text/event-stream; charset=utf-8"}
    return StreamingResponse(
        _event_stream("trans"),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|
+
|
|
|
@app.get("/health", response_class=PlainTextResponse)
async def health():
    """Report whether both background threads are alive.

    Returns the body ``ok`` with status 200 when the translator thread and the
    ASR thread are alive.  Otherwise the body is ``dead`` with status 503, so
    probes that only look at the status code can detect the failure.
    """
    asr_alive = _tc_thread is not None and _tc_thread.is_alive()
    if translator.is_alive() and asr_alive:
        return PlainTextResponse("ok")
    return PlainTextResponse("dead", status_code=503)
|
|
|
+
|
|
|
# Built-in page: two columns side by side (ASR / TRANS_TW).  Each column opens
# its own EventSource and appends one line per received message.
CLIENT_HTML: str = """
<!doctype html>
<html lang="zh-Hant">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" />
<title>ASR / 翻譯 分欄顯示</title>
<style>
:root { --bg:#0b1020; --panel:#121833; --fg:#e7ebff; --muted:#9aa4c7; --blue:#7aa2ff; --green:#93e1a4; }
* { box-sizing: border-box; }
body { margin:0; background:var(--bg); color:var(--fg); font:14px/1.5 ui-monospace, SFMono-Regular, Menlo, Consolas, "Noto Sans TC", monospace; height:100vh; display:flex; flex-direction:column; }
header { background:#101739; padding:10px 14px; border-bottom:1px solid #1f2650; display:flex; justify-content:space-between; align-items:center; }
h1 { margin:0; font-size:16px; }
.wrap { flex:1; display:grid; grid-template-columns:1fr 1fr; gap:8px; padding:8px; min-height:0; }
.col { display:flex; flex-direction:column; border:1px solid #1f2650; border-radius:10px; background:var(--panel); min-height:0; }
.col h2 { margin:0; padding:8px 10px; font-size:13px; border-bottom:1px solid #1f2650; }
.stream { flex:1; overflow:auto; padding:10px; white-space:pre-wrap; word-break:break-word; }
.asr .line::before { content:"ASR "; color:var(--blue); font-weight:700; margin-right:6px; }
.trans .line::before { content:"TRANS_TW "; color:var(--green); font-weight:700; margin-right:6px; }
.controls { display:flex; gap:8px; }
button { background:var(--panel); color:var(--fg); border:1px solid #2a3570; border-radius:999px; padding:6px 12px; cursor:pointer; }
button:hover { border-color:var(--blue); }
</style>
</head>
<body>
  <header>
    <h1>ASR / 翻譯 分欄顯示</h1>
    <div class="controls">
      <button id="clear">清除</button>
      <button id="freeze">暫停自動捲動</button>
    </div>
  </header>
  <div class="wrap">
    <section class="col asr">
      <h2>英文 ASR</h2>
      <div id="asr" class="stream"></div>
    </section>
    <section class="col trans">
      <h2>中文 翻譯</h2>
      <div id="trans" class="stream"></div>
    </section>
  </div>
<script>
let autoScroll = true;
const asrEl = document.getElementById("asr");
const transEl = document.getElementById("trans");
document.getElementById("clear").onclick = () => { asrEl.textContent=""; transEl.textContent=""; };
document.getElementById("freeze").onclick = (e) => { autoScroll = !autoScroll; e.target.textContent = autoScroll ? "暫停自動捲動" : "恢復自動捲動"; };

function connect(url, container){
  const es = new EventSource(url);
  es.onmessage = (e) => {
    const text = (e.data || "").replaceAll("\\\\n","\\n");
    const div = document.createElement("div");
    div.className = "line";
    div.textContent = text;
    container.appendChild(div);
    if (autoScroll) container.scrollTop = container.scrollHeight;
  };
  es.onerror = () => {
    const div = document.createElement("div");
    div.className="line";
    div.style.color="#ff9b9b";
    div.textContent="[連線中斷,稍後自動重試]";
    container.appendChild(div);
  };
}
connect("/stream/asr", asrEl);
connect("/stream/trans", transEl);
</script>
</body>
</html>
"""
|
|
|
+
|
|
|
def main():
    """Parse --host/--port from the command line and run the app with uvicorn."""
    import argparse

    p = argparse.ArgumentParser(description="ASR + Translator integrated with SSE/Front-end")
    p.add_argument("--host", default=HOST_DEFAULT)
    p.add_argument("--port", type=int, default=PORT_DEFAULT)
    args = p.parse_args()
    # Pass the application object itself.  The import string "sse_app:app"
    # made uvicorn import this file a second time (as module `sse_app`) when
    # it was launched as a script, which executed all module-level code twice,
    # including the construction of the translator and its model.  It also
    # tied the entry point to the file name.  reload is off, so an import
    # string is not required.
    uvicorn.run(app, host=args.host, port=args.port, reload=False)


if __name__ == "__main__":
    main()
|