# sse_app.py — integrates main_translate.py with the SSE/frontend into a single FastAPI app
# Dependencies:
#   pip install fastapi uvicorn[standard] transformers torch opencc-python-reimplemented sounddevice
#
# Run:
#   python sse_app.py --host 0.0.0.0 --port 7010

import os
import sys
import threading
import queue
import atexit
import asyncio
import signal
from typing import Optional, AsyncIterator, Set
from contextlib import asynccontextmanager
from pathlib import Path

import torch
from whisper_live.client import TranscriptionClient
from transformers import MarianMTModel, MarianTokenizer
import opencc  # Simplified -> Traditional Chinese (Taiwan phrasing)

from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse, PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

# ===== Tunable parameters (from the original script) =====
ASR_HOST = "192.168.192.83"
ASR_PORT = 9090
LANG = "en"
WHISPER_MODEL = "small"
USE_VAD = False
PRINT_ASR_TO_STDERR = True  # English goes to stderr, Chinese to stdout
# ==========================================================

# ---------- Unbuffered / line-buffered output ----------
os.environ["PYTHONUNBUFFERED"] = "1"
if hasattr(sys.stdout, "reconfigure"):
    try:
        sys.stdout.reconfigure(line_buffering=True)
    except Exception:
        pass
if hasattr(sys.stderr, "reconfigure"):
    try:
        sys.stderr.reconfigure(line_buffering=True)
    except Exception:
        pass

# ========== FastAPI and SSE setup ==========
HOST_DEFAULT = "0.0.0.0"
PORT_DEFAULT = 7010
KEEPALIVE = 15.0   # seconds between SSE keepalive comments
QUEUE_SIZE = 500   # per-subscriber queue size

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Asyncio event loop (captured when the lifespan starts)
_event_loop: Optional[asyncio.AbstractEventLoop] = None

# SSE subscribers (one asyncio.Queue per subscriber)
_asr_subscribers: Set[asyncio.Queue[str]] = set()
_trans_subscribers: Set[asyncio.Queue[str]] = set()


def _threadsafe_broadcast(asr_or_trans: str, line: str):
    """Broadcast a line to the matching subscriber queues, safely from any thread."""
    global _event_loop
    subs = _asr_subscribers if asr_or_trans == "asr" else _trans_subscribers
    if _event_loop is None:
        return
    for q in list(subs):
        try:
            _event_loop.call_soon_threadsafe(q.put_nowait, line)
        except Exception:
            pass


async def _event_stream(kind: str) -> AsyncIterator[bytes]:
    """SSE generator; kind is either "asr" or "trans"."""
    q: asyncio.Queue[str] = asyncio.Queue(maxsize=QUEUE_SIZE)
    subs = _asr_subscribers if kind == "asr" else _trans_subscribers
    subs.add(q)
    try:
        # Initial comment frame so the client knows it is connected
        yield b": connected\n\n"
        while True:
            try:
                item = await asyncio.wait_for(q.get(), timeout=KEEPALIVE)
                data = item.replace("\r", "").replace("\n", "\\n")
                yield f"data: {data}\n\n".encode("utf-8")
            except asyncio.TimeoutError:
                yield b": keepalive\n\n"
    finally:
        subs.discard(q)
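
# --- Illustrative sketch (not used by the app): a minimal SSE consumer ---------------
# This shows, from a client's point of view, the frames emitted by _event_stream():
# "data: ..." payload lines and ": connected" / ": keepalive" comment lines. It uses
# only the standard library; the default URL below is an assumption based on the
# /stream/asr route defined further down (a browser would normally use EventSource).
def _debug_consume_sse(url: str = "http://127.0.0.1:7010/stream/asr") -> None:
    """Print every `data:` payload received from an SSE endpoint (blocking; Ctrl-C to stop)."""
    import urllib.request

    with urllib.request.urlopen(url) as resp:
        for raw in resp:  # HTTPResponse is iterable line by line
            line = raw.decode("utf-8", errors="replace").rstrip("\r\n")
            if line.startswith("data: "):
                print(line[len("data: "):], flush=True)
            # Lines starting with ":" are comment/keepalive frames and are ignored here.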

# ========== Translation worker (original logic, plus SSE broadcast) ==========
def log_asr(*args, **kwargs):
    text = " ".join(str(a) for a in args)
    # Tee to the terminal
    print(*args, file=(sys.stderr if PRINT_ASR_TO_STDERR else sys.stdout), flush=True, **kwargs)
    # Broadcast over SSE
    _threadsafe_broadcast("asr", text)


def log_trans(*args, **kwargs):
    text = " ".join(str(a) for a in args)
    # Tee to the terminal (Chinese output goes to stdout)
    print(*args, file=sys.stdout, flush=True, **kwargs)
    # Broadcast over SSE
    _threadsafe_broadcast("trans", text)


class TranslatorWorker(threading.Thread):
    def __init__(self, model_name: str = "Helsinki-NLP/opus-mt-en-zh", max_len: int = 256):
        super().__init__(daemon=True)
        self.q: "queue.Queue[Optional[str]]" = queue.Queue(maxsize=500)
        self.max_len = max_len
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = MarianTokenizer.from_pretrained(model_name)
        self.mt = MarianMTModel.from_pretrained(model_name)
        self.mt.to(self.device)
        self.mt.eval()
        self.cc = opencc.OpenCC("s2twp")
        self._stop_evt = threading.Event()

    def translate(self, text: str) -> str:
        if not text or not text.strip():
            return ""
        with torch.no_grad():
            batch = self.tok([text.strip()], return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
            batch = {k: v.to(self.device) for k, v in batch.items()}
            gen = self.mt.generate(**batch, max_new_tokens=self.max_len)
        zh_simplified = self.tok.batch_decode(gen, skip_special_tokens=True)[0].strip()
        return self.cc.convert(zh_simplified)

    def run(self):
        while not self._stop_evt.is_set():
            item = self.q.get()
            if item == "__STOP__":
                break
            try:
                zh = self.translate(item)
                if zh:
                    log_trans("TRANS_TW:", zh)
            except Exception as e:
                log_trans(f"[TranslatorWorker] translation failed: {e}")

    def submit(self, text: str):
        if not text or not text.strip():
            return
        try:
            self.q.put_nowait(text)
        except queue.Full:
            # Drop the oldest item so the newest text always gets queued
            try:
                _ = self.q.get_nowait()
            except queue.Empty:
                pass
            self.q.put_nowait(text)

    def stop(self):
        self._stop_evt.set()
        try:
            self.q.put_nowait("__STOP__")
        except queue.Full:
            try:
                _ = self.q.get_nowait()
            except queue.Empty:
                pass
            self.q.put_nowait("__STOP__")


translator = TranslatorWorker(model_name="Helsinki-NLP/opus-mt-en-zh", max_len=256)

# ---------- Deduplication state for the English output ----------
_last_asr_printed = None  # only used to avoid printing the exact same English segment twice in a row


def on_asr(text: str, segments):
    global _last_asr_printed
    if not text:
        return
    normalized = text.strip()
    if normalized != (_last_asr_printed or ""):
        log_asr("ASR:", text)
        _last_asr_printed = normalized
        translator.submit(text)


# ========== Start / stop the Whisper Live client thread ==========
_tc: Optional[TranscriptionClient] = None
_tc_thread: Optional[threading.Thread] = None


def _tc_runner():
    global _tc
    try:
        _tc()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        log_asr(f"[ASR Thread] error: {e}")


def start_services():
    """Start the Translator and TranscriptionClient threads on app startup."""
    global _tc, _tc_thread
    if not translator.is_alive():
        translator.start()
    _tc = TranscriptionClient(
        host=ASR_HOST,
        port=ASR_PORT,
        lang=LANG,
        translate=False,
        model=WHISPER_MODEL,
        use_vad=USE_VAD,
        mute_audio_playback=True,
        save_output_recording=True,
        output_recording_filename="./output_recording.wav",
        transcription_callback=on_asr,
    )
    _tc_thread = threading.Thread(target=_tc_runner, name="ASRThread", daemon=True)
    _tc_thread.start()
    log_asr("[bridge] services started.")


def stop_services():
    """Attempt a graceful stop on app shutdown."""
    try:
        translator.stop()
    except Exception:
        pass
    # TranscriptionClient has no stable stop API; it is reclaimed when the process exits
    log_asr("[bridge] services stopping...")


# FastAPI lifespan (avoids the deprecated on_event hooks)
@asynccontextmanager
async def lifespan(app_: FastAPI):
    global _event_loop
    _event_loop = asyncio.get_running_loop()
    start_services()
    try:
        yield
    finally:
        stop_services()


app.router.lifespan_context = lifespan


# ========== Routes ==========
@app.get("/", response_class=HTMLResponse)
async def index():
    html_path = Path(__file__).parent / "index.html"
    return HTMLResponse(content=html_path.read_text(encoding="utf-8"))


@app.get("/stream/asr")
async def stream_asr():
    return StreamingResponse(
        _event_stream("asr"),
        media_type="text/event-stream",
        headers={"Content-Type": "text/event-stream; charset=utf-8"},
    )


@app.get("/stream/trans")
async def stream_trans():
    return StreamingResponse(
        _event_stream("trans"),
        media_type="text/event-stream",
        headers={"Content-Type": "text/event-stream; charset=utf-8"},
    )


@app.get("/health", response_class=PlainTextResponse)
async def health():
    ok = translator.is_alive() and (_tc_thread is not None and _tc_thread.is_alive())
    return "ok" if ok else "dead"


# Built-in page: two-column layout (ASR / TRANS_TW)
CLIENT_HTML = """