@@ -0,0 +1,74 @@
+import subprocess
+import json
+from typing import Any, List, Optional, Dict
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
+from langchain_core.outputs import ChatResult, ChatGeneration
+from pydantic import Field
+
+class OllamaChatModel(BaseChatModel):
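+    """Minimal LangChain chat model that answers by shelling out to the Ollama CLI."""
+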
+    model_name: str = Field(default="taide-local")
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
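+        # Flatten the chat history into a single prompt string, labelling each
+        # message with its class name (HumanMessage, AIMessage, ...).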
+        prompt = "\n".join([f"{msg.__class__.__name__}: {msg.content}" for msg in messages])
+
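+        # Run the prompt through the local Ollama CLI and capture its output.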
+        command = ["ollama", "run", self.model_name, prompt]
+        result = subprocess.run(command, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            raise RuntimeError(f"Ollama command failed: {result.stderr}")
+
+        content = result.stdout.strip()
+
+        message = AIMessage(content=content)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
+
+    @property
+    def _llm_type(self) -> str:
+        return "ollama-chat-model"
+
+def check_model_availability(model_name: str) -> bool:
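+    """Return True if `model_name` appears in the `ollama list` output, else False."""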
+    result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
+    if result.returncode != 0:
+        print(f"Error checking model availability: {result.stderr}")
+        return False
+
+    models = result.stdout.splitlines()
+    return any(model_name in model for model in models)
+
+# Usage example
+if __name__ == "__main__":
+    model_name = "taide-local"
+
+    print(f"Checking availability of model {model_name}...")
+    if not check_model_availability(model_name):
+        print(f"Model {model_name} is not available. Please check if it's correctly installed in Ollama.")
+        exit(1)
+
+    model = OllamaChatModel(model_name=model_name)
+
+    print(f"Starting chat with {model_name} model. Type 'exit' to quit.")
+
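+    # Keep the full conversation history so each turn includes prior context.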
+    messages = []
+    while True:
+        user_input = input("You: ")
+        if user_input.lower() == 'exit':
+            break
+
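+        # Record the user turn, then ask the model for a reply using the full history.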
+        messages.append(HumanMessage(content=user_input))
+        try:
+            response = model.invoke(messages)
+            print("AI:", response.content)
+            messages.append(AIMessage(content=response.content))
+        except Exception as e:
+            print(f"Error communicating with Ollama: {e}")
+
+    print("Chat session ended. Goodbye!")