diff --git a/app.toml b/app.toml index 1cfca48..dfb0181 100644 --- a/app.toml +++ b/app.toml @@ -42,6 +42,13 @@ BASE_URL = "https://hub.leke.run/qwen/v1" MODEL = "qwen3.5-35b-a3b" API_KEY = "sk-6c7466b543b947ffadc50a5d79135712" +[EMBEDDING] +BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/embeddings" +MODEL = "text-embedding-v4" +API_KEY = "sk-6c7466b543b947ffadc50a5d79135712" +DIM = 1024 +BATCH_SIZE = 10 + [OCR] BASE_URL = "https://hub.leke.run/" TIMEOUT = 300 diff --git a/docs/RAG/RAG会话自动标题开发任务拆解与接口变更清单.md b/docs/RAG/RAG会话自动标题开发任务拆解与接口变更清单.md new file mode 100644 index 0000000..6be7ee3 --- /dev/null +++ b/docs/RAG/RAG会话自动标题开发任务拆解与接口变更清单.md @@ -0,0 +1,185 @@ +# RAG 会话自动标题开发任务拆解与接口变更清单 + +> 最后整理:2026-05-19 +> 对应方案:`docs/RAG/RAG会话自动标题生成与重命名实施方案.md` + +## 1. 本期落地范围 + +本期目标不是做一个“前端点一下自动命名”的伪功能,而是完成第一版可稳定运行的后端自动标题链路: + +1. 新会话首轮回答完成后,后端异步生成标题 +2. 手动重命名后,不再被自动标题覆盖 +3. 会话列表排序从“单纯 updated_at”收口到“最后消息时间优先” +4. 前端移除旧的假自动重命名逻辑 + +--- + +## 2. 任务拆解 + +## 2.1 后端数据层 + +任务: + +1. 为 `rag_conversation` 增加标题状态字段 +2. 增加 `last_message_at` +3. 增加相关索引 + +交付物: + +1. `scripts/创建sql/schema_rag_chat_auto_title.sql` + +验收点: + +1. 老数据执行 SQL 不报错 +2. 默认值正确 +3. 索引存在 + +## 2.2 后端会话查询层 + +任务: + +1. 会话列表返回 `titleSource` +2. 会话列表返回 `lastMessageAt` +3. 排序改为 `COALESCE(last_message_at, updated_at) DESC` + +验收点: + +1. 自动标题更新不会让会话无故跳到顶部 +2. 真正发新消息的会话仍能上浮 + +## 2.3 后端消息完成链路 + +任务: + +1. assistant 最终落库时同步更新 `last_message_at` +2. 首轮回答完成后调度自动标题任务 + +验收点: + +1. 标题生成失败不影响主回答 +2. 只在首轮完整回答后触发 + +## 2.4 后端自动标题生成器 + +任务: + +1. 新增 `_maybe_schedule_auto_title(...)` +2. 新增 `_run_auto_title_task(...)` +3. 新增 `_generate_conversation_title(...)` +4. 新增 `_sanitize_generated_title(...)` +5. 新增 `_build_fallback_title(...)` + +验收点: + +1. 同一会话不会重复并发生成标题 +2. 生成失败能记状态 +3. 生成成功后能回写标题 + +## 2.5 后端手动重命名链路 + +任务: + +1. `RenameConversation()` 改为写入 `title_source = 'manual'` +2. 清空最近一次标题生成错误 + +验收点: + +1. 用户手动改名后,后续自动任务不再覆盖 + +## 2.6 前端聊天主链路 + +任务: + +1. 移除 `generateConversationName(...)` 自动调用 +2. 移除 `sidebarRef.autoRename(...)` 自动调用 +3. 回答完成后轻量刷新会话列表,把后端自动标题同步回来 + +验收点: + +1. 首轮问答完成后,会话名称自然更新 +2. 前端不再双写标题 + +## 2.7 前端会话列表排序与本地更新 + +任务: + +1. 会话排序优先用 `last_message_at` +2. 仅标题更新时不再强行覆盖 `updated_at` + +验收点: + +1. 手动重命名不会导致会话跳序 +2. 自动重命名不会导致会话跳序 + +--- + +## 3. 接口变更清单 + +## 3.1 `GET /api/v3/rag/chat/conversations` + +新增返回字段: + +```json +{ + "id": "conversation_id", + "name": "烟草专卖法中的烟草定义", + "introduction": "", + "titleSource": "auto", + "createdAt": 1779078161, + "updatedAt": 1779078200, + "lastMessageAt": 1779078200 +} +``` + +兼容性说明: + +1. 前端旧代码即使不消费新增字段也不会炸 +2. 新前端会使用 `last_message_at` 改善排序 + +## 3.2 `PATCH /api/v3/rag/chat/conversations/{ConversationId}` + +本期请求体不变: + +```json +{ + "name": "我手动改的标题" +} +``` + +但服务端语义变化为: + +1. 更新标题 +2. 将 `title_source` 切为 `manual` +3. 保护后续自动任务不覆盖 + +--- + +## 4. 数据字段说明 + +| 字段 | 说明 | +|---|---| +| `title_source` | `default / auto / manual` | +| `title_generation_status` | `idle / pending / running / succeeded / failed` | +| `title_generated_at` | 自动标题成功时间 | +| `first_question_message_id` | 预留,本期可暂不写 | +| `first_answer_message_id` | 记录首轮回答 message_id | +| `title_generation_error` | 最近一次失败原因 | +| `last_message_at` | 最后一次消息完成时间 | + +--- + +## 5. 本期已知边界 + +1. 自动标题目前走与 follow-up 相同的在线 LLM 依赖 +2. 如果 LLM 配置错误,主回答可能正常但标题生成失败 +3. 本期还未增加专门的 `conversation_renamed` SSE 事件 +4. 前端目前通过“回答完成后刷新列表”同步标题,优先保证正确性 + +--- + +## 6. 后续增强建议 + +1. 新增 `conversation_renamed` SSE 事件,减少列表刷新 +2. 增加后台补偿任务,扫描 `failed` 会话再重试 +3. 引入标题生成提示词版本号,便于回溯 +4. 对 `first_question_message_id` 也做持久化,补齐审计链路 diff --git a/docs/RAG/RAG会话自动标题生成与重命名实施方案.md b/docs/RAG/RAG会话自动标题生成与重命名实施方案.md new file mode 100644 index 0000000..1aa41a8 --- /dev/null +++ b/docs/RAG/RAG会话自动标题生成与重命名实施方案.md @@ -0,0 +1,722 @@ +# RAG 会话自动标题生成与重命名实施方案 + +> 最后整理:2026-05-19 +> 适用模块:`/chat-with-llm/chat` +> 对应前端:`legal-platform-frontend/components/dify-chat/*` +> 对应后端:`fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py` + +## 1. 结论先行 + +当前“会话自动重命名”能力并不是真正可用的自动标题生成,而是前端在首轮回答完成后又走了一次“重命名接口”,并且现有实现路径存在以下问题: + +1. 标题生成职责放在前端,天然不稳定。 +2. 会话真实 ID 与临时 ID 切换期间,自动重命名容易打到错误对象。 +3. 前端切换页面、断网、刷新、关闭标签页时,自动重命名流程可能直接丢失。 +4. 现有 `generateConversationName()` 实际并没有真正“根据问答内容生成标题”的可靠后端闭环。 +5. 手动重命名与自动重命名没有状态隔离,后续极易互相覆盖。 + +建议采用: + +**后端主导、异步生成、首轮回答完成后触发、手动重命名永不被覆盖**。 + +不建议继续沿用当前前端 `autoRename()` / `generateConversationName()` 这条链路做增强。 + +--- + +## 2. 当前实现现状 + +## 2.1 前端现状 + +当前前端存在两条与“自动重命名”有关的路径: + +1. `legal-platform-frontend/hooks/use-chat-message.ts` + 在首轮回答 `onCompleted` 后调用 `generateConversationName(tempNewConversationId)`。 +2. `legal-platform-frontend/components/dify-chat/index.tsx` + 在 `onConversationIdChange(..., { syncName: true })` 时通过 `sidebarRef.current.autoRename(conversationId)` 再触发一次自动重命名。 + +而 `legal-platform-frontend/components/dify-chat/sidebar.tsx` 中的 `autoRename()` 目前本质上只是: + +- 调 `renameConversation(conversationId, '新对话', false)` + +这不是“生成标题”,只是再次把标题写成“新对话”。 + +## 2.2 前端 API 现状 + +`legal-platform-frontend/lib/api/legacy/dify-chat/client.ts` + +- `generateConversationName(id)` 实际是 `renameConversation(id, '', true)` +- 但 `difyClient.renameConversation()` 最终只是: + - 把空标题兜底为 `'新对话'` + - `PATCH /api/v3/rag/chat/conversations/{conversationId}` + - 请求体只有 `{ name: finalName }` + +这说明当前所谓 `auto_generate=true` 在后端并没有真正形成“让模型生成标题”的能力闭环。 + +## 2.3 后端现状 + +`fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py` + +当前会话创建逻辑: + +```sql +INSERT INTO rag_conversation (conversation_id, user_id, app_id, name, introduction) +VALUES (:conversation_id, :user_id, :app_id, '新对话', '') +``` + +当前重命名逻辑: + +```sql +UPDATE rag_conversation +SET name = :name, updated_at = NOW() +WHERE conversation_id = :conversation_id +``` + +也就是说: + +1. 会话创建时标题恒为 `新对话` +2. 后端没有“自动标题生成任务” +3. 后端没有“标题来源状态” +4. 后端无法区分: + - 默认标题 + - 自动生成标题 + - 用户手动标题 + +--- + +## 3. 问题本质 + +自动标题本质不是一个前端 UI 小功能,而是一个会话元数据状态机问题。 + +如果设计不对,会持续出现这些 bug: + +1. 前端刚发完首问就切走,标题没生成。 +2. 用户断网后端其实已经答完,但标题没改。 +3. 用户手动改名后,异步自动标题结果又把它覆盖掉。 +4. 同一个会话出现两次标题更新,列表排序抖动。 +5. 首轮回答失败或中断,却错误生成了标题。 + +所以这件事必须满足两个原则: + +1. **标题生成结果必须由后端持久化并裁决** +2. **手动重命名优先级必须高于自动生成** + +--- + +## 4. 目标 + +本方案的目标是: + +1. 新会话在首轮有效回答完成后,自动生成一个简短、可读、稳定的标题。 +2. 用户刷新页面、切换会话、切换应用、前端断流时,标题生成不丢。 +3. 手动重命名后,后续自动任务不能再覆盖。 +4. 标题更新不应破坏会话列表选择态、排序和缓存一致性。 +5. 方案应尽量复用现有 RAG 聊天主链路,不引入过重基础设施。 + +非目标: + +1. 不在本期做多轮动态标题持续改写。 +2. 不在本期做“每次问答都重新概括标题”。 +3. 不在本期做复杂工作流引擎或独立任务队列系统。 + +--- + +## 5. 推荐方案 + +## 5.1 总体方案 + +推荐采用: + +**方案 C:后端异步自动标题生成,前端只负责显示结果** + +它和之前“会话 ID 提前同步方案”并不冲突,反而是互补关系: + +1. 会话 ID 提前同步解决的是 `temp -> real` 的绑定竞态。 +2. 自动标题生成解决的是 `real conversation` 的元数据完善问题。 + +完整链路应为: + +1. 前端发送首问。 +2. 后端创建真实会话。 +3. SSE 第一帧返回 `conversation_created`,前端立即完成 temp 升 real。 +4. 后端继续生成主回答。 +5. 主回答真正完成并持久化后,后端异步触发“自动标题生成任务”。 +6. 标题写回 `rag_conversation`。 +7. 前端通过轮询刷新或专门事件获知新标题并更新列表。 + +## 5.2 为什么不建议继续把标题生成放前端 + +因为前端无法可靠保证以下场景: + +1. 用户切走当前会话 +2. 浏览器刷新 +3. 页面关闭 +4. SSE 中断 +5. 多标签页并发 +6. 首轮回答完成但前端回调未触发 + +而这些情况在当前聊天系统里都是真实存在的。 + +自动标题如果依赖前端触发,就一定会丢边界。 + +--- + +## 6. 数据模型设计 + +## 6.1 `rag_conversation` 建议新增字段 + +建议新增以下字段: + +| 字段 | 类型 | 说明 | +|---|---|---| +| `title_source` | `varchar(20)` | 标题来源:`default / auto / manual` | +| `title_generation_status` | `varchar(20)` | 生成状态:`idle / pending / running / succeeded / failed` | +| `title_generated_at` | `timestamp with time zone` | 自动标题最终生成时间 | +| `first_question_message_id` | `varchar(100)` | 首条用户消息 ID | +| `first_answer_message_id` | `varchar(100)` | 首条回答消息 ID | +| `title_generation_error` | `text` | 最近一次失败原因,便于排障 | +| `last_message_at` | `timestamp with time zone` | 会话最后一次消息完成时间,可选 | + +最小可落地集合其实只需要: + +1. `title_source` +2. `title_generation_status` +3. `title_generated_at` + +但从排障和防重角度,建议把 `first_question_message_id / first_answer_message_id / title_generation_error` 一并加上。 + +## 6.2 字段语义 + +### `title_source` + +- `default` + - 当前仍是系统默认标题,例如 `新对话` +- `auto` + - 已由系统自动生成 +- `manual` + - 已被用户手动修改 + +### `title_generation_status` + +- `idle` + - 默认状态,尚未进入自动生成流程 +- `pending` + - 已满足触发条件,等待后台执行 +- `running` + - 正在生成 +- `succeeded` + - 已成功生成 +- `failed` + - 生成失败,但不影响主会话正常使用 + +## 6.3 迁移 SQL 草案 + +```sql +ALTER TABLE rag_conversation +ADD COLUMN IF NOT EXISTS title_source VARCHAR(20) NOT NULL DEFAULT 'default', +ADD COLUMN IF NOT EXISTS title_generation_status VARCHAR(20) NOT NULL DEFAULT 'idle', +ADD COLUMN IF NOT EXISTS title_generated_at TIMESTAMPTZ NULL, +ADD COLUMN IF NOT EXISTS first_question_message_id VARCHAR(100) NULL, +ADD COLUMN IF NOT EXISTS first_answer_message_id VARCHAR(100) NULL, +ADD COLUMN IF NOT EXISTS title_generation_error TEXT NULL, +ADD COLUMN IF NOT EXISTS last_message_at TIMESTAMPTZ NULL; + +CREATE INDEX IF NOT EXISTS idx_rag_conversation_title_generation_status +ON rag_conversation(title_generation_status) +WHERE deleted_at IS NULL; +``` + +如果本期想尽量收敛改动,也可以拆成两批迁移。 + +--- + +## 7. 触发时机设计 + +## 7.1 唯一推荐触发点 + +**首轮 assistant 消息状态从 `running` 变为 `completed` 且内容非空时触发。** + +这比“前端收到 `onCompleted`”更可靠。 + +## 7.2 为什么必须是“回答完成后” + +原因很直接: + +1. 只用首个用户问题生成标题,标题质量偏差较大。 +2. 很多用户首问比较短,例如“这个是什么”“帮我解释一下”。 +3. 加上首轮回答后,标题更容易稳定落到真实主题。 + +示例: + +- 首问:`烟草是什么?` +- 首答:围绕《中华人民共和国烟草专卖法》的定义和专卖属性展开 + +最终标题应更像: + +- `烟草专卖法中的烟草定义` + +而不是: + +- `烟草是什么` + +## 7.3 不应触发的情况 + +以下情况不应自动生成标题: + +1. 用户发送后立即停止,回答为空。 +2. 后端异常,assistant 最终状态为 `error`。 +3. 标题来源已经是 `manual`。 +4. 标题来源已经是 `auto` 且当前标题非默认值。 +5. 会话已被删除。 + +--- + +## 8. 标题生成策略 + +## 8.1 输入材料 + +建议只使用以下材料: + +1. 首条用户问题 `query` +2. 首条 assistant 有效回答 `answer` + +不要直接把整段上下文、多轮历史、引用来源、思考链全部塞进去。 + +原因: + +1. 标题生成目标很小,不需要太多上下文。 +2. 输入越多,标题越容易发散。 +3. 可以降低调用成本和失败面。 + +## 8.2 输出要求 + +建议标题约束为: + +1. 12 到 24 个中文字符优先 +2. 最长不超过 40 个字符 +3. 不加句号 +4. 不加“用户问了什么”这种废话 +5. 不使用“关于”“浅谈”“分析一下”等空泛词 +6. 以主题名词或主题短句为主 + +可接受示例: + +1. `烟草专卖法中的烟草定义` +2. `政府采购法适用范围说明` +3. `合同违约责任条款审查要点` + +不可接受示例: + +1. `新对话` +2. `关于烟草是什么的回答` +3. `用户询问烟草相关内容` +4. `烟草是什么?` + +## 8.3 兜底策略 + +如果模型生成失败,建议兜底顺序: + +1. 从首问做规则截断摘要 +2. 如果首问太短,仍保留 `新对话` + +规则摘要示例: + +- 原文:`请结合中华人民共和国烟草专卖法解释烟草在法律上的定义及监管属性` +- 兜底标题:`烟草的法律定义及监管属性` + +--- + +## 9. 后端实现设计 + +## 9.1 责任归属 + +后端负责: + +1. 判定是否应该生成标题 +2. 防重复调度 +3. 执行标题生成 +4. 持久化最终标题 +5. 保证不覆盖手动标题 + +前端只负责: + +1. 展示当前标题 +2. 用户手动重命名 +3. 在标题变化后刷新列表 + +## 9.2 推荐实现方式 + +建议在 `RagChatServiceImpl` 中新增以下私有方法: + +1. `_maybe_schedule_auto_title(...)` +2. `_run_auto_title_task(...)` +3. `_generate_conversation_title(query: str, answer: str) -> str` +4. `_build_fallback_title(query: str, answer: str) -> str` + +## 9.3 建议调用时机 + +在主回答最终落库成功、assistant 消息状态改为 `completed` 后: + +1. 更新 `rag_message` +2. 更新 `rag_conversation.last_message_at` +3. 调 `_maybe_schedule_auto_title(...)` + +## 9.4 `_maybe_schedule_auto_title(...)` 规则 + +该方法只做“判定 + 抢占式设置 pending/running”,不做实际生成。 + +必须同时满足: + +1. 当前会话存在且未删除 +2. `title_source = 'default'` +3. 当前标题为空或等于 `新对话` +4. 首条回答状态为 `completed` +5. 回答内容非空 +6. 当前未处于 `running` +7. 当前未成功生成过 + +如果满足,则: + +1. 把 `title_generation_status` 更新为 `pending` +2. 记录 `first_question_message_id / first_answer_message_id` +3. 启动后台异步任务 + +## 9.5 `_run_auto_title_task(...)` 规则 + +执行步骤建议如下: + +1. 再次查库确认会话仍存在 +2. 再次确认 `title_source != 'manual'` +3. 将 `title_generation_status` 置为 `running` +4. 调标题生成函数 +5. 清洗标题文本 +6. 二次查库确认用户没有在这期间手动改名 +7. 仅在 `title_source = 'default'` 时执行更新: + - `name = generated_title` + - `title_source = 'auto'` + - `title_generation_status = 'succeeded'` + - `title_generated_at = NOW()` +8. 如果失败: + - `title_generation_status = 'failed'` + - `title_generation_error = ...` + +## 9.6 为什么必须二次查库 + +因为存在这个时序: + +1. 后端开始生成自动标题 +2. 用户在前端手动把标题改成 `烟草法问答` +3. 自动任务稍后返回 `烟草专卖法中的烟草定义` + +如果没有二次查库保护,就会把用户手动标题覆盖掉。 + +这是不能接受的。 + +--- + +## 10. 手动重命名规则 + +## 10.1 规则定义 + +用户手动重命名时,后端必须: + +1. 更新 `name` +2. 设置 `title_source = 'manual'` +3. 可选把 `title_generation_status` 保留原值,或重置为 `succeeded` +4. 清空 `title_generation_error` + +推荐直接: + +- `title_source = 'manual'` +- `title_generation_status` 不再参与后续覆盖判定 + +## 10.2 接口行为 + +当前接口: + +`PATCH /api/v3/rag/chat/conversations/{ConversationId}` + +当前 DTO 只有: + +```python +class RagConversationRenameDTO(BaseModel): + name: str +``` + +本期其实不需要改手动重命名接口入参,后端只要在 `RenameConversation()` 里补充: + +```sql +UPDATE rag_conversation +SET + name = :name, + title_source = 'manual', + updated_at = NOW() +WHERE conversation_id = :conversation_id +``` + +即可满足主需求。 + +## 10.3 是否保留 `auto_generate` 参数 + +建议: + +1. 不再让前端显式调用 `auto_generate=true` +2. 后续可以废弃该语义 + +因为自动生成应该是后端内生行为,而不是前端额外指令。 + +--- + +## 11. 前端同步策略 + +## 11.1 推荐方案 + +推荐采用: + +**列表刷新 + 局部状态更新** + +即: + +1. 会话主链路仍由 SSE 负责回答流 +2. 标题更新不强依赖新增 SSE 事件 +3. 在以下时机刷新会话列表即可: + - 首轮回答完成后 + - 用户切回该应用时 + - 定时轻量刷新 + +原因: + +1. 标题不是强实时核心信息 +2. 不值得为了标题专门把主流协议继续复杂化 +3. 当前项目已经存在会话列表刷新逻辑,接入成本更低 + +## 11.2 如果想更丝滑 + +可以额外新增一个可选 SSE 事件: + +```json +{ + "event": "conversation_renamed", + "conversation_id": "xxx", + "name": "烟草专卖法中的烟草定义", + "source": "auto" +} +``` + +但这不是本期必需。 + +## 11.3 前端必须删除的错误路径 + +建议清理以下逻辑: + +1. `use-chat-message.ts` 中首轮完成后主动 `generateConversationName(...)` +2. `index.tsx` 中通过 `sidebarRef.current.autoRename(...)` 再次发起“自动重命名” +3. `sidebar.tsx` 中把 `autoRename()` 重命名成“仅保留手动重命名能力”,或直接删除 ref 暴露 + +否则后续会形成双写: + +1. 后端自动生成一版 +2. 前端又发起一次旧式 rename + +结果会继续乱。 + +--- + +## 12. 与“对话中断”场景的关系 + +这是本方案里必须明确的一点。 + +## 12.1 前端切走当前会话 + +如果用户在回答过程中切到其他会话: + +1. 后端生成不应停止 +2. 会话标题生成也不应停止 +3. 前端只是暂时不再消费当前流 +4. 回来时应能从后端历史恢复完整回答与最终标题 + +这也是为什么标题生成不能放前端。 + +## 12.2 前端断网 / SSE 断开 + +如果网络中断,但后端任务实际还在跑并最终成功: + +1. 消息应继续在后端落库 +2. 标题自动生成任务也应照常执行 +3. 用户刷新后应直接看到完整回答和最终标题 + +## 12.3 用户主动停止回答 + +如果用户点了“停止回答”: + +1. assistant 消息可能为不完整内容 +2. 这类会话默认不建议自动生成标题 +3. 除非后续有完整回答完成,再考虑生成 + +结论: + +**自动标题触发依据必须是后端最终消息状态,而不是前端是否还在线。** + +--- + +## 13. 会话列表排序影响 + +自动标题更新时,建议: + +1. 更新标题名称 +2. 不额外刷新 `updated_at` + +原因: + +如果自动标题生成把 `updated_at` 也改掉,会造成: + +1. 用户明明没发新消息 +2. 会话却突然跳到列表最上方 +3. 视觉上像“又收到一条新消息” + +这不符合直觉。 + +推荐做法: + +1. 聊天消息完成时更新 `updated_at` +2. 自动标题只改 `name / title_source / title_generated_at` +3. 不改 `updated_at` + +这样排序稳定。 + +--- + +## 14. 失败与重试策略 + +## 14.1 失败影响面 + +自动标题失败不应影响: + +1. 主回答返回 +2. 会话可见性 +3. 消息历史读取 +4. 手动重命名 + +标题生成只是增强能力,不是主链路。 + +## 14.2 是否自动重试 + +本期建议: + +1. 单次失败记为 `failed` +2. 不做无限重试 +3. 可在后续加一个后台补偿脚本,扫描: + - `title_source = 'default'` + - `title_generation_status in ('failed', 'idle')` + - 且首轮回答已完成 + +再进行补偿。 + +这样设计更稳,不会把在线请求链路搞复杂。 + +--- + +## 15. 实施步骤建议 + +## 15.1 第一阶段:先把状态能力补齐 + +1. 为 `rag_conversation` 增加标题来源和生成状态字段 +2. 后端 `RenameConversation()` 改为写入 `title_source = 'manual'` +3. 会话列表查询结果继续返回 `name` 即可,本期不必先暴露所有状态给前端 + +## 15.2 第二阶段:接入后端自动标题 + +1. 在首轮回答完成落库后调用 `_maybe_schedule_auto_title(...)` +2. 异步生成并回写标题 +3. 保证不覆盖手动标题 + +## 15.3 第三阶段:清理前端旧逻辑 + +1. 删除 `generateConversationName(...)` 自动调用 +2. 删除 `sidebarRef.autoRename(...)` 自动调用 +3. 仅保留用户主动重命名能力 + +## 15.4 第四阶段:补 UI 细节 + +可选: + +1. 默认标题会话显示一个淡化文案,比如 `新对话` +2. 自动标题生成成功后列表静默更新 +3. 手动重命名后可加一个标记字段供后续分析,但本期无需展示 + +--- + +## 16. 测试清单 + +## 16.1 基本流程 + +1. 新建会话,首轮回答完成后,标题由 `新对话` 自动变为语义标题 +2. 刷新页面后,标题仍正确存在 +3. 进入历史消息页,标题与会话内容匹配 + +## 16.2 手动重命名保护 + +1. 首轮回答刚完成,自动任务尚未回写前,用户手动重命名 +2. 自动任务返回后,不得覆盖手动名称 +3. 后续再次发消息,也不得重新自动改名 + +## 16.3 中断场景 + +1. 发送首问后立刻切到别的会话,等待后台完成 +2. 再切回原会话,应看到完整回答和自动标题 +3. 发送首问后断网,稍后恢复并刷新,应看到完整回答和自动标题 + +## 16.4 失败场景 + +1. 标题生成模型异常,主回答仍正常显示 +2. 会话仍可手动重命名 +3. 列表不应出现空标题 + +## 16.5 排序稳定性 + +1. 自动标题生成前后的会话排序不跳动 +2. 只有真正发消息时会话才因 `updated_at` 变化上浮 + +## 16.6 并发与幂等 + +1. 同一会话只生成一次自动标题 +2. 页面多开、多次刷新,不得生成多个不同标题反复覆盖 +3. 重复触发 `_maybe_schedule_auto_title(...)` 不得造成重复更新 + +--- + +## 17. 推荐落地结论 + +推荐最终决策如下: + +1. 自动标题只由后端生成和裁决 +2. 触发时机锁定为“首轮 assistant 完整回答持久化后” +3. `rag_conversation` 增加 `title_source` 与 `title_generation_status` +4. 手动重命名立即把 `title_source` 切到 `manual` +5. 前端移除现有假自动重命名链路 +6. 标题更新不改 `updated_at` + +这是当前成本、稳定性、后续可维护性之间最平衡的方案。 + +--- + +## 18. 本期不建议做的事 + +1. 不建议继续扩展前端 `generateConversationName()` 语义 +2. 不建议把标题生成塞进主 SSE 返回链路里阻塞回答结束 +3. 不建议做“每轮消息都重算标题” +4. 不建议让自动标题更新触发会话重新排序 +5. 不建议在没有 `title_source` 状态字段前直接上线自动改名 + +--- + +## 19. 后续可扩展项 + +如果后续要继续增强,可以再加: + +1. 会话标题生成提示词版本号 +2. 后台补偿任务 +3. 管理员手动重跑标题生成 +4. 标题质量审计日志 +5. 基于 `last_message_at` 的更精确列表行为 + +但这些都应该在本期稳定方案落地之后再做。 diff --git a/docs/RAG/RAG聊天接口.md b/docs/RAG/RAG聊天接口.md index 44a4313..f4cfd1b 100644 --- a/docs/RAG/RAG聊天接口.md +++ b/docs/RAG/RAG聊天接口.md @@ -309,10 +309,12 @@ data: {"event":"error","task_id":"...","message_id":"...","code":"llm_error","me 服务端行为说明: - 若 `conversationId` 为空,会自动创建新会话 +- 若 `conversationId` 为空,会自动创建新会话,默认标题为 `新对话` - 会先落一条 `role = user` 消息,再流式生成回答 - 流结束后会落一条 `role = assistant` 消息 - 若命中知识库,会把引用结果写入 `sources / metadata` - 会根据对话内容追加 `suggested_questions` +- 首轮回答完成后,后端会异步尝试生成会话标题 - 当前应用解析顺序是:指定 `appId` -> 任意默认应用 -> 排序第一条应用;每一步都只检查当前命中的那一条记录是否可见,不会遍历全部可见应用 ### 4.6 获取会话列表 @@ -337,10 +339,12 @@ data: {"event":"error","task_id":"...","message_id":"...","code":"llm_error","me "data": [ { "id": "b17d3b0b-xxxx-xxxx", - "name": "新对话", + "name": "烟草专卖法中的烟草定义", "introduction": "", + "titleSource": "auto", "createdAt": 1746580000, - "updatedAt": 1746580066 + "updatedAt": 1746580066, + "lastMessageAt": 1746580066 } ], "hasMore": false, @@ -349,6 +353,15 @@ data: {"event":"error","task_id":"...","message_id":"...","code":"llm_error","me } ``` +补充说明: + +- `titleSource` 取值: + - `default`:默认标题,例如 `新对话` + - `auto`:系统自动生成标题 + - `manual`:用户手动重命名 +- 会话列表排序按 `COALESCE(last_message_at, updated_at) DESC, updated_at DESC` +- 自动标题更新不应该被前端视作“新消息到达” + ### 4.7 获取会话消息 `GET /api/v3/rag/chat/conversations/{ConversationId}/messages` @@ -402,6 +415,11 @@ data: {"event":"error","task_id":"...","message_id":"...","code":"llm_error","me } ``` +补充说明: + +- 手动重命名会将该会话的标题来源标记为 `manual` +- 已手动重命名的会话,后续自动标题任务不得覆盖 + 说明: - 返回结构是按“问答对”聚合后的结果,不是底层 `rag_message` 原始逐条结果。 diff --git a/fastapi_admin/config/__init__.py b/fastapi_admin/config/__init__.py index aeaa3d2..9ded1f5 100644 --- a/fastapi_admin/config/__init__.py +++ b/fastapi_admin/config/__init__.py @@ -13,7 +13,7 @@ from ._loader import load_config as _load_config # 优先加载 TOML → os.environ(必须在 Settings 实例化之前) _load_config() -from ._settings import app, jwt, db, redis, oss, llm, vlm, ocr, leaudit as _leaudit # noqa: E402 +from ._settings import app, jwt, db, redis, oss, llm, vlm, embedding, ocr, leaudit as _leaudit # noqa: E402 def _export_settings(instance: object, prefix: str = "") -> dict[str, object]: @@ -48,6 +48,7 @@ _REDIS = _export_settings(redis) _OSS = _export_settings(oss) _LLM = _export_settings(llm) _VLM = _export_settings(vlm) +_EMBEDDING = _export_settings(embedding) _OCR = _export_settings(ocr) _LEAUDIT = _export_settings(_leaudit) @@ -60,6 +61,7 @@ _ALL.update(_REDIS) _ALL.update(_OSS) _ALL.update(_LLM) _ALL.update(_VLM) +_ALL.update(_EMBEDDING) _ALL.update(_OCR) _ALL.update(_LEAUDIT) diff --git a/fastapi_admin/config/__init__.pyi b/fastapi_admin/config/__init__.pyi index 472990e..9194245 100644 --- a/fastapi_admin/config/__init__.pyi +++ b/fastapi_admin/config/__init__.pyi @@ -44,6 +44,14 @@ LLM_API_KEY: str # VLM VLM_BASE_URL: str VLM_MODEL: str +VLM_API_KEY: str + +# EMBEDDING +EMBEDDING_BASE_URL: str +EMBEDDING_MODEL: str +EMBEDDING_API_KEY: str +EMBEDDING_DIM: int +EMBEDDING_BATCH_SIZE: int # OCR OCR_BASE_URL: str diff --git a/fastapi_admin/config/_settings.py b/fastapi_admin/config/_settings.py index 5d599f4..874d3ee 100644 --- a/fastapi_admin/config/_settings.py +++ b/fastapi_admin/config/_settings.py @@ -82,6 +82,15 @@ class VlmSettings(_Base): VLM_API_KEY: str = "" +class EmbeddingSettings(_Base): + """Embedding 配置 [EMBEDDING]。""" + EMBEDDING_BASE_URL: str = "" + EMBEDDING_MODEL: str = "" + EMBEDDING_API_KEY: str = "" + EMBEDDING_DIM: int = 1024 + EMBEDDING_BATCH_SIZE: int = 10 + + class OcrSettings(_Base): """OCR 配置 [OCR]。""" OCR_BASE_URL: str = "" @@ -125,5 +134,6 @@ redis = RedisSettings() oss = OssSettings() llm = LlmSettings() vlm = VlmSettings() +embedding = EmbeddingSettings() ocr = OcrSettings() leaudit = LeauditSettings() diff --git a/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py b/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py index 84c777a..7ebd5a1 100644 --- a/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py +++ b/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py @@ -15,6 +15,7 @@ from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import ( RagConversationRenameDTO, RagChatSendMessageDTO, RagMessageFeedbackDTO, + RagStopMessageDTO, ) from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import ( RagDatasetBatchDocumentDeleteDTO, @@ -479,6 +480,17 @@ class RagChatController(BaseController): headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "Connection": "keep-alive"}, ) + @self.router.post("/chat/messages/{MessageId}/stop", response_model=Result[RagOperationResultVO]) + async def StopMessage( + MessageId: str, + Body: RagStopMessageDTO | None = None, + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["chat_use"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有停止 RAG 对话权限", "data": None}) + data = await self.RagChatService.StopMessage(int(payload["user_id"]), MessageId, Body) + return Result.success(data=data) + @self.router.get("/chat/conversations", response_model=Result[RagConversationPageVO]) async def GetConversations( appId: int | None = Query(None, description="聊天应用ID"), diff --git a/fastapi_modules/fastapi_leaudit/domian/Dto/ragChatDto.py b/fastapi_modules/fastapi_leaudit/domian/Dto/ragChatDto.py index 12e530b..730e833 100644 --- a/fastapi_modules/fastapi_leaudit/domian/Dto/ragChatDto.py +++ b/fastapi_modules/fastapi_leaudit/domian/Dto/ragChatDto.py @@ -13,3 +13,7 @@ class RagConversationRenameDTO(BaseModel): class RagMessageFeedbackDTO(BaseModel): rating: str | None = Field(None, description="反馈: like/dislike/None") + + +class RagStopMessageDTO(BaseModel): + taskId: str | None = Field(None, description="流式任务ID") diff --git a/fastapi_modules/fastapi_leaudit/domian/vo/ragChatVo.py b/fastapi_modules/fastapi_leaudit/domian/vo/ragChatVo.py index e020a16..47ec176 100644 --- a/fastapi_modules/fastapi_leaudit/domian/vo/ragChatVo.py +++ b/fastapi_modules/fastapi_leaudit/domian/vo/ragChatVo.py @@ -17,8 +17,10 @@ class RagConversationItemVO(BaseModel): id: str = Field(..., description="会话ID") name: str = Field(..., description="会话名称") introduction: str = Field("", description="会话简介") + titleSource: str = Field("default", description="标题来源: default/auto/manual") createdAt: int = Field(0, description="创建时间戳") updatedAt: int = Field(0, description="更新时间戳") + lastMessageAt: int = Field(0, description="最后一条消息完成时间戳") class RagConversationPageVO(BaseModel): @@ -35,6 +37,7 @@ class RagMessageItemVO(BaseModel): feedback: dict | None = Field(None) retrieverResources: list[dict] | None = Field(None) suggestedQuestions: list[str] = Field(default_factory=list) + status: str = Field("completed") createdAt: int = Field(0) diff --git a/fastapi_modules/fastapi_leaudit/govdoc_engine/llm/client.py b/fastapi_modules/fastapi_leaudit/govdoc_engine/llm/client.py index 6044a00..0635e04 100644 --- a/fastapi_modules/fastapi_leaudit/govdoc_engine/llm/client.py +++ b/fastapi_modules/fastapi_leaudit/govdoc_engine/llm/client.py @@ -38,6 +38,7 @@ from fastapi_admin.config import ( LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS, ) from fastapi_modules.fastapi_leaudit.govdoc_engine.llm.cache import LlmCache, make_key +from fastapi_modules.fastapi_leaudit.rag_engine.config import normalize_openai_base_url _log = logging.getLogger(__name__) @@ -153,8 +154,9 @@ class LlmClient: "LLM_API_KEY is not configured. Set LLM_API_KEY in platform config." ) else: - self._client = OpenAI(api_key=key, base_url=base_url or LLM_BASE_URL) - self._aclient = AsyncOpenAI(api_key=key, base_url=base_url or LLM_BASE_URL) + normalized_base_url = normalize_openai_base_url(base_url or LLM_BASE_URL) + self._client = OpenAI(api_key=key, base_url=normalized_base_url) + self._aclient = AsyncOpenAI(api_key=key, base_url=normalized_base_url) self.model = model or LLM_MODEL self.timeout = timeout_seconds if timeout_seconds is not None else LEAUDIT_LLM_REQUEST_TIMEOUT self.max_retries = max_retries if max_retries is not None else LEAUDIT_LLM_RETRY_MAX_ATTEMPTS diff --git a/fastapi_modules/fastapi_leaudit/leaudit_bridge/client_factory.py b/fastapi_modules/fastapi_leaudit/leaudit_bridge/client_factory.py index 3e2c045..4a55fb1 100644 --- a/fastapi_modules/fastapi_leaudit/leaudit_bridge/client_factory.py +++ b/fastapi_modules/fastapi_leaudit/leaudit_bridge/client_factory.py @@ -29,6 +29,7 @@ from fastapi_modules.fastapi_leaudit.leaudit_bridge.resilient_clients import ( ResilientOpenAICompatibleClient, ResilientQwenVLMClient, ) +from fastapi_modules.fastapi_leaudit.rag_engine.config import normalize_openai_base_url if TYPE_CHECKING: from leaudit.llm.base import BaseLLMClient @@ -68,7 +69,7 @@ def create_ocr_client() -> BaseOCRClient: def create_llm_client() -> BaseLLMClient: """Create a leaudit OpenAICompatibleClient from docauditai's LLM config.""" - base_url = LLM_BASE_URL + base_url = normalize_openai_base_url(LLM_BASE_URL) model = LLM_MODEL api_key = LLM_API_KEY or "no-key" @@ -93,7 +94,7 @@ def create_llm_client() -> BaseLLMClient: def create_vlm_client() -> BaseVLMClient | None: """Create a leaudit QwenVLMClient from docauditai's VLM config.""" - base_url = VLM_BASE_URL + base_url = normalize_openai_base_url(VLM_BASE_URL) model = VLM_MODEL api_key = VLM_API_KEY or LLM_API_KEY or "no-key" diff --git a/fastapi_modules/fastapi_leaudit/rag_engine/config.py b/fastapi_modules/fastapi_leaudit/rag_engine/config.py index ff4f995..63f8ef0 100644 --- a/fastapi_modules/fastapi_leaudit/rag_engine/config.py +++ b/fastapi_modules/fastapi_leaudit/rag_engine/config.py @@ -1,6 +1,6 @@ from __future__ import annotations -from fastapi_admin.config._settings import llm +from fastapi_admin.config._settings import embedding, llm def _get_str(name: str, default: str = "") -> str: @@ -36,11 +36,23 @@ RAG_CONFIG = { "CHROMA_PORT": _get_int("RAG_CHROMA_PORT", 8010), "CHROMA_TOKEN": _get_str("RAG_CHROMA_TOKEN", ""), "CHROMA_AUTH_HEADER": _get_str("RAG_CHROMA_AUTH_HEADER", "X-Chroma-Token"), - "EMBED_URL": _get_str("RAG_EMBED_URL", _get_str("GRAPH_RAG_EMBED_URL", "")), - "EMBED_KEY": _get_str("RAG_EMBED_KEY", _get_str("GRAPH_RAG_EMBED_KEY", "")), - "EMBED_MODEL": _get_str("RAG_EMBED_MODEL", _get_str("GRAPH_RAG_EMBED_MODEL", "")), - "EMBED_DIM": _get_int("RAG_EMBED_DIM", 1024), - "EMBED_BATCH_SIZE": _get_int("RAG_EMBED_BATCH_SIZE", 10), + "EMBED_URL": _get_str( + "RAG_EMBED_URL", + _get_str("GRAPH_RAG_EMBED_URL", _get_str("EMBEDDING_BASE_URL", embedding.EMBEDDING_BASE_URL)), + ), + "EMBED_KEY": _get_str( + "RAG_EMBED_KEY", + _get_str("GRAPH_RAG_EMBED_KEY", _get_str("EMBEDDING_API_KEY", embedding.EMBEDDING_API_KEY)), + ), + "EMBED_MODEL": _get_str( + "RAG_EMBED_MODEL", + _get_str("GRAPH_RAG_EMBED_MODEL", _get_str("EMBEDDING_MODEL", embedding.EMBEDDING_MODEL)), + ), + "EMBED_DIM": _get_int("RAG_EMBED_DIM", _get_int("EMBEDDING_DIM", embedding.EMBEDDING_DIM)), + "EMBED_BATCH_SIZE": _get_int( + "RAG_EMBED_BATCH_SIZE", + _get_int("EMBEDDING_BATCH_SIZE", embedding.EMBEDDING_BATCH_SIZE), + ), "RERANKER_URL": _get_str("RAG_RERANKER_URL", _get_str("GRAPH_RAG_RERANKER_URL", "")), "RERANKER_KEY": _get_str("RAG_RERANKER_KEY", _get_str("GRAPH_RAG_RERANKER_KEY", "")), "RERANKER_MODEL": _get_str("RAG_RERANKER_MODEL", _get_str("GRAPH_RAG_RERANKER_MODEL", "")), @@ -58,3 +70,34 @@ RAG_CONFIG = { "HYBRID_SEARCH": _get_bool("RAG_HYBRID_SEARCH", True), "RERANKING": _get_bool("RAG_RERANKING", True), } + + +def build_openai_chat_completions_url(base_url: str) -> str: + normalized = (base_url or "").strip().rstrip("/") + if not normalized: + return "/chat/completions" + if normalized.endswith("/chat/completions"): + return normalized + return f"{normalized}/chat/completions" + + +def build_openai_embeddings_url(base_url: str) -> str: + normalized = (base_url or "").strip().rstrip("/") + if not normalized: + return "/embeddings" + if normalized.endswith("/chat/completions"): + normalized = normalized[:-len("/chat/completions")] + if normalized.endswith("/embeddings"): + return normalized + return f"{normalized}/embeddings" + + +def normalize_openai_base_url(base_url: str) -> str: + normalized = (base_url or "").strip().rstrip("/") + if not normalized: + return "" + if normalized.endswith("/chat/completions"): + return normalized[:-len("/chat/completions")] + if normalized.endswith("/embeddings"): + return normalized[:-len("/embeddings")] + return normalized diff --git a/fastapi_modules/fastapi_leaudit/rag_engine/generator.py b/fastapi_modules/fastapi_leaudit/rag_engine/generator.py index 08076c9..14efc36 100644 --- a/fastapi_modules/fastapi_leaudit/rag_engine/generator.py +++ b/fastapi_modules/fastapi_leaudit/rag_engine/generator.py @@ -7,7 +7,7 @@ from typing import AsyncGenerator import httpx -from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG +from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_chat_completions_url DEFAULT_SYSTEM_PROMPT = """你是烟草行业智慧法务小助手,专注于烟草专卖法规、合同管理、行政处罚等相关法律法规。\n\n回答要求:\n- 先用一句话直接回答,再展开详细说明\n- 多个要点用编号列表\n- 关键法条和数字用 **加粗**\n- 分类信息用表格\n- 层级结构用缩进子列表\n- 不要加标题,直接输出正文""" @@ -17,13 +17,14 @@ async def generate_stream( context_chunks: list[dict], conversation_id: str, message_id: str, + task_id: str | None = None, system_prompt: str = "", model: str = "", temperature: float | None = None, max_tokens: int | None = None, dataset_name: str = "", ) -> AsyncGenerator[str, None]: - task_id = str(uuid.uuid4()) + task_id = task_id or str(uuid.uuid4()) created_at = int(time.time()) _model = model or RAG_CONFIG["LLM_MODEL"] _temp = temperature if temperature is not None else RAG_CONFIG["LLM_TEMPERATURE"] @@ -55,7 +56,7 @@ async def generate_stream( async with httpx.AsyncClient(timeout=RAG_CONFIG["LLM_TIMEOUT"]) as client: async with client.stream( "POST", - f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}" + "/chat/completions", + build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]), json={ "model": _model, "messages": messages, diff --git a/fastapi_modules/fastapi_leaudit/rag_engine/question_chains.py b/fastapi_modules/fastapi_leaudit/rag_engine/question_chains.py index 1941f02..e71c4e4 100644 --- a/fastapi_modules/fastapi_leaudit/rag_engine/question_chains.py +++ b/fastapi_modules/fastapi_leaudit/rag_engine/question_chains.py @@ -4,7 +4,7 @@ import json import httpx -from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG +from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_chat_completions_url async def generate_followups(query: str, answer: str) -> list[str]: @@ -15,7 +15,7 @@ async def generate_followups(query: str, answer: str) -> list[str]: ) async with httpx.AsyncClient(timeout=30.0) as client: resp = await client.post( - f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}" + "/chat/completions", + build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]), json={ "model": RAG_CONFIG["LLM_MODEL"], "messages": [{"role": "user", "content": prompt}], diff --git a/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py b/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py index 33cf0bb..bcb8641 100644 --- a/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py +++ b/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py @@ -1,6 +1,9 @@ from __future__ import annotations +import asyncio import json +import re +import time import uuid from typing import AsyncGenerator @@ -14,6 +17,7 @@ from fastapi_common.fastapi_common_web.exception.LeauditException import Leaudit from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import ( RagConversationRenameDTO, RagMessageFeedbackDTO, + RagStopMessageDTO, ) from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import ( RagAppParametersVO, @@ -26,13 +30,30 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import ( RagMessagePageVO, RagOperationResultVO, ) -from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG +from fastapi_modules.fastapi_leaudit.rag_engine.config import ( + RAG_CONFIG, + build_openai_chat_completions_url, + build_openai_embeddings_url, +) +from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService +DEFAULT_CONVERSATION_NAME = "新对话" +DEFAULT_TITLE_SOURCE = "default" +AUTO_TITLE_SOURCE = "auto" +MANUAL_TITLE_SOURCE = "manual" + + class RagChatServiceImpl(IRagChatService): + _message_tasks: dict[str, asyncio.Task] = {} + _task_events: dict[str, list[dict]] = {} + _task_done: dict[str, bool] = {} + _task_locks: dict[str, asyncio.Lock] = {} + _title_tasks: dict[str, asyncio.Task] = {} + async def GetApps(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppListVO: apps = await self._load_apps(UserArea, UserRole, only_default=False) return RagChatAppListVO(data=apps, total=len(apps)) @@ -63,6 +84,8 @@ class RagChatServiceImpl(IRagChatService): conversationId = await self._ensure_conversation(CurrentUserId, ConversationId, app["id"]) messageId = str(uuid.uuid4()) + taskId = str(uuid.uuid4()) + is_new_conversation = not ConversationId or ConversationId == "-1" async with GetAsyncSession() as session: async with session.begin(): @@ -79,67 +102,17 @@ class RagChatServiceImpl(IRagChatService): "content": Query, }, ) - - context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), Query) - collected_answer = "" - held_message_end: dict | None = None - - async for chunk in generate_stream( - query=Query, - context_chunks=context_chunks, - conversation_id=conversationId, - message_id=messageId, - system_prompt=app.get("system_prompt") or "", - model=app.get("llm_model") or "", - temperature=app.get("temperature"), - max_tokens=app.get("max_tokens"), - dataset_name=dataset_name, - ): - chunk_bytes = chunk.encode("utf-8") - data = self._parse_sse_event(chunk) - if not data: - yield chunk_bytes - continue - - if data.get("event") == "message": - collected_answer += data.get("answer", "") - yield chunk_bytes - continue - - if data.get("event") == "message_end": - held_message_end = data - continue - - yield chunk_bytes - - followups: list[str] = [] - try: - followups = await generate_followups(Query, collected_answer) - except Exception: - followups = [] - - if held_message_end: - try: - held_message_end.setdefault("metadata", {})["suggested_questions"] = followups - yield f"data: {json.dumps(held_message_end, ensure_ascii=False)}\n\n".encode("utf-8") - except Exception: - yield f"data: {json.dumps(held_message_end, ensure_ascii=False)}\n\n".encode("utf-8") - - async with GetAsyncSession() as session: - async with session.begin(): await session.execute( text( """ INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata) - VALUES (:message_id, :conversation_id, 'assistant', :content, CAST(:sources AS jsonb), CAST(:metadata AS jsonb)) + VALUES (:message_id, :conversation_id, 'assistant', '', '[]'::jsonb, CAST(:metadata AS jsonb)) """ ), { "message_id": messageId, "conversation_id": conversationId, - "content": collected_answer, - "sources": json.dumps(self._build_sources(context_chunks, dataset_name), ensure_ascii=False), - "metadata": json.dumps({"suggested_questions": followups}, ensure_ascii=False), + "metadata": json.dumps({"status": "running", "task_id": taskId}, ensure_ascii=False), }, ) await session.execute( @@ -149,6 +122,45 @@ class RagChatServiceImpl(IRagChatService): {"conversation_id": conversationId}, ) + await self._start_message_task( + task_id=taskId, + conversation_id=conversationId, + message_id=messageId, + query=Query, + app=app, + ) + + event_index = 0 + initial_events: list[dict] = [] + if is_new_conversation: + initial_events.append( + { + "event": "conversation_created", + "conversation_id": conversationId, + "message_id": messageId, + "task_id": taskId, + } + ) + + while True: + if event_index < len(initial_events): + payload = initial_events[event_index] + event_index += 1 + yield self._format_sse(payload) + continue + + events = self._task_events.get(taskId, []) + if event_index - len(initial_events) < len(events): + payload = events[event_index - len(initial_events)] + event_index += 1 + yield self._format_sse(payload) + continue + + if self._task_done.get(taskId): + break + + await asyncio.sleep(0.05) + async def GetConversations(self, CurrentUserId: int, AppId: int | None, Page: int, PageSize: int) -> RagConversationPageVO: async with GetAsyncSession() as session: rows = ( @@ -156,11 +168,13 @@ class RagChatServiceImpl(IRagChatService): text( """ SELECT conversation_id, name, introduction, created_at, updated_at + , COALESCE(title_source, 'default') AS title_source + , COALESCE(EXTRACT(EPOCH FROM last_message_at), 0) AS last_message_at FROM rag_conversation WHERE user_id = :user_id AND deleted_at IS NULL AND (CAST(:app_id AS BIGINT) IS NULL OR app_id = CAST(:app_id AS BIGINT)) - ORDER BY updated_at DESC + ORDER BY COALESCE(last_message_at, updated_at) DESC, updated_at DESC OFFSET :offset LIMIT :limit """ ), @@ -180,8 +194,10 @@ class RagChatServiceImpl(IRagChatService): id=row["conversation_id"], name=row["name"], introduction=row.get("introduction") or "", + titleSource=str(row.get("title_source") or DEFAULT_TITLE_SOURCE), createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, + lastMessageAt=int(float(row.get("last_message_at") or 0)), ) for row in items ], @@ -199,7 +215,13 @@ class RagChatServiceImpl(IRagChatService): SELECT message_id, role, content, sources, metadata, feedback, created_at FROM rag_message WHERE conversation_id = :conversation_id - ORDER BY created_at ASC + ORDER BY created_at ASC, + CASE role + WHEN 'user' THEN 0 + WHEN 'assistant' THEN 1 + ELSE 2 + END ASC, + message_id ASC OFFSET :offset LIMIT :limit """ ), @@ -218,20 +240,43 @@ class RagChatServiceImpl(IRagChatService): row = items[idx] if row["role"] == "user": answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None - answer_sources = self._parse_json_field(answer.get("sources")) if answer else [] - answer_metadata = self._parse_json_field(answer.get("metadata")) if answer else {} - suggested_questions = answer_metadata.get("suggested_questions") if isinstance(answer_metadata, dict) else [] - if not isinstance(suggested_questions, list): - suggested_questions = [] + answer_metadata = dict((answer.get("metadata") if answer else None) or {}) + answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running")) + answer_content = (answer.get("content") if answer else None) or "" + + if answer: + task_id = str(answer_metadata.get("task_id") or "").strip() + reconstructed_content = self._rebuild_message_content_from_events(task_id) if task_id else "" + if reconstructed_content and len(reconstructed_content) >= len(answer_content): + if reconstructed_content != answer_content: + await self._update_message_progress( + conversation_id=ConversationId, + message_id=answer["message_id"], + content=reconstructed_content, + metadata=answer_metadata, + ) + answer_content = reconstructed_content + + normalized_status = await self._resolve_persisted_message_status( + conversation_id=ConversationId, + message_id=answer["message_id"], + content=answer_content, + metadata=answer_metadata, + ) + if normalized_status != answer_status: + answer_status = normalized_status + answer_metadata["status"] = normalized_status + data.append( RagMessageItemVO( id=(answer["message_id"] if answer else row["message_id"]), conversationId=ConversationId, query=row["content"], - answer=answer["content"] if answer else "", + answer=answer_content if answer else "", feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None), - retrieverResources=answer_sources or None, - suggestedQuestions=[str(item) for item in suggested_questions], + retrieverResources=(answer.get("sources") if answer else None), + suggestedQuestions=[str(item) for item in (answer_metadata.get("suggested_questions") or []) if str(item).strip()], + status=answer_status, createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, ) ) @@ -240,17 +285,79 @@ class RagChatServiceImpl(IRagChatService): idx += 1 return RagMessagePageVO(data=data, hasMore=has_more, limit=PageSize) + async def _resolve_persisted_message_status( + self, + *, + conversation_id: str, + message_id: str, + content: str, + metadata: dict, + ) -> str: + status = str(metadata.get("status") or "completed") + if status != "running": + return status + + task_id = str(metadata.get("task_id") or "").strip() + task = self._message_tasks.get(task_id) if task_id else None + task_done = self._task_done.get(task_id, False) if task_id else False + + if task and not task.done() and not task_done: + return "running" + + normalized_status = "completed" if content.strip() else "error" + normalized_metadata = { + **metadata, + "status": normalized_status, + } + if normalized_status == "error" and not normalized_metadata.get("error"): + normalized_metadata["error"] = "生成任务已结束,但未产出有效回答" + + await self._update_message_progress( + conversation_id=conversation_id, + message_id=message_id, + content=content, + metadata=normalized_metadata, + ) + return normalized_status + + def _rebuild_message_content_from_events(self, task_id: str) -> str: + if not task_id: + return "" + + chunks: list[str] = [] + for event in self._task_events.get(task_id, []): + if event.get("event") != "message": + continue + answer = event.get("answer") + if isinstance(answer, str) and answer: + chunks.append(answer) + return "".join(chunks) + async def RenameConversation(self, CurrentUserId: int, ConversationId: str, Body: RagConversationRenameDTO) -> RagConversationRenameVO: await self._ensure_conversation_owner(CurrentUserId, ConversationId) + final_name = Body.name.strip() + if not final_name: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话名称不能为空") async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( - "UPDATE rag_conversation SET name = :name, updated_at = NOW() WHERE conversation_id = :conversation_id" + """ + UPDATE rag_conversation + SET name = :name, + title_source = 'manual', + title_generation_status = CASE + WHEN COALESCE(title_generation_status, 'idle') = 'running' THEN 'succeeded' + ELSE COALESCE(title_generation_status, 'idle') + END, + title_generation_error = NULL, + updated_at = NOW() + WHERE conversation_id = :conversation_id + """ ), - {"name": Body.name, "conversation_id": ConversationId}, + {"name": final_name, "conversation_id": ConversationId}, ) - return RagConversationRenameVO(result="success", name=Body.name) + return RagConversationRenameVO(result="success", name=final_name) async def DeleteConversation(self, CurrentUserId: int, ConversationId: str) -> RagOperationResultVO: await self._ensure_conversation_owner(CurrentUserId, ConversationId) @@ -290,6 +397,35 @@ class RagChatServiceImpl(IRagChatService): ) return RagOperationResultVO(result="success") + async def StopMessage(self, CurrentUserId: int, MessageId: str, Body: RagStopMessageDTO | None = None) -> RagOperationResultVO: + async with GetAsyncSession() as session: + row = ( + await session.execute( + text( + """ + SELECT m.metadata, c.user_id + FROM rag_message m + JOIN rag_conversation c ON c.conversation_id = m.conversation_id + WHERE m.message_id = :message_id + AND c.deleted_at IS NULL + LIMIT 1 + """ + ), + {"message_id": MessageId}, + ) + ).mappings().first() + if not row: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在") + if int(row["user_id"]) != CurrentUserId: + raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权停止该消息") + + metadata = row.get("metadata") or {} + task_id = str(Body.taskId or metadata.get("task_id") or "").strip() + task = self._message_tasks.get(task_id) if task_id else None + if task and not task.done(): + task.cancel() + return RagOperationResultVO(result="success") + async def GetAppParameters( self, CurrentUserId: int, @@ -400,18 +536,6 @@ class RagChatServiceImpl(IRagChatService): area = row.get("area") or "" return area in ("", "省级", user_area or "") or bool(row.get("dataset_public")) - def _parse_json_field(self, value): - if value is None: - return {} - if isinstance(value, (dict, list)): - return value - if isinstance(value, str): - try: - return json.loads(value) - except Exception: - return {} - return {} - async def _ensure_conversation(self, user_id: int, conversation_id: str | None, app_id: int | None) -> str: if conversation_id and conversation_id != "-1": async with GetAsyncSession() as session: @@ -440,10 +564,15 @@ class RagChatServiceImpl(IRagChatService): text( """ INSERT INTO rag_conversation (conversation_id, user_id, app_id, name, introduction) - VALUES (:conversation_id, :user_id, :app_id, '新对话', '') + VALUES (:conversation_id, :user_id, :app_id, :name, '') """ ), - {"conversation_id": conversation_id, "user_id": user_id, "app_id": app_id}, + { + "conversation_id": conversation_id, + "user_id": user_id, + "app_id": app_id, + "name": DEFAULT_CONVERSATION_NAME, + }, ) return conversation_id @@ -490,45 +619,57 @@ class RagChatServiceImpl(IRagChatService): except (TypeError, ValueError): score_threshold = None try: - from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma - except Exception: - return [], dataset.get("name") or "" - try: - collection = get_chroma().get_or_create_collection(dataset["collection_name"]) query_embedding = await self._embed_texts([query], dataset.get("embedding_model") or "") + collection = get_chroma().get_or_create_collection(dataset["collection_name"]) result = collection.query( query_embeddings=query_embedding, n_results=max(top_k, 1), include=["documents", "metadatas", "distances"], ) + ids = (result.get("ids") or [[]])[0] if result.get("ids") else [] docs = (result.get("documents") or [[]])[0] metas = (result.get("metadatas") or [[]])[0] distances = (result.get("distances") or [[]])[0] chunks: list[dict] = [] for idx, doc in enumerate(docs): meta = metas[idx] if idx < len(metas) else {} - dist = distances[idx] if idx < len(distances) else 0.0 - distance = max(0.0, float(dist or 0.0)) - score = 1.0 / (1.0 + distance) + dist = float(distances[idx]) if idx < len(distances) and distances[idx] is not None else 1.0 + score = 1.0 / (1.0 + max(dist, 0.0)) if score_threshold is not None and score < score_threshold: continue chunks.append( { - "id": str(meta.get("id") or idx), + "id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx), "text": doc, "source": meta.get("source") or meta.get("document_name") or dataset.get("name") or "", "score": score, - "chunk_index": idx, + "chunk_index": int(meta.get("chunk_index") or idx), "document_name": meta.get("document_name") or meta.get("source") or "", + "document_id": meta.get("document_id"), + "page": meta.get("page"), } ) chunks = await self._hydrate_document_hits(dataset_id, chunks) + if chunks: + return chunks[:top_k], dataset.get("name") or "" + except Exception: + pass + + try: + chunks = await self._keyword_retrieve_context( + dataset_id=dataset_id, + collection_name=str(dataset["collection_name"]), + dataset_name=str(dataset.get("name") or ""), + query=query, + top_k=top_k, + score_threshold=score_threshold, + ) return chunks[:top_k], dataset.get("name") or "" except Exception: return [], dataset.get("name") or "" async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]: - embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings" + embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"]) embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"] embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4" batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10)) @@ -567,6 +708,647 @@ class RagChatServiceImpl(IRagChatService): raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") return embeddings + async def _start_message_task( + self, + *, + task_id: str, + conversation_id: str, + message_id: str, + query: str, + app: dict, + ) -> None: + self._task_events[task_id] = [] + self._task_done[task_id] = False + self._task_locks.setdefault(task_id, asyncio.Lock()) + task = asyncio.create_task( + self._run_message_task( + task_id=task_id, + conversation_id=conversation_id, + message_id=message_id, + query=query, + app=app, + ) + ) + self._message_tasks[task_id] = task + + async def _run_message_task( + self, + *, + task_id: str, + conversation_id: str, + message_id: str, + query: str, + app: dict, + ) -> None: + context_chunks: list[dict] = [] + dataset_name = "" + collected_answer = "" + message_end_payload: dict | None = None + final_status = "completed" + error_payload: dict | None = None + last_persisted_length = 0 + last_persisted_at = time.monotonic() + + try: + context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), query) + async for chunk in generate_stream( + query=query, + context_chunks=context_chunks, + conversation_id=conversation_id, + message_id=message_id, + system_prompt=app.get("system_prompt") or "", + model=app.get("llm_model") or "", + temperature=app.get("temperature"), + max_tokens=app.get("max_tokens"), + dataset_name=dataset_name, + task_id=task_id, + ): + data = self._parse_sse_event(chunk) + if not data: + continue + + event = data.get("event") + if event == "message": + collected_answer += data.get("answer", "") + now = time.monotonic() + if len(collected_answer) - last_persisted_length >= 80 or now - last_persisted_at >= 0.5: + await self._update_message_progress( + conversation_id=conversation_id, + message_id=message_id, + content=collected_answer, + metadata={"status": "running", "task_id": task_id}, + ) + last_persisted_length = len(collected_answer) + last_persisted_at = now + await self._append_task_event(task_id, data) + continue + if event == "message_end": + message_end_payload = data + continue + if event == "error": + final_status = "error" + error_payload = data + await self._append_task_event(task_id, data) + continue + + await self._append_task_event(task_id, data) + + if final_status == "completed": + followups: list[str] = [] + try: + followups = await generate_followups(query, collected_answer) + except Exception: + followups = [] + + if message_end_payload: + message_end_payload.setdefault("metadata", {})["suggested_questions"] = followups + await self._append_task_event(task_id, message_end_payload) + await self._finalize_message_record( + conversation_id=conversation_id, + message_id=message_id, + content=collected_answer, + sources=self._build_sources(context_chunks, dataset_name), + metadata={"suggested_questions": followups, "status": "completed", "task_id": task_id}, + ) + await self._maybe_schedule_auto_title( + conversation_id=conversation_id, + message_id=message_id, + query=query, + answer=collected_answer, + ) + else: + await self._finalize_message_record( + conversation_id=conversation_id, + message_id=message_id, + content=collected_answer, + sources=self._build_sources(context_chunks, dataset_name), + metadata={ + "suggested_questions": [], + "status": "error", + "task_id": task_id, + "error": (error_payload or {}).get("message", ""), + }, + ) + except asyncio.CancelledError: + final_status = "stopped" + await self._append_task_event( + task_id, + { + "event": "error", + "task_id": task_id, + "message_id": message_id, + "code": "message_stopped", + "message": "用户已停止回答", + }, + ) + await self._finalize_message_record( + conversation_id=conversation_id, + message_id=message_id, + content=collected_answer, + sources=self._build_sources(context_chunks, dataset_name), + metadata={"suggested_questions": [], "status": "stopped", "task_id": task_id}, + ) + raise + except Exception as exc: + final_status = "error" + await self._append_task_event( + task_id, + { + "event": "error", + "task_id": task_id, + "message_id": message_id, + "code": "server_error", + "message": str(exc), + }, + ) + await self._finalize_message_record( + conversation_id=conversation_id, + message_id=message_id, + content=collected_answer, + sources=self._build_sources(context_chunks, dataset_name), + metadata={"suggested_questions": [], "status": "error", "task_id": task_id, "error": str(exc)}, + ) + finally: + self._task_done[task_id] = True + self._message_tasks.pop(task_id, None) + self._task_locks.pop(task_id, None) + + async def _append_task_event(self, task_id: str, payload: dict) -> None: + lock = self._task_locks.setdefault(task_id, asyncio.Lock()) + async with lock: + self._task_events.setdefault(task_id, []).append(payload) + + async def _finalize_message_record( + self, + *, + conversation_id: str, + message_id: str, + content: str, + sources: list[dict], + metadata: dict, + ) -> None: + async with GetAsyncSession() as session: + async with session.begin(): + await session.execute( + text( + """ + UPDATE rag_message + SET content = :content, + sources = CAST(:sources AS jsonb), + metadata = CAST(:metadata AS jsonb) + WHERE message_id = :message_id + AND conversation_id = :conversation_id + AND role = 'assistant' + """ + ), + { + "conversation_id": conversation_id, + "message_id": message_id, + "content": content, + "sources": json.dumps(sources, ensure_ascii=False), + "metadata": json.dumps(metadata, ensure_ascii=False), + }, + ) + await session.execute( + text( + """ + UPDATE rag_conversation + SET updated_at = NOW(), + last_message_at = NOW() + WHERE conversation_id = :conversation_id + """ + ), + {"conversation_id": conversation_id}, + ) + + async def _update_message_progress( + self, + *, + conversation_id: str, + message_id: str, + content: str, + metadata: dict, + ) -> None: + async with GetAsyncSession() as session: + async with session.begin(): + await session.execute( + text( + """ + UPDATE rag_message + SET content = :content, + metadata = CAST(:metadata AS jsonb) + WHERE message_id = :message_id + AND conversation_id = :conversation_id + AND role = 'assistant' + """ + ), + { + "conversation_id": conversation_id, + "message_id": message_id, + "content": content, + "metadata": json.dumps(metadata, ensure_ascii=False), + }, + ) + + async def _maybe_schedule_auto_title( + self, + *, + conversation_id: str, + message_id: str, + query: str, + answer: str, + ) -> None: + normalized_query = (query or "").strip() + normalized_answer = (answer or "").strip() + if not normalized_query or not normalized_answer: + return + + async with GetAsyncSession() as session: + row = ( + await session.execute( + text( + """ + SELECT conversation_id, + name, + COALESCE(title_source, 'default') AS title_source, + COALESCE(title_generation_status, 'idle') AS title_generation_status, + first_answer_message_id + FROM rag_conversation + WHERE conversation_id = :conversation_id + AND deleted_at IS NULL + LIMIT 1 + """ + ), + {"conversation_id": conversation_id}, + ) + ).mappings().first() + + if not row: + return + + title_source = str(row.get("title_source") or DEFAULT_TITLE_SOURCE) + if title_source == MANUAL_TITLE_SOURCE: + return + + current_name = str(row.get("name") or "").strip() + if current_name and current_name != DEFAULT_CONVERSATION_NAME and title_source != DEFAULT_TITLE_SOURCE: + return + + current_status = str(row.get("title_generation_status") or "idle") + if current_status in {"pending", "running", "succeeded"}: + return + + if row.get("first_answer_message_id") and str(row.get("first_answer_message_id")) != message_id: + return + + async with GetAsyncSession() as session: + async with session.begin(): + await session.execute( + text( + """ + UPDATE rag_conversation + SET title_generation_status = 'pending', + first_answer_message_id = COALESCE(first_answer_message_id, :message_id), + title_generation_error = NULL + WHERE conversation_id = :conversation_id + AND deleted_at IS NULL + AND COALESCE(title_source, 'default') = 'default' + """ + ), + { + "conversation_id": conversation_id, + "message_id": message_id, + }, + ) + + existing_task = self._title_tasks.get(conversation_id) + if existing_task and not existing_task.done(): + return + + task = asyncio.create_task( + self._run_auto_title_task( + conversation_id=conversation_id, + answer_message_id=message_id, + query=normalized_query, + answer=normalized_answer, + ) + ) + self._title_tasks[conversation_id] = task + + async def _run_auto_title_task( + self, + *, + conversation_id: str, + answer_message_id: str, + query: str, + answer: str, + ) -> None: + try: + async with GetAsyncSession() as session: + row = ( + await session.execute( + text( + """ + SELECT name, + COALESCE(title_source, 'default') AS title_source, + COALESCE(title_generation_status, 'idle') AS title_generation_status, + first_answer_message_id + FROM rag_conversation + WHERE conversation_id = :conversation_id + AND deleted_at IS NULL + LIMIT 1 + """ + ), + {"conversation_id": conversation_id}, + ) + ).mappings().first() + + if not row: + return + + if str(row.get("title_source") or DEFAULT_TITLE_SOURCE) == MANUAL_TITLE_SOURCE: + return + + if row.get("first_answer_message_id") and str(row.get("first_answer_message_id")) != answer_message_id: + return + + async with GetAsyncSession() as session: + async with session.begin(): + await session.execute( + text( + """ + UPDATE rag_conversation + SET title_generation_status = 'running', + title_generation_error = NULL + WHERE conversation_id = :conversation_id + AND deleted_at IS NULL + AND COALESCE(title_source, 'default') = 'default' + """ + ), + {"conversation_id": conversation_id}, + ) + + generated_title = await self._generate_conversation_title(query=query, answer=answer) + cleaned_title = self._sanitize_generated_title(generated_title) + if not cleaned_title: + cleaned_title = self._build_fallback_title(query=query, answer=answer) + + if not cleaned_title: + raise ValueError("未生成有效标题") + + async with GetAsyncSession() as session: + async with session.begin(): + await session.execute( + text( + """ + UPDATE rag_conversation + SET name = :name, + title_source = 'auto', + title_generation_status = 'succeeded', + title_generated_at = NOW(), + title_generation_error = NULL + WHERE conversation_id = :conversation_id + AND deleted_at IS NULL + AND COALESCE(title_source, 'default') = 'default' + AND ( + name IS NULL + OR BTRIM(name) = '' + OR name = :default_name + ) + """ + ), + { + "conversation_id": conversation_id, + "name": cleaned_title, + "default_name": DEFAULT_CONVERSATION_NAME, + }, + ) + except Exception as exc: + async with GetAsyncSession() as session: + async with session.begin(): + await session.execute( + text( + """ + UPDATE rag_conversation + SET title_generation_status = 'failed', + title_generation_error = :error + WHERE conversation_id = :conversation_id + AND deleted_at IS NULL + AND COALESCE(title_source, 'default') = 'default' + """ + ), + { + "conversation_id": conversation_id, + "error": str(exc)[:1000], + }, + ) + finally: + self._title_tasks.pop(conversation_id, None) + + async def _generate_conversation_title(self, *, query: str, answer: str) -> str: + prompt = ( + "请基于用户首轮提问和助手首轮回答,生成一个简洁、准确的中文会话标题。" + "要求:" + "1. 只输出标题本身;" + "2. 不要标点结尾;" + "3. 不要出现“关于”“用户询问”“问题解答”等空话;" + "4. 优先 12-24 个中文字符,最长不超过 40 个字符。\\n" + f"用户问题:{query[:500]}\\n" + f"助手回答:{answer[:1500]}" + ) + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]), + json={ + "model": RAG_CONFIG["LLM_MODEL"], + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.2, + "max_tokens": 80, + "chat_template_kwargs": {"enable_thinking": False}, + }, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}", + }, + ) + resp.raise_for_status() + content = resp.json()["choices"][0]["message"]["content"] + return str(content or "").strip() + + def _sanitize_generated_title(self, value: str) -> str: + title = str(value or "").strip() + if not title: + return "" + + title = title.replace("\r", " ").replace("\n", " ").strip() + title = re.sub(r"^```[\w-]*", "", title).strip() + title = title.replace("```", "").strip() + title = re.sub(r'^[\"“”\'‘’\[\]()()【】]+', "", title) + title = re.sub(r'[\"“”\'‘’\[\]()()【】]+$', "", title) + title = re.sub(r"^(标题|会话标题)[::]\\s*", "", title) + title = re.sub(r"\\s+", " ", title).strip() + title = title.rstrip("。!?!?.;;,,::") + + if title in {"新对话", "会话标题", "标题"}: + return "" + + if len(title) > 40: + title = title[:40].rstrip(",,。;;:: ") + + return title + + def _build_fallback_title(self, *, query: str, answer: str) -> str: + base = query.strip() or answer.strip() + if not base: + return "" + + base = re.sub(r"\\s+", " ", base) + base = re.sub(r"^[请问帮我一下关于针对结合根据请解释说明一下::,,\\s]+", "", base) + base = base.rstrip("。!?!?.;;,,::") + if len(base) > 24: + base = base[:24].rstrip(",,。;;:: ") + return base + + async def _keyword_retrieve_context( + self, + *, + dataset_id: int, + collection_name: str, + dataset_name: str, + query: str, + top_k: int, + score_threshold: float | None, + ) -> list[dict]: + collection = get_chroma().get_or_create_collection(collection_name) + raw = collection.get(include=["documents", "metadatas"]) + ids = raw.get("ids") or [] + docs = raw.get("documents") or [] + metas = raw.get("metadatas") or [] + + terms = self._build_keyword_terms(query) + if not terms: + return [] + + scored_chunks: list[dict] = [] + for idx, chunk_id in enumerate(ids): + doc = docs[idx] if idx < len(docs) else "" + meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {} + score = self._score_keyword_chunk( + query=query, + terms=terms, + content=doc or "", + document_name=str(meta.get("document_name") or meta.get("source") or ""), + ) + if score <= 0: + continue + if score_threshold is not None and score < score_threshold: + continue + scored_chunks.append( + { + "id": str(chunk_id), + "text": doc or "", + "source": meta.get("source") or meta.get("document_name") or dataset_name, + "score": score, + "chunk_index": int(meta.get("chunk_index") or idx), + "document_name": meta.get("document_name") or meta.get("source") or "", + "document_id": meta.get("document_id"), + "page": meta.get("page"), + } + ) + + scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0))) + hydrated = await self._hydrate_document_hits(dataset_id, scored_chunks[: max(top_k * 3, top_k)]) + return hydrated[:top_k] + + def _build_keyword_terms(self, query: str) -> list[str]: + normalized = self._normalize_keyword_query(query) + spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()] + if not spans: + return [] + + stop_terms = { + "什么", + "请问", + "一下", + "有关", + "关于", + "如何", + "哪些", + "怎么", + "是否", + "规定", + "办法", + "条例", + "法律", + } + terms: list[str] = [] + for span in spans: + if span in stop_terms: + continue + terms.append(span) + if re.fullmatch(r"[\u4e00-\u9fff]+", span): + for size in (2, 3, 4): + if len(span) > size: + for start in range(0, len(span) - size + 1): + token = span[start:start + size] + if token not in stop_terms: + terms.append(token) + + unique_terms: list[str] = [] + seen: set[str] = set() + for term in sorted(terms, key=len, reverse=True): + if term and term not in seen: + unique_terms.append(term) + seen.add(term) + return unique_terms[:20] + + def _normalize_keyword_query(self, query: str) -> str: + normalized = (query or "").strip().lower() + patterns = [ + "是什么", + "什么是", + "有哪些", + "有什么", + "是什么?", + "是什么?", + "请问", + "介绍一下", + "解释一下", + "帮我分析", + "帮我看看", + ] + for pattern in patterns: + normalized = normalized.replace(pattern, " ") + return re.sub(r"\s+", " ", normalized).strip() + + def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float: + haystack = f"{document_name}\n{content}".lower() + if not haystack: + return 0.0 + + exact_query = self._normalize_keyword_query(query) + if exact_query and exact_query in haystack: + return 0.98 + + matched_weight = 0.0 + total_weight = 0.0 + name_bonus = 0.0 + for term in terms: + weight = float(max(len(term), 1) ** 2) + total_weight += weight + if term.lower() in haystack: + matched_weight += weight + if term.lower() in document_name.lower(): + name_bonus += min(0.15, 0.03 * len(term)) + + if total_weight <= 0: + return 0.0 + score = (matched_weight / total_weight) + name_bonus + return round(min(score, 0.99), 6) + + def _format_sse(self, payload: dict) -> bytes: + return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8") + def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]: return [ { diff --git a/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py b/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py index 360988e..62b7615 100644 --- a/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py +++ b/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py @@ -36,7 +36,7 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import ( ) from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import RagOperationResultVO from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma -from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG +from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_embeddings_url from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService @@ -1503,7 +1503,7 @@ class RagDatasetServiceImpl(IRagDatasetService): return chunks async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]: - embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings" + embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"]) embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"] embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4" batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10)) diff --git a/fastapi_modules/fastapi_leaudit/services/ragChatService.py b/fastapi_modules/fastapi_leaudit/services/ragChatService.py index b943605..19fd995 100644 --- a/fastapi_modules/fastapi_leaudit/services/ragChatService.py +++ b/fastapi_modules/fastapi_leaudit/services/ragChatService.py @@ -6,6 +6,7 @@ from typing import AsyncGenerator from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import ( RagConversationRenameDTO, RagMessageFeedbackDTO, + RagStopMessageDTO, ) from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import ( RagAppParametersVO, @@ -52,6 +53,9 @@ class IRagChatService(ABC): @abstractmethod async def UpdateFeedback(self, CurrentUserId: int, MessageId: str, Body: RagMessageFeedbackDTO) -> RagOperationResultVO: ... + @abstractmethod + async def StopMessage(self, CurrentUserId: int, MessageId: str, Body: RagStopMessageDTO | None = None) -> RagOperationResultVO: ... + @abstractmethod async def GetAppParameters( self, diff --git a/scripts/创建sql/schema_rag_chat_auto_title.sql b/scripts/创建sql/schema_rag_chat_auto_title.sql new file mode 100644 index 0000000..d87df03 --- /dev/null +++ b/scripts/创建sql/schema_rag_chat_auto_title.sql @@ -0,0 +1,24 @@ +ALTER TABLE rag_conversation + ADD COLUMN IF NOT EXISTS title_source VARCHAR(20) NOT NULL DEFAULT 'default', + ADD COLUMN IF NOT EXISTS title_generation_status VARCHAR(20) NOT NULL DEFAULT 'idle', + ADD COLUMN IF NOT EXISTS title_generated_at TIMESTAMPTZ, + ADD COLUMN IF NOT EXISTS first_question_message_id VARCHAR(100), + ADD COLUMN IF NOT EXISTS first_answer_message_id VARCHAR(100), + ADD COLUMN IF NOT EXISTS title_generation_error TEXT, + ADD COLUMN IF NOT EXISTS last_message_at TIMESTAMPTZ; + +UPDATE rag_conversation +SET title_source = 'default' +WHERE title_source IS NULL OR BTRIM(title_source) = ''; + +UPDATE rag_conversation +SET title_generation_status = 'idle' +WHERE title_generation_status IS NULL OR BTRIM(title_generation_status) = ''; + +CREATE INDEX IF NOT EXISTS idx_rag_conversation_title_generation_status +ON rag_conversation(title_generation_status) +WHERE deleted_at IS NULL; + +CREATE INDEX IF NOT EXISTS idx_rag_conversation_last_message_at +ON rag_conversation(last_message_at DESC) +WHERE deleted_at IS NULL;