main.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  import os
  from datetime import datetime, timedelta
  from typing import Optional

  import dataset
  import pymysql
  from fastapi import Depends, FastAPI, HTTPException, status, Request, Form, Cookie, Response, Header
  from fastapi.openapi.docs import get_swagger_ui_html
  from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
  from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
  from fastapi.staticfiles import StaticFiles
  from fastapi.templating import Jinja2Templates
  from fastapi_jwt_auth import AuthJWT
  from fastapi_jwt_auth.exceptions import AuthJWTException
  from jose import JWTError, jwt
  from passlib.context import CryptContext
  from pydantic import BaseModel
  # Register PyMySQL under the MySQLdb module name before any database
  # connection is opened.
  pymysql.install_as_MySQLdb()

  # JWT signing key. Supply it through the AI_ANCHOR_SECRET_KEY environment
  # variable; the literal is only a fallback kept so existing deployments keep
  # working, and a key committed to source control should be rotated.
  # To generate a new key run:
  #     openssl rand -hex 32
  SECRET_KEY = os.environ.get(
      "AI_ANCHOR_SECRET_KEY",
      "df2f77bd544240801a048bd4293afd8eeb7fff3cb7050e42c791db4b83ebadcd",
  )
  # Signing algorithm passed to jwt.encode / jwt.decode.
  ALGORITHM = "HS256"
  # Lifetime, in minutes, of tokens issued by the login and token endpoints.
  ACCESS_TOKEN_EXPIRE_MINUTES = 300
  class Token(BaseModel):
      """Schema holding an access token together with its token type."""

      # Encoded token string.
      access_token: str
      # Token scheme label; the endpoints in this file use "bearer".
      token_type: str
  def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
      """Return a signed JWT carrying ``data`` plus an ``exp`` claim.

      Args:
          data: Claims to embed. The dict is copied, so the caller's object is
              left untouched.
          expires_delta: Token lifetime. A missing or falsy value (``None`` or a
              zero ``timedelta``) falls back to 15 minutes.

      Returns:
          The value produced by ``jwt.encode`` using ``SECRET_KEY`` and
          ``ALGORITHM``.
      """
      lifetime = expires_delta or timedelta(minutes=15)
      claims = {**data, "exp": datetime.utcnow() + lifetime}
      return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
  # The ASGI application object that every route below is registered on.
  app = FastAPI()
  # To switch off the generated API docs, construct the app with
  # docs_url=None, redoc_url=None, openapi_url=None instead.

  # Serve files from the local "static" directory under the /static URL prefix.
  app.mount("/static", StaticFiles(directory="static"), name="static")
  # Database connection. The URL embeds credentials, so it can be supplied
  # through the AI_ANCHOR_DATABASE_URL environment variable; the literal is only
  # a fallback kept so existing deployments keep working.
  db = dataset.connect(
      os.environ.get(
          "AI_ANCHOR_DATABASE_URL",
          'mysql://choozmo:pAssw0rd@db.ptt.cx:3306/AI_anchor?charset=utf8mb4',
      )
  )

  # Password hashing context: bcrypt is the only configured scheme.
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  def get_password_hash(password):
      """Return the hash of ``password`` as produced by ``pwd_context``."""
      hashed = pwd_context.hash(password)
      return hashed
  def verify_password(plain_password, hashed_password):
      """Return whether ``plain_password`` matches ``hashed_password``.

      The comparison is delegated to ``pwd_context.verify``.
      """
      matches = pwd_context.verify(plain_password, hashed_password)
      return matches
  def authenticate_user(username: str, password: str):
      """Look up ``username`` and check ``password`` against the stored hash.

      Args:
          username: Login name supplied by the client (untrusted input).
          password: Plain-text password supplied by the client.

      Returns:
          The matching ``User`` on success, otherwise ``False`` (unknown
          username or wrong password).
      """
      # Filter through dataset's keyword API so the username travels as a bound
      # parameter instead of being spliced into the SQL text. A single lookup
      # also replaces the previous "exists?" query followed by a second SELECT.
      row = db['users'].find_one(username=username)
      if row is None:  # no such user
          return False
      user = User(**row)
      if not verify_password(password, user.password):
          return False
      return user
  # AuthJWT configuration
  class Settings(BaseModel):
      """Configuration values handed to ``AuthJWT`` by ``get_config``."""

      # Key used for the JWTs handled by AuthJWT; shared with SECRET_KEY.
      authjwt_secret_key: str = SECRET_KEY
      # Store and read the JWT from cookies.
      authjwt_token_location: set = {"cookies"}
      # When True, JWT cookies are only sent over HTTPS; disabled here.
      authjwt_cookie_secure: bool = False
      # CSRF double-submit protection (the library default is True).
      authjwt_cookie_csrf_protect: bool = True
      # For production, consider a SameSite policy of 'lax' for additional
      # CSRF protection (the library default is None):
      # authjwt_cookie_samesite: str = 'lax'
  @AuthJWT.load_config
  def get_config():
      """Return a fresh ``Settings`` instance for ``AuthJWT`` to load."""
      settings = Settings()
      return settings
  @app.exception_handler(AuthJWTException)
  def authjwt_exception_handler(request: Request, exc: AuthJWTException):
      """Turn an ``AuthJWTException`` into a JSON error response.

      The response reuses the exception's status code and places its message
      under the ``detail`` key. ``request`` is accepted but not used.
      """
      code = exc.status_code
      body = {"detail": exc.message}
      return JSONResponse(status_code=code, content=body)
  class User(BaseModel):
      """A user record as read from, and written to, the ``users`` table."""

      username: str
      email: str
      # Plain text when built from the registration form; replaced by its hash
      # in user_register before the row is inserted.
      password: str
      # Set by the /login endpoint, which stores a freshly issued token here.
      token: Optional[str] = None
  class TokenData(BaseModel):
      """Data extracted from a decoded token: the ``sub`` claim as a username."""

      username: Optional[str] = None
  def user_register(user):
      """Hash the user's password in place and insert the record into ``users``.

      Note that ``user.password`` is overwritten with its hash, so the caller's
      object no longer holds the plain-text password afterwards.
      """
      users_table = db['users']
      user.password = get_password_hash(user.password)
      record = dict(user)
      users_table.insert(record)
  # Jinja2 template loader rooted at the local "templates" directory; used by
  # the HTML page routes.
  templates = Jinja2Templates(directory="templates")
  # Home page
  @app.get("/", response_class=HTMLResponse)
  async def get_home_page(request: Request, response: Response):
      """Render ``index.html`` with the request and response in its context."""
      context = {"request": request, "response": response}
      return templates.TemplateResponse("index.html", context)
  # Login and registration page
  @app.get("/login", response_class=HTMLResponse)
  async def get_login_and_register_page(request: Request):
      """Render ``login.html``, the combined login and registration page."""
      context = {"request": request}
      return templates.TemplateResponse("login.html", context)
  # Bearer-token dependency used by get_current_user; "token" is the URL of the
  # endpoint that issues tokens.
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  def get_user(username: str):
      """Fetch the user stored under ``username``.

      Args:
          username: Name to look up (may originate from untrusted input).

      Returns:
          The matching ``User``, or ``False`` when no such user exists. The
          ``False`` sentinel is kept for backward compatibility, so callers
          should test truthiness rather than compare against ``None``.
      """
      # dataset's keyword filter sends the username as a bound parameter, which
      # closes the SQL injection hole left by building the query with "+".
      row = db['users'].find_one(username=username)
      if row is None:  # no such user
          return False
      return User(**row)
  @app.post("/login")
  async def login_for_access_token(
      request: Request,
      form_data: OAuth2PasswordRequestForm = Depends(),
      Authorize: AuthJWT = Depends(),
  ):
      """Authenticate form credentials and start a cookie-based JWT session.

      On success a token from ``create_access_token`` is saved on the user's
      row, and separate access/refresh tokens created by ``Authorize`` are set
      as cookies. The ``Authorize`` access token is the one returned in the
      body.

      Raises:
          HTTPException: 401 when the username or password is wrong.
      """
      user = authenticate_user(form_data.username, form_data.password)
      if not user:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail="Incorrect username or password",
              headers={"WWW-Authenticate": "Bearer"},
          )

      # Persist a freshly issued token on the user's row (matched by username).
      user.token = create_access_token(
          data={"sub": user.username},
          expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
      )
      db['users'].update(dict(user), ['username'])

      # Issue the cookie tokens and attach them to the outgoing response.
      cookie_access = Authorize.create_access_token(subject=user.username)
      cookie_refresh = Authorize.create_refresh_token(subject=user.username)
      Authorize.set_access_cookies(cookie_access)
      Authorize.set_refresh_cookies(cookie_refresh)

      return {"access_token": cookie_access, "token_type": "bearer"}
  @app.post("/token")
  async def access_token(form_data: OAuth2PasswordRequestForm = Depends(), Authorize: AuthJWT = Depends()):
      """Exchange form credentials for a bearer token.

      The token is built by ``create_access_token`` with the username as the
      ``sub`` claim and a lifetime of ``ACCESS_TOKEN_EXPIRE_MINUTES``.
      ``Authorize`` is accepted but not used here.

      Raises:
          HTTPException: 401 when the username or password is wrong.
      """
      user = authenticate_user(form_data.username, form_data.password)
      if not user:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail="Incorrect username or password",
              headers={"WWW-Authenticate": "Bearer"},
          )

      lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
      token = create_access_token(data={"sub": user.username}, expires_delta=lifetime)
      return {"access_token": token, "token_type": "bearer"}
  async def get_current_user(token: str = Depends(oauth2_scheme)):
      """Resolve the bearer ``token`` to the ``User`` it was issued for.

      Args:
          token: Encoded JWT supplied by the ``oauth2_scheme`` dependency.

      Returns:
          The ``User`` named by the token's ``sub`` claim.

      Raises:
          HTTPException: 401 when the token cannot be decoded, carries no
              ``sub`` claim, or names a user that does not exist.
      """
      credentials_exception = HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail="Could not validate credentials",
          headers={"WWW-Authenticate": "Bearer"},
      )
      try:
          payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
      except JWTError as err:
          raise credentials_exception from err

      username = payload.get("sub")
      if username is None:
          raise credentials_exception
      token_data = TokenData(username=username)

      user = get_user(username=token_data.username)
      # get_user reports a missing user with False, not None, so an
      # "is None" test would let False through as the current user.
      if not user:
          raise credentials_exception
      return user
  def check_user_exists(username):
      """Return ``True`` when the ``users`` table has a row for ``username``.

      Args:
          username: Name to look for (may originate from untrusted input).
      """
      # dataset's keyword filter sends the username as a bound parameter, which
      # closes the SQL injection hole left by building the query with "+".
      return db['users'].find_one(username=username) is not None
  @app.post("/register")
  async def register(request: Request):
      """Create a user from the submitted form and re-render the login page.

      The form fields are used to build a ``User``; the record is stored via
      ``user_register``, which hashes the password before inserting it.

      Raises:
          HTTPException: 409 when the username is already taken.
      """
      user = User(**await request.form())
      # Refuse duplicates: logins and token updates address rows by username,
      # so two rows with the same name would be ambiguous. The lookup uses a
      # bound parameter. NOTE(review): a unique index on users.username would
      # also close the window between this check and the insert.
      if db['users'].find_one(username=user.username) is not None:
          raise HTTPException(
              status_code=status.HTTP_409_CONFLICT,
              detail="Username already registered",
          )
      user_register(user)
      # NOTE(review): the page is rendered with a 302 status but no redirect
      # target is set here; kept as-is because clients may rely on it.
      return templates.TemplateResponse("login.html", {'request': request,"success": True}, status_code=status.HTTP_302_FOUND)
  @app.get('/user_profile', response_class=HTMLResponse)
  def protected(request: Request, Authorize: AuthJWT = Depends()):
      """Return the subject of the caller's JWT.

      ``Authorize.jwt_required()`` runs first, so the subject is only read
      after that check has passed. With the settings in this file the JWT is
      taken from cookies. ``request`` is accepted but not used.
      """
      Authorize.jwt_required()
      return Authorize.get_jwt_subject()
  @app.get('/logout')
  def logout(request: Request, Authorize: AuthJWT = Depends()):
      """Log the caller out by unsetting the JWT cookies.

      The JWTs live in httponly cookies, which the frontend cannot delete on
      its own, so the backend has to send a response that removes them.
      ``Authorize.jwt_required()`` runs first. ``request`` is accepted but not
      used.
      """
      Authorize.jwt_required()
      Authorize.unset_jwt_cookies()
      result = {"msg": "Successfully logout"}
      return result
  if __name__ == "__main__":
      # The file never imports uvicorn at module level, so the call below would
      # fail with NameError. Importing it here keeps it a requirement only when
      # the module is run directly as a script.
      import uvicorn

      # Serve the app on all interfaces, port 9996.
      uvicorn.run(app, host="0.0.0.0", port=9996)