فهرست منبع

Merge branch 'master' of http://git.choozmo.com:3000/ai-anchor/video-maker

tomoya 2 سال پیش
والد
کامیت
3ed6c04f05

+ 1 - 0
backend/app/.gitignore

@@ -1,3 +1,4 @@
 .mypy_cache
 .mypy_cache
 .coverage
 .coverage
 htmlcov
 htmlcov
+app/worker.py

+ 37 - 17
backend/app/app/api/api_v1/endpoints/images.py

@@ -1,7 +1,7 @@
 from typing import Any, List, Optional
 from typing import Any, List, Optional
 import subprocess
 import subprocess
 from fastapi import UploadFile, File, Form
 from fastapi import UploadFile, File, Form
-from fastapi.responses import FileResponse
+from fastapi.responses import FileResponse, JSONResponse
 from fastapi import APIRouter, Depends, HTTPException
 from fastapi import APIRouter, Depends, HTTPException
 from fastapi import WebSocket, BackgroundTasks
 from fastapi import WebSocket, BackgroundTasks
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
@@ -52,19 +52,36 @@ async def supser_resolution(
           return {"error": str(e)}
           return {"error": str(e)}
       finally:
       finally:
           upload_files[i].file.close()
           upload_files[i].file.close()
-    source = [f"{str(BACKEND_ZIP_STORAGE/filename)}" for filename in filenames]
-    r = subprocess.run(["sshpass", "-p", "choozmo9", 
-                    "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(new_dir_path)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}"])
-    
-    background_tasks.add_task(wait_finish, new_dir_path, stemnames)
+
+
+    background_tasks.add_task(wait_finish, new_dir, filenames)
     
     
     print(filenames)
     print(filenames)
-    return {"filenames": stemnames}
+    return JSONResponse({"filenames": filenames}, background=background_tasks)
 
 
 async def wait_finish(dirname, filenames):
 async def wait_finish(dirname, filenames):
+    process = await asyncio.create_subprocess_exec("sshpass", "-p", "choozmo9", 
+                     "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(BACKEND_ZIP_STORAGE/dirname)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}")
+    await process.wait()
+    # r = subprocess.run(["sshpass", "-p", "choozmo9", 
+    #                 "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"{str(new_dir_path)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE)}"])
+    res = celery_app.send_task("app.worker.super_resolution", args=[dirname, filenames])
+
+    while True:
+       await asyncio.sleep(0.5)
+       if res.state == "SUCCESS":
+           break
+    restored_imgs = "restored_imgs"
+    process = await asyncio.create_subprocess_exec("sshpass", "-p", "choozmo9", 
+                     "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", "-r", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE/dirname/restored_imgs)}/*", f"{str(BACKEND_ZIP_STORAGE)}")
+    await process.wait()
     for filename in filenames:
     for filename in filenames:
-      await asyncio.sleep(3)
-      await publish(filename)
+        await publish(filename)
+
+
+async def publish(data):
+    for sr_client in sr_clients.values():
+        await sr_client.send_text(f"{data}")
 
 
 @router.get("/sr")
 @router.get("/sr")
 def get_image(
 def get_image(
@@ -77,9 +94,17 @@ def get_image(
     """
     """
     Download image
     Download image
     """
     """
-    filename = Path(file_name)
-    response_filename = filename.stem + "_hr.png"
-    return FileResponse(path="test_medias/superman_resolution.png", media_type='image/png', filename=response_filename)
+    return_files = list(BACKEND_ZIP_STORAGE.glob(stored_file_name))
+    if return_files:
+        return_file = return_files[0]
+        return FileResponse(path=str(return_file), filename=file_name+"_hr"+return_file.suffix)
+    else:
+        print("non")
+    
+
+@router.delete("/sr")
+def del_image():
+    pass
 
 
 sr_clients = {}
 sr_clients = {}
 
 
@@ -93,15 +118,10 @@ async def websocket_endpoint(websocket: WebSocket):
             data = await websocket.receive_text()
             data = await websocket.receive_text()
             if not data.startswith("subscribe"):
             if not data.startswith("subscribe"):
               del sr_clients[key]
               del sr_clients[key]
-              return 
               #for client in sr_clients.values():
               #for client in sr_clients.values():
               #      await client.send_text(f"ID: {key} | Message: {data}")
               #      await client.send_text(f"ID: {key} | Message: {data}")
 
 
     except:
     except:
-        #await websocket.close()
         # 接続が切れた場合、当該クライアントを削除する
         # 接続が切れた場合、当該クライアントを削除する
         del sr_clients[key]
         del sr_clients[key]
 
 
-async def publish(data):
-    for sr_client in sr_clients.values():
-        await sr_client.send_text(f"{data}")

+ 1 - 1
backend/app/app/api/api_v1/endpoints/reputations.py

@@ -38,4 +38,4 @@ async def post_reputation(
     #print(posted_article)
     #print(posted_article)
     article = crud.artivle.create_with_owner(db=db, obj_in=posted_article, owner_id=current_user.id, posted_datetime=str(datetime.now()))
     article = crud.artivle.create_with_owner(db=db, obj_in=posted_article, owner_id=current_user.id, posted_datetime=str(datetime.now()))
     if article:
     if article:
-      return {"id":article.id}
+        return {"id":article.id}

+ 1 - 1
backend/app/app/api/api_v1/endpoints/videos.py

@@ -73,7 +73,7 @@ def upload_plot(
     r = subprocess.run(["sshpass", "-p", "choozmo9", 
     r = subprocess.run(["sshpass", "-p", "choozmo9", 
                     "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", f"{str(BACKEND_ZIP_STORAGE/zip_filename)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE/zip_filename)}"])
                     "scp", "-P", "5722", "-o", "StrictHostKeyChecking=no", f"{str(BACKEND_ZIP_STORAGE/zip_filename)}", f"root@172.104.93.163:{str(LOCAL_ZIP_STORAGE/zip_filename)}"])
     print(r.returncode)
     print(r.returncode)
-    celery_app.send_task("app.worker.make_video", args=[video.id, video.stored_file_name, current_user.id])
+    celery_app.send_task("app.worker.make_video", args=[video.id, video.stored_file_name, current_user.id, anchor_id, current_user.membership_status, current_user.available_time])
     return video
     return video
 
 
 @router.get("/{id}")
 @router.get("/{id}")

+ 1 - 0
backend/app/app/celeryconf.py

@@ -0,0 +1 @@
+task_track_started=True

+ 2 - 2
backend/app/app/core/celery_app.py

@@ -1,7 +1,7 @@
 from celery import Celery
 from celery import Celery
 
 
-celery_app = Celery("worker", broker="redis://172.104.93.163:16379/0")
+celery_app = Celery("worker", broker="redis://172.104.93.163:16379/0", backend="redis://172.104.93.163:16379/0")
 
 
 
 
 
 
-celery_app.conf.task_routes = {"app.worker.make_video": "main-queue"}
+celery_app.conf.task_routes = {"app.worker.make_video": "main-queue", "app.worker.super_resolution": "main-queue"}

+ 1 - 1
backend/app/app/models/user.py

@@ -11,7 +11,7 @@ if TYPE_CHECKING:
 class User(Base):
 class User(Base):
   id = Column(Integer, primary_key=True, index=True)
   id = Column(Integer, primary_key=True, index=True)
   full_name = Column(String(20), index=True)
   full_name = Column(String(20), index=True)
-  email = Column(String(30), unique=True, index=True, nullable=False)
+  email = Column(String(50), unique=True, index=True, nullable=False)
   hashed_password = Column(String(100), nullable=False)
   hashed_password = Column(String(100), nullable=False)
   membership_status = Column(String(10), 
   membership_status = Column(String(10), 
                       ForeignKey("membership.status", onupdate="CASCADE", ondelete="RESTRICT"), 
                       ForeignKey("membership.status", onupdate="CASCADE", ondelete="RESTRICT"),