from langchain_core.language_models import BaseChatModel
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from pydantic import Field
import subprocess
from typing import Any, List, Optional
from tenacity import retry, stop_after_attempt, wait_random_exponential

from config import system_prompt


class OllamaChatModel(BaseChatModel):
    """LangChain chat model that shells out to a local Ollama model via the CLI."""

    model_name: str = Field(default="taide-local-3")

    @retry(stop=stop_after_attempt(3), wait=wait_random_exponential(min=1, max=60))
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,  # stop sequences are not forwarded to the CLI
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Normalize LangChain message objects into role/content dicts.
        formatted_messages = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                formatted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                formatted_messages.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                formatted_messages.append({"role": "system", "content": msg.content})

        # Build a Llama-2-style [INST] prompt. The system prompt comes from
        # config; per-conversation SystemMessages are collected above but are
        # not interpolated into the prompt here.
        prompt = f"[INST] {system_prompt}\n\n"
        for msg in formatted_messages:
            if msg["role"] == "user":
                prompt += f"{msg['content']} [/INST]"
            elif msg["role"] == "assistant":
                prompt += f"{msg['content']} [INST]"

        command = ["ollama", "run", self.model_name, prompt]
        try:
            result = subprocess.run(command, capture_output=True, text=True, timeout=60)  # 60-second timeout
        except subprocess.TimeoutExpired:
            raise Exception("Ollama command timed out")

        if result.returncode != 0:
            raise Exception(f"Ollama command failed: {result.stderr}")

        content = result.stdout.strip()
        message = AIMessage(content=content)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        return "ollama-chat-model"
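

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of calling the model through LangChain's standard chat
# interface. Assumes `config.py` defines `system_prompt` and that the
# "taide-local-3" model has already been pulled into the local Ollama install.
if __name__ == "__main__":
    chat = OllamaChatModel()
    reply = chat.invoke([HumanMessage(content="Summarize what Ollama does in one sentence.")])
    print(reply.content)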