123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- # fastapi
- from fastapi import FastAPI, Request, Response, HTTPException, status, Depends
- from fastapi import templating
- from fastapi.templating import Jinja2Templates
- from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- # static file
- from fastapi.staticfiles import StaticFiles
- # fastapi view function parameters
- from typing import List, Optional
- # path
- import os
- # time
- # import datetime
- from datetime import timedelta, datetime
- # db
- import dataset
- from passlib import context
- import models
- # authorize
- from passlib.context import CryptContext
- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
- from jose import JWTError, jwt
- from fastapi_jwt_auth import AuthJWT
- from fastapi_jwt_auth.exceptions import AuthJWTException
- from fastapi.security import OAuth2AuthorizationCodeBearer, OAuth2PasswordRequestForm
- import pymysql
# Let PyMySQL stand in for the MySQLdb driver so the "mysql://" URLs passed
# to dataset.connect() elsewhere in this module can be opened.
pymysql.install_as_MySQLdb()

# app
app = FastAPI()
# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets
# any site make credentialed (cookie-bearing) requests: Starlette reflects the
# caller's Origin in this configuration. Restrict to the real front-end
# origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Settings used by create_access_token(). SECRET_KEY can be overridden through
# the SECRET_KEY environment variable. The hard-coded fallback keeps existing
# deployments working but should not be relied on in production.
SECRET_KEY = os.environ.get(
    "SECRET_KEY",
    "df2f77bd544240801a048bd4293afd8eeb7fff3cb7050e42c791db4b83ebadcd",
)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 3000
# pwd_context is created once next to the imports; it is not redefined here.

# Static file mounts.
# NOTE(review): this serves the raw Jinja2 template sources under /templates;
# confirm that exposing them is intended.
app.mount(path='/templates', app=StaticFiles(directory='templates'), name='templates')
# The mount name must be exactly 'static' (no trailing space), otherwise
# url_for('static', path=...) cannot resolve it.
app.mount(path='/static', app=StaticFiles(directory='static'), name='static')

# Jinja2 renderer shared by all page views.
templates = Jinja2Templates(directory='templates')
@AuthJWT.load_config
def get_config():
    """Provide fastapi_jwt_auth with its configuration object (``models.Settings``)."""
    settings = models.Settings()
    return settings
- # view
@app.get('/', response_class=HTMLResponse)
async def index(request: Request):
    """Render the landing page (index.html).

    The leftover debug ``print(request)`` was removed; it wrote to stdout on
    every request.
    """
    return templates.TemplateResponse(name='index.html', context={'request': request})
-
@app.get('/login', response_class=HTMLResponse)
async def login(request: Request):
    """Render the login form (login.html)."""
    page_context = {'request': request}
    return templates.TemplateResponse(name='login.html', context=page_context)
@app.post("/login")
async def login_for_access_token(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), Authorize: AuthJWT = Depends()):
    """Log a user in from the OAuth2 password form and issue tokens.

    The credentials are checked with ``authenticate_user``. A JWT from
    ``create_access_token`` is then stored in the user's ``token`` column.
    Finally a fastapi_jwt_auth access/refresh pair is created, both tokens
    are set as cookies, and the access token is returned.

    Raises:
        HTTPException: 401 if the username or password is incorrect.
    """
    # NOTE(review): authenticate_user() and dataset do blocking DB I/O inside
    # an ``async def`` handler, which stalls the event loop. Consider a plain
    # ``def`` handler so FastAPI runs it in its threadpool.
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    user.token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    # Only connect once authentication has succeeded. The user object is not
    # printed: it carries the stored password hash.
    db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
    db['users'].update(dict(user), ['username'])

    # NOTE(review): the token returned and set as cookies below is a different
    # JWT (fastapi_jwt_auth) from the one stored in the DB above. Confirm
    # which one consumers are meant to use.
    access_token = Authorize.create_access_token(subject=user.username)
    refresh_token = Authorize.create_refresh_token(subject=user.username)
    Authorize.set_access_cookies(access_token)
    Authorize.set_refresh_cookies(refresh_token)
    return {"access_token": access_token, "token_type": "bearer"}  # return the token to the front end
-
@app.get('/register', response_class=HTMLResponse)
async def register_page(request: Request):
    """Render the registration form (register.html)."""
    # Named distinctly: reusing ``login`` rebinds the module-level name of the
    # GET /login handler and gives both routes the same url_for route name.
    return templates.TemplateResponse(name='register.html', context={'request': request})
@app.post('/register')
async def register(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
    """Create a user account from the submitted registration form.

    The whole form is loaded into ``models.User``. The password is hashed
    with ``get_password_hash`` before the row is inserted into ``users``.
    On success the login page is rendered.

    Raises:
        HTTPException: 400 if the username is already registered.
    """
    # form_data stays declared so FastAPI enforces the username/password
    # fields. Credentials are not printed: the old print leaked the plaintext
    # password to stdout.
    # NOTE(review): every model field is taken from the form (e.g. ``token``);
    # consider whitelisting the fields a client may set.
    user = models.User(**await request.form())

    db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
    user_table = db['users']
    # Login looks users up by username, so a duplicate row would be ambiguous.
    if user_table.find_one(username=user.username) is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered",
        )

    # Store only the password hash.
    user.password = get_password_hash(user.password)
    user_table.insert(dict(user))

    # Send the user on to the login page.
    return templates.TemplateResponse(name='login.html', context={'request': request})
@app.get('/home', response_class=HTMLResponse)
async def home(request: Request):
    """Render the home page (home.html)."""
    page_context = {'request': request}
    return templates.TemplateResponse(name='home.html', context=page_context)
@app.get('/tower', response_class=HTMLResponse)
async def tower(request: Request):
    """Render the tower page (tower.html)."""
    page_context = {'request': request}
    return templates.TemplateResponse(name='tower.html', context=page_context)
@app.get('/optim', response_class=HTMLResponse)
async def optim(request: Request):
    """Render the optim page (optim.html)."""
    page_context = {'request': request}
    return templates.TemplateResponse(name='optim.html', context=page_context)
@app.get('/vibration', response_class=HTMLResponse)
async def vibration(request: Request):
    """Render the vibration page (vibration.html)."""
    page_context = {'request': request}
    return templates.TemplateResponse(name='vibration.html', context=page_context)
@app.get('/history', response_class=HTMLResponse)
async def history(request: Request):
    """Render the history page (history.html)."""
    page_context = {'request': request}
    return templates.TemplateResponse(name='history.html', context=page_context)
@app.get('/device', response_class=HTMLResponse)
async def device(request: Request):
    """Render the device page (device.html)."""
    page_context = {'request': request}
    return templates.TemplateResponse(name='device.html', context=page_context)
@app.get('/system', response_class=HTMLResponse)
async def system(request: Request):
    """Render the system page (system.html)."""
    page_context = {'request': request}
    return templates.TemplateResponse(name='system.html', context=page_context)
# Login helper functions
def check_user_exists(username: str):
    """Return True if at least one row in ``aaron_testdb.users`` has this userName.

    The username is sent as a bound parameter instead of being concatenated
    into the SQL, so quotes in it cannot change the query (SQL injection).
    """
    db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
    rows = db.query(
        'SELECT COUNT(*) AS n FROM aaron_testdb.users WHERE userName = :username',
        username=username,
    )
    return int(next(iter(rows))['n']) > 0
def get_user(username: str):
    """Fetch a user's record as a ``models.User``.

    A single parameterized query is used. It replaces the previous
    count-then-select pair, which built both SQL statements by string
    concatenation.

    Returns:
        The first matching ``models.User``, or ``False`` if no row has this userName.
    """
    db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
    rows = db.query(
        'SELECT * FROM aaron_testdb.users WHERE userName = :username',
        username=username,
    )
    user_dict = next(iter(rows), None)
    if user_dict is None:  # user does not exist
        return False
    return models.User(**user_dict)
-
def user_register(user):
    """Insert ``user`` into the ``users`` table with its password hashed.

    The caller's object is left untouched: only the inserted record carries
    the hash. Registering the same object twice therefore cannot store a
    hash of a hash.
    """
    record = dict(user)
    record['password'] = get_password_hash(user.password)
    db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
    db['users'].insert(record)
def get_password_hash(password):
    """Hash a plaintext password with the module's ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash using ``pwd_context``."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def authenticate_user(username: str, password: str):
    """Validate a username/password pair against the users table.

    The lookup is delegated to ``get_user`` rather than re-running its own
    string-built SQL (and a second existence query).

    Returns:
        The ``models.User`` when the password matches, otherwise ``False``
        (unknown user or wrong password).
    """
    user = get_user(username)
    if not user:  # user does not exist
        return False
    if not verify_password(password, user.password):
        return False
    return user
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode ``data`` as a JWT with an ``exp`` claim.

    Args:
        data: Claims to encode. The mapping is copied, not mutated.
        expires_delta: Token lifetime. Only ``None`` selects the 15-minute
            default; an explicit ``timedelta(0)`` is now honoured instead of
            being treated as "not given".

    Returns:
        The token string, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    from datetime import timezone  # datetime.utcnow() is deprecated since 3.12

    lifetime = timedelta(minutes=15) if expires_delta is None else expires_delta
    to_encode = {**data, "exp": datetime.now(timezone.utc) + lifetime}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|