from __future__ import annotations import json import time import uuid from typing import AsyncGenerator import httpx from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_chat_completions_url DEFAULT_SYSTEM_PROMPT = """你是烟草行业智慧法务小助手,专注于烟草专卖法规、合同管理、行政处罚等相关法律法规。\n\n回答要求:\n- 先用一句话直接回答,再展开详细说明\n- 多个要点用编号列表\n- 关键法条和数字用 **加粗**\n- 分类信息用表格\n- 层级结构用缩进子列表\n- 不要加标题,直接输出正文""" async def generate_stream( query: str, context_chunks: list[dict], conversation_id: str, message_id: str, task_id: str | None = None, system_prompt: str = "", model: str = "", temperature: float | None = None, max_tokens: int | None = None, dataset_name: str = "", ) -> AsyncGenerator[str, None]: task_id = task_id or str(uuid.uuid4()) created_at = int(time.time()) _model = model or RAG_CONFIG["LLM_MODEL"] _temp = temperature if temperature is not None else RAG_CONFIG["LLM_TEMPERATURE"] _max_tok = max_tokens or RAG_CONFIG["LLM_MAX_TOKENS"] _prompt = system_prompt or DEFAULT_SYSTEM_PROMPT max_context_chars = 8000 if context_chunks: attachment_parts: list[str] = [] formal_parts: list[str] = [] total_len = 0 for chunk in context_chunks: scope = str(chunk.get("source_scope") or chunk.get("data_source_type") or "formal_kb") if scope == "chat_attachment": label = "上传文件事实" else: label = "正式知识库依据" part = f"[{label} | 来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}" if total_len + len(part) > max_context_chars: break if scope == "chat_attachment": attachment_parts.append(part) else: formal_parts.append(part) total_len += len(part) context_sections: list[str] = [] if attachment_parts: context_sections.append("【上传文件事实】\n" + "\\n\\n---\\n\\n".join(attachment_parts)) if formal_parts: context_sections.append("【正式知识库依据】\n" + "\\n\\n---\\n\\n".join(formal_parts)) context_text = "\\n\\n======\\n\\n".join(context_sections) user_content = ( "请严格区分上下文来源:上传文件事实只能用于判断用户材料中的事实;正式知识库依据只能用于法律、处罚、裁量或案例依据。" "如果依据不足,请明确说明,不要编造。\n\n" f"{context_text}\\n\\n用户问题: {query}" ) else: user_content = query messages = [ {"role": "system", "content": _prompt}, {"role": "user", "content": user_content}, ] total_tokens = 0 try: async with httpx.AsyncClient(timeout=RAG_CONFIG["LLM_TIMEOUT"]) as client: async with client.stream( "POST", build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]), json={ "model": _model, "messages": messages, "temperature": _temp, "max_tokens": _max_tok, "stream": True, "chat_template_kwargs": {"enable_thinking": False}, }, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}", }, ) as resp: resp.raise_for_status() async for line in resp.aiter_lines(): if not line.startswith("data: "): continue payload = line[6:].strip() if payload == "[DONE]": break chunk = json.loads(payload) choices = chunk.get("choices") or [] first_choice = choices[0] if choices and isinstance(choices[0], dict) else {} delta = first_choice.get("delta", {}) text = delta.get("content", "") if text: yield _sse_line( { "event": "message", "task_id": task_id, "message_id": message_id, "conversation_id": conversation_id, "answer": text, "created_at": created_at, } ) usage = chunk.get("usage") if usage: total_tokens = usage.get("total_tokens", total_tokens) except Exception as exc: yield _sse_line( { "event": "error", "task_id": task_id, "message_id": message_id, "code": "llm_error", "message": str(exc), } ) return retriever_resources = [ { "position": i + 1, "dataset_id": "", "dataset_name": dataset_name, "document_id": "", "document_name": chunk.get("source", ""), "data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file", "segment_id": chunk.get("id", ""), "retriever_from": "rag", "score": round(chunk.get("score", 0.0), 4), "hit_count": 0, "word_count": len(chunk.get("text", "")), "segment_position": i + 1, "index_node_hash": "", "content": chunk.get("text", "")[:500], "page": None, "source_scope": chunk.get("source_scope") or "", "attachment_id": chunk.get("attachment_id") or "", } for i, chunk in enumerate(context_chunks) ] yield _sse_line( { "event": "message_end", "task_id": task_id, "message_id": message_id, "conversation_id": conversation_id, "metadata": { "usage": {"total_tokens": total_tokens}, "retriever_resources": retriever_resources, }, } ) def _sse_line(data: dict) -> str: return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"