# Files
#
# 359 lines
# 16 KiB
# Python

from __future__ import annotations
import json
import os
from datetime import datetime
from typing import Any
from urllib.parse import quote_plus
from sqlalchemy import bindparam, text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from fastapi_admin.config import DB_HOST, DB_PASSWORD, DB_PORT, DB_USER
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_modules.fastapi_leaudit.domian.Dto.promptTemplateDto import PromptTemplateCreateDTO, PromptTemplateUpdateDTO
from fastapi_modules.fastapi_leaudit.domian.vo.promptTemplateVo import PromptTemplateListVO, PromptTemplateTypeListVO, PromptTemplateTypeOptionVO, PromptTemplateVO
from fastapi_modules.fastapi_leaudit.services.promptTemplateService import IPromptTemplateService
# Name of the legacy database holding prompt_templates; overridable through the
# LEGACY_RULE_DB_NAME environment variable.
_LEGACY_DB_NAME = os.getenv("LEGACY_RULE_DB_NAME", "docauditai")

# asyncpg connection URL for the legacy database. User, password and database
# name are percent-encoded with quote_plus so reserved characters cannot break
# the URL structure.
_LEGACY_DB_URL = (
    "postgresql+asyncpg://"
    f"{quote_plus(str(DB_USER))}:{quote_plus(str(DB_PASSWORD))}"
    f"@{DB_HOST}:{DB_PORT}/{quote_plus(_LEGACY_DB_NAME)}"
)

# Engine and session factory are built at import time. pool_pre_ping makes the
# pool test a connection before handing it out; expire_on_commit=False keeps
# loaded objects readable after a commit.
_LEGACY_ENGINE = create_async_engine(_LEGACY_DB_URL, pool_pre_ping=True)
_LegacySession = async_sessionmaker(bind=_LEGACY_ENGINE, expire_on_commit=False)

# Display label for every supported template type.
_TYPE_LABELS = {
    "LLM_Extraction": "LLM抽取",
    "VLM_Extraction": "VLM抽取",
    "Evaluation": "评查",
    "Summary": "摘要",
    "Common": "通用",
}
# Template types accepted on write and as list filters (exactly the label keys).
_ALLOWED_TEMPLATE_TYPES = set(_TYPE_LABELS)
class PromptTemplateServiceImpl(IPromptTemplateService):
    """Prompt-template CRUD service backed by the legacy ``prompt_templates`` table.

    Template rows are read and written through ``_LegacySession``; creator
    usernames are resolved separately from ``sso_users`` via ``GetAsyncSession``.
    """

    async def ListTemplates(
        self,
        Search: str | None,
        TemplateTypes: list[str] | None,
        Status: int | None,
        Page: int,
        PageSize: int,
    ) -> PromptTemplateListVO:
        """Return one page of templates matching the optional filters.

        Args:
            Search: Case-insensitive substring matched against template name or code;
                blank/whitespace-only means "no search filter".
            TemplateTypes: Types to include. Values outside ``_ALLOWED_TEMPLATE_TYPES``
                are dropped; if none remain, no type filter is applied.
            Status: Exact status to match, or ``None`` for any status.
            Page: 1-based page number; values below 1 are treated as page 1.
            PageSize: Maximum number of rows in the page.

        Returns:
            Total match count plus the requested page, ordered by ``updated_at`` then
            ``id``, both descending.
        """
        offset = max(Page - 1, 0) * PageSize
        filters = ["1=1"]
        params: dict[str, Any] = {"limit": PageSize, "offset": offset}

        search = (Search or "").strip()
        if search:
            filters.append("(pt.template_name ILIKE :search OR pt.template_code ILIKE :search)")
            params["search"] = f"%{search}%"

        # Unknown type names are silently ignored rather than rejected.
        normalized_types = [item for item in (TemplateTypes or []) if item in _ALLOWED_TEMPLATE_TYPES]
        if normalized_types:
            filters.append("pt.template_type IN :template_types")
            params["template_types"] = tuple(normalized_types)

        if Status is not None:
            filters.append("pt.status = :status")
            params["status"] = Status

        # Only fixed SQL fragments are interpolated; every user value is a bound parameter.
        where_clause = " AND ".join(filters)
        count_sql = text(f"SELECT COUNT(*) FROM prompt_templates pt WHERE {where_clause}")
        list_sql = text(
            f"""
            SELECT
                pt.id,
                pt.template_name,
                pt.template_type,
                pt.description,
                pt.template_content,
                pt.variables,
                pt.status,
                pt.version,
                pt.created_by,
                pt.created_at,
                pt.updated_at,
                pt.template_code,
                pt.template_abbreviation
            FROM prompt_templates pt
            WHERE {where_clause}
            ORDER BY pt.updated_at DESC, pt.id DESC
            LIMIT :limit OFFSET :offset
            """
        )
        if normalized_types:
            # "IN :template_types" needs an expanding bind so the tuple renders as a list.
            count_sql = count_sql.bindparams(bindparam("template_types", expanding=True))
            list_sql = list_sql.bindparams(bindparam("template_types", expanding=True))

        async with _LegacySession() as session:
            total = int((await session.execute(count_sql, params)).scalar_one())
            rows = (await session.execute(list_sql, params)).mappings().all()

        creator_ids = [cid for cid in (self._creator_id(row) for row in rows) if cid is not None]
        usernames = await self._load_usernames(creator_ids)
        # Rows without a creator get username None instead of failing on int(None).
        items = [self._to_vo(row, usernames.get(self._creator_id(row))) for row in rows]
        return PromptTemplateListVO(total=total, page=Page, page_size=PageSize, items=items)

    async def GetTemplate(self, TemplateId: int) -> PromptTemplateVO:
        """Load a single template by id.

        Raises:
            LeauditException: HTTP 404 when no row has ``id == TemplateId``.
        """
        async with _LegacySession() as session:
            row = (
                await session.execute(
                    text(
                        """
                        SELECT
                            id,
                            template_name,
                            template_type,
                            description,
                            template_content,
                            variables,
                            status,
                            version,
                            created_by,
                            created_at,
                            updated_at,
                            template_code,
                            template_abbreviation
                        FROM prompt_templates
                        WHERE id = :id
                        """
                    ),
                    {"id": TemplateId},
                )
            ).mappings().first()
        if not row:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "提示词模板不存在")
        creator_id = self._creator_id(row)
        usernames = await self._load_usernames([creator_id] if creator_id is not None else [])
        return self._to_vo(row, usernames.get(creator_id))

    async def CreateTemplate(self, Body: PromptTemplateCreateDTO) -> PromptTemplateVO:
        """Insert a new template and return it as re-read from the database.

        Raises:
            LeauditException: HTTP 400 for an invalid type, a VLM template missing its
                code or abbreviation, or a template code already in use (case-insensitive).
        """
        await self._validate_template_payload(Body.template_type, Body.template_code, Body.template_abbreviation)
        await self._ensure_template_code_unique(Body.template_code)
        now = self._utcnow()
        payload = self._build_write_payload(Body)
        payload.update({"created_at": now, "updated_at": now})
        async with _LegacySession() as session:
            # session.begin() commits on normal exit; no explicit commit() needed.
            async with session.begin():
                new_id = await session.scalar(
                    text(
                        """
                        INSERT INTO prompt_templates (
                            template_name,
                            template_type,
                            description,
                            template_content,
                            variables,
                            status,
                            version,
                            created_by,
                            created_at,
                            updated_at,
                            template_code,
                            template_abbreviation
                        ) VALUES (
                            :template_name,
                            :template_type,
                            :description,
                            :template_content,
                            CAST(:variables AS jsonb),
                            :status,
                            :version,
                            :created_by,
                            :created_at,
                            :updated_at,
                            :template_code,
                            :template_abbreviation
                        ) RETURNING id
                        """
                    ),
                    payload,
                )
        return await self.GetTemplate(int(new_id))

    async def UpdateTemplate(self, TemplateId: int, Body: PromptTemplateUpdateDTO) -> PromptTemplateVO:
        """Apply a partial update (only fields explicitly set on ``Body``).

        String fields are trimmed with the same rules as creation. Validation runs
        against the merged (new-or-current) type, code and abbreviation.

        Returns:
            The template as stored after the update; unchanged if ``Body`` sets nothing.

        Raises:
            LeauditException: HTTP 404 if the template does not exist; HTTP 400 for
                invalid type/VLM fields or a duplicate template code.
        """
        current = await self.GetTemplate(TemplateId)
        payload = self._normalize_update_payload(Body.model_dump(exclude_unset=True))
        if not payload:
            return current

        new_type = payload.get("template_type", current.template_type)
        new_code = payload.get("template_code", current.template_code)
        new_abbr = payload.get("template_abbreviation", current.template_abbreviation)
        await self._validate_template_payload(new_type, new_code, new_abbr)
        if "template_code" in payload:
            await self._ensure_template_code_unique(payload.get("template_code"), TemplateId)

        updates: list[str] = []
        params: dict[str, Any] = {"id": TemplateId, "updated_at": self._utcnow()}
        simple_fields = [
            "template_name",
            "template_type",
            "description",
            "template_content",
            "status",
            "version",
            "template_code",
            "template_abbreviation",
        ]
        # Column names come from the fixed list above, never from user input.
        for field in simple_fields:
            if field in payload:
                params[field] = payload[field]
                updates.append(f"{field} = :{field}")
        if "variables" in payload:
            params["variables"] = json.dumps(payload.get("variables") or {}, ensure_ascii=False)
            updates.append("variables = CAST(:variables AS jsonb)")
        updates.append("updated_at = :updated_at")

        async with _LegacySession() as session:
            async with session.begin():
                await session.execute(text(f"UPDATE prompt_templates SET {', '.join(updates)} WHERE id = :id"), params)
        return await self.GetTemplate(TemplateId)

    async def DeleteTemplate(self, TemplateId: int) -> None:
        """Delete a template by id.

        Raises:
            LeauditException: HTTP 404 if the template does not exist.
        """
        await self.GetTemplate(TemplateId)
        async with _LegacySession() as session:
            async with session.begin():
                await session.execute(text("DELETE FROM prompt_templates WHERE id = :id"), {"id": TemplateId})

    async def GetTemplateTypes(self) -> PromptTemplateTypeListVO:
        """List each distinct non-blank template type with its row count.

        The label comes from ``_TYPE_LABELS``; unknown types use the raw value as label.
        """
        async with _LegacySession() as session:
            rows = (
                await session.execute(
                    text(
                        """
                        SELECT template_type, COUNT(*)::int AS count
                        FROM prompt_templates
                        WHERE template_type IS NOT NULL AND TRIM(template_type) <> ''
                        GROUP BY template_type
                        ORDER BY template_type ASC
                        """
                    )
                )
            ).mappings().all()
        items = [
            PromptTemplateTypeOptionVO(
                value=str(row["template_type"]),
                label=_TYPE_LABELS.get(str(row["template_type"]), str(row["template_type"])),
                count=int(row["count"] or 0),
            )
            for row in rows
        ]
        return PromptTemplateTypeListVO(items=items)

    async def DuplicateTemplate(self, TemplateId: int, NewCode: str | None) -> PromptTemplateVO:
        """Create a copy of a template under a new code.

        Args:
            TemplateId: Id of the template to copy.
            NewCode: Code for the copy; when blank, one is generated from the source code
                (or ``template_<id>`` if the source has no code).

        Raises:
            LeauditException: HTTP 404 if the source is missing; HTTP 400 if the code is taken.
        """
        current = await self.GetTemplate(TemplateId)
        code = (NewCode or "").strip() or self._generate_copy_code(current.template_code or f"template_{TemplateId}")
        await self._ensure_template_code_unique(code)
        body = PromptTemplateCreateDTO(
            template_name=f"{current.template_name}-副本",
            template_type=current.template_type,
            description=current.description,
            template_content=current.template_content,
            variables=current.variables,
            status=current.status,
            version=current.version,
            created_by=current.created_by,
            template_code=code,
            template_abbreviation=current.template_abbreviation,
        )
        return await self.CreateTemplate(body)

    async def _load_usernames(self, user_ids: list[int]) -> dict[int, str]:
        """Map user id -> username from ``sso_users``.

        ``None`` entries and duplicates are ignored; ids not found are absent from the
        result; a NULL username maps to ``""``.
        """
        ids = sorted({int(item) for item in user_ids if item is not None})
        if not ids:
            return {}
        async with GetAsyncSession() as session:
            rows = (
                await session.execute(
                    text("SELECT id, username FROM sso_users WHERE id IN :ids").bindparams(bindparam("ids", expanding=True)),
                    {"ids": ids},
                )
            ).mappings().all()
        return {int(row["id"]): str(row.get("username") or "") for row in rows}

    async def _validate_template_payload(self, template_type: str | None, template_code: str | None, template_abbreviation: str | None) -> None:
        """Check the type is allowed and VLM templates carry a code and abbreviation.

        Raises:
            LeauditException: HTTP 400 on the first failed rule.
        """
        if template_type not in _ALLOWED_TEMPLATE_TYPES:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "模板类型不合法")
        if template_type == "VLM_Extraction":
            if not str(template_code or "").strip():
                raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "VLM抽取模板必须填写模板code")
            if not str(template_abbreviation or "").strip():
                raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "VLM抽取模板必须填写模板简称")

    async def _ensure_template_code_unique(self, template_code: str | None, template_id: int | None = None) -> None:
        """Reject a code already used by another template (trimmed, case-insensitive).

        Args:
            template_code: Code to check; blank/``None`` codes are never checked.
            template_id: Row to exclude from the check (the template being updated).

        Raises:
            LeauditException: HTTP 400 if another row already uses the code.
        """
        code = str(template_code or "").strip()
        if not code:
            return
        sql = "SELECT id FROM prompt_templates WHERE LOWER(template_code) = LOWER(:template_code)"
        params: dict[str, Any] = {"template_code": code}
        if template_id is not None:
            sql += " AND id <> :id"
            params["id"] = template_id
        # Several rows can match LOWER(...) (e.g. "ABC" and "abc"); without LIMIT 1,
        # scalar_one_or_none() would raise MultipleResultsFound instead of a 400.
        sql += " LIMIT 1"
        async with _LegacySession() as session:
            existing_id = (await session.execute(text(sql), params)).scalar_one_or_none()
        if existing_id is not None:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "模板code已存在")

    def _build_write_payload(self, body: PromptTemplateCreateDTO) -> dict[str, Any]:
        """Build INSERT bind values from a create DTO.

        Strings are trimmed (blank optional strings become ``None``), ``variables`` is
        serialized to JSON text for the ``CAST(... AS jsonb)`` bind, and ``version``
        defaults to ``"v1.0"``.
        """
        return {
            "template_name": body.template_name.strip(),
            "template_type": body.template_type,
            "description": (body.description or "").strip() or None,
            "template_content": body.template_content,
            "variables": json.dumps(body.variables or {}, ensure_ascii=False),
            "status": int(body.status),
            "version": str(body.version or "v1.0").strip() or "v1.0",
            "created_by": int(body.created_by) if body.created_by is not None else None,
            "template_code": (body.template_code or "").strip() or None,
            "template_abbreviation": (body.template_abbreviation or "").strip() or None,
        }

    @staticmethod
    def _normalize_update_payload(payload: dict[str, Any]) -> dict[str, Any]:
        """Trim a partial-update payload with the same rules as ``_build_write_payload``.

        Only keys present in ``payload`` are touched, so unset fields stay unset.
        """
        normalized = dict(payload)
        if isinstance(normalized.get("template_name"), str):
            normalized["template_name"] = normalized["template_name"].strip()
        for key in ("description", "template_code", "template_abbreviation"):
            if key in normalized:
                normalized[key] = (normalized[key] or "").strip() or None
        if "version" in normalized:
            normalized["version"] = str(normalized["version"] or "v1.0").strip() or "v1.0"
        return normalized

    @staticmethod
    def _creator_id(row: Any) -> int | None:
        """Return ``row["created_by"]`` as an int, or ``None`` when it is missing/NULL."""
        value = row.get("created_by")
        return int(value) if value is not None else None

    @staticmethod
    def _utcnow() -> datetime:
        """Current UTC time as a naive datetime (non-deprecated ``utcnow()`` equivalent)."""
        from datetime import timezone

        # Kept naive like datetime.utcnow(); NOTE(review): asyncpg may reject aware
        # values for "timestamp without time zone" columns — confirm column type.
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _to_vo(self, row: dict[str, Any], username: str | None) -> PromptTemplateVO:
        """Convert a ``prompt_templates`` row mapping into a ``PromptTemplateVO``."""
        return PromptTemplateVO(
            id=int(row["id"]),
            template_name=str(row.get("template_name") or ""),
            template_code=row.get("template_code"),
            template_type=str(row.get("template_type") or ""),
            description=row.get("description"),
            template_content=str(row.get("template_content") or ""),
            template_abbreviation=row.get("template_abbreviation"),
            variables=self._parse_variables(row.get("variables")),
            status=int(row.get("status") or 0),
            version=str(row.get("version") or "v1.0"),
            created_by=self._creator_id(row),
            created_by_username=username,
            created_at=self._format_datetime(row.get("created_at")),
            updated_at=self._format_datetime(row.get("updated_at")),
        )

    def _parse_variables(self, value: Any) -> dict[str, str]:
        """Coerce a ``variables`` value (dict or JSON-object string) to ``dict[str, str]``.

        Anything else (``None``, invalid JSON, non-object JSON, other types) yields ``{}``.
        """
        if value is None:
            return {}
        if isinstance(value, dict):
            return {str(k): str(v) for k, v in value.items()}
        if isinstance(value, str):
            try:
                parsed = json.loads(value)
            except json.JSONDecodeError:
                return {}
            if isinstance(parsed, dict):
                return {str(k): str(v) for k, v in parsed.items()}
        return {}

    def _format_datetime(self, value: Any) -> str:
        """ISO-format datetimes; stringify anything else, with ``None``/falsy as ``""``."""
        if isinstance(value, datetime):
            return value.isoformat()
        return str(value or "")

    def _generate_copy_code(self, code: str) -> str:
        """Derive a copy code by appending ``_copy_<unix seconds>`` to ``code``."""
        from datetime import timezone

        # Use an aware "now": a naive utcnow().timestamp() is interpreted as local
        # time and would be skewed by the server's UTC offset.
        return f"{code}_copy_{int(datetime.now(timezone.utc).timestamp())}"