main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. # fastapi
  2. from fastapi import FastAPI, Request, Response, HTTPException, status, Depends
  3. from fastapi import templating
  4. from fastapi.templating import Jinja2Templates
  5. from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
  6. from fastapi.middleware.cors import CORSMiddleware
  7. # static file
  8. from fastapi.staticfiles import StaticFiles
  9. # fastapi view function parameters
  10. from typing import List, Optional
  11. # path
  12. import os
  13. # time
  14. # import datetime
  15. from datetime import timedelta, datetime
  16. # db
  17. import dataset
  18. from passlib import context
  19. import models
  20. # authorize
  21. from passlib.context import CryptContext
  22. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  23. from jose import JWTError, jwt
  24. from fastapi_jwt_auth import AuthJWT
  25. from fastapi_jwt_auth.exceptions import AuthJWTException
  26. from fastapi.security import OAuth2AuthorizationCodeBearer, OAuth2PasswordRequestForm
  27. import numpy as np
  28. import pymysql
  29. pymysql.install_as_MySQLdb()
  30. # app
  31. app = FastAPI()
  32. app.add_middleware(
  33. CORSMiddleware,
  34. allow_origins=["*"],
  35. allow_credentials=True,
  36. allow_methods=["*"],
  37. allow_headers=["*"],
  38. )
  39. SECRET_KEY = "df2f77bd544240801a048bd4293afd8eeb7fff3cb7050e42c791db4b83ebadcd"
  40. ALGORITHM = "HS256"
  41. ACCESS_TOKEN_EXPIRE_MINUTES = 3000
  42. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  43. #
  44. app.mount(path='/templates', app=StaticFiles(directory='templates'), name='templates')
  45. app.mount(path='/static', app=StaticFiles(directory='static'), name='static ')
  46. #
  47. templates = Jinja2Templates(directory='templates')
  48. @AuthJWT.load_config
  49. def get_config():
  50. return models.Settings()
  51. # view
  52. @app.get('/', response_class=HTMLResponse)
  53. async def index(request: Request):
  54. print(request)
  55. return templates.TemplateResponse(name='index.html', context={'request': request})
  56. @app.get('/login', response_class=HTMLResponse)
  57. async def login(request: Request):
  58. return templates.TemplateResponse(name='login_test.html', context={'request': request})
  59. @app.post("/login")
  60. async def login_for_access_token(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), Authorize: AuthJWT = Depends()):
  61. db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
  62. user = authenticate_user(form_data.username, form_data.password)
  63. if not user:
  64. raise HTTPException(
  65. status_code=status.HTTP_401_UNAUTHORIZED,
  66. detail="Incorrect username or password",
  67. headers={"WWW-Authenticate": "Bearer"},
  68. )
  69. access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
  70. access_token = create_access_token(
  71. data={"sub": user.username}, expires_delta=access_token_expires
  72. )
  73. table = db['users']
  74. user.token = access_token
  75. print(user)
  76. table.update(dict(user), ['username'])
  77. access_token = Authorize.create_access_token(subject=user.username)
  78. refresh_token = Authorize.create_refresh_token(subject=user.username)
  79. Authorize.set_access_cookies(access_token)
  80. Authorize.set_refresh_cookies(refresh_token)
  81. # return templates.TemplateResponse("home.html", {"request": request, "msg": 'Login'})
  82. return {"access_token": access_token, "token_type": "bearer"} # 回傳token給前端
  83. @app.get('/register', response_class=HTMLResponse)
  84. async def login(request: Request):
  85. return templates.TemplateResponse(name='rigister_test.html', context={'request': request})
  86. @app.post('/register')
  87. async def register(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
  88. user = models.User(**await request.form())
  89. print(form_data.username, form_data.password, user)
  90. # 密碼加密
  91. user.password = get_password_hash(user.password)
  92. # 存入DB
  93. db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
  94. user_table = db['users']
  95. user_table.insert(dict(user))
  96. # 跳轉頁面至登入
  97. return templates.TemplateResponse(name='login.html', context={'request': request})
  98. @app.get('/home', response_class=HTMLResponse)
  99. async def home(request: Request):
  100. return templates.TemplateResponse(name='home.html', context={'request': request})
  101. @app.get('/tower', response_class=HTMLResponse)
  102. async def tower(request: Request, Authorize: AuthJWT = Depends()):
  103. try:
  104. Authorize.jwt_required()
  105. except Exception as e:
  106. print(e)
  107. return RedirectResponse('/login')
  108. # current_user = Authorize.get_jwt_subject()
  109. return templates.TemplateResponse(name='tower.html', context={'request': request})
  110. @app.get('/optim', response_class=HTMLResponse)
  111. async def optim(request: Request, Authorize: AuthJWT = Depends()):
  112. try:
  113. Authorize.jwt_required()
  114. except Exception as e:
  115. print(e)
  116. return RedirectResponse('/login')
  117. # current_user = Authorize.get_jwt_subject()
  118. return templates.TemplateResponse(name='optim.html', context={'request': request})
  119. @app.get('/vibration', response_class=HTMLResponse)
  120. async def vibration(request: Request, Authorize: AuthJWT = Depends()):
  121. try:
  122. Authorize.jwt_required()
  123. except Exception as e:
  124. print(e)
  125. return RedirectResponse('/login')
  126. # current_user = Authorize.get_jwt_subject()
  127. return templates.TemplateResponse(name='vibration.html', context={'request': request})
  128. @app.get('/history', response_class=HTMLResponse)
  129. async def history(request: Request, Authorize: AuthJWT = Depends()):
  130. try:
  131. Authorize.jwt_required()
  132. except Exception as e:
  133. print(e)
  134. return RedirectResponse('/login')
  135. # current_user = Authorize.get_jwt_subject()
  136. return templates.TemplateResponse(name='history.html', context={'request': request})
  137. @app.get('/device', response_class=HTMLResponse)
  138. async def device(request: Request, Authorize: AuthJWT = Depends()):
  139. try:
  140. Authorize.jwt_required()
  141. except Exception as e:
  142. print(e)
  143. return RedirectResponse('/login')
  144. # current_user = Authorize.get_jwt_subject()
  145. return templates.TemplateResponse(name='device.html', context={'request': request})
  146. @app.get('/system', response_class=HTMLResponse)
  147. async def system(request: Request, Authorize: AuthJWT = Depends()):
  148. try:
  149. Authorize.jwt_required()
  150. except Exception as e:
  151. print(e)
  152. return RedirectResponse('/login')
  153. # current_user = Authorize.get_jwt_subject()
  154. return templates.TemplateResponse(name='system.html', context={'request': request})
  155. # 溫度API
  156. @app.get('/temperature')
  157. async def get_temperatures():
  158. """ 撈DB溫度 """
  159. return {'hot_water': 30.48, 'cold_water': 28.10, 'wet_ball': 25.14}
  160. @app.get('/health')
  161. async def get_health(date: str):
  162. """ 撈健康指標、預設健康指標 """
  163. date = str(datetime.strptime(date, "%Y-%m-%d"))[:10]
  164. print(date)
  165. print(str(datetime.today()))
  166. print(str(datetime.today()-timedelta(days=1)))
  167. fake_data = {
  168. str(datetime.today())[:10]: {'curr_health': 0.7, 'pred_health': 0.8},
  169. str(datetime.today()-timedelta(days=1))[:10]: {'curr_health': 0.6, 'pred_health': 0.7},
  170. }
  171. return fake_data[date]
  172. @app.get('/history_data')
  173. async def get_history(time_end: str):
  174. """ 透過終點時間,抓取歷史資料。 """
  175. date = str(datetime.strptime(time_end, "%Y-%m-%d"))[:10]
  176. print(date)
  177. print(str(datetime.today()))
  178. print(str(datetime.today()-timedelta(days=1)))
  179. fake_data = {
  180. str(datetime.today())[:10]: {
  181. 'curr_history': {
  182. 'RPM_1X': list(np.random.rand(13)),
  183. 'RPM_2X': list(np.random.rand(13)),
  184. 'RPM_3X': list(np.random.rand(13)),
  185. 'RPM_4X': list(np.random.rand(13)),
  186. 'RPM_5X': list(np.random.rand(13)),
  187. 'RPM_6X': list(np.random.rand(13)),
  188. 'RPM_7X': list(np.random.rand(13)),
  189. 'RPM_8X': list(np.random.rand(13)),
  190. 'Gear_1X': list(np.random.rand(13)),
  191. 'Gear_2X': list(np.random.rand(13)),
  192. 'Gear_3X': list(np.random.rand(13)),
  193. 'Gear_4X': list(np.random.rand(13)),
  194. },
  195. 'past_history': {
  196. 'RPM_1X': list(np.random.rand(13)),
  197. 'RPM_2X': list(np.random.rand(13)),
  198. 'RPM_3X': list(np.random.rand(13)),
  199. 'RPM_4X': list(np.random.rand(13)),
  200. 'RPM_5X': list(np.random.rand(13)),
  201. 'RPM_6X': list(np.random.rand(13)),
  202. 'RPM_7X': list(np.random.rand(13)),
  203. 'RPM_8X': list(np.random.rand(13)),
  204. 'Gear_1X': list(np.random.rand(13)),
  205. 'Gear_2X': list(np.random.rand(13)),
  206. 'Gear_3X': list(np.random.rand(13)),
  207. 'Gear_4X': list(np.random.rand(13)),
  208. }
  209. },
  210. str(datetime.today()-timedelta(days=1))[:10]: {
  211. 'curr_history': {
  212. 'RPM_1X': list(np.random.rand(13)),
  213. 'RPM_2X': list(np.random.rand(13)),
  214. 'RPM_3X': list(np.random.rand(13)),
  215. 'RPM_4X': list(np.random.rand(13)),
  216. 'RPM_5X': list(np.random.rand(13)),
  217. 'RPM_6X': list(np.random.rand(13)),
  218. 'RPM_7X': list(np.random.rand(13)),
  219. 'RPM_8X': list(np.random.rand(13)),
  220. 'Gear_1X': list(np.random.rand(13)),
  221. 'Gear_2X': list(np.random.rand(13)),
  222. 'Gear_3X': list(np.random.rand(13)),
  223. 'Gear_4X': list(np.random.rand(13)),
  224. },
  225. 'past_history': {
  226. 'RPM_1X': list(np.random.rand(13)),
  227. 'RPM_2X': list(np.random.rand(13)),
  228. 'RPM_3X': list(np.random.rand(13)),
  229. 'RPM_4X': list(np.random.rand(13)),
  230. 'RPM_5X': list(np.random.rand(13)),
  231. 'RPM_6X': list(np.random.rand(13)),
  232. 'RPM_7X': list(np.random.rand(13)),
  233. 'RPM_8X': list(np.random.rand(13)),
  234. 'Gear_1X': list(np.random.rand(13)),
  235. 'Gear_2X': list(np.random.rand(13)),
  236. 'Gear_3X': list(np.random.rand(13)),
  237. 'Gear_4X': list(np.random.rand(13)),
  238. }
  239. },
  240. }
  241. return fake_data[date]
  242. # Get data from db
  243. def get_data_from_db(query):
  244. db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_test_db?charset=utf8mb4')
  245. data = db.query(query=query)
  246. return data
  247. # Login funtion part
  248. def check_user_exists(username):
  249. db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
  250. if int(next(iter(db.query('SELECT COUNT(*) FROM aaron_testdb.users WHERE userName = "'+username+'"')))['COUNT(*)']) > 0:
  251. return True
  252. else:
  253. return False
  254. def get_user(username: str):
  255. """ 取得使用者資訊(Model) """
  256. db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
  257. if not check_user_exists(username): # if user don't exist
  258. return False
  259. user_dict = next(
  260. iter(db.query('SELECT * FROM aaron_testdb.users where userName ="'+username+'"')))
  261. user = models.User(**user_dict)
  262. return user
  263. def user_register(user):
  264. db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
  265. table = db['users']
  266. user.password = get_password_hash(user.password)
  267. table.insert(dict(user))
  268. def get_password_hash(password):
  269. """ 加密密碼 """
  270. return pwd_context.hash(password)
  271. def verify_password(plain_password, hashed_password):
  272. """ 驗證密碼(hashed) """
  273. return pwd_context.verify(plain_password, hashed_password)
  274. def authenticate_user(username: str, password: str):
  275. """ 連線DB,讀取使用者是否存在。 """
  276. db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
  277. if not check_user_exists(username): # if user don't exist
  278. return False
  279. user_dict = next(iter(db.query('SELECT * FROM aaron_testdb.users where userName ="'+username+'"')))
  280. user = models.User(**user_dict)
  281. if not verify_password(password, user.password):
  282. return False
  283. return user
  284. def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
  285. """ 創建token,並設定過期時間。 """
  286. to_encode = data.copy()
  287. if expires_delta:
  288. expire = datetime.utcnow() + expires_delta
  289. else:
  290. expire = datetime.utcnow() + timedelta(minutes=15)
  291. to_encode.update({"exp": expire})
  292. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  293. return encoded_jwt