Merge pull request '后端:稳定租户链路与 VLM 图片质量检测' (#9) from wren-dev into main
Reviewed-on: #9
This commit was merged in pull request #9.
This commit is contained in:
@@ -239,7 +239,7 @@ class ResilientQwenVLMClient(QwenVLMClient):
|
|||||||
body = response.json()
|
body = response.json()
|
||||||
text = (body.get("choices") or [{}])[0].get("message", {}).get("content", "")
|
text = (body.get("choices") or [{}])[0].get("message", {}).get("content", "")
|
||||||
parsed = _parse_json_loose(text)
|
parsed = _parse_json_loose(text)
|
||||||
return parsed if isinstance(parsed, dict) else {}
|
return parsed if isinstance(parsed, dict) else {"result": text, "reason": text}
|
||||||
|
|
||||||
|
|
||||||
class ResilientChandraOCRClient(ChandraOCRClient):
|
class ResilientChandraOCRClient(ChandraOCRClient):
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Any
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tempfile
|
import tempfile
|
||||||
import logging
|
import logging
|
||||||
|
import json
|
||||||
|
|
||||||
import fitz
|
import fitz
|
||||||
from leaudit.converters import doc2pdf
|
from leaudit.converters import doc2pdf
|
||||||
@@ -30,9 +31,11 @@ _PAGE_QUALITY_VLM_PROMPT = """
|
|||||||
你是文档扫描图片质量检测员。请判断这 1 页文档图片是否适合继续做 OCR 与合同/公文评查。
|
你是文档扫描图片质量检测员。请判断这 1 页文档图片是否适合继续做 OCR 与合同/公文评查。
|
||||||
|
|
||||||
判定标准:
|
判定标准:
|
||||||
1. pass:文字主体清晰、方向正常、没有明显截断,能稳定阅读。
|
1. 必须同时检查整页扫描质量,以及页面内所有内嵌照片、证据照片、现场照片、截图、印章和签名图片的清晰度。
|
||||||
2. review:存在轻微模糊、倾斜、阴影、低对比度、局部遮挡、轻微截断,建议人工确认但仍可能可读。
|
2. pass:文字主体清晰、方向正常、没有明显截断;页面内嵌照片/证据照片也能辨认关键视觉信息。
|
||||||
3. reject:严重模糊、重影、过曝/过暗、页面大面积缺失、关键文字不可辨认、方向严重错误、空白页或非文档页,建议重拍。
|
3. review:存在轻微模糊、倾斜、阴影、低对比度、局部遮挡、轻微截断;或内嵌照片/证据照片主体明显发虚、牌匾/场所/人物/关键物证不易辨认,建议人工确认但仍可能可用。
|
||||||
|
4. reject:严重模糊、重影、过曝/过暗、页面大面积缺失、关键文字不可辨认、方向严重错误、空白页或非文档页;或内嵌证据照片主体无法辨认、关键证据信息不可用,建议重拍。
|
||||||
|
5. 即使页面周边文字清楚,只要内嵌证据照片明显模糊,也不能判 pass,至少判 review,严重时判 reject。
|
||||||
|
|
||||||
只输出 JSON,不要输出 Markdown,不要解释额外文本:
|
只输出 JSON,不要输出 Markdown,不要解释额外文本:
|
||||||
{"status":"pass|review|reject","score":0.0到1.0,"reason":"20字以内中文原因"}
|
{"status":"pass|review|reject","score":0.0到1.0,"reason":"20字以内中文原因"}
|
||||||
@@ -495,12 +498,28 @@ class PageQualityServiceImpl(IPageQualityService):
|
|||||||
logger.warning("VLM page quality detection failed: %s", exc)
|
logger.warning("VLM page quality detection failed: %s", exc)
|
||||||
return "review", 0.5, "VLM图片质量检测失败,需人工确认"
|
return "review", 0.5, "VLM图片质量检测失败,需人工确认"
|
||||||
|
|
||||||
status = str((result or {}).get("status") or "").strip().lower()
|
result_dict = self._coerce_vlm_result(result)
|
||||||
|
status = self._normalize_quality_status(
|
||||||
|
self._first_non_empty(
|
||||||
|
result_dict,
|
||||||
|
("status", "quality_status", "qualityStatus", "result", "label", "decision", "conclusion"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
reason = self._normalize_quality_reason(
|
||||||
|
self._first_non_empty(
|
||||||
|
result_dict,
|
||||||
|
("reason", "quality_reason", "qualityReason", "message", "msg", "detail", "explanation", "description"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if status is None and reason:
|
||||||
|
status = self._normalize_quality_status(reason)
|
||||||
if status not in {"pass", "review", "reject"}:
|
if status not in {"pass", "review", "reject"}:
|
||||||
return "review", 0.5, "VLM返回结果不可用,需人工确认"
|
return "review", 0.5, "VLM返回结果不可用,需人工确认"
|
||||||
|
|
||||||
score = self._normalize_quality_score((result or {}).get("score"), status)
|
score = self._normalize_quality_score(
|
||||||
reason = str((result or {}).get("reason") or "").strip() or None
|
self._first_non_empty(result_dict, ("score", "quality_score", "qualityScore", "confidence")),
|
||||||
|
status,
|
||||||
|
)
|
||||||
if status != "pass" and not reason:
|
if status != "pass" and not reason:
|
||||||
reason = "页面图片质量需人工确认"
|
reason = "页面图片质量需人工确认"
|
||||||
return status, score, reason
|
return status, score, reason
|
||||||
@@ -526,6 +545,56 @@ class PageQualityServiceImpl(IPageQualityService):
|
|||||||
return defaults[status]
|
return defaults[status]
|
||||||
return max(0.0, min(1.0, score))
|
return max(0.0, min(1.0, score))
|
||||||
|
|
||||||
|
def _coerce_vlm_result(self, result: Any) -> dict[str, Any]:
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return result
|
||||||
|
if isinstance(result, str):
|
||||||
|
text_result = result.strip()
|
||||||
|
if not text_result:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
parsed = json.loads(text_result)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return {"result": text_result, "reason": text_result}
|
||||||
|
return parsed if isinstance(parsed, dict) else {"result": text_result}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _first_non_empty(self, payload: dict[str, Any], keys: tuple[str, ...]) -> Any:
|
||||||
|
for key in keys:
|
||||||
|
value = payload.get(key)
|
||||||
|
if value is not None and str(value).strip():
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _normalize_quality_status(self, raw_status: Any) -> str | None:
|
||||||
|
text_status = str(raw_status or "").strip().lower()
|
||||||
|
if not text_status:
|
||||||
|
return None
|
||||||
|
compact_status = text_status.replace(" ", "").replace("_", "").replace("-", "")
|
||||||
|
if compact_status in {"pass", "passed", "ok", "good", "clear", "readable"}:
|
||||||
|
return "pass"
|
||||||
|
if compact_status in {"review", "warn", "warning", "manual", "uncertain", "suspect", "suspicious"}:
|
||||||
|
return "review"
|
||||||
|
if compact_status in {"reject", "rejected", "fail", "failed", "bad", "unreadable", "retake"}:
|
||||||
|
return "reject"
|
||||||
|
|
||||||
|
reject_keywords = ("不通过", "拒绝", "重拍", "不可读", "无法辨认", "无法识别", "严重", "大面积缺失", "空白页")
|
||||||
|
review_keywords = ("复核", "人工", "疑似", "轻微", "建议确认", "建议人工", "模糊", "不清晰", "低对比", "发虚")
|
||||||
|
pass_keywords = ("通过", "合格", "清晰", "可读")
|
||||||
|
if any(keyword in text_status for keyword in reject_keywords):
|
||||||
|
return "reject"
|
||||||
|
if any(keyword in text_status for keyword in review_keywords):
|
||||||
|
return "review"
|
||||||
|
if any(keyword in text_status for keyword in pass_keywords):
|
||||||
|
return "pass"
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _normalize_quality_reason(self, raw_reason: Any) -> str | None:
|
||||||
|
reason = str(raw_reason or "").strip()
|
||||||
|
if not reason:
|
||||||
|
return None
|
||||||
|
return reason[:80]
|
||||||
|
|
||||||
def _document_service(self):
|
def _document_service(self):
|
||||||
if self.DocumentService is None:
|
if self.DocumentService is None:
|
||||||
from fastapi_modules.fastapi_leaudit.services.impl.documentServiceImpl import DocumentServiceImpl
|
from fastapi_modules.fastapi_leaudit.services.impl.documentServiceImpl import DocumentServiceImpl
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""RAG 聊天服务实现。"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -49,6 +51,8 @@ MANUAL_TITLE_SOURCE = "manual"
|
|||||||
|
|
||||||
|
|
||||||
class RagChatServiceImpl(IRagChatService):
|
class RagChatServiceImpl(IRagChatService):
|
||||||
|
"""RAG 聊天服务实现。"""
|
||||||
|
|
||||||
_message_tasks: dict[str, asyncio.Task] = {}
|
_message_tasks: dict[str, asyncio.Task] = {}
|
||||||
_task_events: dict[str, list[dict]] = {}
|
_task_events: dict[str, list[dict]] = {}
|
||||||
_task_done: dict[str, bool] = {}
|
_task_done: dict[str, bool] = {}
|
||||||
@@ -1143,14 +1147,17 @@ class RagChatServiceImpl(IRagChatService):
|
|||||||
except Exception:
|
except Exception:
|
||||||
followups = []
|
followups = []
|
||||||
|
|
||||||
|
sources = self._build_sources(context_chunks, dataset_name)
|
||||||
if message_end_payload:
|
if message_end_payload:
|
||||||
message_end_payload.setdefault("metadata", {})["suggested_questions"] = followups
|
message_end_metadata = message_end_payload.setdefault("metadata", {})
|
||||||
|
message_end_metadata["suggested_questions"] = followups
|
||||||
|
message_end_metadata["retriever_resources"] = sources
|
||||||
await self._append_task_event(task_id, message_end_payload)
|
await self._append_task_event(task_id, message_end_payload)
|
||||||
await self._finalize_message_record(
|
await self._finalize_message_record(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
content=collected_answer,
|
content=collected_answer,
|
||||||
sources=self._build_sources(context_chunks, dataset_name),
|
sources=sources,
|
||||||
metadata={"suggested_questions": followups, "status": "completed", "task_id": task_id},
|
metadata={"suggested_questions": followups, "status": "completed", "task_id": task_id},
|
||||||
)
|
)
|
||||||
await self._maybe_schedule_auto_title(
|
await self._maybe_schedule_auto_title(
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""RAG 知识库服务实现。"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -42,6 +44,8 @@ from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantR
|
|||||||
|
|
||||||
|
|
||||||
class RagDatasetServiceImpl(IRagDatasetService):
|
class RagDatasetServiceImpl(IRagDatasetService):
|
||||||
|
"""RAG 知识库服务实现。"""
|
||||||
|
|
||||||
_ACTIVE_INDEXING_STATUSES = {"waiting", "parsing", "cleaning", "splitting", "indexing"}
|
_ACTIVE_INDEXING_STATUSES = {"waiting", "parsing", "cleaning", "splitting", "indexing"}
|
||||||
_DELETABLE_DOCUMENT_STATUSES = {"completed", "error", "paused"}
|
_DELETABLE_DOCUMENT_STATUSES = {"completed", "error", "paused"}
|
||||||
_APP_LINK_SQL = """
|
_APP_LINK_SQL = """
|
||||||
@@ -295,7 +299,32 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
if target_is_default:
|
if target_is_default:
|
||||||
await self._clear_default_flags(session, tenant_code=resolved_tenant_code)
|
await self._clear_default_flags(session, tenant_code=resolved_tenant_code)
|
||||||
elif existing.get("is_default") and Body.get("is_default") is False:
|
elif existing.get("is_default") and Body.get("is_default") is False:
|
||||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不能直接取消,请先将其他知识库设为默认")
|
normalized_tenant_code = str(resolved_tenant_code or "").strip()
|
||||||
|
default_filters = [
|
||||||
|
"deleted_at IS NULL",
|
||||||
|
"is_default = TRUE",
|
||||||
|
"id <> :dataset_id",
|
||||||
|
]
|
||||||
|
default_params = {"dataset_id": DatasetId}
|
||||||
|
if normalized_tenant_code:
|
||||||
|
default_filters.append("tenant_code = :tenant_code")
|
||||||
|
default_params["tenant_code"] = normalized_tenant_code
|
||||||
|
else:
|
||||||
|
default_filters.append("(tenant_code IS NULL OR BTRIM(tenant_code) = '')")
|
||||||
|
other_default_count = (
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
f"""
|
||||||
|
SELECT COUNT(1)
|
||||||
|
FROM rag_dataset
|
||||||
|
WHERE {" AND ".join(default_filters)}
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
default_params,
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
if int(other_default_count or 0) <= 0:
|
||||||
|
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不能直接取消,请先将其他知识库设为默认")
|
||||||
await session.execute(
|
await session.execute(
|
||||||
text(
|
text(
|
||||||
"""
|
"""
|
||||||
|
|||||||
+1
-1
Submodule legal-platform-frontend updated: df04238bbb...20c9d4872a
@@ -1,5 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from fastapi_modules.fastapi_leaudit.leaudit_bridge.resilient_clients import ResilientQwenVLMClient
|
||||||
from fastapi_modules.fastapi_leaudit.services.impl.pageQualityServiceImpl import PageQualityServiceImpl
|
from fastapi_modules.fastapi_leaudit.services.impl.pageQualityServiceImpl import PageQualityServiceImpl
|
||||||
|
|
||||||
|
|
||||||
@@ -32,6 +34,58 @@ async def test_vlm_page_quality_reject_result_is_used():
|
|||||||
assert score == 0.18
|
assert score == 0.18
|
||||||
assert "严重模糊" in reason
|
assert "严重模糊" in reason
|
||||||
assert "只输出 JSON" in service.VlmClient.prompts[0][0]
|
assert "只输出 JSON" in service.VlmClient.prompts[0][0]
|
||||||
|
assert "内嵌照片" in service.VlmClient.prompts[0][0]
|
||||||
|
assert "即使页面周边文字清楚" in service.VlmClient.prompts[0][0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vlm_page_quality_embedded_evidence_blur_cannot_pass():
|
||||||
|
service = PageQualityServiceImpl()
|
||||||
|
service.VlmClient = _FakeVlmClient(
|
||||||
|
{
|
||||||
|
"quality_status": "疑似模糊",
|
||||||
|
"quality_score": "0.42",
|
||||||
|
"message": "内嵌证据照片主体发虚,门头文字不易辨认",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
status, score, reason = await service._classify_page_image_by_vlm(b"image-bytes")
|
||||||
|
|
||||||
|
assert status == "review"
|
||||||
|
assert score == 0.42
|
||||||
|
assert "内嵌证据照片" in reason
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vlm_page_quality_chinese_reject_status_is_supported():
|
||||||
|
service = PageQualityServiceImpl()
|
||||||
|
service.VlmClient = _FakeVlmClient(
|
||||||
|
{
|
||||||
|
"result": "不通过",
|
||||||
|
"confidence": 0.1,
|
||||||
|
"detail": "证据照片严重模糊,关键场所无法辨认",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
status, score, reason = await service._classify_page_image_by_vlm(b"image-bytes")
|
||||||
|
|
||||||
|
assert status == "reject"
|
||||||
|
assert score == 0.1
|
||||||
|
assert "严重模糊" in reason
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vlm_page_quality_json_string_result_is_supported():
|
||||||
|
service = PageQualityServiceImpl()
|
||||||
|
service.VlmClient = _FakeVlmClient(
|
||||||
|
'{"status":"review","score":0.33,"reason":"页面内照片模糊"}'
|
||||||
|
)
|
||||||
|
|
||||||
|
status, score, reason = await service._classify_page_image_by_vlm(b"image-bytes")
|
||||||
|
|
||||||
|
assert status == "review"
|
||||||
|
assert score == 0.33
|
||||||
|
assert reason == "页面内照片模糊"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -56,3 +110,32 @@ async def test_vlm_page_quality_error_falls_back_to_review_not_pass():
|
|||||||
assert status == "review"
|
assert status == "review"
|
||||||
assert score == 0.5
|
assert score == 0.5
|
||||||
assert "VLM图片质量检测失败" in reason
|
assert "VLM图片质量检测失败" in reason
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resilient_vlm_extract_multifield_keeps_raw_text_when_json_parse_fails(monkeypatch):
|
||||||
|
client = ResilientQwenVLMClient(base_url="http://example.test", api_key="x", model="vlm-test")
|
||||||
|
|
||||||
|
async def fake_post_with_retry(payload):
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": "疑似模糊:内嵌证据照片主体发虚,建议人工复核",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(client, "_post_with_retry", fake_post_with_retry)
|
||||||
|
|
||||||
|
result = await client.extract_multifield(
|
||||||
|
prompt="图片质量检测",
|
||||||
|
images_data_urls=["data:image/png;base64,xxx"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["result"].startswith("疑似模糊")
|
||||||
|
assert "内嵌证据照片" in result["reason"]
|
||||||
|
|||||||
@@ -0,0 +1,91 @@
|
|||||||
|
"""RAG 聊天流式引用来源测试。"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
|
||||||
|
|
||||||
|
|
||||||
|
async def _fake_generate_stream(**kwargs):
|
||||||
|
yield "data: " + json.dumps(
|
||||||
|
{
|
||||||
|
"event": "message",
|
||||||
|
"task_id": kwargs["task_id"],
|
||||||
|
"message_id": kwargs["message_id"],
|
||||||
|
"conversation_id": kwargs["conversation_id"],
|
||||||
|
"answer": "测试回答",
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
) + "\n\n"
|
||||||
|
yield "data: " + json.dumps(
|
||||||
|
{
|
||||||
|
"event": "message_end",
|
||||||
|
"task_id": kwargs["task_id"],
|
||||||
|
"message_id": kwargs["message_id"],
|
||||||
|
"conversation_id": kwargs["conversation_id"],
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
) + "\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_streaming_task() -> list[dict]:
|
||||||
|
service = RagChatServiceImpl()
|
||||||
|
task_id = "task-test"
|
||||||
|
service._task_events[task_id] = []
|
||||||
|
service._task_done[task_id] = False
|
||||||
|
service._task_locks[task_id] = asyncio.Lock()
|
||||||
|
|
||||||
|
async def fake_retrieve_context(dataset_id, query):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"document_id": "doc-1",
|
||||||
|
"document_name": "引用文档.pdf",
|
||||||
|
"source": "引用文档.pdf",
|
||||||
|
"id": "segment-1",
|
||||||
|
"score": 0.91,
|
||||||
|
"hit_count": 3,
|
||||||
|
"text": "引用片段内容",
|
||||||
|
}
|
||||||
|
], "测试知识库"
|
||||||
|
|
||||||
|
async def noop_async(*args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
service._retrieve_context = fake_retrieve_context
|
||||||
|
service._finalize_message_record = noop_async
|
||||||
|
service._maybe_schedule_auto_title = noop_async
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
|
||||||
|
_fake_generate_stream,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
|
||||||
|
return_value=[],
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await service._run_message_task(
|
||||||
|
task_id=task_id,
|
||||||
|
conversation_id="conversation-test",
|
||||||
|
message_id="message-test",
|
||||||
|
query="测试问题",
|
||||||
|
app={"dataset_id": 7},
|
||||||
|
)
|
||||||
|
|
||||||
|
return service._task_events[task_id]
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_message_end_includes_retriever_resources():
|
||||||
|
events = asyncio.run(_run_streaming_task())
|
||||||
|
|
||||||
|
message_end = next(event for event in events if event.get("event") == "message_end")
|
||||||
|
|
||||||
|
resources = message_end["metadata"].get("retriever_resources")
|
||||||
|
assert resources
|
||||||
|
assert resources[0]["dataset_id"] == "7"
|
||||||
|
assert resources[0]["dataset_name"] == "测试知识库"
|
||||||
|
assert resources[0]["document_name"] == "引用文档.pdf"
|
||||||
@@ -0,0 +1,127 @@
|
|||||||
|
"""RAG 知识库默认标记测试。"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||||||
|
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||||||
|
|
||||||
|
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
|
||||||
|
|
||||||
|
|
||||||
|
class _ScalarResult:
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def scalar_one(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
class _ExecuteResult:
|
||||||
|
def mappings(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
return {"id": 1}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self, other_default_count: int):
|
||||||
|
self.other_default_count = other_default_count
|
||||||
|
self.executed_sql = []
|
||||||
|
self.executed_params = []
|
||||||
|
|
||||||
|
async def execute(self, sql, params=None):
|
||||||
|
sql_text = str(sql)
|
||||||
|
self.executed_sql.append(sql_text)
|
||||||
|
self.executed_params.append(params or {})
|
||||||
|
if "COUNT(1)" in sql_text and "is_default = TRUE" in sql_text:
|
||||||
|
return _ScalarResult(self.other_default_count)
|
||||||
|
return _ExecuteResult()
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSessionContext:
|
||||||
|
def __init__(self, session):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self.session
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_default_flag_with_other_count(other_default_count: int):
|
||||||
|
service = RagDatasetServiceImpl()
|
||||||
|
fake_session = _FakeSession(other_default_count=other_default_count)
|
||||||
|
existing_row = {
|
||||||
|
"id": 10,
|
||||||
|
"name": "旧默认知识库",
|
||||||
|
"description": "",
|
||||||
|
"area": "梅州",
|
||||||
|
"tenant_code": "MZ",
|
||||||
|
"tenant_name": "梅州",
|
||||||
|
"is_public": False,
|
||||||
|
"is_default": True,
|
||||||
|
"status": 1,
|
||||||
|
"document_count": 0,
|
||||||
|
"total_chunks": 0,
|
||||||
|
"chunk_max_size": 800,
|
||||||
|
"chunk_min_size": 20,
|
||||||
|
"sort_order": 0,
|
||||||
|
"retrieval_model": {},
|
||||||
|
"created_at": None,
|
||||||
|
"updated_at": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def fake_get_dataset_row(dataset_id):
|
||||||
|
return existing_row
|
||||||
|
|
||||||
|
async def fake_resolve_tenant_context(**kwargs):
|
||||||
|
return {"tenant_code": "MZ", "tenant_name": "梅州", "tenant_type": "CUSTOM", "area": "梅州"}
|
||||||
|
|
||||||
|
async def fake_resolve_dataset_area_input(**kwargs):
|
||||||
|
return "梅州", "MZ", object()
|
||||||
|
|
||||||
|
async def noop_async(*args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
service._get_dataset_row = fake_get_dataset_row
|
||||||
|
service._resolve_tenant_context = fake_resolve_tenant_context
|
||||||
|
service._resolve_dataset_area_input = fake_resolve_dataset_area_input
|
||||||
|
service._assert_manage_area_scope = noop_async
|
||||||
|
service._ensure_rag_tenant_schema = noop_async
|
||||||
|
service._ensure_linked_app = noop_async
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl.GetAsyncSession",
|
||||||
|
return_value=_FakeSessionContext(fake_session),
|
||||||
|
):
|
||||||
|
result = await service.UpdateAdminDataset(
|
||||||
|
CurrentUserId=1,
|
||||||
|
UserArea="梅州",
|
||||||
|
UserRole="super_admin",
|
||||||
|
TenantCode="MZ",
|
||||||
|
TenantName="梅州",
|
||||||
|
DatasetId=10,
|
||||||
|
Body={"is_default": False},
|
||||||
|
)
|
||||||
|
|
||||||
|
return result, fake_session
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_admin_dataset_allows_unsetting_default_when_same_tenant_has_another_default():
|
||||||
|
result, fake_session = asyncio.run(_update_default_flag_with_other_count(other_default_count=1))
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert any("UPDATE rag_dataset" in sql_text for sql_text in fake_session.executed_sql)
|
||||||
|
assert any("tenant_code = :tenant_code" in sql_text for sql_text in fake_session.executed_sql)
|
||||||
|
assert any(params.get("tenant_code") == "MZ" for params in fake_session.executed_params)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_admin_dataset_rejects_unsetting_only_default_in_tenant():
|
||||||
|
try:
|
||||||
|
asyncio.run(_update_default_flag_with_other_count(other_default_count=0))
|
||||||
|
assert False, "expected LeauditException"
|
||||||
|
except LeauditException as exc:
|
||||||
|
assert exc.status == StatusCodeEnum.HTTP_400_BAD_REQUEST
|
||||||
Reference in New Issue
Block a user