import subprocess
from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic import Field
from tenacity import retry, stop_after_attempt, wait_random_exponential

from config import system_prompt

class OllamaChatModel(BaseChatModel):
    """Chat model that generates replies by shelling out to the local Ollama CLI."""

    model_name: str = Field(default="taide-local-3")

    # Retry up to 3 times with randomized exponential backoff on any failure
    # (timeouts, non-zero exit codes).
    @retry(stop=stop_after_attempt(3), wait=wait_random_exponential(min=1, max=60))
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Normalize LangChain message objects into plain role/content dicts.
        formatted_messages = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                formatted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                formatted_messages.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                formatted_messages.append({"role": "system", "content": msg.content})

        # Build a Llama-style [INST] prompt, opening with the configured system prompt.
        prompt = f"[INST] {system_prompt}\n\n"
        for msg in formatted_messages:
            if msg["role"] == "system":
                # Fold any conversation-level system messages in after the base system prompt.
                prompt += f"{msg['content']}\n\n"
            elif msg["role"] == "user":
                prompt += f"{msg['content']} [/INST]"
            elif msg["role"] == "assistant":
                prompt += f"{msg['content']} [INST]"
        # Invoke the local model through the Ollama CLI.
        command = ["ollama", "run", self.model_name, prompt]
        try:
            # 60-second timeout so a hung model process cannot block the caller.
            result = subprocess.run(command, capture_output=True, text=True, timeout=60)
        except subprocess.TimeoutExpired as exc:
            raise Exception("Ollama command timed out") from exc
        if result.returncode != 0:
            raise Exception(f"Ollama command failed: {result.stderr}")

        content = result.stdout.strip()
        # The CLI invocation does not pass stop sequences, so honor `stop`
        # client-side by truncating the output at the first matching sequence.
        if stop:
            for token in stop:
                idx = content.find(token)
                if idx != -1:
                    content = content[:idx]
        message = AIMessage(content=content)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])
    @property
    def _llm_type(self) -> str:
        return "ollama-chat-model"