|
@@ -1,5 +1,5 @@
|
|
|
# fastapi
|
|
|
-from fastapi import FastAPI, Request, Response, HTTPException
|
|
|
+from fastapi import FastAPI, Request, Response, HTTPException, status, Depends
|
|
|
from fastapi import templating
|
|
|
from fastapi.templating import Jinja2Templates
|
|
|
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
|
@@ -16,14 +16,17 @@ import os
|
|
|
|
|
|
# time
|
|
|
import datetime
|
|
|
+from datetime import timedelta
|
|
|
|
|
|
# db
|
|
|
import dataset
|
|
|
from passlib import context
|
|
|
+import models
|
|
|
|
|
|
# authorize
|
|
|
from passlib.context import CryptContext
|
|
|
-# from jose import JWTError, jwt
|
|
|
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
+from jose import JWTError, jwt
|
|
|
from fastapi_jwt_auth import AuthJWT
|
|
|
from fastapi_jwt_auth.exceptions import AuthJWTException
|
|
|
from fastapi.security import OAuth2AuthorizationCodeBearer, OAuth2PasswordRequestForm
|
|
@@ -31,8 +34,19 @@ from fastapi.security import OAuth2AuthorizationCodeBearer, OAuth2PasswordReques
|
|
|
|
|
|
# app
|
|
|
app = FastAPI()
|
|
|
+app.add_middleware(
|
|
|
+ CORSMiddleware,
|
|
|
+ allow_origins=["*"],
|
|
|
+ allow_credentials=True,
|
|
|
+ allow_methods=["*"],
|
|
|
+ allow_headers=["*"],
|
|
|
+)
|
|
|
|
|
|
|
|
|
+SECRET_KEY = "df2f77bd544240801a048bd4293afd8eeb7fff3cb7050e42c791db4b83ebadcd"
|
|
|
+ALGORITHM = "HS256"
|
|
|
+ACCESS_TOKEN_EXPIRE_MINUTES = 3000
|
|
|
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
|
|
#
|
|
|
app.mount(path='/templates', app=StaticFiles(directory='templates'), name='templates')
|
|
@@ -59,13 +73,41 @@ async def index(request: Request):
|
|
|
async def login(request: Request):
|
|
|
return templates.TemplateResponse(name='login.html', context={'request': request})
|
|
|
|
|
|
+
|
|
|
+@app.post("/login")
|
|
|
+async def login_for_access_token(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), Authorize: AuthJWT = Depends()):
|
|
|
+ db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
|
|
|
+ user = authenticate_user(form_data.username, form_data.password)
|
|
|
+ if not user:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
+ detail="Incorrect username or password",
|
|
|
+ headers={"WWW-Authenticate": "Bearer"},
|
|
|
+ )
|
|
|
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
+ access_token = create_access_token(
|
|
|
+ data={"sub": user.username}, expires_delta=access_token_expires
|
|
|
+ )
|
|
|
+ table = db['users']
|
|
|
+ user.token = access_token
|
|
|
+ table.update(dict(user), ['username'])
|
|
|
+ access_token = Authorize.create_access_token(subject=user.username)
|
|
|
+ refresh_token = Authorize.create_refresh_token(subject=user.username)
|
|
|
+ Authorize.set_access_cookies(access_token)
|
|
|
+ Authorize.set_refresh_cookies(refresh_token)
|
|
|
+ #return templates.TemplateResponse("index.html", {"request": request, "msg": 'Login'})
|
|
|
+ return {"access_token": access_token, "token_type": "bearer"}
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
@app.get('/home', response_class=HTMLResponse)
|
|
|
async def login(request: Request):
|
|
|
return templates.TemplateResponse(name='home.html', context={'request': request})
|
|
|
|
|
|
@app.get('/monitor/tower', response_class=HTMLResponse)
|
|
|
async def login(request: Request):
|
|
|
- return templates.TemplateResponse(name='home.html', context={'request': request})
|
|
|
+ return templates.TemplateResponse(name='tower.html', context={'request': request})
|
|
|
|
|
|
@app.get('/optim', response_class=HTMLResponse)
|
|
|
async def login(request: Request):
|
|
@@ -85,4 +127,64 @@ async def login(request: Request):
|
|
|
|
|
|
@app.get('/set_up/system', response_class=HTMLResponse)
|
|
|
async def login(request: Request):
|
|
|
- return templates.TemplateResponse(name='system.html', context={'request': request})
|
|
|
+ return templates.TemplateResponse(name='system.html', context={'request': request})
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+# Login funtion part
|
|
|
+def check_user_exists(username):
|
|
|
+ db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
|
|
|
+ if int(next(iter(db.query('SELECT COUNT(*) FROM aaron_testdb.user WHERE username = "'+username+'"')))['COUNT(*)']) > 0:
|
|
|
+ return True
|
|
|
+ else:
|
|
|
+ return False
|
|
|
+
|
|
|
+def get_user(username: str):
|
|
|
+ """ 取得使用者資訊(Model) """
|
|
|
+ db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
|
|
|
+ if not check_user_exists(username): # if user don't exist
|
|
|
+ return False
|
|
|
+ user_dict = next(
|
|
|
+ iter(db.query('SELECT * FROM aaron_testdb.user where username ="'+username+'"')))
|
|
|
+ user = models.User(**user_dict)
|
|
|
+ return user
|
|
|
+
|
|
|
+# def user_register(user):
|
|
|
+# db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
|
|
|
+# table = db['users']
|
|
|
+# user.password = get_password_hash(user.password)
|
|
|
+# table.insert(dict(user))
|
|
|
+
|
|
|
+def get_password_hash(password):
|
|
|
+ """ 加密密碼 """
|
|
|
+ return pwd_context.hash(password)
|
|
|
+
|
|
|
+def verify_password(plain_password, hashed_password):
|
|
|
+ """ 驗證密碼(hashed) """
|
|
|
+ return pwd_context.verify(plain_password, hashed_password)
|
|
|
+
|
|
|
+def authenticate_user(username: str, password: str):
|
|
|
+ """ 連線DB,讀取使用者是否存在。 """
|
|
|
+ db = dataset.connect('mysql://choozmo:pAssw0rd@db.ptt.cx:3306/aaron_testdb?charset=utf8mb4')
|
|
|
+ if not check_user_exists(username): # if user don't exist
|
|
|
+ return False
|
|
|
+ user_dict = next(iter(db.query('SELECT * FROM aaron_testdb.user where username ="'+username+'"')))
|
|
|
+ user = models.User(**user_dict)
|
|
|
+ if not verify_password(password, user.password):
|
|
|
+ return False
|
|
|
+ return user
|
|
|
+
|
|
|
+def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
|
|
+ """ 創建token,並設定過期時間。 """
|
|
|
+ to_encode = data.copy()
|
|
|
+ if expires_delta:
|
|
|
+ expire = datetime.utcnow() + expires_delta
|
|
|
+ else:
|
|
|
+ expire = datetime.utcnow() + timedelta(minutes=15)
|
|
|
+ to_encode.update({"exp": expire})
|
|
|
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
+ return encoded_jwt
|