|
@@ -0,0 +1,267 @@
|
|
|
|
+from fastapi.openapi.docs import get_swagger_ui_html
|
|
|
|
+from typing import Optional
|
|
|
|
+
|
|
|
|
+from fastapi import Depends, FastAPI, HTTPException, status, Request, Form, Cookie, Response, Header
|
|
|
|
+from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
|
|
|
+from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
|
|
+from fastapi.templating import Jinja2Templates
|
|
|
|
+from fastapi.staticfiles import StaticFiles
|
|
|
|
+from passlib.context import CryptContext
|
|
|
|
+from jose import JWTError, jwt
|
|
|
|
+from fastapi_jwt_auth import AuthJWT
|
|
|
|
+from fastapi_jwt_auth.exceptions import AuthJWTException
|
|
|
|
+from datetime import datetime, timedelta
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+from pydantic import BaseModel
|
|
|
|
+
|
|
|
|
+import dataset
|
|
|
|
+import pymysql
|
|
|
|
+pymysql.install_as_MySQLdb()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# to get a string like this run:
|
|
|
|
+# openssl rand -hex 32
|
|
|
|
+SECRET_KEY = "df2f77bd544240801a048bd4293afd8eeb7fff3cb7050e42c791db4b83ebadcd"
|
|
|
|
+ALGORITHM = "HS256"
|
|
|
|
+ACCESS_TOKEN_EXPIRE_MINUTES = 300
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class Token(BaseModel):
|
|
|
|
+ access_token: str
|
|
|
|
+ token_type: str
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
|
|
|
+ to_encode = data.copy()
|
|
|
|
+ if expires_delta:
|
|
|
|
+ expire = datetime.utcnow() + expires_delta
|
|
|
|
+ else:
|
|
|
|
+ expire = datetime.utcnow() + timedelta(minutes=15)
|
|
|
|
+ to_encode.update({"exp": expire})
|
|
|
|
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
+ return encoded_jwt
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+app = FastAPI()
|
|
|
|
+# app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
|
|
|
|
+app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# connect DB
|
|
|
|
+db = dataset.connect(
|
|
|
|
+ 'mysql://choozmo:pAssw0rd@db.ptt.cx:3306/AI_anchor?charset=utf8mb4')
|
|
|
|
+
|
|
|
|
+# password hash function
|
|
|
|
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_password_hash(password):
|
|
|
|
+ return pwd_context.hash(password)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def verify_password(plain_password, hashed_password):
|
|
|
|
+ return pwd_context.verify(plain_password, hashed_password)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def authenticate_user(username: str, password: str):
|
|
|
|
+
|
|
|
|
+ if not check_user_exists(username): # if user don't exist
|
|
|
|
+ return False
|
|
|
|
+
|
|
|
|
+ user_dict = next(iter(db.query('SELECT * FROM AI_anchor.users where username ="'+username+'"')))
|
|
|
|
+ user = User(**user_dict)
|
|
|
|
+
|
|
|
|
+ if not verify_password(password, user.password):
|
|
|
|
+ return False
|
|
|
|
+ return user
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# AuthJWT setting
|
|
|
|
+class Settings(BaseModel):
|
|
|
|
+ authjwt_secret_key: str = SECRET_KEY
|
|
|
|
+ # Configure application to store and get JWT from cookies
|
|
|
|
+ authjwt_token_location: set = {"cookies"}
|
|
|
|
+ # Only allow JWT cookies to be sent over https
|
|
|
|
+ authjwt_cookie_secure: bool = False
|
|
|
|
+ # Enable csrf double submit protection. default is True
|
|
|
|
+ authjwt_cookie_csrf_protect: bool = True
|
|
|
|
+ # Change to 'lax' in production to make your website more secure from CSRF Attacks, default is None
|
|
|
|
+ # authjwt_cookie_samesite: str = 'lax'
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@AuthJWT.load_config
|
|
|
|
+def get_config():
|
|
|
|
+ return Settings()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.exception_handler(AuthJWTException)
|
|
|
|
+def authjwt_exception_handler(request: Request, exc: AuthJWTException):
|
|
|
|
+ return JSONResponse(
|
|
|
|
+ status_code=exc.status_code,
|
|
|
|
+ content={"detail": exc.message}
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+class User(BaseModel):
|
|
|
|
+ username: str
|
|
|
|
+ email: str
|
|
|
|
+ password: str
|
|
|
|
+ token: Optional[str] = None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class TokenData(BaseModel):
|
|
|
|
+ username: Optional[str] = None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def user_register(user):
|
|
|
|
+ table = db['users']
|
|
|
|
+ user.password = get_password_hash(user.password)
|
|
|
|
+ table.insert(dict(user))
|
|
|
|
+
|
|
|
|
+templates = Jinja2Templates(directory="templates")
|
|
|
|
+
|
|
|
|
+# home page
|
|
|
|
+@app.get("/", response_class=HTMLResponse)
|
|
|
|
+async def get_home_page(request: Request, response: Response):
|
|
|
|
+ return templates.TemplateResponse("index.html", {"request": request, "response": response})
|
|
|
|
+
|
|
|
|
+# login & register page
|
|
|
|
+@app.get("/login", response_class=HTMLResponse)
|
|
|
|
+async def get_login_and_register_page(request: Request):
|
|
|
|
+
|
|
|
|
+ # ads_id: Optional[str] = Cookie(None)
|
|
|
|
+
|
|
|
|
+ return templates.TemplateResponse("login.html", {"request": request})
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_user(username: str):
|
|
|
|
+ if not check_user_exists(username): # if user don't exist
|
|
|
|
+ return False
|
|
|
|
+ user_dict = next(
|
|
|
|
+ iter(db.query('SELECT * FROM AI_anchor.users where username ="'+username+'"')))
|
|
|
|
+ user = User(**user_dict)
|
|
|
|
+ return user
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.post("/login")
|
|
|
|
+async def login_for_access_token(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), Authorize: AuthJWT = Depends()):
|
|
|
|
+
|
|
|
|
+ user = authenticate_user(form_data.username, form_data.password)
|
|
|
|
+ if not user:
|
|
|
|
+ raise HTTPException(
|
|
|
|
+ status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
+ detail="Incorrect username or password",
|
|
|
|
+ headers={"WWW-Authenticate": "Bearer"},
|
|
|
|
+ )
|
|
|
|
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
|
+ access_token = create_access_token(
|
|
|
|
+ data={"sub": user.username}, expires_delta=access_token_expires
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ table = db['users']
|
|
|
|
+ user.token = access_token
|
|
|
|
+ table.update(dict(user), ['username'])
|
|
|
|
+
|
|
|
|
+ # Create the tokens and passing to set_access_cookies or set_refresh_cookies
|
|
|
|
+ access_token = Authorize.create_access_token(subject=user.username)
|
|
|
|
+ refresh_token = Authorize.create_refresh_token(subject=user.username)
|
|
|
|
+
|
|
|
|
+ # Set the JWT cookies in the response
|
|
|
|
+ Authorize.set_access_cookies(access_token)
|
|
|
|
+ Authorize.set_refresh_cookies(refresh_token)
|
|
|
|
+
|
|
|
|
+ # return templates.TemplateResponse("index.html", {"request": request, "msg": 'Login'})
|
|
|
|
+ return {"access_token": access_token, "token_type": "bearer"}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.post("/token")
|
|
|
|
+async def access_token(form_data: OAuth2PasswordRequestForm = Depends(), Authorize: AuthJWT = Depends()):
|
|
|
|
+
|
|
|
|
+ user = authenticate_user(form_data.username, form_data.password)
|
|
|
|
+ if not user:
|
|
|
|
+ raise HTTPException(
|
|
|
|
+ status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
+ detail="Incorrect username or password",
|
|
|
|
+ headers={"WWW-Authenticate": "Bearer"},
|
|
|
|
+ )
|
|
|
|
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
|
+ access_token = create_access_token(
|
|
|
|
+ data={"sub": user.username}, expires_delta=access_token_expires
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ return {"access_token": access_token, "token_type": "bearer"}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+async def get_current_user(token: str = Depends(oauth2_scheme)):
|
|
|
|
+ credentials_exception = HTTPException(
|
|
|
|
+ status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
+ detail="Could not validate credentials",
|
|
|
|
+ headers={"WWW-Authenticate": "Bearer"},
|
|
|
|
+ )
|
|
|
|
+ try:
|
|
|
|
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
|
|
+ username: str = payload.get("sub")
|
|
|
|
+ if username is None:
|
|
|
|
+ raise credentials_exception
|
|
|
|
+ token_data = TokenData(username=username)
|
|
|
|
+ except JWTError:
|
|
|
|
+ raise credentials_exception
|
|
|
|
+ user = get_user(username=token_data.username)
|
|
|
|
+ if user is None:
|
|
|
|
+ raise credentials_exception
|
|
|
|
+ return user
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def check_user_exists(username):
|
|
|
|
+ if int(next(iter(db.query('SELECT COUNT(*) FROM AI_anchor.users WHERE username = "'+username+'"')))['COUNT(*)']) > 0:
|
|
|
|
+ return True
|
|
|
|
+ else:
|
|
|
|
+ return False
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.post("/register")
|
|
|
|
+async def register(request: Request):
|
|
|
|
+ user = User(**await request.form())
|
|
|
|
+ user_register(user)
|
|
|
|
+ return templates.TemplateResponse("login.html", {'request': request,"success": True}, status_code=status.HTTP_302_FOUND)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.get('/user_profile', response_class=HTMLResponse)
|
|
|
|
+def protected(request: Request, Authorize: AuthJWT = Depends()):
|
|
|
|
+ """
|
|
|
|
+ We do not need to make any changes to our protected endpoints. They
|
|
|
|
+ will all still function the exact same as they do when sending the
|
|
|
|
+ JWT in via a headers instead of a cookies
|
|
|
|
+ """
|
|
|
|
+ Authorize.jwt_required()
|
|
|
|
+
|
|
|
|
+ current_user = Authorize.get_jwt_subject()
|
|
|
|
+
|
|
|
|
+ return current_user
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.get('/logout')
|
|
|
|
+def logout(request: Request, Authorize: AuthJWT = Depends()):
|
|
|
|
+ """
|
|
|
|
+ Because the JWT are stored in an httponly cookie now, we cannot
|
|
|
|
+ log the user out by simply deleting the cookies in the frontend.
|
|
|
|
+ We need the backend to send us a response to delete the cookies.
|
|
|
|
+ """
|
|
|
|
+ Authorize.jwt_required()
|
|
|
|
+
|
|
|
|
+ Authorize.unset_jwt_cookies()
|
|
|
|
+ return {"msg": "Successfully logout"}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
+ uvicorn.run(app, host="0.0.0.0", port=9996)
|