tomoya il y a 2 ans
Parent
commit
1b731a22ac

+ 44 - 19
backend/app/app/api/api_v1/endpoints/images.py

@@ -1,16 +1,18 @@
 from typing import Any, List, Optional
 import subprocess
 from fastapi import UploadFile, File, Form
-from fastapi.responses import FileResponse
+from fastapi.responses import FileResponse, JSONResponse
 from fastapi import APIRouter, Depends, HTTPException
 from fastapi import WebSocket, BackgroundTasks
 from sqlalchemy.orm import Session
 import asyncio
+import time
 import shutil
 import app.crud as crud
 import app.models as models
 import app.schemas as schemas 
 from app.api import deps
+from datetime import datetime
 
 from app.core.celery_app import celery_app
 from app.core.config import settings
@@ -27,7 +29,7 @@ router = APIRouter()
 
 
 @router.post("/sr")
-async def supser_resolution(
+def supser_resolution(
     *,
     db: Session = Depends(deps.get_db),
     current_user: models.User = Depends(deps.get_current_active_user),
@@ -37,6 +39,7 @@ async def supser_resolution(
     """
     Super Resolution.
     """
+    print("api start: "+str(datetime.now()))
     filenames = [random_name(20)+Path(file.filename).suffix for file in upload_files]
     stemnames = [Path(filename).stem for filename in filenames]
     new_dir = random_name(10)
@@ -52,19 +55,42 @@ async def supser_resolution(
           return {"error": str(e)}
       finally:
           upload_files[i].file.close()
-    source = [f"{str(BACKEND_ZIP_STORAGE/filename)}" for filename in filenames]
-    r = subprocess.run(["sshpass", "-p", "choozmo9", 
-                    "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(new_dir_path)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}"])
-    
-    background_tasks.add_task(wait_finish, new_dir_path, stemnames)
     
+    background_tasks.add_task(wait_finish, new_dir, filenames)
+
     print(filenames)
-    return {"filenames": stemnames}
+    return JSONResponse({"filenames": stemnames}, background=background_tasks)
 
 async def wait_finish(dirname, filenames):
+    # print("start: "+str(datetime.now()))
+    new_dir_path = Path(BACKEND_ZIP_STORAGE).joinpath(dirname)
+    process = await asyncio.create_subprocess_exec("sshpass", "-p", "choozmo9", 
+                    "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(new_dir_path)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}")
+    await process.wait()
+    # r = subprocess.run(["sshpass", "-p", "choozmo9", 
+    #                 "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(new_dir_path)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}"])
+    
+    restored_imgs = "restored_imgs" 
+    res = celery_app.send_task("app.worker.super_resolution", args=[dirname])
+    # print(res.state)
+    while True:
+        await asyncio.sleep(0.5)
+        print(res.state)
+        if res.status == 'SUCCESS':
+            print("recieve finished")
+            break
     for filename in filenames:
-      await asyncio.sleep(3)
-      await publish(filename)
+        process = await asyncio.create_subprocess_exec("sshpass", "-p", "choozmo9", 
+                      "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE/dirname/restored_imgs/filename)}", f"{str(BACKEND_ZIP_STORAGE)}")
+        await process.wait()
+        # re = subprocess.run(["sshpass", "-p", "choozmo9", 
+        #             "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE/dirname/restored_imgs/filename)}", f"{str(BACKEND_ZIP_STORAGE)}"])
+        
+        for sr_client in sr_clients.values():
+            await sr_client.send_text(f"{filename}")
+    # print("end: "+str(datetime.now()))
+
+    
 
 @router.get("/sr")
 def get_image(
@@ -77,10 +103,11 @@ def get_image(
     """
     Download image
     """
-    filename = Path(file_name)
-    response_filename = filename.stem + "_hr.png"
-    return FileResponse(path="test_medias/superman_resolution.png", media_type='image/png', filename=response_filename)
+    
+    return_file_path = list(BACKEND_ZIP_STORAGE.glob(stored_file_name+"*"))[0]
+    return FileResponse(path=str(return_file_path), media_type='image/png', filename=file_name+return_file_path.suffix)
 
+ws_clients = {}
 sr_clients = {}
 
 @router.websocket("/sr")
@@ -88,12 +115,14 @@ async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
     key = websocket.headers.get('sec-websocket-key')
     sr_clients[key] = websocket
+    print(f"new comer: {key}")
     try:
         while True:
             data = await websocket.receive_text()
-            if not data.startswith("subscribe"):
+            print(f"{key}:{data}")
+            if data.startswith("unsubscribe"):
               del sr_clients[key]
-              return 
+              print(f"beybey: {key}")
               #for client in sr_clients.values():
               #      await client.send_text(f"ID: {key} | Message: {data}")
 
@@ -101,7 +130,3 @@ async def websocket_endpoint(websocket: WebSocket):
         #await websocket.close()
         # 接続が切れた場合、当該クライアントを削除する
         del sr_clients[key]
-
-async def publish(data):
-    for sr_client in sr_clients.values():
-        await sr_client.send_text(f"{data}")

+ 6 - 3
backend/app/app/api/api_v1/endpoints/videos.py

@@ -56,9 +56,7 @@ def upload_plot(
     print(title)
     print(upload_file.filename)
     file_name = crud.video.generate_file_name(db=db, n=20)
-    video_create = schemas.VideoCreate(title=title, progress_state="waiting", stored_file_name=file_name)
-    video = crud.video.create_with_owner(db=db, obj_in=video_create, owner_id=current_user.id)
-
+    
     try:
         with open(str(Path(BACKEND_ZIP_STORAGE).joinpath(video.stored_file_name+".zip")), 'wb') as f:
             while contents := upload_file.file.read(1024 * 1024):
@@ -68,6 +66,11 @@ def upload_plot(
         return {"error": str(e)}
     finally:
         upload_file.file.close()
+
+    # check valid file
+    video_create = schemas.VideoCreate(title=title, progress_state="waiting", stored_file_name=file_name)
+    video = crud.video.create_with_owner(db=db, obj_in=video_create, owner_id=current_user.id)
+
     zip_filename = video.stored_file_name+".zip"
     print(str(BACKEND_ZIP_STORAGE/zip_filename))
     r = subprocess.run(["sshpass", "-p", "choozmo9", 

+ 1 - 0
backend/app/app/celeryconf.py

@@ -0,0 +1 @@
+task_track_started=True

+ 2 - 2
backend/app/app/core/celery_app.py

@@ -1,7 +1,7 @@
 from celery import Celery
 
-celery_app = Celery("worker", broker="redis://172.104.93.163:16379/0")
+celery_app = Celery("worker", broker="redis://172.104.93.163:16379/0", backend="redis://172.104.93.163:16379/0")
 
 
 
-celery_app.conf.task_routes = {"app.worker.make_video": "main-queue"}
+celery_app.conf.task_routes = {"app.worker.make_video": "main-queue", "app.worker.super_resolution": "main-queue"}

+ 34 - 22
backend/app/app/worker.py

@@ -11,6 +11,7 @@ import dataset
 from app.db.session import SessionLocal
 from app.models import video
 from app import crud
+import gc
 download_to_local_url = urljoin(settings.SERVER_HOST, settings.API_V1_STR+"/videos/worker")
 upload_to_server_url = urljoin(settings.SERVER_HOST, settings.API_V1_STR+"/videos/worker")
 
@@ -23,8 +24,8 @@ STORAGE_IP = '192.168.192.252'#os.getenv('STORAGE_IP')
 if not STORAGE_IP:
     raise Exception
 
-@celery_app.task(acks_late=True)
-def make_video(video_id, filename, user_id) -> str:
+@celery_app.task(acks_late=True, bind=True, track_started=True)
+def make_video(self, video_id, filename, user_id) -> str:
     #video_id, zip_filename, user_id = args
     # download 
     '''
@@ -45,29 +46,40 @@ def make_video(video_id, filename, user_id) -> str:
     db.commit()
     # make video
     try:
-      make_video_from_zip(working_dir=CELERY_ZIP_STORAGE,style=Path("app/style/choozmo"),  inputfile=zip_filename,opening=False, ending=False)
+      content_time = make_video_from_zip(working_dir=CELERY_ZIP_STORAGE,style=Path("app/style/choozmo"),
+                                         inputfile=zip_filename,
+                                         opening=False, 
+                                         ending=False,
+                                         watermark_path='medias/logo_watermark.jpg')
     except Exception as e:
-      print(f"error:{e}")
-      db.execute(f"UPDATE video SET progress_state='failed' where id={video_id}")
-      db.commit()
+        print(f"error:{e}")
+        db.execute(f"UPDATE video SET progress_state='failed' where id={video_id}")
+        db.commit()
     else:
-      # 
-      video_filename = filename + ".mp4"
-      r = subprocess.run(["sshpass", "-p", "choozmo9", 
-                          "scp", "-o", "StrictHostKeyChecking=no", f"{str(CELERY_ZIP_STORAGE/'output.mp4')}", f"root@{STORAGE_IP}:{'/var/www/videos/'+video_filename}"])
-      print(f"return to local storage: {r.returncode}")
-      print(f"video_id: {video_id}, file name: {filename}")
+        # 
+        video_filename = filename + ".mp4"
+        r = subprocess.run(["sshpass", "-p", "choozmo9", 
+                            "scp", "-o", "StrictHostKeyChecking=no", f"{str(CELERY_ZIP_STORAGE/'output.mp4')}", f"root@{STORAGE_IP}:{'/var/www/videos/'+video_filename}"])
+        print(f"return to local storage: {r.returncode}")
+        print(f"video_id: {video_id}, file name: {filename}")
 
-      db.execute(f"UPDATE video SET progress_state='completed' where id={video_id}")
-      db.commit()
+        db.execute(f"UPDATE video SET progress_state='completed' where id={video_id}")
+        db.commit()
 
-      
-      return "complete"
+        gc_ret = gc.collect()
+        print(f"gc_collected: {gc_ret}")
+        
+        return str(int(content_time))
 
-@celery_app.task(acks_late=True)
-def super_resolution(filenames):
-   source = [f"root@{STORAGE_IP}:{str(LOCAL_ZIP_STORAGE/filename)}" for filename in filenames]
-   r = subprocess.run(["sshpass", "-p", "choozmo9", 
-                        "scp", "-o", "StrictHostKeyChecking=no", f"root@{STORAGE_IP}:{str(LOCAL_ZIP_STORAGE/zip_filename)}", f"{str(CELERY_ZIP_STORAGE/zip_filename)}"])
-   
+@celery_app.task(acks_late=True, bind=True, track_started=True)
+def super_resolution(self, dirname:str):
+    print(dirname)
+    re = subprocess.run(["sshpass", "-p", "choozmo9", 
+                        "scp", "-o", "StrictHostKeyChecking=no", "-r", f"root@{STORAGE_IP}:{str(LOCAL_ZIP_STORAGE/dirname)}", f"{str(CELERY_ZIP_STORAGE)}"])
+    result_dir = dirname+"_result"
+    re = subprocess.run(["python", "/root/github/GFPGAN/inference_gfpgan.py", "-i", f"{str(CELERY_ZIP_STORAGE/dirname)}", "-o", f"{str(CELERY_ZIP_STORAGE/dirname)}", "-v", "1.4"])
 
+    re = subprocess.run(["sshpass", "-p", "choozmo9", 
+                        "scp", "-o", "StrictHostKeyChecking=no", "-r", f"{str(CELERY_ZIP_STORAGE/dirname)}", f"root@{STORAGE_IP}:{str(LOCAL_ZIP_STORAGE)}"])
+    
+    return "complete"