tomoya il y a 7 mois
Parent
commit
353de1bb51
1 fichiers modifiés avec 41 ajouts et 1 suppressions
  1. 41 1
      backend/app/app/api/api_v1/endpoints/images.py

+ 41 - 1
backend/app/app/api/api_v1/endpoints/images.py

@@ -1,4 +1,5 @@
-from typing import Any, List, Optional
+import os
+from typing import Any, List, Optional, Literal
 import subprocess
 from fastapi import UploadFile, File, Form
 from fastapi.responses import FileResponse, JSONResponse
@@ -20,6 +21,7 @@ from pathlib import Path
 from app.utils import random_name
 
 from app.core.celery_app import celery_app
+from gradio_client import Client
 
 BACKEND_ZIP_STORAGE = Path("/app").joinpath(settings.BACKEND_ZIP_STORAGE)
 LOCAL_ZIP_STORAGE = Path("/").joinpath(settings.LOCAL_ZIP_STORAGE)
@@ -120,6 +122,42 @@ async def delete_img(filename:str):
 def del_image():
     pass
 
+aspects = ['16:9', '3:2', '1:1', '2:3', '9:16']
+flux_image_size = [
+    (450, 800),
+    (480, 720),
+    (600, 600),
+    (720, 480),
+    (800, 250)
+]
+
+def remove_file(path: str) -> None:
+    os.remove(path)
+
+@router.post('/flux')
+def get_image(
+    *,
+    db: Session = Depends(deps.get_db),
+    current_user: models.User = Depends(deps.get_current_active_user),
+    prompt:str, 
+    aspect:Literal['16:9', '3:2', '1:1', '2:3', '9:16'],
+    background_tasks: BackgroundTasks
+) -> Any:
+    width, height = flux_image_size[aspects.index(aspect)]
+    client = Client("http://192.168.192.83:7860/")
+    result = client.predict(
+            model_id="models/FLUX.1-schnell",
+            prompt=prompt,
+            width=width,
+            height=height,
+            seed=-1,
+            steps=4,
+            guidance_scale=3.5,
+            add_sampling_metadata=True,
+            api_name="/generate"
+    )   
+    return FileResponse(result, background=background_tasks(remove_file, result))
+
 @router.websocket("/sr")
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
@@ -138,3 +176,5 @@ async def websocket_endpoint(websocket: WebSocket):
     except:
         # 接続が切れた場合、当該クライアントを削除する
         del sr_clients[key]
+        
+