123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267 |
import html
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

import dataset
import pymysql
from fastapi import (
    Cookie,
    Depends,
    FastAPI,
    Form,
    Header,
    HTTPException,
    Request,
    Response,
    status,
)
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
- pymysql.install_as_MySQLdb()
# JWT signing configuration.
# The signing key is read from the JWT_SECRET_KEY environment variable so it
# does not have to live in source control. The historical hard-coded value is
# kept only as a backward-compatible fallback and should be rotated.
# To generate a fresh key run: openssl rand -hex 32
SECRET_KEY = os.environ.get(
    "JWT_SECRET_KEY",
    "df2f77bd544240801a048bd4293afd8eeb7fff3cb7050e42c791db4b83ebadcd",
)
# HMAC-SHA256: the same SECRET_KEY both signs and verifies tokens.
ALGORITHM = "HS256"
# Lifetime of tokens issued by /login and /token.
ACCESS_TOKEN_EXPIRE_MINUTES = 300
class Token(BaseModel):
    """Shape of the token payload returned by the /login and /token endpoints.

    NOTE(review): those endpoints currently return plain dicts with these
    keys rather than using this model as a response_model.
    """
    # The encoded JWT string.
    access_token: str
    # Token scheme; the endpoints in this module always send "bearer".
    token_type: str
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Create a signed JWT carrying ``data`` plus an ``exp`` expiry claim.

    Args:
        data: Claims to embed, e.g. ``{"sub": username}``. The caller's dict
            is not modified.
        expires_delta: Token lifetime measured from now. ``None`` selects the
            15-minute default. A zero delta is honoured as given instead of
            silently falling back to the default.

    Returns:
        The encoded JWT, signed with SECRET_KEY using ALGORITHM.
    """
    to_encode = data.copy()  # copy so the caller's claims dict stays untouched
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Timezone-aware UTC; datetime.utcnow() is naive and deprecated since 3.12.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
# The FastAPI application instance all routes below are registered on.
app = FastAPI()
# To disable the interactive API docs (e.g. in production), build the app as:
# app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
# Serve files from the local ./static directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Connect to the MySQL database through `dataset`.
# The connection URL (which embeds credentials) can be supplied through the
# DATABASE_URL environment variable so it need not be committed. The original
# URL is kept as a backward-compatible fallback and should be removed once
# deployments set DATABASE_URL.
db = dataset.connect(
    os.environ.get(
        "DATABASE_URL",
        "mysql://choozmo:pAssw0rd@db.ptt.cx:3306/AI_anchor?charset=utf8mb4",
    )
)
# Password hashing context shared by get_password_hash / verify_password.
# bcrypt is the only configured scheme, so every new hash is bcrypt.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_password_hash(password):
    """Return the bcrypt hash of *password*, produced by ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password, hashed_password):
    """Report whether *plain_password* matches the stored *hashed_password*."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def authenticate_user(username: str, password: str):
    """Look up *username* and check *password* against its stored hash.

    Args:
        username: Login name submitted by the client (untrusted input).
        password: Plaintext password submitted by the client.

    Returns:
        The matching ``User`` when the credentials are valid, otherwise
        ``False`` (unknown user or wrong password).
    """
    # Parameterized lookup through dataset. The previous version concatenated
    # the username into raw SQL, allowing SQL injection from the login form.
    row = db['users'].find_one(username=username)
    if row is None:  # user doesn't exist
        return False
    user = User(**row)
    if not verify_password(password, user.password):
        return False
    return user
# AuthJWT settings
class Settings(BaseModel):
    """fastapi_jwt_auth configuration, handed to the library by get_config()."""
    # Key used to sign/verify the cookie JWTs; same key as create_access_token.
    authjwt_secret_key: str = SECRET_KEY
    # Configure application to store and get JWT from cookies
    authjwt_token_location: set = {"cookies"}
    # Only allow JWT cookies to be sent over https.
    # NOTE(review): False lets the cookies travel over plain HTTP; set this to
    # True in production behind HTTPS.
    authjwt_cookie_secure: bool = False
    # Enable CSRF double-submit protection (the library default is True).
    authjwt_cookie_csrf_protect: bool = True
    # Change to 'lax' in production to make the site more resistant to CSRF
    # attacks; the library default is None.
    # authjwt_cookie_samesite: str = 'lax'
@AuthJWT.load_config
def get_config():
    """Give fastapi_jwt_auth this app's Settings (signing key, cookie options)."""
    settings = Settings()
    return settings
@app.exception_handler(AuthJWTException)
def authjwt_exception_handler(request: Request, exc: AuthJWTException):
    """Render any fastapi_jwt_auth failure as a JSON ``{"detail": ...}`` body.

    The HTTP status code is taken from the exception itself.
    """
    body = {"detail": exc.message}
    return JSONResponse(content=body, status_code=exc.status_code)
class User(BaseModel):
    """A user account, mirroring the columns read from the ``users`` table."""
    username: str
    email: str
    # Plaintext when parsed from the registration form; user_register stores
    # a bcrypt hash of it in the database.
    password: str
    # Access token written back to the user row by /login; defaults to None.
    token: Optional[str] = None
class TokenData(BaseModel):
    """Claims extracted from a decoded JWT; currently only the ``sub`` username."""
    username: Optional[str] = None
def user_register(user):
    """Insert *user* into the ``users`` table with its password bcrypt-hashed.

    The caller's ``user`` object is left unmodified (the previous version
    overwrote ``user.password`` with the hash as a hidden side effect); only
    the stored row carries the hash.

    Args:
        user: A ``User`` whose ``password`` field holds the plaintext password.
    """
    row = dict(user)
    row["password"] = get_password_hash(user.password)
    db['users'].insert(row)
# Jinja2 templates (index.html, login.html, ...) are loaded from ./templates.
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def get_home_page(request: Request, response: Response):
    """Home page: render index.html with the request and response objects."""
    context = {"request": request, "response": response}
    return templates.TemplateResponse("index.html", context)
@app.get("/login", response_class=HTMLResponse)
async def get_login_and_register_page(request: Request):
    """Render the combined login / registration page (login.html)."""
    page = "login.html"
    return templates.TemplateResponse(page, {"request": request})
# Bearer-token dependency used by get_current_user; tokenUrl points clients at
# the POST /token endpoint defined below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def get_user(username: str):
    """Load the ``User`` named *username* from the ``users`` table.

    Args:
        username: Login name to look up (may come from a JWT ``sub`` claim).

    Returns:
        The ``User``, or ``None`` when no such user exists. This was
        previously ``False``, which slipped past ``is None`` checks in
        callers; both values are falsy, so truthiness checks are unaffected.
    """
    # Parameterized lookup; the old string-concatenated SQL was injectable.
    row = db['users'].find_one(username=username)
    if row is None:
        return None
    return User(**row)
@app.post("/login")
async def login_for_access_token(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), Authorize: AuthJWT = Depends()):
    """Log a user in from an OAuth2 password form.

    On success this:
      1. issues a token with create_access_token and saves it in the user's
         ``token`` column;
      2. issues fastapi_jwt_auth access and refresh tokens and sets them as
         cookies on the response;
      3. returns the fastapi_jwt_auth access token in the JSON body.

    NOTE(review): the token saved in step 1 is not the one returned in step 3;
    confirm which one downstream consumers of ``users.token`` expect.

    Raises:
        HTTPException: 401 if the username/password pair is invalid.
    """
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Step 1: persist a token on the user record.
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    stored_token = create_access_token(
        data={"sub": user.username}, expires_delta=lifetime
    )
    users_table = db['users']
    user.token = stored_token
    users_table.update(dict(user), ['username'])

    # Step 2: cookie-based session tokens managed by fastapi_jwt_auth.
    cookie_access = Authorize.create_access_token(subject=user.username)
    cookie_refresh = Authorize.create_refresh_token(subject=user.username)
    Authorize.set_access_cookies(cookie_access)
    Authorize.set_refresh_cookies(cookie_refresh)

    # Step 3.
    return {"access_token": cookie_access, "token_type": "bearer"}
@app.post("/token")
async def access_token(form_data: OAuth2PasswordRequestForm = Depends(), Authorize: AuthJWT = Depends()):
    """OAuth2 password-flow token endpoint (the ``tokenUrl`` of oauth2_scheme).

    Unlike /login, this neither sets cookies nor writes to the database; it
    only returns a bearer JWT valid for ACCESS_TOKEN_EXPIRE_MINUTES.
    ``Authorize`` is accepted but unused here.

    Raises:
        HTTPException: 401 if the credentials are invalid.
    """
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    token = create_access_token(data={"sub": user.username}, expires_delta=lifetime)
    return {"access_token": token, "token_type": "bearer"}
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer token supplied via ``oauth2_scheme`` to a ``User``.

    Raises:
        HTTPException: 401 when the token cannot be decoded/verified
            (JWTError), has no ``sub`` claim, or names a user that no longer
            exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception
    username: Optional[str] = payload.get("sub")
    if username is None:
        raise credentials_exception
    token_data = TokenData(username=username)
    user = get_user(username=token_data.username)
    # get_user may signal "not found" with False rather than None, so test
    # truthiness; the old `is None` check let tokens of deleted users through.
    if not user:
        raise credentials_exception
    return user
def check_user_exists(username):
    """Return True if at least one ``users`` row has this *username*.

    The value is bound as a query parameter by dataset. The previous version
    concatenated it into raw SQL, which allowed SQL injection.
    """
    return db['users'].count(username=username) > 0
@app.post("/register")
async def register(request: Request):
    """Create an account from the submitted registration form.

    Raises:
        HTTPException: 400 when the username is already registered.
    """
    user = User(**await request.form())
    # Lookups by username read only the first matching row, so a duplicate
    # account could never log in. Reject it instead of inserting a second row.
    # NOTE(review): a UNIQUE index on users.username is the race-free fix.
    if check_user_exists(user.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered",
        )
    user_register(user)
    # NOTE(review): a 302 without a Location header is not a usable redirect;
    # kept unchanged in case the frontend checks this status — confirm.
    return templates.TemplateResponse("login.html", {'request': request, "success": True}, status_code=status.HTTP_302_FOUND)
@app.get('/user_profile', response_class=HTMLResponse)
def protected(request: Request, Authorize: AuthJWT = Depends()):
    """Return the logged-in user's username as an HTML response.

    Requires a valid access-token cookie. Otherwise fastapi_jwt_auth raises
    AuthJWTException, which this app's exception handler renders as JSON.
    """
    Authorize.jwt_required()
    current_user = Authorize.get_jwt_subject()
    # The subject is the username chosen at registration (user input) and
    # this route is served as HTML, so escape it to prevent script injection.
    return html.escape(current_user)
@app.get('/logout')
def logout(request: Request, Authorize: AuthJWT = Depends()):
    """Log the current user out by clearing the JWT cookies server-side.

    The tokens are cookies set by the backend, so logging out needs a backend
    response that unsets them. A valid access token is required.
    """
    Authorize.jwt_required()
    Authorize.unset_jwt_cookies()
    result = {"msg": "Successfully logout"}
    return result
if __name__ == "__main__":
    # uvicorn was called here but never imported, so running the module
    # directly raised NameError. Importing locally keeps uvicorn out of the
    # import path when the app is served by an external ASGI runner.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=9996)
|