```python
from typing import Any, List, Optional
import subprocess

from langchain_core.language_models import BaseChatModel
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from pydantic import Field
from tenacity import retry, stop_after_attempt, wait_random_exponential

from config import system_prompt


class OllamaChatModel(BaseChatModel):
    model_name: str = Field(default="taide-local-3")

    @retry(stop=stop_after_attempt(3), wait=wait_random_exponential(min=1, max=60))
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Map LangChain message objects onto plain role/content dicts.
        formatted_messages = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                formatted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                formatted_messages.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                formatted_messages.append({"role": "system", "content": msg.content})

        # Build a Llama-2-style [INST] prompt, prepending the system prompt
        # from config; assistant turns close with </s> and open a new [INST].
        prompt = f"<s>[INST] {system_prompt}\n\n"
        for msg in formatted_messages:
            if msg["role"] == "user":
                prompt += f"{msg['content']} [/INST]"
            elif msg["role"] == "assistant":
                prompt += f"{msg['content']} </s><s>[INST]"

        # Run the local model through the Ollama CLI.
        command = ["ollama", "run", self.model_name, prompt]
        try:
            result = subprocess.run(command, capture_output=True, text=True, timeout=60)  # add a 60-second timeout
        except subprocess.TimeoutExpired:
            raise Exception("Ollama command timed out")
        if result.returncode != 0:
            raise Exception(f"Ollama command failed: {result.stderr}")

        content = result.stdout.strip()
        message = AIMessage(content=content)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        return "ollama-chat-model"
```
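Because the class subclasses `BaseChatModel`, it plugs into the standard LangChain runnable interface. Below is a minimal usage sketch, assuming the `taide-local-3` model has already been registered with Ollama (e.g. via `ollama pull` or `ollama create`) and that `config.py` defining `system_prompt` is importable:

```python
# Minimal usage sketch (assumes the "taide-local-3" model exists in Ollama
# and config.py provides system_prompt).
from langchain_core.messages import HumanMessage

llm = OllamaChatModel()
reply = llm.invoke([HumanMessage(content="Summarize what LangChain is in one sentence.")])
print(reply.content)
```

`invoke` routes through `_generate`, so the tenacity decorator transparently retries up to three times with exponential backoff if the Ollama subprocess fails or times out.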