Skip to content

Commit

Permalink
refactor(webui): 优化视觉分析批次处理逻辑
Browse files Browse the repository at this point in the history
- 提取 vision_batch_size 到单独变量,提高代码可读性
- 使用 vision_batch_size 替代多次调用 config(frames.get("vision_batch_size")
- 添加调试日志,记录批次数量和每批次的图片数量
  • Loading branch information
linyqh committed Nov 22, 2024
1 parent 72116a7 commit 4ad9c41
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflows/dockerImageBuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ name: build_docker
on:
release:
types: [created] # 表示在创建新的 Release 时触发
workflow_dispatch:

jobs:
build_docker:
Expand Down
4 changes: 3 additions & 1 deletion app/utils/vision_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def _generate_content_with_retry(self, prompt, batch):
async def analyze_images(self,
images: Union[List[str], List[PIL.Image.Image]],
prompt: str,
batch_size: int = 5) -> List[Dict]:
batch_size: int) -> List[Dict]:
"""批量分析多张图片"""
try:
# 加载图片
Expand All @@ -82,6 +82,8 @@ async def analyze_images(self,
results = []
total_batches = (len(images) + batch_size - 1) // batch_size

logger.debug(f"共 {total_batches} 个批次,每批次 {batch_size} 张图片")

with tqdm(total=total_batches, desc="分析进度") as pbar:
for i in range(0, len(images), batch_size):
batch = images[i:i + batch_size]
Expand Down
9 changes: 5 additions & 4 deletions webui/components/script_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,11 +417,12 @@ def update_progress(progress: float, message: str = ""):
asyncio.set_event_loop(loop)

# 执行异步分析
vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size")
results = loop.run_until_complete(
analyzer.analyze_images(
images=keyframe_files,
prompt=config.app.get('vision_analysis_prompt'),
batch_size=config.frames.get("vision_batch_size", st.session_state.get('vision_batch_size', 5))
batch_size=vision_batch_size
)
)
loop.close()
Expand All @@ -437,8 +438,8 @@ def update_progress(progress: float, message: str = ""):
if 'error' in result:
logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}")
continue
batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5))
# 获取当前批次的文件列表
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片")
logger.debug(batch_files)

Expand Down Expand Up @@ -477,7 +478,7 @@ def update_progress(progress: float, message: str = ""):
if 'error' in result:
continue

batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5))
batch_files = get_batch_files(keyframe_files, result, vision_batch_size)
_, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files)

frame_content = {
Expand Down

0 comments on commit 4ad9c41

Please sign in to comment.