|
@@ -1,4 +1,5 @@
|
|
|
-from typing import Any, List, Optional
|
|
|
+import os
|
|
|
+from typing import Any, List, Optional, Literal
|
|
|
import subprocess
|
|
|
from fastapi import UploadFile, File, Form
|
|
|
from fastapi.responses import FileResponse, JSONResponse
|
|
@@ -20,6 +21,7 @@ from pathlib import Path
|
|
|
from app.utils import random_name
|
|
|
|
|
|
from app.core.celery_app import celery_app
|
|
|
+from gradio_client import Client
|
|
|
|
|
|
BACKEND_ZIP_STORAGE = Path("/app").joinpath(settings.BACKEND_ZIP_STORAGE)
|
|
|
LOCAL_ZIP_STORAGE = Path("/").joinpath(settings.LOCAL_ZIP_STORAGE)
|
|
@@ -120,6 +122,42 @@ async def delete_img(filename:str):
|
|
|
def del_image():
|
|
|
pass
|
|
|
|
|
|
+aspects = ['16:9', '3:2', '1:1', '2:3', '9:16']
|
|
|
+flux_image_size = [
|
|
|
+ (450, 800),
|
|
|
+ (480, 720),
|
|
|
+ (600, 600),
|
|
|
+ (720, 480),
|
|
|
+ (800, 250)
|
|
|
+]
|
|
|
+
|
|
|
+def remove_file(path: str) -> None:
|
|
|
+ os.remove(path)
|
|
|
+
|
|
|
+@router.post('/flux')
|
|
|
+def get_image(
|
|
|
+ *,
|
|
|
+ db: Session = Depends(deps.get_db),
|
|
|
+ current_user: models.User = Depends(deps.get_current_active_user),
|
|
|
+ prompt:str,
|
|
|
+ aspect:Literal['16:9', '3:2', '1:1', '2:3', '9:16'],
|
|
|
+ background_tasks: BackgroundTasks
|
|
|
+) -> Any:
|
|
|
+ width, height = flux_image_size[aspects.index(aspect)]
|
|
|
+ client = Client("http://192.168.192.83:7860/")
|
|
|
+ result = client.predict(
|
|
|
+ model_id="models/FLUX.1-schnell",
|
|
|
+ prompt=prompt,
|
|
|
+ width=width,
|
|
|
+ height=height,
|
|
|
+ seed=-1,
|
|
|
+ steps=4,
|
|
|
+ guidance_scale=3.5,
|
|
|
+ add_sampling_metadata=True,
|
|
|
+ api_name="/generate"
|
|
|
+ )
|
|
|
+ return FileResponse(result, background=background_tasks(remove_file, result))
|
|
|
+
|
|
|
@router.websocket("/sr")
|
|
|
async def websocket_endpoint(websocket: WebSocket):
|
|
|
await websocket.accept()
|
|
@@ -138,3 +176,5 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
|
except:
|
|
|
# 接続が切れた場合、当該クライアントを削除する
|
|
|
del sr_clients[key]
|
|
|
+
|
|
|
+
|