feat: add backend rule group and permission support
This commit is contained in:
@@ -0,0 +1,358 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from sqlalchemy import bindparam, text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from fastapi_admin.config import DB_HOST, DB_PASSWORD, DB_PORT, DB_USER
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||||
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||||
from fastapi_modules.fastapi_leaudit.domian.Dto.promptTemplateDto import PromptTemplateCreateDTO, PromptTemplateUpdateDTO
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.promptTemplateVo import PromptTemplateListVO, PromptTemplateTypeListVO, PromptTemplateTypeOptionVO, PromptTemplateVO
|
||||
from fastapi_modules.fastapi_leaudit.services.promptTemplateService import IPromptTemplateService
|
||||
|
||||
_LEGACY_DB_NAME = os.getenv("LEGACY_RULE_DB_NAME", "docauditai")
|
||||
_LEGACY_DB_URL = (
|
||||
f"postgresql+asyncpg://{quote_plus(str(DB_USER))}:{quote_plus(str(DB_PASSWORD))}"
|
||||
f"@{DB_HOST}:{DB_PORT}/{quote_plus(_LEGACY_DB_NAME)}"
|
||||
)
|
||||
_LEGACY_ENGINE = create_async_engine(_LEGACY_DB_URL, pool_pre_ping=True)
|
||||
_LegacySession = async_sessionmaker(_LEGACY_ENGINE, expire_on_commit=False)
|
||||
|
||||
_ALLOWED_TEMPLATE_TYPES = {"LLM_Extraction", "VLM_Extraction", "Evaluation", "Summary", "Common"}
|
||||
_TYPE_LABELS = {
|
||||
"LLM_Extraction": "LLM抽取",
|
||||
"VLM_Extraction": "VLM抽取",
|
||||
"Evaluation": "评查",
|
||||
"Summary": "摘要",
|
||||
"Common": "通用",
|
||||
}
|
||||
|
||||
|
||||
class PromptTemplateServiceImpl(IPromptTemplateService):
|
||||
async def ListTemplates(
|
||||
self,
|
||||
Search: str | None,
|
||||
TemplateTypes: list[str] | None,
|
||||
Status: int | None,
|
||||
Page: int,
|
||||
PageSize: int,
|
||||
) -> PromptTemplateListVO:
|
||||
offset = max(Page - 1, 0) * PageSize
|
||||
filters = ["1=1"]
|
||||
params: dict[str, Any] = {"limit": PageSize, "offset": offset}
|
||||
|
||||
if Search:
|
||||
filters.append("(pt.template_name ILIKE :search OR pt.template_code ILIKE :search)")
|
||||
params["search"] = f"%{Search.strip()}%"
|
||||
normalized_types = [item for item in (TemplateTypes or []) if item in _ALLOWED_TEMPLATE_TYPES]
|
||||
if normalized_types:
|
||||
filters.append("pt.template_type IN :template_types")
|
||||
params["template_types"] = tuple(normalized_types)
|
||||
if Status is not None:
|
||||
filters.append("pt.status = :status")
|
||||
params["status"] = Status
|
||||
|
||||
where_clause = " AND ".join(filters)
|
||||
|
||||
count_sql = text(f"SELECT COUNT(*) FROM prompt_templates pt WHERE {where_clause}")
|
||||
list_sql = text(
|
||||
f"""
|
||||
SELECT
|
||||
pt.id,
|
||||
pt.template_name,
|
||||
pt.template_type,
|
||||
pt.description,
|
||||
pt.template_content,
|
||||
pt.variables,
|
||||
pt.status,
|
||||
pt.version,
|
||||
pt.created_by,
|
||||
pt.created_at,
|
||||
pt.updated_at,
|
||||
pt.template_code,
|
||||
pt.template_abbreviation
|
||||
FROM prompt_templates pt
|
||||
WHERE {where_clause}
|
||||
ORDER BY pt.updated_at DESC, pt.id DESC
|
||||
LIMIT :limit OFFSET :offset
|
||||
"""
|
||||
)
|
||||
if normalized_types:
|
||||
count_sql = count_sql.bindparams(bindparam("template_types", expanding=True))
|
||||
list_sql = list_sql.bindparams(bindparam("template_types", expanding=True))
|
||||
|
||||
async with _LegacySession() as session:
|
||||
total = int((await session.execute(count_sql, params)).scalar_one())
|
||||
rows = (
|
||||
await session.execute(
|
||||
list_sql,
|
||||
params,
|
||||
)
|
||||
).mappings().all()
|
||||
|
||||
usernames = await self._load_usernames([int(row["created_by"]) for row in rows if row.get("created_by") is not None])
|
||||
items = [self._to_vo(row, usernames.get(int(row["created_by"]))) for row in rows]
|
||||
return PromptTemplateListVO(total=total, page=Page, page_size=PageSize, items=items)
|
||||
|
||||
async def GetTemplate(self, TemplateId: int) -> PromptTemplateVO:
|
||||
async with _LegacySession() as session:
|
||||
row = (
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT
|
||||
id,
|
||||
template_name,
|
||||
template_type,
|
||||
description,
|
||||
template_content,
|
||||
variables,
|
||||
status,
|
||||
version,
|
||||
created_by,
|
||||
created_at,
|
||||
updated_at,
|
||||
template_code,
|
||||
template_abbreviation
|
||||
FROM prompt_templates
|
||||
WHERE id = :id
|
||||
"""
|
||||
),
|
||||
{"id": TemplateId},
|
||||
)
|
||||
).mappings().first()
|
||||
if not row:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "提示词模板不存在")
|
||||
usernames = await self._load_usernames([int(row["created_by"])] if row.get("created_by") is not None else [])
|
||||
return self._to_vo(row, usernames.get(int(row["created_by"])) if row.get("created_by") is not None else None)
|
||||
|
||||
async def CreateTemplate(self, Body: PromptTemplateCreateDTO) -> PromptTemplateVO:
|
||||
await self._validate_template_payload(Body.template_type, Body.template_code, Body.template_abbreviation)
|
||||
await self._ensure_template_code_unique(Body.template_code)
|
||||
now = datetime.utcnow()
|
||||
payload = self._build_write_payload(Body)
|
||||
payload.update({"created_at": now, "updated_at": now})
|
||||
async with _LegacySession() as session:
|
||||
async with session.begin():
|
||||
new_id = await session.scalar(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO prompt_templates (
|
||||
template_name,
|
||||
template_type,
|
||||
description,
|
||||
template_content,
|
||||
variables,
|
||||
status,
|
||||
version,
|
||||
created_by,
|
||||
created_at,
|
||||
updated_at,
|
||||
template_code,
|
||||
template_abbreviation
|
||||
) VALUES (
|
||||
:template_name,
|
||||
:template_type,
|
||||
:description,
|
||||
:template_content,
|
||||
CAST(:variables AS jsonb),
|
||||
:status,
|
||||
:version,
|
||||
:created_by,
|
||||
:created_at,
|
||||
:updated_at,
|
||||
:template_code,
|
||||
:template_abbreviation
|
||||
) RETURNING id
|
||||
"""
|
||||
),
|
||||
payload,
|
||||
)
|
||||
await session.commit()
|
||||
return await self.GetTemplate(int(new_id))
|
||||
|
||||
async def UpdateTemplate(self, TemplateId: int, Body: PromptTemplateUpdateDTO) -> PromptTemplateVO:
|
||||
current = await self.GetTemplate(TemplateId)
|
||||
payload = Body.model_dump(exclude_unset=True)
|
||||
if not payload:
|
||||
return current
|
||||
new_type = payload.get("template_type", current.template_type)
|
||||
new_code = payload.get("template_code", current.template_code)
|
||||
new_abbr = payload.get("template_abbreviation", current.template_abbreviation)
|
||||
await self._validate_template_payload(new_type, new_code, new_abbr)
|
||||
if "template_code" in payload:
|
||||
await self._ensure_template_code_unique(payload.get("template_code"), TemplateId)
|
||||
|
||||
updates: list[str] = []
|
||||
params: dict[str, Any] = {"id": TemplateId, "updated_at": datetime.utcnow()}
|
||||
simple_fields = [
|
||||
"template_name",
|
||||
"template_type",
|
||||
"description",
|
||||
"template_content",
|
||||
"status",
|
||||
"version",
|
||||
"template_code",
|
||||
"template_abbreviation",
|
||||
]
|
||||
for field in simple_fields:
|
||||
if field in payload:
|
||||
params[field] = payload[field]
|
||||
updates.append(f"{field} = :{field}")
|
||||
if "variables" in payload:
|
||||
params["variables"] = json.dumps(payload.get("variables") or {}, ensure_ascii=False)
|
||||
updates.append("variables = CAST(:variables AS jsonb)")
|
||||
updates.append("updated_at = :updated_at")
|
||||
|
||||
async with _LegacySession() as session:
|
||||
async with session.begin():
|
||||
await session.execute(text(f"UPDATE prompt_templates SET {', '.join(updates)} WHERE id = :id"), params)
|
||||
await session.commit()
|
||||
return await self.GetTemplate(TemplateId)
|
||||
|
||||
async def DeleteTemplate(self, TemplateId: int) -> None:
|
||||
await self.GetTemplate(TemplateId)
|
||||
async with _LegacySession() as session:
|
||||
async with session.begin():
|
||||
await session.execute(text("DELETE FROM prompt_templates WHERE id = :id"), {"id": TemplateId})
|
||||
await session.commit()
|
||||
|
||||
async def GetTemplateTypes(self) -> PromptTemplateTypeListVO:
|
||||
async with _LegacySession() as session:
|
||||
rows = (
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT template_type, COUNT(*)::int AS count
|
||||
FROM prompt_templates
|
||||
WHERE template_type IS NOT NULL AND TRIM(template_type) <> ''
|
||||
GROUP BY template_type
|
||||
ORDER BY template_type ASC
|
||||
"""
|
||||
)
|
||||
)
|
||||
).mappings().all()
|
||||
items = [
|
||||
PromptTemplateTypeOptionVO(
|
||||
value=str(row["template_type"]),
|
||||
label=_TYPE_LABELS.get(str(row["template_type"]), str(row["template_type"])),
|
||||
count=int(row["count"] or 0),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
return PromptTemplateTypeListVO(items=items)
|
||||
|
||||
async def DuplicateTemplate(self, TemplateId: int, NewCode: str | None) -> PromptTemplateVO:
|
||||
current = await self.GetTemplate(TemplateId)
|
||||
code = (NewCode or "").strip() or self._generate_copy_code(current.template_code or f"template_{TemplateId}")
|
||||
await self._ensure_template_code_unique(code)
|
||||
body = PromptTemplateCreateDTO(
|
||||
template_name=f"{current.template_name}-副本",
|
||||
template_type=current.template_type,
|
||||
description=current.description,
|
||||
template_content=current.template_content,
|
||||
variables=current.variables,
|
||||
status=current.status,
|
||||
version=current.version,
|
||||
created_by=current.created_by,
|
||||
template_code=code,
|
||||
template_abbreviation=current.template_abbreviation,
|
||||
)
|
||||
return await self.CreateTemplate(body)
|
||||
|
||||
async def _load_usernames(self, user_ids: list[int]) -> dict[int, str]:
|
||||
ids = sorted({int(item) for item in user_ids if item is not None})
|
||||
if not ids:
|
||||
return {}
|
||||
async with GetAsyncSession() as session:
|
||||
rows = (
|
||||
await session.execute(
|
||||
text("SELECT id, username FROM sso_users WHERE id IN :ids").bindparams(bindparam("ids", expanding=True)),
|
||||
{"ids": ids},
|
||||
)
|
||||
).mappings().all()
|
||||
return {int(row["id"]): str(row.get("username") or "") for row in rows}
|
||||
|
||||
async def _validate_template_payload(self, template_type: str | None, template_code: str | None, template_abbreviation: str | None) -> None:
|
||||
if template_type not in _ALLOWED_TEMPLATE_TYPES:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "模板类型不合法")
|
||||
if template_type == "VLM_Extraction":
|
||||
if not str(template_code or "").strip():
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "VLM抽取模板必须填写模板code")
|
||||
if not str(template_abbreviation or "").strip():
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "VLM抽取模板必须填写模板简称")
|
||||
|
||||
async def _ensure_template_code_unique(self, template_code: str | None, template_id: int | None = None) -> None:
|
||||
code = str(template_code or "").strip()
|
||||
if not code:
|
||||
return
|
||||
sql = "SELECT id FROM prompt_templates WHERE LOWER(template_code) = LOWER(:template_code)"
|
||||
params: dict[str, Any] = {"template_code": code}
|
||||
if template_id is not None:
|
||||
sql += " AND id <> :id"
|
||||
params["id"] = template_id
|
||||
async with _LegacySession() as session:
|
||||
exists = (await session.execute(text(sql), params)).scalar_one_or_none()
|
||||
if exists:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "模板code已存在")
|
||||
|
||||
def _build_write_payload(self, body: PromptTemplateCreateDTO) -> dict[str, Any]:
|
||||
return {
|
||||
"template_name": body.template_name.strip(),
|
||||
"template_type": body.template_type,
|
||||
"description": (body.description or "").strip() or None,
|
||||
"template_content": body.template_content,
|
||||
"variables": json.dumps(body.variables or {}, ensure_ascii=False),
|
||||
"status": int(body.status),
|
||||
"version": str(body.version or "v1.0").strip() or "v1.0",
|
||||
"created_by": int(body.created_by) if body.created_by is not None else None,
|
||||
"template_code": (body.template_code or "").strip() or None,
|
||||
"template_abbreviation": (body.template_abbreviation or "").strip() or None,
|
||||
}
|
||||
|
||||
def _to_vo(self, row: dict[str, Any], username: str | None) -> PromptTemplateVO:
|
||||
return PromptTemplateVO(
|
||||
id=int(row["id"]),
|
||||
template_name=str(row.get("template_name") or ""),
|
||||
template_code=row.get("template_code"),
|
||||
template_type=str(row.get("template_type") or ""),
|
||||
description=row.get("description"),
|
||||
template_content=str(row.get("template_content") or ""),
|
||||
template_abbreviation=row.get("template_abbreviation"),
|
||||
variables=self._parse_variables(row.get("variables")),
|
||||
status=int(row.get("status") or 0),
|
||||
version=str(row.get("version") or "v1.0"),
|
||||
created_by=int(row["created_by"]) if row.get("created_by") is not None else None,
|
||||
created_by_username=username,
|
||||
created_at=self._format_datetime(row.get("created_at")),
|
||||
updated_at=self._format_datetime(row.get("updated_at")),
|
||||
)
|
||||
|
||||
def _parse_variables(self, value: Any) -> dict[str, str]:
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return {str(k): str(v) for k, v in value.items()}
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed = json.loads(value)
|
||||
if isinstance(parsed, dict):
|
||||
return {str(k): str(v) for k, v in parsed.items()}
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return {}
|
||||
|
||||
def _format_datetime(self, value: Any) -> str:
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
return str(value or "")
|
||||
|
||||
def _generate_copy_code(self, code: str) -> str:
|
||||
return f"{code}_copy_{int(datetime.utcnow().timestamp())}"
|
||||
Reference in New Issue
Block a user