359 lines
16 KiB
Python
359 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from datetime import datetime
|
|
from typing import Any
|
|
from urllib.parse import quote_plus
|
|
|
|
from sqlalchemy import bindparam, text
|
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
|
|
from fastapi_admin.config import DB_HOST, DB_PASSWORD, DB_PORT, DB_USER
|
|
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
|
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
|
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
|
from fastapi_modules.fastapi_leaudit.domian.Dto.promptTemplateDto import PromptTemplateCreateDTO, PromptTemplateUpdateDTO
|
|
from fastapi_modules.fastapi_leaudit.domian.vo.promptTemplateVo import PromptTemplateListVO, PromptTemplateTypeListVO, PromptTemplateTypeOptionVO, PromptTemplateVO
|
|
from fastapi_modules.fastapi_leaudit.services.promptTemplateService import IPromptTemplateService
|
|
|
|
# Name of the legacy database; overridable via the LEGACY_RULE_DB_NAME
# environment variable.
_LEGACY_DB_NAME = os.getenv("LEGACY_RULE_DB_NAME", "docauditai")

# asyncpg connection URL. User, password and database name are URL-quoted so
# reserved characters in them cannot corrupt the URL; host and port are used
# exactly as configured.
_LEGACY_DB_URL = "".join(
    (
        "postgresql+asyncpg://",
        quote_plus(str(DB_USER)),
        ":",
        quote_plus(str(DB_PASSWORD)),
        f"@{DB_HOST}:{DB_PORT}/",
        quote_plus(_LEGACY_DB_NAME),
    )
)

# Engine and session factory are created once at import time and shared by
# every method of the service below.
_LEGACY_ENGINE = create_async_engine(_LEGACY_DB_URL, pool_pre_ping=True)
_LegacySession = async_sessionmaker(_LEGACY_ENGINE, expire_on_commit=False)

# Template types accepted when filtering, creating or updating templates.
_ALLOWED_TEMPLATE_TYPES = {
    "LLM_Extraction",
    "VLM_Extraction",
    "Evaluation",
    "Summary",
    "Common",
}

# Display label for each template type.
_TYPE_LABELS = {
    "LLM_Extraction": "LLM抽取",
    "VLM_Extraction": "VLM抽取",
    "Evaluation": "评查",
    "Summary": "摘要",
    "Common": "通用",
}
|
|
|
|
|
|
class PromptTemplateServiceImpl(IPromptTemplateService):
|
|
async def ListTemplates(
    self,
    Search: str | None,
    TemplateTypes: list[str] | None,
    Status: int | None,
    Page: int,
    PageSize: int,
) -> PromptTemplateListVO:
    """Return one page of prompt templates matching the optional filters.

    Args:
        Search: Case-insensitive substring matched against the template
            name or code. ``None`` or a blank string disables the filter.
        TemplateTypes: Template types to include. Values that are not in
            ``_ALLOWED_TEMPLATE_TYPES`` are dropped; if nothing remains the
            type filter is not applied.
        Status: Exact status to match, or ``None`` for any status.
        Page: 1-based page number; values below 1 select the first page.
        PageSize: Maximum number of rows returned.

    Returns:
        A ``PromptTemplateListVO`` carrying the total number of matching
        rows and the requested page, newest ``updated_at`` first.
    """
    offset = max(Page - 1, 0) * PageSize
    filters = ["1=1"]
    params: dict[str, Any] = {"limit": PageSize, "offset": offset}

    # Strip before testing so a whitespace-only search is treated as "no
    # search" instead of producing a '%%' pattern.
    keyword = (Search or "").strip()
    if keyword:
        filters.append("(pt.template_name ILIKE :search OR pt.template_code ILIKE :search)")
        params["search"] = f"%{keyword}%"

    normalized_types = [item for item in (TemplateTypes or []) if item in _ALLOWED_TEMPLATE_TYPES]
    if normalized_types:
        filters.append("pt.template_type IN :template_types")
        params["template_types"] = tuple(normalized_types)

    if Status is not None:
        filters.append("pt.status = :status")
        params["status"] = Status

    # Only fixed SQL fragments are interpolated here; every user-supplied
    # value travels as a bound parameter.
    where_clause = " AND ".join(filters)

    count_sql = text(f"SELECT COUNT(*) FROM prompt_templates pt WHERE {where_clause}")
    list_sql = text(
        f"""
        SELECT
            pt.id,
            pt.template_name,
            pt.template_type,
            pt.description,
            pt.template_content,
            pt.variables,
            pt.status,
            pt.version,
            pt.created_by,
            pt.created_at,
            pt.updated_at,
            pt.template_code,
            pt.template_abbreviation
        FROM prompt_templates pt
        WHERE {where_clause}
        ORDER BY pt.updated_at DESC, pt.id DESC
        LIMIT :limit OFFSET :offset
        """
    )
    if normalized_types:
        # An "expanding" bind parameter renders the tuple as an IN (...) list.
        count_sql = count_sql.bindparams(bindparam("template_types", expanding=True))
        list_sql = list_sql.bindparams(bindparam("template_types", expanding=True))

    async with _LegacySession() as session:
        total = int((await session.execute(count_sql, params)).scalar_one())
        rows = (await session.execute(list_sql, params)).mappings().all()

    creator_ids = [int(row["created_by"]) for row in rows if row.get("created_by") is not None]
    usernames = await self._load_usernames(creator_ids)

    # Rows without a creator must not reach int(None); they simply get no
    # username.
    items = [
        self._to_vo(
            row,
            usernames.get(int(row["created_by"])) if row.get("created_by") is not None else None,
        )
        for row in rows
    ]
    return PromptTemplateListVO(total=total, page=Page, page_size=PageSize, items=items)
|
|
|
|
async def GetTemplate(self, TemplateId: int) -> PromptTemplateVO:
    """Fetch a single prompt template by its primary key.

    Args:
        TemplateId: Value of ``prompt_templates.id`` to look up.

    Returns:
        The row converted to a ``PromptTemplateVO``, with the creator's
        username attached when the row has a creator.

    Raises:
        LeauditException: With the HTTP 404 status code when no row has
            the given id.
    """
    select_sql = text(
        """
        SELECT
            id,
            template_name,
            template_type,
            description,
            template_content,
            variables,
            status,
            version,
            created_by,
            created_at,
            updated_at,
            template_code,
            template_abbreviation
        FROM prompt_templates
        WHERE id = :id
        """
    )
    async with _LegacySession() as session:
        result = await session.execute(select_sql, {"id": TemplateId})
        record = result.mappings().first()

    if not record:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "提示词模板不存在")

    creator = record.get("created_by")
    creator_ids = [] if creator is None else [int(creator)]
    names = await self._load_usernames(creator_ids)
    display_name = None if creator is None else names.get(int(creator))
    return self._to_vo(record, display_name)
|
|
|
|
async def CreateTemplate(self, Body: PromptTemplateCreateDTO) -> PromptTemplateVO:
    """Validate and insert a new prompt template, then return the stored row.

    Args:
        Body: Values for the new template.

    Returns:
        The freshly inserted template, re-read through ``GetTemplate``.

    Raises:
        LeauditException: With the HTTP 400 status code when the template
            type is not allowed, a ``VLM_Extraction`` template lacks a code
            or abbreviation, or the code is already used by another row.
    """
    await self._validate_template_payload(Body.template_type, Body.template_code, Body.template_abbreviation)
    await self._ensure_template_code_unique(Body.template_code)

    # Naive UTC timestamp, shared by created_at and updated_at.
    now = datetime.utcnow()
    payload = self._build_write_payload(Body)
    payload.update({"created_at": now, "updated_at": now})

    insert_sql = text(
        """
        INSERT INTO prompt_templates (
            template_name,
            template_type,
            description,
            template_content,
            variables,
            status,
            version,
            created_by,
            created_at,
            updated_at,
            template_code,
            template_abbreviation
        ) VALUES (
            :template_name,
            :template_type,
            :description,
            :template_content,
            CAST(:variables AS jsonb),
            :status,
            :version,
            :created_by,
            :created_at,
            :updated_at,
            :template_code,
            :template_abbreviation
        ) RETURNING id
        """
    )
    async with _LegacySession() as session:
        # session.begin() commits when the block exits cleanly and rolls
        # back on an exception, so no explicit commit() is required.
        async with session.begin():
            new_id = await session.scalar(insert_sql, payload)
    return await self.GetTemplate(int(new_id))
|
|
|
|
async def UpdateTemplate(self, TemplateId: int, Body: PromptTemplateUpdateDTO) -> PromptTemplateVO:
    """Apply a partial update to an existing prompt template.

    Only fields explicitly set on ``Body`` are written; ``updated_at`` is
    always refreshed when at least one field is set.

    Args:
        TemplateId: Id of the template to update.
        Body: Partial set of new values.

    Returns:
        The template as stored after the update, or the unchanged current
        template when ``Body`` sets no field.

    Raises:
        LeauditException: HTTP 404 when the template does not exist; HTTP
            400 when the resulting type/code/abbreviation combination is
            invalid or the new code is already used by another template.
    """
    current = await self.GetTemplate(TemplateId)
    payload = Body.model_dump(exclude_unset=True)
    if not payload:
        return current

    # Store code and abbreviation in the same trimmed-or-NULL form that the
    # create path uses. The uniqueness check compares the trimmed code, so
    # persisting an untrimmed value would let duplicates slip past it later.
    for key in ("template_code", "template_abbreviation"):
        if key in payload:
            payload[key] = str(payload[key] or "").strip() or None

    # Validate the state the row will have after the update: new values
    # where supplied, current values otherwise.
    new_type = payload.get("template_type", current.template_type)
    new_code = payload.get("template_code", current.template_code)
    new_abbr = payload.get("template_abbreviation", current.template_abbreviation)
    await self._validate_template_payload(new_type, new_code, new_abbr)
    if "template_code" in payload:
        await self._ensure_template_code_unique(new_code, TemplateId)

    updates: list[str] = []
    params: dict[str, Any] = {"id": TemplateId, "updated_at": datetime.utcnow()}
    # Column names come from this fixed list only, never from the request,
    # so interpolating them into the statement below is safe.
    simple_fields = [
        "template_name",
        "template_type",
        "description",
        "template_content",
        "status",
        "version",
        "template_code",
        "template_abbreviation",
    ]
    for field in simple_fields:
        if field in payload:
            params[field] = payload[field]
            updates.append(f"{field} = :{field}")
    if "variables" in payload:
        params["variables"] = json.dumps(payload.get("variables") or {}, ensure_ascii=False)
        updates.append("variables = CAST(:variables AS jsonb)")
    updates.append("updated_at = :updated_at")

    async with _LegacySession() as session:
        # session.begin() commits on clean exit; no explicit commit() needed.
        async with session.begin():
            await session.execute(
                text(f"UPDATE prompt_templates SET {', '.join(updates)} WHERE id = :id"),
                params,
            )
    return await self.GetTemplate(TemplateId)
|
|
|
|
async def DeleteTemplate(self, TemplateId: int) -> None:
    """Delete a prompt template by id.

    Args:
        TemplateId: Id of the template to remove.

    Raises:
        LeauditException: HTTP 404 (raised by ``GetTemplate``) when the
            template does not exist.
    """
    # Existence check first so a missing id is reported as 404 instead of
    # silently deleting nothing.
    await self.GetTemplate(TemplateId)
    async with _LegacySession() as session:
        # session.begin() commits on clean exit; no explicit commit() needed.
        async with session.begin():
            await session.execute(
                text("DELETE FROM prompt_templates WHERE id = :id"),
                {"id": TemplateId},
            )
|
|
|
|
async def GetTemplateTypes(self) -> PromptTemplateTypeListVO:
    """List the template types currently in use together with row counts.

    Rows whose ``template_type`` is NULL or blank are ignored. Each option's
    label comes from ``_TYPE_LABELS`` and falls back to the raw type value
    when no label is defined.

    Returns:
        A ``PromptTemplateTypeListVO`` with one option per distinct type,
        ordered by type name ascending.
    """
    stats_sql = text(
        """
        SELECT template_type, COUNT(*)::int AS count
        FROM prompt_templates
        WHERE template_type IS NOT NULL AND TRIM(template_type) <> ''
        GROUP BY template_type
        ORDER BY template_type ASC
        """
    )
    async with _LegacySession() as session:
        result = await session.execute(stats_sql)
        records = result.mappings().all()

    options = []
    for record in records:
        type_name = str(record["template_type"])
        options.append(
            PromptTemplateTypeOptionVO(
                value=type_name,
                label=_TYPE_LABELS.get(type_name, type_name),
                count=int(record["count"] or 0),
            )
        )
    return PromptTemplateTypeListVO(items=options)
|
|
|
|
async def DuplicateTemplate(self, TemplateId: int, NewCode: str | None) -> PromptTemplateVO:
    """Create a copy of an existing template under a new code.

    Args:
        TemplateId: Id of the template to copy.
        NewCode: Code for the copy. When ``None`` or blank, a code is
            generated from the source template's code (or from
            ``template_<id>`` if the source has none).

    Returns:
        The newly created copy; its name is the source name with the
        suffix ``-副本``.

    Raises:
        LeauditException: HTTP 404 when the source does not exist; HTTP 400
            when the chosen code is already in use or the copied values fail
            validation in ``CreateTemplate``.
    """
    source = await self.GetTemplate(TemplateId)

    requested_code = (NewCode or "").strip()
    if requested_code:
        code = requested_code
    else:
        code = self._generate_copy_code(source.template_code or f"template_{TemplateId}")
    await self._ensure_template_code_unique(code)

    copy_request = PromptTemplateCreateDTO(
        template_name=f"{source.template_name}-副本",
        template_type=source.template_type,
        description=source.description,
        template_content=source.template_content,
        variables=source.variables,
        status=source.status,
        version=source.version,
        created_by=source.created_by,
        template_code=code,
        template_abbreviation=source.template_abbreviation,
    )
    return await self.CreateTemplate(copy_request)
|
|
|
|
async def _load_usernames(self, user_ids: list[int]) -> dict[int, str]:
|
|
ids = sorted({int(item) for item in user_ids if item is not None})
|
|
if not ids:
|
|
return {}
|
|
async with GetAsyncSession() as session:
|
|
rows = (
|
|
await session.execute(
|
|
text("SELECT id, username FROM sso_users WHERE id IN :ids").bindparams(bindparam("ids", expanding=True)),
|
|
{"ids": ids},
|
|
)
|
|
).mappings().all()
|
|
return {int(row["id"]): str(row.get("username") or "") for row in rows}
|
|
|
|
async def _validate_template_payload(self, template_type: str | None, template_code: str | None, template_abbreviation: str | None) -> None:
    """Check that a type/code/abbreviation combination is acceptable.

    Args:
        template_type: Must be one of ``_ALLOWED_TEMPLATE_TYPES``.
        template_code: Required (non-blank) for ``VLM_Extraction``.
        template_abbreviation: Required (non-blank) for ``VLM_Extraction``.

    Raises:
        LeauditException: With the HTTP 400 status code on the first
            violated rule.
    """
    if template_type not in _ALLOWED_TEMPLATE_TYPES:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "模板类型不合法")

    # Only VLM extraction templates have additional required fields.
    if template_type != "VLM_Extraction":
        return

    required_fields = (
        (template_code, "VLM抽取模板必须填写模板code"),
        (template_abbreviation, "VLM抽取模板必须填写模板简称"),
    )
    for field_value, message in required_fields:
        if not str(field_value or "").strip():
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, message)
|
|
|
|
async def _ensure_template_code_unique(self, template_code: str | None, template_id: int | None = None) -> None:
|
|
code = str(template_code or "").strip()
|
|
if not code:
|
|
return
|
|
sql = "SELECT id FROM prompt_templates WHERE LOWER(template_code) = LOWER(:template_code)"
|
|
params: dict[str, Any] = {"template_code": code}
|
|
if template_id is not None:
|
|
sql += " AND id <> :id"
|
|
params["id"] = template_id
|
|
async with _LegacySession() as session:
|
|
exists = (await session.execute(text(sql), params)).scalar_one_or_none()
|
|
if exists:
|
|
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "模板code已存在")
|
|
|
|
def _build_write_payload(self, body: PromptTemplateCreateDTO) -> dict[str, Any]:
|
|
return {
|
|
"template_name": body.template_name.strip(),
|
|
"template_type": body.template_type,
|
|
"description": (body.description or "").strip() or None,
|
|
"template_content": body.template_content,
|
|
"variables": json.dumps(body.variables or {}, ensure_ascii=False),
|
|
"status": int(body.status),
|
|
"version": str(body.version or "v1.0").strip() or "v1.0",
|
|
"created_by": int(body.created_by) if body.created_by is not None else None,
|
|
"template_code": (body.template_code or "").strip() or None,
|
|
"template_abbreviation": (body.template_abbreviation or "").strip() or None,
|
|
}
|
|
|
|
def _to_vo(self, row: dict[str, Any], username: str | None) -> PromptTemplateVO:
    """Convert a ``prompt_templates`` row mapping into a ``PromptTemplateVO``.

    Missing or NULL text columns become empty strings (version defaults to
    ``"v1.0"``, status to ``0``); ``variables`` is parsed into a string
    mapping and timestamps are rendered as text.

    Args:
        row: Mapping of column name to value for one template row.
        username: Username of the row's creator, or ``None`` if unknown.

    Returns:
        The populated view object.
    """
    creator = row.get("created_by")
    fields = {
        "id": int(row["id"]),
        "template_name": str(row.get("template_name") or ""),
        "template_code": row.get("template_code"),
        "template_type": str(row.get("template_type") or ""),
        "description": row.get("description"),
        "template_content": str(row.get("template_content") or ""),
        "template_abbreviation": row.get("template_abbreviation"),
        "variables": self._parse_variables(row.get("variables")),
        "status": int(row.get("status") or 0),
        "version": str(row.get("version") or "v1.0"),
        "created_by": None if creator is None else int(creator),
        "created_by_username": username,
        "created_at": self._format_datetime(row.get("created_at")),
        "updated_at": self._format_datetime(row.get("updated_at")),
    }
    return PromptTemplateVO(**fields)
|
|
|
|
def _parse_variables(self, value: Any) -> dict[str, str]:
|
|
if value is None:
|
|
return {}
|
|
if isinstance(value, dict):
|
|
return {str(k): str(v) for k, v in value.items()}
|
|
if isinstance(value, str):
|
|
try:
|
|
parsed = json.loads(value)
|
|
if isinstance(parsed, dict):
|
|
return {str(k): str(v) for k, v in parsed.items()}
|
|
except json.JSONDecodeError:
|
|
return {}
|
|
return {}
|
|
|
|
def _format_datetime(self, value: Any) -> str:
|
|
if isinstance(value, datetime):
|
|
return value.isoformat()
|
|
return str(value or "")
|
|
|
|
def _generate_copy_code(self, code: str) -> str:
|
|
return f"{code}_copy_{int(datetime.utcnow().timestamp())}"
|