audio_processing.py

from openai import OpenAI
from api.openai_scripts_new.config import SYSTEM_PROMPT, OPENAI_API_KEY, SUPABASE_KEY, SUPABASE_URL
from supabase import create_client, Client
from api.openai_scripts_new.text_processing import fuzzy_correct_chinese
from transformers import pipeline
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.callbacks import get_openai_callback
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
import torchaudio
import torch

client = OpenAI(api_key=OPENAI_API_KEY)
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
# Whisper-small checkpoint fine-tuned for Taiwanese (Min Nan), emitting Tai-lo pinyin
pipe = pipeline(model="linshoufan/linshoufanfork-whisper-small-nan-tw-pinyin")

def transcribe(audio_data):
    try:
        # table_name = "word_database"
        # response = supabase.table(table_name).select("term").execute()
        # custom_vocab = []
        # if response.data:
        #     for item in response.data:
        #         custom_vocab.append({item['term']})
        # else:
        #     print(f"No data found or an error occurred: {response.error}")
        #     print("Using default dictionary as Supabase data couldn't be fetched.")

        # If the audio is stereo, convert it to mono
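        # A minimal sketch of that conversion (assumption: audio_data may arrive as a
        # (channels, samples) torch tensor; file paths and mono arrays pass through untouched).
        if isinstance(audio_data, torch.Tensor) and audio_data.dim() == 2 and audio_data.shape[0] > 1:
            audio_data = audio_data.mean(dim=0)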
        transcript = pipe(audio_data)["text"]
        print(transcript)

        # Translate the Tai-lo pinyin into Traditional Chinese
        model_name = "gpt-4o"
        llm = ChatOpenAI(model_name=model_name, temperature=0.7, api_key=OPENAI_API_KEY, max_tokens=4096)
        with get_openai_callback() as cb:
            # System prompt (Chinese): "You are an assistant specializing in translating
            # Tai-lo pinyin; you convert Taiwanese speech accurately into its Traditional
            # Chinese meaning. You are a senior expert in large language models, well versed
            # in model architecture and deployment; just translate into Traditional Chinese."
            qa_system_prompt = """你是一個專門翻譯台羅拼音的助理,可以將台語音精準的轉換成在繁體中文中的意思.
你是一名資深的大語言模型領域的專家,精通模型架構原理和落地應用實踐,只需要翻譯成繁體中文即可."""
            qa_prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", qa_system_prompt),
                    ("human", "{transcript}"),
                ]
            )
            rag_chain = (
                qa_prompt
                | llm
                | StrOutputParser()
            )
            # session_id = "abc123"  # this should be a dynamic value obtained from some context
            # chat_history = get_session_history(session_id)
            text = rag_chain.invoke(
                {"transcript": transcript}
            )
            # # Update the chat history
            # chat_history.add_user_message(inp)
            # chat_history.add_ai_message(text)
            # # chat_history.add_message({'role':HumanMessage(content=input), 'message':AIMessage(content=text)})
            # save_session_history(session_id, chat_history)
            print(f"Total Tokens: {cb.total_tokens}")
            print(f"Prompt Tokens: {cb.prompt_tokens}")
            print(f"Completion Tokens: {cb.completion_tokens}")
            print(f"Total Cost (USD): ${cb.total_cost}")
        return text
    except Exception as e:
        print(f"Error during transcription: {str(e)}")
        return None


def post_process_transcript(transcript, temperature=0):
    corrected_transcript = fuzzy_correct_chinese(transcript)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        # User message (Chinese): "Please proofread and correct the following transcript,
        # but do not change its meaning or answer any questions."
        {"role": "user", "content": f"請校對並修正以下轉錄文本,但不要改變其原意或回答問題:\n\n{corrected_transcript}"}
    ]
    response = client.chat.completions.create(
        model="gpt-4",
        temperature=temperature,
        messages=messages
    )
    return response.choices[0].message.content


def process_audio(audio_data):
    raw_transcript = transcribe(audio_data)
    print(raw_transcript)
    if raw_transcript is None:
        return None, None
    corrected_transcript = post_process_transcript(raw_transcript)
    return raw_transcript, corrected_transcript
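

# Example usage sketch (assumptions: the ASR pipeline accepts a path to a local audio
# file, and "sample.wav" is a hypothetical filename, not part of the original module).
if __name__ == "__main__":
    raw, corrected = process_audio("sample.wav")
    print("Raw transcript:", raw)
    print("Corrected transcript:", corrected)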