123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356 |
# sse_app.py — combines main_translate.py with the SSE endpoints and the front-end in a single FastAPI application
# Dependencies:
# pip install fastapi uvicorn[standard] transformers torch sentencepiece opencc-python-reimplemented sounddevice whisper-live
#
# Run:
# python sse_app.py --host 0.0.0.0 --port 7010
- import os
- import sys
- import threading
- import queue
- import atexit
- import asyncio
- import signal
- from typing import Optional, AsyncIterator, Set
- from contextlib import asynccontextmanager
- import torch
- from whisper_live.client import TranscriptionClient
- from transformers import MarianMTModel, MarianTokenizer
- import opencc # 簡→繁(台灣用詞)
- from fastapi import FastAPI
- from fastapi.responses import HTMLResponse, StreamingResponse, PlainTextResponse
- from fastapi.middleware.cors import CORSMiddleware
- import uvicorn
- from pathlib import Path
# ===== Tunable parameters (carried over from the original script) =====
# ASR_HOST / ASR_PORT / LANG / WHISPER_MODEL / USE_VAD are forwarded to
# TranscriptionClient in start_services().
ASR_HOST = "192.168.192.83"
ASR_PORT = 9090
LANG = "en"
WHISPER_MODEL = "small"
USE_VAD = False
PRINT_ASR_TO_STDERR = True # English (ASR) is printed to stderr, Chinese (translation) to stdout
# ============================
# ---------- Unbuffered / line-buffered output ----------
os.environ["PYTHONUNBUFFERED"] = "1"
# Switch both standard streams to line buffering where the stream object
# supports it; any failure is ignored on purpose (best effort).
for _std_stream in (sys.stdout, sys.stderr):
    if not hasattr(_std_stream, "reconfigure"):
        continue
    try:
        _std_stream.reconfigure(line_buffering=True)
    except Exception:
        pass
del _std_stream
# ========== FastAPI and SSE plumbing ==========
HOST_DEFAULT = "0.0.0.0"
PORT_DEFAULT = 7010
# Seconds an SSE stream waits for a queued item before emitting a keepalive comment.
KEEPALIVE = 15.0
# maxsize of each per-subscriber asyncio.Queue created in _event_stream().
QUEUE_SIZE = 500
app = FastAPI()
# NOTE(review): every origin is allowed together with allow_credentials=True —
# confirm this is intended before exposing the service outside a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
)
# Asyncio event loop (captured when the lifespan handler starts)
_event_loop: Optional[asyncio.AbstractEventLoop] = None
# SSE subscribers (one asyncio.Queue per connected subscriber)
_asr_subscribers: Set[asyncio.Queue[str]] = set()
_trans_subscribers: Set[asyncio.Queue[str]] = set()
def _enqueue_drop_oldest(q, line):
    """Put *line* on *q*, discarding the oldest entry first when *q* is full.

    Runs on the event loop thread (scheduled by ``_threadsafe_broadcast``), so
    nothing can interleave between the ``full()`` check and the put.
    """
    if q.full():
        try:
            q.get_nowait()
        except asyncio.QueueEmpty:
            pass
    q.put_nowait(line)


def _threadsafe_broadcast(asr_or_trans: str, line: str):
    """Broadcast *line* to every subscriber queue of one stream, from any thread.

    Args:
        asr_or_trans: ``"asr"`` selects the ASR subscribers; any other value
            selects the translation subscribers.
        line: Text to deliver.

    The enqueue itself is scheduled onto the event loop because asyncio queues
    are not thread-safe. A subscriber whose queue is full loses its oldest
    entry instead of raising ``QueueFull`` inside the loop callback, where the
    caller could never catch it. Nothing happens while no loop is registered.
    """
    loop = _event_loop
    if loop is None:
        return
    subs = _asr_subscribers if asr_or_trans == "asr" else _trans_subscribers
    # Iterate over a snapshot: the loop thread adds/removes subscribers concurrently.
    for q in list(subs):
        try:
            loop.call_soon_threadsafe(_enqueue_drop_oldest, q, line)
        except RuntimeError:
            # The loop has been closed (application shutdown); nothing left to notify.
            return
async def _event_stream(kind: str) -> AsyncIterator[bytes]:
    """Yield Server-Sent-Event frames for one subscriber.

    Args:
        kind: ``"asr"`` subscribes to the ASR stream; any other value
            subscribes to the translation stream.

    Yields:
        A ``: connected`` comment first, then one ``data:`` frame per queued
        line (carriage returns removed, newlines escaped as a literal ``\\n``),
        or a ``: keepalive`` comment whenever nothing arrives within
        ``KEEPALIVE`` seconds. The subscriber queue is unregistered when the
        generator is closed.
    """
    registry = _asr_subscribers if kind == "asr" else _trans_subscribers
    inbox: asyncio.Queue[str] = asyncio.Queue(maxsize=QUEUE_SIZE)
    registry.add(inbox)
    try:
        # Initial comment frame so the client sees the connection immediately.
        yield b": connected\n\n"
        while True:
            try:
                message = await asyncio.wait_for(inbox.get(), timeout=KEEPALIVE)
            except asyncio.TimeoutError:
                yield b": keepalive\n\n"
                continue
            payload = message.replace("\r", "").replace("\n", "\\n")
            yield ("data: " + payload + "\n\n").encode("utf-8")
    finally:
        registry.discard(inbox)
# ========== Translation thread (kept from the original script, plus broadcasting) ==========
def log_asr(*args, **kwargs):
    """Print ASR output to the terminal and broadcast it to the "asr" SSE stream.

    Args:
        *args: Values to print; for SSE they are stringified and joined with
            single spaces.
        **kwargs: Extra keyword arguments forwarded to ``print``.
    """
    message = " ".join(map(str, args))
    # Terminal tee: stderr or stdout depending on PRINT_ASR_TO_STDERR.
    if PRINT_ASR_TO_STDERR:
        target = sys.stderr
    else:
        target = sys.stdout
    print(*args, file=target, flush=True, **kwargs)
    # SSE broadcast.
    _threadsafe_broadcast("asr", message)
def log_trans(*args, **kwargs):
    """Print translation output to stdout and broadcast it to the "trans" SSE stream.

    Args:
        *args: Values to print; for SSE they are stringified and joined with
            single spaces.
        **kwargs: Extra keyword arguments forwarded to ``print``.
    """
    message = " ".join(map(str, args))
    # Terminal tee (Chinese output always goes to stdout).
    print(*args, file=sys.stdout, flush=True, **kwargs)
    # SSE broadcast.
    _threadsafe_broadcast("trans", message)
class TranslatorWorker(threading.Thread):
    """Daemon thread that translates queued English text and logs the result.

    Text handed to :meth:`submit` is translated with a MarianMT model, passed
    through an OpenCC converter configured with ``"s2twp"``, and emitted via
    ``log_trans`` with a ``TRANS_TW:`` prefix.
    """

    def __init__(self, model_name: str = "Helsinki-NLP/opus-mt-en-zh", max_len: int = 256):
        """Load tokenizer, model and converter, and create the bounded work queue.

        Args:
            model_name: Name passed to ``from_pretrained`` for both the
                tokenizer and the model.
            max_len: Token limit used for input truncation and as
                ``max_new_tokens`` during generation.
        """
        super().__init__(daemon=True)
        # Work items are non-empty strings; ``None`` is the stop sentinel, so no
        # real transcript can ever be mistaken for the stop request.
        self.q: "queue.Queue[Optional[str]]" = queue.Queue(maxsize=500)
        self.max_len = max_len
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = MarianTokenizer.from_pretrained(model_name)
        self.mt = MarianMTModel.from_pretrained(model_name)
        self.mt.to(self.device)
        self.mt.eval()
        self.cc = opencc.OpenCC("s2twp")
        self._stop_evt = threading.Event()

    def translate(self, text: str) -> str:
        """Translate *text* and return the OpenCC-converted result.

        Returns an empty string for empty or whitespace-only input.
        """
        if not text or not text.strip():
            return ""
        with torch.no_grad():
            batch = self.tok([text.strip()], return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
            batch = {k: v.to(self.device) for k, v in batch.items()}
            gen = self.mt.generate(**batch, max_new_tokens=self.max_len)
        zh_simplified = self.tok.batch_decode(gen, skip_special_tokens=True)[0].strip()
        return self.cc.convert(zh_simplified)

    def run(self):
        """Consume the queue until the stop event is set or the sentinel arrives."""
        while not self._stop_evt.is_set():
            item = self.q.get()
            if item is None:
                break
            try:
                zh = self.translate(item)
                if zh:
                    log_trans("TRANS_TW:", zh)
            except Exception as e:
                # Keep the worker alive; report the failure on the translation stream.
                log_trans(f"[TranslatorWorker] 翻譯失敗: {e}")

    def _put_drop_oldest(self, item: Optional[str]) -> None:
        """Enqueue *item*, discarding the oldest entries for as long as the queue is full.

        Retrying in a loop covers the case where another thread refills the
        queue between our drop and our put.
        """
        while True:
            try:
                self.q.put_nowait(item)
                return
            except queue.Full:
                try:
                    self.q.get_nowait()
                except queue.Empty:
                    pass

    def submit(self, text: str):
        """Queue *text* for translation; empty or whitespace-only text is ignored.

        When the queue is full the oldest pending item is dropped to make room.
        """
        if not text or not text.strip():
            return
        self._put_drop_oldest(text)

    def stop(self):
        """Ask the worker to exit: set the stop event and enqueue the sentinel.

        The sentinel wakes the thread if it is blocked waiting on the queue.
        """
        self._stop_evt.set()
        self._put_drop_oldest(None)
# Module-level worker; constructing it runs TranslatorWorker.__init__
# (tokenizer / model loading) when this module is imported.
translator = TranslatorWorker(model_name="Helsinki-NLP/opus-mt-en-zh", max_len=256)
# ---------- De-duplication state for English output ----------
_last_asr_printed = None # only used to avoid printing the exact same English text twice in a row
def on_asr(text: str, segments):
    """Transcription callback: log new ASR text and queue it for translation.

    Args:
        text: Transcribed text delivered to the callback.
        segments: Second callback argument; not used here.

    A callback whose stripped text equals the previously handled one is
    ignored entirely, so an unchanged transcript is neither printed again nor
    re-queued for translation.
    """
    # NOTE(review): the source lost its indentation; the submit call is treated
    # as part of the de-duplicated branch so repeats do not flood the translator.
    global _last_asr_printed
    if not text:
        return
    normalized = text.strip()
    if normalized == (_last_asr_printed or ""):
        return
    log_asr("ASR:", text)
    _last_asr_printed = normalized
    translator.submit(text)
# ========== Starting / stopping the Whisper Live thread ==========
# Client instance created by start_services(); None until then.
_tc: Optional[TranscriptionClient] = None
# Thread that runs _tc_runner(); None until start_services() has been called.
_tc_thread: Optional[threading.Thread] = None
def _tc_runner():
    """Thread target: invoke the module-level transcription client.

    A KeyboardInterrupt ends the thread quietly; any other exception is
    reported through ``log_asr`` instead of propagating.
    """
    try:
        client = _tc
        client()
    except KeyboardInterrupt:
        return
    except Exception as exc:
        log_asr(f"[ASR Thread] error: {exc}")
def start_services():
    """Start the translator thread (if not running) and a new ASR client thread.

    Called at application start-up. Rebinds the module-level ``_tc`` and
    ``_tc_thread`` to the freshly created client and its runner thread.
    """
    global _tc, _tc_thread
    if not translator.is_alive():
        translator.start()
    client_options = dict(
        host=ASR_HOST,
        port=ASR_PORT,
        lang=LANG,
        translate=False,
        model=WHISPER_MODEL,
        use_vad=USE_VAD,
        mute_audio_playback=True,
        save_output_recording=True,
        output_recording_filename="./output_recording.wav",
        transcription_callback=on_asr,
    )
    _tc = TranscriptionClient(**client_options)
    # Daemon thread: it does not keep the process alive at shutdown.
    _tc_thread = threading.Thread(target=_tc_runner, name="ASRThread", daemon=True)
    _tc_thread.start()
    log_asr("[bridge] services started.")
def stop_services():
    """Try to stop background services gracefully at application shutdown.

    Only the translator is asked to stop; a failure while doing so is reported
    through ``log_asr`` rather than silently discarded, and never propagates.
    """
    try:
        translator.stop()
    except Exception as exc:
        log_asr(f"[bridge] translator stop failed: {exc}")
    # TranscriptionClient has no stable stop API; its daemon thread is left to
    # be reclaimed when the process exits.
    log_asr("[bridge] services stopping...")
# FastAPI lifespan handler (used instead of the deprecated on_event hooks)
@asynccontextmanager
async def lifespan(app_: FastAPI):
    """Application lifespan: start services on entry, stop them on exit.

    Also records the running event loop in ``_event_loop`` so that other
    threads can schedule broadcasts onto it.
    """
    global _event_loop
    running_loop = asyncio.get_running_loop()
    _event_loop = running_loop
    start_services()
    try:
        yield
    finally:
        stop_services()


# Installed on the router directly because ``app`` was created before this handler existed.
app.router.lifespan_context = lifespan
# ========== Routes ==========
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the front-end page.

    Prefers an ``index.html`` located next to this file; when that file cannot
    be read, the built-in ``CLIENT_HTML`` page is served instead of failing
    the request.
    """
    html_path = Path(__file__).parent / "index.html"
    try:
        content = html_path.read_text(encoding="utf-8")
    except OSError:
        # Missing or unreadable file: fall back to the embedded page.
        content = CLIENT_HTML
    return HTMLResponse(content=content)
@app.get("/stream/asr")
async def stream_asr():
    """SSE endpoint streaming ASR lines (see ``_event_stream("asr")``)."""
    return StreamingResponse(
        _event_stream("asr"),
        media_type="text/event-stream",
        headers={
            "Content-Type": "text/event-stream; charset=utf-8",
            # An event stream must never be cached or replayed by intermediaries.
            "Cache-Control": "no-cache",
            # Ask nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )
@app.get("/stream/trans")
async def stream_trans():
    """SSE endpoint streaming translation lines (see ``_event_stream("trans")``)."""
    return StreamingResponse(
        _event_stream("trans"),
        media_type="text/event-stream",
        headers={
            "Content-Type": "text/event-stream; charset=utf-8",
            # An event stream must never be cached or replayed by intermediaries.
            "Cache-Control": "no-cache",
            # Ask nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )
@app.get("/health", response_class=PlainTextResponse)
async def health():
    """Return ``"ok"`` when both the translator and the ASR thread are alive, else ``"dead"``."""
    if translator.is_alive() and _tc_thread is not None and _tc_thread.is_alive():
        return "ok"
    return "dead"
# Built-in page: two side-by-side columns (ASR / TRANS_TW).
# Its script opens one EventSource per stream (/stream/asr and /stream/trans)
# and turns the literal "\n" sequences produced by _event_stream() back into
# real newlines before appending each line to its column.
CLIENT_HTML = """
<!doctype html>
<html lang="zh-Hant">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" />
<title>ASR / 翻譯 分欄顯示</title>
<style>
:root { --bg:#0b1020; --panel:#121833; --fg:#e7ebff; --muted:#9aa4c7; --blue:#7aa2ff; --green:#93e1a4; }
* { box-sizing: border-box; }
body { margin:0; background:var(--bg); color:var(--fg); font:14px/1.5 ui-monospace, SFMono-Regular, Menlo, Consolas, "Noto Sans TC", monospace; height:100vh; display:flex; flex-direction:column; }
header { background:#101739; padding:10px 14px; border-bottom:1px solid #1f2650; display:flex; justify-content:space-between; align-items:center; }
h1 { margin:0; font-size:16px; }
.wrap { flex:1; display:grid; grid-template-columns:1fr 1fr; gap:8px; padding:8px; min-height:0; }
.col { display:flex; flex-direction:column; border:1px solid #1f2650; border-radius:10px; background:var(--panel); min-height:0; }
.col h2 { margin:0; padding:8px 10px; font-size:13px; border-bottom:1px solid #1f2650; }
.stream { flex:1; overflow:auto; padding:10px; white-space:pre-wrap; word-break:break-word; }
.asr .line::before { content:"ASR "; color:var(--blue); font-weight:700; margin-right:6px; }
.trans .line::before { content:"TRANS_TW "; color:var(--green); font-weight:700; margin-right:6px; }
.controls { display:flex; gap:8px; }
button { background:var(--panel); color:var(--fg); border:1px solid #2a3570; border-radius:999px; padding:6px 12px; cursor:pointer; }
button:hover { border-color:var(--blue); }
</style>
</head>
<body>
<header>
<h1>ASR / 翻譯 分欄顯示</h1>
<div class="controls">
<button id="clear">清除</button>
<button id="freeze">暫停自動捲動</button>
</div>
</header>
<div class="wrap">
<section class="col asr">
<h2>英文 ASR</h2>
<div id="asr" class="stream"></div>
</section>
<section class="col trans">
<h2>中文 翻譯</h2>
<div id="trans" class="stream"></div>
</section>
</div>
<script>
let autoScroll = true;
const asrEl = document.getElementById("asr");
const transEl = document.getElementById("trans");
document.getElementById("clear").onclick = () => { asrEl.textContent=""; transEl.textContent=""; };
document.getElementById("freeze").onclick = (e) => { autoScroll = !autoScroll; e.target.textContent = autoScroll ? "暫停自動捲動" : "恢復自動捲動"; };
function connect(url, container){
const es = new EventSource(url);
es.onmessage = (e) => {
const text = (e.data || "").replaceAll("\\\\n","\\n");
const div = document.createElement("div");
div.className = "line";
div.textContent = text;
container.appendChild(div);
if (autoScroll) container.scrollTop = container.scrollHeight;
};
es.onerror = () => {
const div = document.createElement("div");
div.className="line";
div.style.color="#ff9b9b";
div.textContent="[連線中斷,稍後自動重試]";
container.appendChild(div);
};
}
connect("/stream/asr", asrEl);
connect("/stream/trans", transEl);
</script>
</body>
</html>
"""
def main():
    """Parse ``--host`` / ``--port`` and serve the application with uvicorn."""
    import argparse
    parser = argparse.ArgumentParser(description="ASR + Translator integrated with SSE/Front-end")
    parser.add_argument("--host", default=HOST_DEFAULT)
    parser.add_argument("--port", type=int, default=PORT_DEFAULT)
    args = parser.parse_args()
    # Pass the app object rather than the "sse_app:app" import string: when this
    # file runs as a script, the import string makes uvicorn import it a second
    # time under another module name, which would build a second
    # TranslatorWorker (loading the model twice) and would break if the file
    # were renamed.
    uvicorn.run(app, host=args.host, port=args.port, reload=False)


if __name__ == "__main__":
    main()
|