tomoya 1 mês atrás
pai
commit
db6961c3da
1 arquivos alterados com 39 adições e 36 exclusões
  1. 39 36
      backend/app/app/api/api_v1/endpoints/text2zip.py

+ 39 - 36
backend/app/app/api/api_v1/endpoints/text2zip.py

@@ -166,7 +166,6 @@ async def generate_video(
         return HTTPException("No texts.")
 
 async def wait_finish(model, email, texts, lang): 
-    db = SessionLocal()
     if not model:
         model = 'sd3'
     wb = px.Workbook()
@@ -175,42 +174,44 @@ async def wait_finish(model, email, texts, lang):
     ws['A1'] = '大標'
     ws['B1'] = '字幕'
     ws['C1'] = '素材'
-    with tempfile.TemporaryDirectory() as td:
-        dir = Path(f'{td}/{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}')
-        dir.mkdir(exist_ok=False)
-        texts = [text for text in texts if text]
-        for i, text in enumerate(texts):
-            print(f'{i+1}/{len(texts)}')
-            prompt = gen_prompt(text)
-            if model=='flux':
-                img_path = Path(gen_flux_image(prompt))
+    td = Path(f'{td}/{datetime.datetime.now().strftime("%Y%m%d%H%M%S-td")}')
+    td.mkdir(exist_ok=False)
+    dir = td/f'{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}'
+    dir.mkdir(exist_ok=False)
+    texts = [text for text in texts if text]
+    for i, text in enumerate(texts):
+        print(f'{i+1}/{len(texts)}')
+        prompt = gen_prompt(text)
+        if model=='flux':
+            img_path = Path(gen_flux_image(prompt))
 
-            elif model=='sd3':
-                img_path = Path(gen_sd_image(prompt))
-            print("before", str(img_path))
-            img_path = img_path.rename(dir/(f'{i+1:02}'+img_path.suffix))
-            print("after", str(img_path))
-            ws['B'+ str(i+2)] = re.sub(punctuation, r"\\", text)
-            ws['C'+ str(i+2)] = img_path.name
-        excel_path = Path(dir/'script.xlsx')
-        wb.save(excel_path)
-        output_dir = '/tmp'
-        shutil.make_archive(f'{output_dir}/{dir.name}', format='zip', root_dir=td)
-        # def remove_zip():
-        #     if os.path.exists(f'{output_dir}/{dir.name}.zip'):
-        #         os.remove(f'{output_dir}/{dir.name}.zip')
-        current_user = crud.user.get(db, id=0)
-        video_create = schemas.VideoCreate(title="guest", progress_state="PENDING", stored_filename=dir.name)
-        video = crud.video.create_with_owner(db=db, obj_in=video_create, owner_id=current_user.id)
-        video_data = jsonable_encoder(video)
-        video_data['membership_status'] = current_user.membership_status
-        video_data['available_time'] = current_user.available_time
-        video_data['video_id'] = video_data['id']
-        video_data['character'] = "hannah-2"
-        video_data['anchor'] = "hannah-2"
-        video_data['style'] = "style14"
-        video_data['lang'] = lang
-        video_data['email'] = email
+        elif model=='sd3':
+            img_path = Path(gen_sd_image(prompt))
+        print("before", str(img_path))
+        img_path = img_path.rename(dir/(f'{i+1:02}'+img_path.suffix))
+        print("after", str(img_path))
+        ws['B'+ str(i+2)] = re.sub(punctuation, r"\\", text)
+        ws['C'+ str(i+2)] = img_path.name
+    excel_path = Path(dir/'script.xlsx')
+    wb.save(excel_path)
+    output_dir = '/tmp'
+    shutil.make_archive(f'{output_dir}/{dir.name}', format='zip', root_dir=td)
+    # def remove_zip():
+    #     if os.path.exists(f'{output_dir}/{dir.name}.zip'):
+    #         os.remove(f'{output_dir}/{dir.name}.zip')
+    db = SessionLocal()
+    current_user = crud.user.get(db, id=0)
+    video_create = schemas.VideoCreate(title="guest", progress_state="PENDING", stored_filename=dir.name)
+    video = crud.video.create_with_owner(db=db, obj_in=video_create, owner_id=current_user.id)
+    video_data = jsonable_encoder(video)
+    video_data['membership_status'] = current_user.membership_status
+    video_data['available_time'] = current_user.available_time
+    video_data['video_id'] = video_data['id']
+    video_data['character'] = "hannah-2"
+    video_data['anchor'] = "hannah-2"
+    video_data['style'] = "style14"
+    video_data['lang'] = lang
+    video_data['email'] = email
     db.close()
     
     zip_filename = video_data['stored_filename']+".zip"
@@ -220,6 +221,8 @@ async def wait_finish(model, email, texts, lang):
     await process.wait()
     if os.path.exists(f"/tmp/{zip_filename}"):
         os.remove(f"/tmp/{zip_filename}")
+    if td.exists():
+        shutil.rmtree(str(td))
     headers = {
         'Authorization': 'Bearer ' + LINE_TOKEN    
     }