浏览代码

add image sr

tomoya 2 年之前
父节点
当前提交
1b731a22ac

+ 44 - 19
backend/app/app/api/api_v1/endpoints/images.py

@@ -1,16 +1,18 @@
 from typing import Any, List, Optional
 from typing import Any, List, Optional
 import subprocess
 import subprocess
 from fastapi import UploadFile, File, Form
 from fastapi import UploadFile, File, Form
-from fastapi.responses import FileResponse
+from fastapi.responses import FileResponse, JSONResponse
 from fastapi import APIRouter, Depends, HTTPException
 from fastapi import APIRouter, Depends, HTTPException
 from fastapi import WebSocket, BackgroundTasks
 from fastapi import WebSocket, BackgroundTasks
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 import asyncio
 import asyncio
+import time
 import shutil
 import shutil
 import app.crud as crud
 import app.crud as crud
 import app.models as models
 import app.models as models
 import app.schemas as schemas 
 import app.schemas as schemas 
 from app.api import deps
 from app.api import deps
+from datetime import datetime
 
 
 from app.core.celery_app import celery_app
 from app.core.celery_app import celery_app
 from app.core.config import settings
 from app.core.config import settings
@@ -27,7 +29,7 @@ router = APIRouter()
 
 
 
 
 @router.post("/sr")
 @router.post("/sr")
-async def supser_resolution(
+def supser_resolution(
     *,
     *,
     db: Session = Depends(deps.get_db),
     db: Session = Depends(deps.get_db),
     current_user: models.User = Depends(deps.get_current_active_user),
     current_user: models.User = Depends(deps.get_current_active_user),
@@ -37,6 +39,7 @@ async def supser_resolution(
     """
     """
     Super Resolution.
     Super Resolution.
     """
     """
+    print("api start: "+str(datetime.now()))
     filenames = [random_name(20)+Path(file.filename).suffix for file in upload_files]
     filenames = [random_name(20)+Path(file.filename).suffix for file in upload_files]
     stemnames = [Path(filename).stem for filename in filenames]
     stemnames = [Path(filename).stem for filename in filenames]
     new_dir = random_name(10)
     new_dir = random_name(10)
@@ -52,19 +55,42 @@ async def supser_resolution(
           return {"error": str(e)}
           return {"error": str(e)}
       finally:
       finally:
           upload_files[i].file.close()
           upload_files[i].file.close()
-    source = [f"{str(BACKEND_ZIP_STORAGE/filename)}" for filename in filenames]
-    r = subprocess.run(["sshpass", "-p", "choozmo9", 
-                    "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(new_dir_path)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}"])
-    
-    background_tasks.add_task(wait_finish, new_dir_path, stemnames)
     
     
+    background_tasks.add_task(wait_finish, new_dir, filenames)
+
     print(filenames)
     print(filenames)
-    return {"filenames": stemnames}
+    return JSONResponse({"filenames": stemnames}, background=background_tasks)
 
 
 async def wait_finish(dirname, filenames):
 async def wait_finish(dirname, filenames):
+    # print("start: "+str(datetime.now()))
+    new_dir_path = Path(BACKEND_ZIP_STORAGE).joinpath(dirname)
+    process = await asyncio.create_subprocess_exec("sshpass", "-p", "choozmo9", 
+                    "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(new_dir_path)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}")
+    await process.wait()
+    # r = subprocess.run(["sshpass", "-p", "choozmo9", 
+    #                 "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(new_dir_path)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}"])
+    
+    restored_imgs = "restored_imgs" 
+    res = celery_app.send_task("app.worker.super_resolution", args=[dirname])
+    # print(res.state)
+    while True:
+        await asyncio.sleep(0.5)
+        print(res.state)
+        if res.status == 'SUCCESS':
+            print("recieve finished")
+            break
     for filename in filenames:
     for filename in filenames:
-      await asyncio.sleep(3)
-      await publish(filename)
+        process = await asyncio.create_subprocess_exec("sshpass", "-p", "choozmo9", 
+                      "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE/dirname/restored_imgs/filename)}", f"{str(BACKEND_ZIP_STORAGE)}")
+        await process.wait()
+        # re = subprocess.run(["sshpass", "-p", "choozmo9", 
+        #             "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE/dirname/restored_imgs/filename)}", f"{str(BACKEND_ZIP_STORAGE)}"])
+        
+        for sr_client in sr_clients.values():
+            await sr_client.send_text(f"{filename}")
+    # print("end: "+str(datetime.now()))
+
+    
 
 
 @router.get("/sr")
 @router.get("/sr")
 def get_image(
 def get_image(
@@ -77,10 +103,11 @@ def get_image(
     """
     """
     Download image
     Download image
     """
     """
-    filename = Path(file_name)
-    response_filename = filename.stem + "_hr.png"
-    return FileResponse(path="test_medias/superman_resolution.png", media_type='image/png', filename=response_filename)
+    
+    return_file_path = list(BACKEND_ZIP_STORAGE.glob(stored_file_name+"*"))[0]
+    return FileResponse(path=str(return_file_path), media_type='image/png', filename=file_name+return_file_path.suffix)
 
 
+ws_clients = {}
 sr_clients = {}
 sr_clients = {}
 
 
 @router.websocket("/sr")
 @router.websocket("/sr")
@@ -88,12 +115,14 @@ async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
     await websocket.accept()
     key = websocket.headers.get('sec-websocket-key')
     key = websocket.headers.get('sec-websocket-key')
     sr_clients[key] = websocket
     sr_clients[key] = websocket
+    print(f"new comer: {key}")
     try:
     try:
         while True:
         while True:
             data = await websocket.receive_text()
             data = await websocket.receive_text()
-            if not data.startswith("subscribe"):
+            print(f"{key}:{data}")
+            if data.startswith("unsubscribe"):
               del sr_clients[key]
               del sr_clients[key]
-              return 
+              print(f"beybey: {key}")
               #for client in sr_clients.values():
               #for client in sr_clients.values():
               #      await client.send_text(f"ID: {key} | Message: {data}")
               #      await client.send_text(f"ID: {key} | Message: {data}")
 
 
@@ -101,7 +130,3 @@ async def websocket_endpoint(websocket: WebSocket):
         #await websocket.close()
         #await websocket.close()
         # 接続が切れた場合、当該クライアントを削除する
         # 接続が切れた場合、当該クライアントを削除する
         del sr_clients[key]
         del sr_clients[key]
-
-async def publish(data):
-    for sr_client in sr_clients.values():
-        await sr_client.send_text(f"{data}")

+ 6 - 3
backend/app/app/api/api_v1/endpoints/videos.py

@@ -56,9 +56,7 @@ def upload_plot(
     print(title)
     print(title)
     print(upload_file.filename)
     print(upload_file.filename)
     file_name = crud.video.generate_file_name(db=db, n=20)
     file_name = crud.video.generate_file_name(db=db, n=20)
-    video_create = schemas.VideoCreate(title=title, progress_state="waiting", stored_file_name=file_name)
-    video = crud.video.create_with_owner(db=db, obj_in=video_create, owner_id=current_user.id)
-
+    
     try:
     try:
         with open(str(Path(BACKEND_ZIP_STORAGE).joinpath(video.stored_file_name+".zip")), 'wb') as f:
         with open(str(Path(BACKEND_ZIP_STORAGE).joinpath(video.stored_file_name+".zip")), 'wb') as f:
             while contents := upload_file.file.read(1024 * 1024):
             while contents := upload_file.file.read(1024 * 1024):
@@ -68,6 +66,11 @@ def upload_plot(
         return {"error": str(e)}
         return {"error": str(e)}
     finally:
     finally:
         upload_file.file.close()
         upload_file.file.close()
+
+    # check valid file
+    video_create = schemas.VideoCreate(title=title, progress_state="waiting", stored_file_name=file_name)
+    video = crud.video.create_with_owner(db=db, obj_in=video_create, owner_id=current_user.id)
+
     zip_filename = video.stored_file_name+".zip"
     zip_filename = video.stored_file_name+".zip"
     print(str(BACKEND_ZIP_STORAGE/zip_filename))
     print(str(BACKEND_ZIP_STORAGE/zip_filename))
     r = subprocess.run(["sshpass", "-p", "choozmo9", 
     r = subprocess.run(["sshpass", "-p", "choozmo9", 

+ 1 - 0
backend/app/app/celeryconf.py

@@ -0,0 +1 @@
+task_track_started=True

+ 2 - 2
backend/app/app/core/celery_app.py

@@ -1,7 +1,7 @@
 from celery import Celery
 from celery import Celery
 
 
-celery_app = Celery("worker", broker="redis://172.104.93.163:16379/0")
+celery_app = Celery("worker", broker="redis://172.104.93.163:16379/0", backend="redis://172.104.93.163:16379/0")
 
 
 
 
 
 
-celery_app.conf.task_routes = {"app.worker.make_video": "main-queue"}
+celery_app.conf.task_routes = {"app.worker.make_video": "main-queue", "app.worker.super_resolution": "main-queue"}

+ 34 - 22
backend/app/app/worker.py

@@ -11,6 +11,7 @@ import dataset
 from app.db.session import SessionLocal
 from app.db.session import SessionLocal
 from app.models import video
 from app.models import video
 from app import crud
 from app import crud
+import gc
 download_to_local_url = urljoin(settings.SERVER_HOST, settings.API_V1_STR+"/videos/worker")
 download_to_local_url = urljoin(settings.SERVER_HOST, settings.API_V1_STR+"/videos/worker")
 upload_to_server_url = urljoin(settings.SERVER_HOST, settings.API_V1_STR+"/videos/worker")
 upload_to_server_url = urljoin(settings.SERVER_HOST, settings.API_V1_STR+"/videos/worker")
 
 
@@ -23,8 +24,8 @@ STORAGE_IP = '192.168.192.252'#os.getenv('STORAGE_IP')
 if not STORAGE_IP:
 if not STORAGE_IP:
     raise Exception
     raise Exception
 
 
-@celery_app.task(acks_late=True)
-def make_video(video_id, filename, user_id) -> str:
+@celery_app.task(acks_late=True, bind=True, track_started=True)
+def make_video(self, video_id, filename, user_id) -> str:
     #video_id, zip_filename, user_id = args
     #video_id, zip_filename, user_id = args
     # download 
     # download 
     '''
     '''
@@ -45,29 +46,40 @@ def make_video(video_id, filename, user_id) -> str:
     db.commit()
     db.commit()
     # make video
     # make video
     try:
     try:
-      make_video_from_zip(working_dir=CELERY_ZIP_STORAGE,style=Path("app/style/choozmo"),  inputfile=zip_filename,opening=False, ending=False)
+      content_time = make_video_from_zip(working_dir=CELERY_ZIP_STORAGE,style=Path("app/style/choozmo"),
+                                         inputfile=zip_filename,
+                                         opening=False, 
+                                         ending=False,
+                                         watermark_path='medias/logo_watermark.jpg')
     except Exception as e:
     except Exception as e:
-      print(f"error:{e}")
-      db.execute(f"UPDATE video SET progress_state='failed' where id={video_id}")
-      db.commit()
+        print(f"error:{e}")
+        db.execute(f"UPDATE video SET progress_state='failed' where id={video_id}")
+        db.commit()
     else:
     else:
-      # 
-      video_filename = filename + ".mp4"
-      r = subprocess.run(["sshpass", "-p", "choozmo9", 
-                          "scp", "-o", "StrictHostKeyChecking=no", f"{str(CELERY_ZIP_STORAGE/'output.mp4')}", f"root@{STORAGE_IP}:{'/var/www/videos/'+video_filename}"])
-      print(f"return to local storage: {r.returncode}")
-      print(f"video_id: {video_id}, file name: {filename}")
+        # 
+        video_filename = filename + ".mp4"
+        r = subprocess.run(["sshpass", "-p", "choozmo9", 
+                            "scp", "-o", "StrictHostKeyChecking=no", f"{str(CELERY_ZIP_STORAGE/'output.mp4')}", f"root@{STORAGE_IP}:{'/var/www/videos/'+video_filename}"])
+        print(f"return to local storage: {r.returncode}")
+        print(f"video_id: {video_id}, file name: {filename}")
 
 
-      db.execute(f"UPDATE video SET progress_state='completed' where id={video_id}")
-      db.commit()
+        db.execute(f"UPDATE video SET progress_state='completed' where id={video_id}")
+        db.commit()
 
 
-      
-      return "complete"
+        gc_ret = gc.collect()
+        print(f"gc_collected: {gc_ret}")
+        
+        return str(int(content_time))
 
 
-@celery_app.task(acks_late=True)
-def super_resolution(filenames):
-   source = [f"root@{STORAGE_IP}:{str(LOCAL_ZIP_STORAGE/filename)}" for filename in filenames]
-   r = subprocess.run(["sshpass", "-p", "choozmo9", 
-                        "scp", "-o", "StrictHostKeyChecking=no", f"root@{STORAGE_IP}:{str(LOCAL_ZIP_STORAGE/zip_filename)}", f"{str(CELERY_ZIP_STORAGE/zip_filename)}"])
-   
+@celery_app.task(acks_late=True, bind=True, track_started=True)
+def super_resolution(self, dirname:str):
+    print(dirname)
+    re = subprocess.run(["sshpass", "-p", "choozmo9", 
+                        "scp", "-o", "StrictHostKeyChecking=no", "-r", f"root@{STORAGE_IP}:{str(LOCAL_ZIP_STORAGE/dirname)}", f"{str(CELERY_ZIP_STORAGE)}"])
+    result_dir = dirname+"_result"
+    re = subprocess.run(["python", "/root/github/GFPGAN/inference_gfpgan.py", "-i", f"{str(CELERY_ZIP_STORAGE/dirname)}", "-o", f"{str(CELERY_ZIP_STORAGE/dirname)}", "-v", "1.4"])
 
 
+    re = subprocess.run(["sshpass", "-p", "choozmo9", 
+                        "scp", "-o", "StrictHostKeyChecking=no", "-r", f"{str(CELERY_ZIP_STORAGE/dirname)}", f"root@{STORAGE_IP}:{str(LOCAL_ZIP_STORAGE)}"])
+    
+    return "complete"