local_llm.py

from langchain_community.chat_models import ChatOllama
from langchain_openai import ChatOpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from langchain_huggingface import HuggingFacePipeline
from typing import Any, List, Optional, Dict
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatResult, ChatGeneration
from pydantic import Field
import subprocess
import time
from dotenv import load_dotenv

load_dotenv()
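# Note (assumption): the .env file loaded above is expected to hold the credentials the
# non-local or gated backends need, e.g. OPENAI_API_KEY for openai_() and, if the
# meta-llama checkpoint is gated for this account, a Hugging Face token such as HF_TOKEN.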

# "You are an AI assistant from Taiwan. Your name is TAIDE; you are happy to help users
# from a Taiwanese perspective and answer questions in Traditional Chinese."
system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"


def hf():
    """Load Meta-Llama-3.1-8B-Instruct locally through a HuggingFace text-generation pipeline."""
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    llm = HuggingFacePipeline.from_model_id(
        model_id=model_id,
        task="text-generation",
        model_kwargs={"torch_dtype": torch.bfloat16},
        pipeline_kwargs={
            "return_full_text": False,
            "max_new_tokens": 512,
            "repetition_penalty": 1.03,
        },
        device=0,  # pass either device or device_map, not both; the original passed both
    )
    # print(llm.pipeline)
    # Llama 3.1 exposes a list of EOS token ids; use the first one as the pad token.
    llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
    return llm
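
# Sketch (assumption): HuggingFacePipeline is a plain text-completion LLM, so to drive it
# with chat messages like the other backends it could be wrapped in ChatHuggingFace, e.g.:
#
#   from langchain_huggingface import ChatHuggingFace
#   chat_llm = ChatHuggingFace(llm=hf())
#   print(chat_llm.invoke("溫室氣體是什麼?").content)  # "What are greenhouse gases?"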


def ollama_():
    """Chat through a locally running Ollama server."""
    # model = "cwchang/llama3-taide-lx-8b-chat-alpha1"
    model = "llama3.1:latest"
    # model = "llama3.1:70b"
    # model = "893379029/piccolo-large-zh-v2"
    # "You are an AI assistant from Taiwan, happy to help users from a Taiwanese perspective,
    # answering in Traditional Chinese. Answer the question in at most 5 sentences."
    sys = "你是一個來自台灣的 AI 助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。請用 5 句話以內回答問題。"
    # llm = ChatOllama(model=model, num_gpu=2, num_thread=32, temperature=0, system=sys, keep_alive="10m", verbose=True)
    llm = ChatOllama(model=model, num_gpu=2, temperature=0, system=sys, keep_alive="10m")
    return llm


def openai_():  # not local; calls the OpenAI API
    llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
    return llm


class OllamaChatModel(BaseChatModel):
    """Custom chat model that runs a prompt through `ollama run` inside a Docker container."""

    model_name: str = Field(default="taide-local-llama3")

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Normalise LangChain message objects into role/content dicts.
        formatted_messages = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                formatted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                formatted_messages.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                formatted_messages.append({"role": "system", "content": msg.content})

        # Build the raw prompt string in the model's chat template.
        # prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"  # TAIDE llama2
        prompt = f"<|begin_of_text|><|start_header_id|>{system_prompt}<|end_header_id|>"  # TAIDE llama3
        for msg in formatted_messages:
            if msg['role'] == 'user':
                # prompt += f"{msg['content']} [/INST]"  # TAIDE llama2
                prompt += f"<|eot_id|><|start_header_id|>{msg['content']}<|end_header_id|>"  # TAIDE llama3
            elif msg['role'] == "assistant":
                # prompt += f"{msg['content']} </s><s>[INST]"  # TAIDE llama2
                prompt += f"<|eot_id|><|start_header_id|>{msg['content']}<|end_header_id|>"  # TAIDE llama3

        # Run the prompt through the `ollama` CLI inside the running "ollama" container.
        command = ["docker", "exec", "-it", "ollama", "ollama", "run", self.model_name, prompt]
        result = subprocess.run(command, capture_output=True, text=True)
        if result.returncode != 0:
            raise Exception(f"Ollama command failed: {result.stderr}")

        content = result.stdout.strip()
        message = AIMessage(content=content)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        return "ollama-chat-model"


# taide_llm = OllamaChatModel(model_name="taide-local-llama2")
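#
# Sketch (assumption): how the custom model above could be exercised directly. This assumes a
# model tagged "taide-local-llama3" has already been pulled or created inside the "ollama" container.
#
#   taide_llm = OllamaChatModel(model_name="taide-local-llama3")
#   reply = taide_llm.invoke([HumanMessage(content="溫室氣體是什麼?")])  # "What are greenhouse gases?"
#   print(reply.content)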

if __name__ == "__main__":
    # Simple benchmark loop: ask each backend the same question and time the round trip.
    question = ""
    while question.lower() != "exit":
        question = input("Question: ")
        # e.g. 溫室氣體是什麼? ("What are greenhouse gases?")
        for function in [ollama_, hf, openai_]:  # the original listed undefined huggingface_ / huggingface2_
            start = time.time()
            llm = function()
            answer = llm.invoke(question)
            print(answer)
            processing_time = time.time() - start
            print(processing_time)