from datetime import datetime, timedelta, timezone
import os
from typing import Optional

import dataset
import pymysql
from fastapi import (
    Cookie,
    Depends,
    FastAPI,
    Form,
    Header,
    HTTPException,
    Request,
    Response,
    status,
)
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel

# Make PyMySQL answer to the ``MySQLdb`` module name, which is the driver
# SQLAlchemy uses for a plain ``mysql://`` URL. Must run before dataset.connect().
pymysql.install_as_MySQLdb()

# JWT signing settings. Generate a fresh key with: openssl rand -hex 32
# NOTE(review): the fallback key is committed to source control; set SECRET_KEY
# in the environment for real deployments and rotate the committed value.
SECRET_KEY = os.environ.get(
    "SECRET_KEY",
    "df2f77bd544240801a048bd4293afd8eeb7fff3cb7050e42c791db4b83ebadcd",
)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 300


class Token(BaseModel):
    access_token: str
    token_type: str


class User(BaseModel):
    username: str
    email: str
    # Plaintext while registering; user_register replaces it with a bcrypt hash.
    password: str
    token: Optional[str] = None


class TokenData(BaseModel):
    username: Optional[str] = None


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Return a JWT carrying *data* plus an ``exp`` claim.

    The token is signed with SECRET_KEY using ALGORITHM. It expires after
    *expires_delta*, or after 15 minutes when *expires_delta* is omitted or zero.
    """
    to_encode = data.copy()  # don't add "exp" to the caller's dict
    expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


app = FastAPI()
# To hide the interactive API docs (e.g. in production), construct the app as
# FastAPI(docs_url=None, redoc_url=None, openapi_url=None) instead.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Database connection.
# NOTE(review): credentials are committed in the fallback URL; prefer setting
# DATABASE_URL in the environment.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "mysql://choozmo:pAssw0rd@db.ptt.cx:3306/AI_anchor?charset=utf8mb4",
)
db = dataset.connect(DATABASE_URL)

# Password hashing (bcrypt).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


def get_password_hash(password):
    """Return a bcrypt hash of *password*."""
    return pwd_context.hash(password)


def verify_password(plain_password, hashed_password):
    """Return True if *plain_password* matches *hashed_password*."""
    return pwd_context.verify(plain_password, hashed_password)


def _find_user_row(username: str):
    """Return the ``users`` row whose username equals *username*, or None."""
    # Keyword filters are bound as SQL parameters by dataset, so the
    # (user-supplied) username is never spliced into the query text.
    return db['users'].find_one(username=username)


def authenticate_user(username: str, password: str):
    """Return the User for *username* if *password* is correct, else False."""
    row = _find_user_row(username)
    if row is None:  # unknown username
        return False
    user = User(**row)
    if not verify_password(password, user.password):
        return False
    return user


def get_user(username: str):
    """Return the User for *username*, or False if no such user exists."""
    row = _find_user_row(username)
    if row is None:
        return False
    return User(**row)


def user_register(user):
    """Hash *user*'s password (mutating *user*) and insert it into ``users``."""
    user.password = get_password_hash(user.password)
    db['users'].insert(dict(user))


# AuthJWT settings
class Settings(BaseModel):
    authjwt_secret_key: str = SECRET_KEY
    # Store and read the JWTs in cookies.
    authjwt_token_location: set = {"cookies"}
    # When True, JWT cookies are only sent over HTTPS. False allows plain HTTP
    # (e.g. local development); set it to True in production.
    authjwt_cookie_secure: bool = False
    # Enable CSRF double-submit protection (library default is True).
    authjwt_cookie_csrf_protect: bool = True
    # Consider setting SameSite to 'lax' in production to further limit CSRF
    # exposure (library default is None):
    # authjwt_cookie_samesite: str = 'lax'


@AuthJWT.load_config
def get_config():
    """Provide the AuthJWT configuration."""
    return Settings()


@app.exception_handler(AuthJWTException)
def authjwt_exception_handler(request: Request, exc: AuthJWTException):
    """Convert AuthJWT errors into JSON responses using the exception's status code."""
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.message},
    )


# home page
@app.get("/", response_class=HTMLResponse)
async def get_home_page(request: Request, response: Response):
    """Render index.html."""
    return templates.TemplateResponse(
        "index.html", {"request": request, "response": response}
    )


# login & register page
@app.get("/login", response_class=HTMLResponse)
async def get_login_and_register_page(request: Request):
    """Render the combined login/registration page."""
    return templates.TemplateResponse("login.html", {"request": request})
def _authenticate_or_401(form_data):
    """Return the User matching the OAuth2 form credentials, or raise HTTP 401."""
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user


@app.post("/login")
async def login_for_access_token(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), Authorize: AuthJWT = Depends()):
    """Log a user in: verify credentials, set the JWT cookies, and return a bearer token.

    Raises:
        HTTPException: 401 if the username/password pair is wrong.
    """
    user = _authenticate_or_401(form_data)

    # Record a python-jose token (ACCESS_TOKEN_EXPIRE_MINUTES lifetime) on the user row.
    stored_token = create_access_token(
        data={"sub": user.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    user.token = stored_token
    db['users'].update(dict(user), ['username'])

    # NOTE(review): the token returned and set in the cookie below is a *separate*
    # AuthJWT token, governed by AuthJWT's own expiry settings, not the one stored
    # above. Confirm which token clients and the `token` column are meant to use.
    access_token = Authorize.create_access_token(subject=user.username)
    refresh_token = Authorize.create_refresh_token(subject=user.username)
    Authorize.set_access_cookies(access_token)
    Authorize.set_refresh_cookies(refresh_token)
    return {"access_token": access_token, "token_type": "bearer"}


@app.post("/token")
async def access_token(form_data: OAuth2PasswordRequestForm = Depends(), Authorize: AuthJWT = Depends()):
    """OAuth2 password-flow endpoint (the ``tokenUrl`` of ``oauth2_scheme``).

    Returns a bearer JWT valid for ACCESS_TOKEN_EXPIRE_MINUTES. No cookies are set.
    ``Authorize`` is unused here and is kept only so the dependency list is unchanged.

    Raises:
        HTTPException: 401 if the username/password pair is wrong.
    """
    user = _authenticate_or_401(form_data)
    token = create_access_token(
        data={"sub": user.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}


async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer token to its User.

    Raises:
        HTTPException: 401 if the token is invalid, has no ``sub`` claim, or
            names a user that no longer exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception
    username = payload.get("sub")
    if username is None:
        raise credentials_exception
    token_data = TokenData(username=username)
    user = get_user(username=token_data.username)
    # get_user returns False (not None) for unknown users, so test falsiness;
    # `is None` let tokens for deleted users through.
    if not user:
        raise credentials_exception
    return user


def check_user_exists(username):
    """Return True if at least one ``users`` row has this username."""
    # dataset binds the keyword filter as a SQL parameter (no string-built SQL).
    return db['users'].count(username=username) > 0


@app.post("/register")
async def register(request: Request):
    """Create a user from the submitted form and re-render the login page.

    Raises:
        HTTPException: 409 if the username is already registered.
    """
    user = User(**await request.form())
    # Without this check a second row with the same username would be inserted,
    # and username lookups would then return whichever row the query yields first.
    if check_user_exists(user.username):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username already registered",
        )
    user_register(user)
    # NOTE(review): a 302 without a Location header is not a real redirect; clients
    # may simply render this body. Consider 200, or RedirectResponse("/login").
    return templates.TemplateResponse(
        "login.html",
        {'request': request, "success": True},
        status_code=status.HTTP_302_FOUND,
    )


@app.get('/user_profile', response_class=HTMLResponse)
def protected(request: Request, Authorize: AuthJWT = Depends()):
    """Return the username (JWT subject) of the caller authenticated via cookies.

    AuthJWT raises (and the app's AuthJWTException handler reports) when no
    valid access-token cookie is present.
    """
    Authorize.jwt_required()
    current_user = Authorize.get_jwt_subject()
    return current_user


@app.get('/logout')
def logout(request: Request, Authorize: AuthJWT = Depends()):
    """Log out by having the server clear the JWT cookies.

    The cookies are set by the backend via AuthJWT, so clearing them is done
    here with ``unset_jwt_cookies`` rather than in the frontend.
    """
    Authorize.jwt_required()
    Authorize.unset_jwt_cookies()
    return {"msg": "Successfully logout"}


if __name__ == "__main__":
    import uvicorn  # only needed when run as a script; was previously never imported

    uvicorn.run(app, host="0.0.0.0", port=9996)