from langchain_core.language_models import BaseChatModel
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from pydantic import Field
import subprocess
from typing import Any, List, Optional
from tenacity import retry, stop_after_attempt, wait_random_exponential
from config import system_prompt


class OllamaChatModel(BaseChatModel):
    """Custom LangChain chat model that shells out to the `ollama` CLI."""

    model_name: str = Field(default="taide-local-3")

    @retry(stop=stop_after_attempt(3), wait=wait_random_exponential(min=1, max=60))
    def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> ChatResult:
        # Map LangChain message objects to plain role/content dicts.
        formatted_messages = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                formatted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                formatted_messages.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                formatted_messages.append({"role": "system", "content": msg.content})

        # Build a Llama-2-style [INST] prompt with the system prompt up front.
        prompt = f"<s>[INST] {system_prompt}\n\n"
        for msg in formatted_messages:
            if msg['role'] == 'user':
                prompt += f"{msg['content']} [/INST]"
            elif msg['role'] == "assistant":
                prompt += f"{msg['content']} </s><s>[INST]"

        command = ["ollama", "run", self.model_name, prompt]
        try:
            # 60-second timeout so a hung model doesn't block the caller.
            result = subprocess.run(command, capture_output=True, text=True, timeout=60)
        except subprocess.TimeoutExpired:
            raise Exception("Ollama command timed out")
        if result.returncode != 0:
            raise Exception(f"Ollama command failed: {result.stderr}")

        content = result.stdout.strip()
        message = AIMessage(content=content)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        return "ollama-chat-model"
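For reference, a minimal usage sketch. It assumes the `ollama` CLI is on the PATH with the `taide-local-3` model already pulled, and that `config.py` defines `system_prompt`; the message text is only illustrative:

from langchain_core.messages import HumanMessage

chat = OllamaChatModel()
# invoke() is inherited from BaseChatModel; with a list of messages it
# runs _generate() and returns the resulting AIMessage.
reply = chat.invoke([HumanMessage(content="Hello, who are you?")])
print(reply.content)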