tomoya 7 meses atrás
pai
commit
353de1bb51
1 arquivos alterados com 41 adições e 1 exclusões
  1. 41 1
      backend/app/app/api/api_v1/endpoints/images.py

+ 41 - 1
backend/app/app/api/api_v1/endpoints/images.py

@@ -1,4 +1,5 @@
-from typing import Any, List, Optional
+import os
+from typing import Any, List, Optional, Literal
 import subprocess
 import subprocess
 from fastapi import UploadFile, File, Form
 from fastapi import UploadFile, File, Form
 from fastapi.responses import FileResponse, JSONResponse
 from fastapi.responses import FileResponse, JSONResponse
@@ -20,6 +21,7 @@ from pathlib import Path
 from app.utils import random_name
 from app.utils import random_name
 
 
 from app.core.celery_app import celery_app
 from app.core.celery_app import celery_app
+from gradio_client import Client
 
 
 BACKEND_ZIP_STORAGE = Path("/app").joinpath(settings.BACKEND_ZIP_STORAGE)
 BACKEND_ZIP_STORAGE = Path("/app").joinpath(settings.BACKEND_ZIP_STORAGE)
 LOCAL_ZIP_STORAGE = Path("/").joinpath(settings.LOCAL_ZIP_STORAGE)
 LOCAL_ZIP_STORAGE = Path("/").joinpath(settings.LOCAL_ZIP_STORAGE)
@@ -120,6 +122,42 @@ async def delete_img(filename:str):
 def del_image():
 def del_image():
     pass
     pass
 
 
+aspects = ['16:9', '3:2', '1:1', '2:3', '9:16']
+flux_image_size = [
+    (450, 800),
+    (480, 720),
+    (600, 600),
+    (720, 480),
+    (800, 250)
+]
+
+def remove_file(path: str) -> None:
+    os.remove(path)
+
+@router.post('/flux')
+def get_image(
+    *,
+    db: Session = Depends(deps.get_db),
+    current_user: models.User = Depends(deps.get_current_active_user),
+    prompt:str, 
+    aspect:Literal['16:9', '3:2', '1:1', '2:3', '9:16'],
+    background_tasks: BackgroundTasks
+) -> Any:
+    width, height = flux_image_size[aspects.index(aspect)]
+    client = Client("http://192.168.192.83:7860/")
+    result = client.predict(
+            model_id="models/FLUX.1-schnell",
+            prompt=prompt,
+            width=width,
+            height=height,
+            seed=-1,
+            steps=4,
+            guidance_scale=3.5,
+            add_sampling_metadata=True,
+            api_name="/generate"
+    )   
+    return FileResponse(result, background=background_tasks(remove_file, result))
+
 @router.websocket("/sr")
 @router.websocket("/sr")
 async def websocket_endpoint(websocket: WebSocket):
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
     await websocket.accept()
@@ -138,3 +176,5 @@ async def websocket_endpoint(websocket: WebSocket):
     except:
     except:
         # 接続が切れた場合、当該クライアントを削除する
         # 接続が切れた場合、当該クライアントを削除する
         del sr_clients[key]
         del sr_clients[key]
+        
+