from __future__ import annotations import json import os from datetime import datetime from typing import Any from urllib.parse import quote_plus from sqlalchemy import bindparam, text from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from fastapi_admin.config import DB_HOST, DB_PASSWORD, DB_PORT, DB_USER from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum from fastapi_modules.fastapi_leaudit.domian.Dto.promptTemplateDto import PromptTemplateCreateDTO, PromptTemplateUpdateDTO from fastapi_modules.fastapi_leaudit.domian.vo.promptTemplateVo import PromptTemplateListVO, PromptTemplateTypeListVO, PromptTemplateTypeOptionVO, PromptTemplateVO from fastapi_modules.fastapi_leaudit.services.promptTemplateService import IPromptTemplateService _LEGACY_DB_NAME = os.getenv("LEGACY_RULE_DB_NAME", "docauditai") _LEGACY_DB_URL = ( f"postgresql+asyncpg://{quote_plus(str(DB_USER))}:{quote_plus(str(DB_PASSWORD))}" f"@{DB_HOST}:{DB_PORT}/{quote_plus(_LEGACY_DB_NAME)}" ) _LEGACY_ENGINE = create_async_engine(_LEGACY_DB_URL, pool_pre_ping=True) _LegacySession = async_sessionmaker(_LEGACY_ENGINE, expire_on_commit=False) _ALLOWED_TEMPLATE_TYPES = {"LLM_Extraction", "VLM_Extraction", "Evaluation", "Summary", "Common"} _TYPE_LABELS = { "LLM_Extraction": "LLM抽取", "VLM_Extraction": "VLM抽取", "Evaluation": "评查", "Summary": "摘要", "Common": "通用", } class PromptTemplateServiceImpl(IPromptTemplateService): async def ListTemplates( self, Search: str | None, TemplateTypes: list[str] | None, Status: int | None, Page: int, PageSize: int, ) -> PromptTemplateListVO: offset = max(Page - 1, 0) * PageSize filters = ["1=1"] params: dict[str, Any] = {"limit": PageSize, "offset": offset} if Search: filters.append("(pt.template_name ILIKE :search OR pt.template_code ILIKE :search)") params["search"] = f"%{Search.strip()}%" normalized_types = [item for item in (TemplateTypes or []) if item in _ALLOWED_TEMPLATE_TYPES] if normalized_types: filters.append("pt.template_type IN :template_types") params["template_types"] = tuple(normalized_types) if Status is not None: filters.append("pt.status = :status") params["status"] = Status where_clause = " AND ".join(filters) count_sql = text(f"SELECT COUNT(*) FROM prompt_templates pt WHERE {where_clause}") list_sql = text( f""" SELECT pt.id, pt.template_name, pt.template_type, pt.description, pt.template_content, pt.variables, pt.status, pt.version, pt.created_by, pt.created_at, pt.updated_at, pt.template_code, pt.template_abbreviation FROM prompt_templates pt WHERE {where_clause} ORDER BY pt.updated_at DESC, pt.id DESC LIMIT :limit OFFSET :offset """ ) if normalized_types: count_sql = count_sql.bindparams(bindparam("template_types", expanding=True)) list_sql = list_sql.bindparams(bindparam("template_types", expanding=True)) async with _LegacySession() as session: total = int((await session.execute(count_sql, params)).scalar_one()) rows = ( await session.execute( list_sql, params, ) ).mappings().all() usernames = await self._load_usernames([int(row["created_by"]) for row in rows if row.get("created_by") is not None]) items = [self._to_vo(row, usernames.get(int(row["created_by"]))) for row in rows] return PromptTemplateListVO(total=total, page=Page, page_size=PageSize, items=items) async def GetTemplate(self, TemplateId: int) -> PromptTemplateVO: async with _LegacySession() as session: row = ( await session.execute( text( """ SELECT id, template_name, template_type, description, template_content, variables, status, version, created_by, created_at, updated_at, template_code, template_abbreviation FROM prompt_templates WHERE id = :id """ ), {"id": TemplateId}, ) ).mappings().first() if not row: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "提示词模板不存在") usernames = await self._load_usernames([int(row["created_by"])] if row.get("created_by") is not None else []) return self._to_vo(row, usernames.get(int(row["created_by"])) if row.get("created_by") is not None else None) async def CreateTemplate(self, Body: PromptTemplateCreateDTO) -> PromptTemplateVO: await self._validate_template_payload(Body.template_type, Body.template_code, Body.template_abbreviation) await self._ensure_template_code_unique(Body.template_code) now = datetime.utcnow() payload = self._build_write_payload(Body) payload.update({"created_at": now, "updated_at": now}) async with _LegacySession() as session: async with session.begin(): new_id = await session.scalar( text( """ INSERT INTO prompt_templates ( template_name, template_type, description, template_content, variables, status, version, created_by, created_at, updated_at, template_code, template_abbreviation ) VALUES ( :template_name, :template_type, :description, :template_content, CAST(:variables AS jsonb), :status, :version, :created_by, :created_at, :updated_at, :template_code, :template_abbreviation ) RETURNING id """ ), payload, ) await session.commit() return await self.GetTemplate(int(new_id)) async def UpdateTemplate(self, TemplateId: int, Body: PromptTemplateUpdateDTO) -> PromptTemplateVO: current = await self.GetTemplate(TemplateId) payload = Body.model_dump(exclude_unset=True) if not payload: return current new_type = payload.get("template_type", current.template_type) new_code = payload.get("template_code", current.template_code) new_abbr = payload.get("template_abbreviation", current.template_abbreviation) await self._validate_template_payload(new_type, new_code, new_abbr) if "template_code" in payload: await self._ensure_template_code_unique(payload.get("template_code"), TemplateId) updates: list[str] = [] params: dict[str, Any] = {"id": TemplateId, "updated_at": datetime.utcnow()} simple_fields = [ "template_name", "template_type", "description", "template_content", "status", "version", "template_code", "template_abbreviation", ] for field in simple_fields: if field in payload: params[field] = payload[field] updates.append(f"{field} = :{field}") if "variables" in payload: params["variables"] = json.dumps(payload.get("variables") or {}, ensure_ascii=False) updates.append("variables = CAST(:variables AS jsonb)") updates.append("updated_at = :updated_at") async with _LegacySession() as session: async with session.begin(): await session.execute(text(f"UPDATE prompt_templates SET {', '.join(updates)} WHERE id = :id"), params) await session.commit() return await self.GetTemplate(TemplateId) async def DeleteTemplate(self, TemplateId: int) -> None: await self.GetTemplate(TemplateId) async with _LegacySession() as session: async with session.begin(): await session.execute(text("DELETE FROM prompt_templates WHERE id = :id"), {"id": TemplateId}) await session.commit() async def GetTemplateTypes(self) -> PromptTemplateTypeListVO: async with _LegacySession() as session: rows = ( await session.execute( text( """ SELECT template_type, COUNT(*)::int AS count FROM prompt_templates WHERE template_type IS NOT NULL AND TRIM(template_type) <> '' GROUP BY template_type ORDER BY template_type ASC """ ) ) ).mappings().all() items = [ PromptTemplateTypeOptionVO( value=str(row["template_type"]), label=_TYPE_LABELS.get(str(row["template_type"]), str(row["template_type"])), count=int(row["count"] or 0), ) for row in rows ] return PromptTemplateTypeListVO(items=items) async def DuplicateTemplate(self, TemplateId: int, NewCode: str | None) -> PromptTemplateVO: current = await self.GetTemplate(TemplateId) code = (NewCode or "").strip() or self._generate_copy_code(current.template_code or f"template_{TemplateId}") await self._ensure_template_code_unique(code) body = PromptTemplateCreateDTO( template_name=f"{current.template_name}-副本", template_type=current.template_type, description=current.description, template_content=current.template_content, variables=current.variables, status=current.status, version=current.version, created_by=current.created_by, template_code=code, template_abbreviation=current.template_abbreviation, ) return await self.CreateTemplate(body) async def _load_usernames(self, user_ids: list[int]) -> dict[int, str]: ids = sorted({int(item) for item in user_ids if item is not None}) if not ids: return {} async with GetAsyncSession() as session: rows = ( await session.execute( text("SELECT id, username FROM sso_users WHERE id IN :ids").bindparams(bindparam("ids", expanding=True)), {"ids": ids}, ) ).mappings().all() return {int(row["id"]): str(row.get("username") or "") for row in rows} async def _validate_template_payload(self, template_type: str | None, template_code: str | None, template_abbreviation: str | None) -> None: if template_type not in _ALLOWED_TEMPLATE_TYPES: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "模板类型不合法") if template_type == "VLM_Extraction": if not str(template_code or "").strip(): raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "VLM抽取模板必须填写模板code") if not str(template_abbreviation or "").strip(): raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "VLM抽取模板必须填写模板简称") async def _ensure_template_code_unique(self, template_code: str | None, template_id: int | None = None) -> None: code = str(template_code or "").strip() if not code: return sql = "SELECT id FROM prompt_templates WHERE LOWER(template_code) = LOWER(:template_code)" params: dict[str, Any] = {"template_code": code} if template_id is not None: sql += " AND id <> :id" params["id"] = template_id async with _LegacySession() as session: exists = (await session.execute(text(sql), params)).scalar_one_or_none() if exists: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "模板code已存在") def _build_write_payload(self, body: PromptTemplateCreateDTO) -> dict[str, Any]: return { "template_name": body.template_name.strip(), "template_type": body.template_type, "description": (body.description or "").strip() or None, "template_content": body.template_content, "variables": json.dumps(body.variables or {}, ensure_ascii=False), "status": int(body.status), "version": str(body.version or "v1.0").strip() or "v1.0", "created_by": int(body.created_by) if body.created_by is not None else None, "template_code": (body.template_code or "").strip() or None, "template_abbreviation": (body.template_abbreviation or "").strip() or None, } def _to_vo(self, row: dict[str, Any], username: str | None) -> PromptTemplateVO: return PromptTemplateVO( id=int(row["id"]), template_name=str(row.get("template_name") or ""), template_code=row.get("template_code"), template_type=str(row.get("template_type") or ""), description=row.get("description"), template_content=str(row.get("template_content") or ""), template_abbreviation=row.get("template_abbreviation"), variables=self._parse_variables(row.get("variables")), status=int(row.get("status") or 0), version=str(row.get("version") or "v1.0"), created_by=int(row["created_by"]) if row.get("created_by") is not None else None, created_by_username=username, created_at=self._format_datetime(row.get("created_at")), updated_at=self._format_datetime(row.get("updated_at")), ) def _parse_variables(self, value: Any) -> dict[str, str]: if value is None: return {} if isinstance(value, dict): return {str(k): str(v) for k, v in value.items()} if isinstance(value, str): try: parsed = json.loads(value) if isinstance(parsed, dict): return {str(k): str(v) for k, v in parsed.items()} except json.JSONDecodeError: return {} return {} def _format_datetime(self, value: Any) -> str: if isinstance(value, datetime): return value.isoformat() return str(value or "") def _generate_copy_code(self, code: str) -> str: return f"{code}_copy_{int(datetime.utcnow().timestamp())}"