|
@@ -1,16 +1,18 @@
|
|
from typing import Any, List, Optional
|
|
from typing import Any, List, Optional
|
|
import subprocess
|
|
import subprocess
|
|
from fastapi import UploadFile, File, Form
|
|
from fastapi import UploadFile, File, Form
|
|
-from fastapi.responses import FileResponse
|
|
|
|
|
|
+from fastapi.responses import FileResponse, JSONResponse
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from fastapi import WebSocket, BackgroundTasks
|
|
from fastapi import WebSocket, BackgroundTasks
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.orm import Session
|
|
import asyncio
|
|
import asyncio
|
|
|
|
+import time
|
|
import shutil
|
|
import shutil
|
|
import app.crud as crud
|
|
import app.crud as crud
|
|
import app.models as models
|
|
import app.models as models
|
|
import app.schemas as schemas
|
|
import app.schemas as schemas
|
|
from app.api import deps
|
|
from app.api import deps
|
|
|
|
+from datetime import datetime
|
|
|
|
|
|
from app.core.celery_app import celery_app
|
|
from app.core.celery_app import celery_app
|
|
from app.core.config import settings
|
|
from app.core.config import settings
|
|
@@ -27,7 +29,7 @@ router = APIRouter()
|
|
|
|
|
|
|
|
|
|
@router.post("/sr")
|
|
@router.post("/sr")
|
|
-async def supser_resolution(
|
|
|
|
|
|
+def supser_resolution(
|
|
*,
|
|
*,
|
|
db: Session = Depends(deps.get_db),
|
|
db: Session = Depends(deps.get_db),
|
|
current_user: models.User = Depends(deps.get_current_active_user),
|
|
current_user: models.User = Depends(deps.get_current_active_user),
|
|
@@ -37,6 +39,7 @@ async def supser_resolution(
|
|
"""
|
|
"""
|
|
Super Resolution.
|
|
Super Resolution.
|
|
"""
|
|
"""
|
|
|
|
+ print("api start: "+str(datetime.now()))
|
|
filenames = [random_name(20)+Path(file.filename).suffix for file in upload_files]
|
|
filenames = [random_name(20)+Path(file.filename).suffix for file in upload_files]
|
|
stemnames = [Path(filename).stem for filename in filenames]
|
|
stemnames = [Path(filename).stem for filename in filenames]
|
|
new_dir = random_name(10)
|
|
new_dir = random_name(10)
|
|
@@ -52,19 +55,42 @@ async def supser_resolution(
|
|
return {"error": str(e)}
|
|
return {"error": str(e)}
|
|
finally:
|
|
finally:
|
|
upload_files[i].file.close()
|
|
upload_files[i].file.close()
|
|
- source = [f"{str(BACKEND_ZIP_STORAGE/filename)}" for filename in filenames]
|
|
|
|
- r = subprocess.run(["sshpass", "-p", "choozmo9",
|
|
|
|
- "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(new_dir_path)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}"])
|
|
|
|
-
|
|
|
|
- background_tasks.add_task(wait_finish, new_dir_path, stemnames)
|
|
|
|
|
|
|
|
|
|
+ background_tasks.add_task(wait_finish, new_dir, filenames)
|
|
|
|
+
|
|
print(filenames)
|
|
print(filenames)
|
|
- return {"filenames": stemnames}
|
|
|
|
|
|
+ return JSONResponse({"filenames": stemnames}, background=background_tasks)
|
|
|
|
|
|
async def wait_finish(dirname, filenames):
|
|
async def wait_finish(dirname, filenames):
|
|
|
|
+ # print("start: "+str(datetime.now()))
|
|
|
|
+ new_dir_path = Path(BACKEND_ZIP_STORAGE).joinpath(dirname)
|
|
|
|
+ process = await asyncio.create_subprocess_exec("sshpass", "-p", "choozmo9",
|
|
|
|
+ "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(new_dir_path)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}")
|
|
|
|
+ await process.wait()
|
|
|
|
+ # r = subprocess.run(["sshpass", "-p", "choozmo9",
|
|
|
|
+ # "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(new_dir_path)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}"])
|
|
|
|
+
|
|
|
|
+ restored_imgs = "restored_imgs"
|
|
|
|
+ res = celery_app.send_task("app.worker.super_resolution", args=[dirname])
|
|
|
|
+ # print(res.state)
|
|
|
|
+ while True:
|
|
|
|
+ await asyncio.sleep(0.5)
|
|
|
|
+ print(res.state)
|
|
|
|
+ if res.status == 'SUCCESS':
|
|
|
|
+ print("recieve finished")
|
|
|
|
+ break
|
|
for filename in filenames:
|
|
for filename in filenames:
|
|
- await asyncio.sleep(3)
|
|
|
|
- await publish(filename)
|
|
|
|
|
|
+ process = await asyncio.create_subprocess_exec("sshpass", "-p", "choozmo9",
|
|
|
|
+ "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE/dirname/restored_imgs/filename)}", f"{str(BACKEND_ZIP_STORAGE)}")
|
|
|
|
+ await process.wait()
|
|
|
|
+ # re = subprocess.run(["sshpass", "-p", "choozmo9",
|
|
|
|
+ # "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE/dirname/restored_imgs/filename)}", f"{str(BACKEND_ZIP_STORAGE)}"])
|
|
|
|
+
|
|
|
|
+ for sr_client in sr_clients.values():
|
|
|
|
+ await sr_client.send_text(f"{filename}")
|
|
|
|
+ # print("end: "+str(datetime.now()))
|
|
|
|
+
|
|
|
|
+
|
|
|
|
|
|
@router.get("/sr")
|
|
@router.get("/sr")
|
|
def get_image(
|
|
def get_image(
|
|
@@ -77,10 +103,11 @@ def get_image(
|
|
"""
|
|
"""
|
|
Download image
|
|
Download image
|
|
"""
|
|
"""
|
|
- filename = Path(file_name)
|
|
|
|
- response_filename = filename.stem + "_hr.png"
|
|
|
|
- return FileResponse(path="test_medias/superman_resolution.png", media_type='image/png', filename=response_filename)
|
|
|
|
|
|
+
|
|
|
|
+ return_file_path = list(BACKEND_ZIP_STORAGE.glob(stored_file_name+"*"))[0]
|
|
|
|
+ return FileResponse(path=str(return_file_path), media_type='image/png', filename=file_name+return_file_path.suffix)
|
|
|
|
|
|
|
|
+ws_clients = {}
|
|
sr_clients = {}
|
|
sr_clients = {}
|
|
|
|
|
|
@router.websocket("/sr")
|
|
@router.websocket("/sr")
|
|
@@ -88,12 +115,14 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
await websocket.accept()
|
|
await websocket.accept()
|
|
key = websocket.headers.get('sec-websocket-key')
|
|
key = websocket.headers.get('sec-websocket-key')
|
|
sr_clients[key] = websocket
|
|
sr_clients[key] = websocket
|
|
|
|
+ print(f"new comer: {key}")
|
|
try:
|
|
try:
|
|
while True:
|
|
while True:
|
|
data = await websocket.receive_text()
|
|
data = await websocket.receive_text()
|
|
- if not data.startswith("subscribe"):
|
|
|
|
|
|
+ print(f"{key}:{data}")
|
|
|
|
+ if data.startswith("unsubscribe"):
|
|
del sr_clients[key]
|
|
del sr_clients[key]
|
|
- return
|
|
|
|
|
|
+ print(f"beybey: {key}")
|
|
#for client in sr_clients.values():
|
|
#for client in sr_clients.values():
|
|
# await client.send_text(f"ID: {key} | Message: {data}")
|
|
# await client.send_text(f"ID: {key} | Message: {data}")
|
|
|
|
|
|
@@ -101,7 +130,3 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
#await websocket.close()
|
|
#await websocket.close()
|
|
# 接続が切れた場合、当該クライアントを削除する
|
|
# 接続が切れた場合、当該クライアントを削除する
|
|
del sr_clients[key]
|
|
del sr_clients[key]
|
|
-
|
|
|
|
-async def publish(data):
|
|
|
|
- for sr_client in sr_clients.values():
|
|
|
|
- await sr_client.send_text(f"{data}")
|
|
|