# models.py
import subprocess
from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic import Field
from tenacity import retry, stop_after_attempt, wait_random_exponential

from config import system_prompt
  10. class OllamaChatModel(BaseChatModel):
  11. model_name: str = Field(default="taide-local-3")
  12. @retry(stop=stop_after_attempt(3), wait=wait_random_exponential(min=1, max=60))
  13. def _generate(
  14. self,
  15. messages: List[BaseMessage],
  16. stop: Optional[List[str]] = None,
  17. run_manager: Optional[CallbackManagerForLLMRun] = None,
  18. **kwargs: Any,
  19. ) -> ChatResult:
  20. formatted_messages = []
  21. for msg in messages:
  22. if isinstance(msg, HumanMessage):
  23. formatted_messages.append({"role": "user", "content": msg.content})
  24. elif isinstance(msg, AIMessage):
  25. formatted_messages.append({"role": "assistant", "content": msg.content})
  26. elif isinstance(msg, SystemMessage):
  27. formatted_messages.append({"role": "system", "content": msg.content})
  28. prompt = f"<s>[INST] {system_prompt}\n\n"
  29. for msg in formatted_messages:
  30. if msg['role'] == 'user':
  31. prompt += f"{msg['content']} [/INST]"
  32. elif msg['role'] == "assistant":
  33. prompt += f"{msg['content']} </s><s>[INST]"
  34. command = ["ollama", "run", self.model_name, prompt]
  35. try:
  36. result = subprocess.run(command, capture_output=True, text=True, timeout=60) # 添加60秒超時
  37. except subprocess.TimeoutExpired:
  38. raise Exception("Ollama command timed out")
  39. if result.returncode != 0:
  40. raise Exception(f"Ollama command failed: {result.stderr}")
  41. content = result.stdout.strip()
  42. message = AIMessage(content=content)
  43. generation = ChatGeneration(message=message)
  44. return ChatResult(generations=[generation])
  45. @property
  46. def _llm_type(self) -> str:
  47. return "ollama-chat-model"