diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..145189c0100ccb67a8da21250bcea156505d1eb6 --- /dev/null +++ b/.env.example @@ -0,0 +1,23 @@ +# .env.example +# 数据库配置,默认使用postgresql +DB_TYPE= +# postgresql配置 +POSTGRES_USER= +POSTGRES_PASSWORD= +POSTGRES_HOST= +POSTGRES_PORT= +POSTGRES_DB= + +# MongoDB 配置 +MONGODB_USER= +MONGODB_PASSWORD= +MONGODB_HOST= +MONGODB_PORT= + +# 是否需要认证鉴权: deepInsight集成到openeuler intelligence依赖认证健全,如果单用户部署则不需要,默认用户为admin +REQUIRE_AUTHENTICATION=false + +# 默认监听地址和端口 +HOST=0.0.0.0 +PORT=8000 + diff --git a/.gitignore b/.gitignore index 383d8a0d6ec3d13ec92093194e5a3e756a475119..ead8a8d69049f83709cba08fb2738b9474924770 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea output -**/*/__pycache__ -mcp_config.json \ No newline at end of file +**/__pycache__ +mcp_config.json +/.env diff --git a/deepinsight/api/app.py b/deepinsight/api/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2be7a8dcc81b95348b76161ee71ed53e910e1de6 --- /dev/null +++ b/deepinsight/api/app.py @@ -0,0 +1,153 @@ +import os +import uuid +from datetime import datetime +from typing import Dict + +from flask import Flask, request, jsonify, Response, stream_with_context +from flask_cors import CORS + +from deepinsight.service.conversation import ConversationService +from deepinsight.service.deep_research import MessageType, DeepResearchService +from deepinsight.service.schemas.chat import GetChatHistoryStructure, GetChatHistoryRsp +from deepinsight.service.schemas.conversation import ( + ConversationListRsp, ConversationListMsg, ConversationListItem, + AddConversationRsp, AddConversationMsg +) + +# 读取环境变量中的 API 前缀 +API_PREFIX = os.getenv("API_PREFIX", "") + +# 创建 Flask 实例 +app = Flask(__name__) +_conversations: Dict[str, ConversationListItem] = {} + +# 跨域配置 +CORS(app, resources={r"/*": {"origins": "*"}}) + + +@app.route(f"{API_PREFIX}/api/chat", methods=['POST']) +def chat_stream(): + 
try: + body = request.get_json() + conversation_id = body.get("conversation_id", "") or str(uuid.uuid4()) + conversation_info = ConversationService.get_conversation_info(conversation_id) + if not conversation_info: + return jsonify({"error": "Conversation not found"}), 404 + + messages = body.get("messages", []) + if not isinstance(messages, list): + return jsonify({"error": "messages must be a list"}), 400 + + query = None + for item in reversed(messages): + if item.get('role') == MessageType.USER.value and item.get("content", None): + query = item.get("content"); break + + def generate(): + uuid_str = str(uuid.uuid4()) + for item in DeepResearchService.research(query=query, conversation_id=conversation_id, user_id=""): + resp = GetChatHistoryRsp( + code=0, + message="", + data=GetChatHistoryStructure( + id=uuid_str, + conversation_id=conversation_id, + user_id=conversation_info.user_id, + created_time=str(datetime.now()), + title=conversation_info.title, + status=conversation_info.status, + messages=item + ) + ).model_dump_json() + yield f"data: {resp}\n\n" + yield 'data: [DONE]\n\n' + + return Response(stream_with_context(generate()), mimetype="text/event-stream") + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route(f"{API_PREFIX}/api/conversations", methods=['GET']) +def get_conversation_list(): + try: + user_id = "test_user" + offset = int(request.args.get('offset', 0)) + limit = int(request.args.get('limit', 100)) + + conversation_list = ConversationService.get_list(user_id=user_id, offset=offset, limit=limit) + response = ConversationListRsp( + code=0, + message="OK", + data=ConversationListMsg(conversations=conversation_list) + ) + return jsonify(response.dict()) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route(f"{API_PREFIX}/api/conversation", methods=['POST']) +def add_conversation(): + try: + body = request.get_json() + new_conversation = ConversationService.add_conversation( + user_id="test_user", +
title=body['title'], + conversation_id=body.get('conversation_id') + ) + + response = AddConversationRsp( + code=0, + message="OK", + data=AddConversationMsg( + conversationId=str(new_conversation.conversation_id), + created_time=str(new_conversation.created_time) + ) + ) + return jsonify(response.dict()) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route(f"{API_PREFIX}/api/conversation", methods=['DELETE']) +def delete_conversation(): + try: + data = request.get_json() + for cid in data['conversation_list']: + ConversationService.del_conversation(conversation_id=cid) + return jsonify({"code": 0, "message": "Deleted", "data": {}}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route(f"{API_PREFIX}/api/conversation", methods=['PUT']) +def rename_conversation(): + try: + data = request.get_json() + conversation, is_succeed = ConversationService.rename_conversation( + conversation_id=data['conversation_id'], + new_name=data['new_name'] + ) + if is_succeed: + return jsonify({"code": 0, "message": "Modified", "data": {"new_name": data['new_name']}}) + else: + return jsonify({"code": 100, "message": "Conversation Not Found", "data": {"new_name": data['new_name']}}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route(f"{API_PREFIX}/api/conversations/<conversation_id>/messages", methods=['GET']) +def get_conversation_messages(conversation_id): + try: + conversation_info = ConversationService.get_conversation_info(conversation_id) + history_present = ConversationService.get_history_messages(conversation_id) + new_data = GetChatHistoryStructure( + conversation_id=conversation_id, + user_id=conversation_info.user_id, + created_time=str(conversation_info.created_time), + title=conversation_info.title, + status=conversation_info.status, + messages=history_present + ) + return jsonify(GetChatHistoryRsp(code=0, message="ok", data=new_data).dict()) + except Exception as e: + return jsonify({"error": str(e)}),
500 \ No newline at end of file diff --git a/deepinsight/core/agent/base.py b/deepinsight/core/agent/base.py index 4675b7382ab9eefaef4bb565e3922cd1101fe006..5cbdc20b9933900f050d3eda30784c2beaf355f0 100644 --- a/deepinsight/core/agent/base.py +++ b/deepinsight/core/agent/base.py @@ -9,6 +9,7 @@ # See the Mulan PSL v2 for more details. from __future__ import annotations +import uuid from abc import abstractmethod, ABC from typing import Any, Dict, Optional, TypeVar, Generator, Generic @@ -18,7 +19,9 @@ from camel.toolkits import MCPToolkit from deepinsight.config.model import ModelConfig from deepinsight.core.agent.stream_chat_agent import StreamChatAgent -from deepinsight.core.types.messages import Message +from deepinsight.core.prompt.prompt_template import PromptTemplate +from deepinsight.core.types.agent import AgentMessageAdditionType +from deepinsight.core.types.messages import Message, CompleteMessage, MessageMetadataKey from deepinsight.utils.aio import get_or_create_loop OutputType = TypeVar("OutputType") @@ -39,6 +42,7 @@ class BaseAgent(ABC, Generic[OutputType]): model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, ) -> None: """ Initialize the base agent with configuration. 
@@ -51,6 +55,7 @@ class BaseAgent(ABC, Generic[OutputType]): self.mcp_toolkit_instance = None self.mcp_tools_config_path = mcp_tools_config_path self.mcp_client_timeout = mcp_client_timeout + self.tips_prompt_template = tips_prompt_template self.connect_mcp() if model_config.model_config_dict and model_config.model_config_dict.get("stream", False): @@ -161,6 +166,7 @@ class BaseAgent(ABC, Generic[OutputType]): Returns: T: The parsed output from parse_output() """ + yield from self.pre_run(query, context) prompt = self.build_user_prompt(query=query, context=context) if isinstance(self.agent, StreamChatAgent): response = yield from self.agent.stream_step(prompt) @@ -170,7 +176,16 @@ class BaseAgent(ABC, Generic[OutputType]): self.post_run(output) return output - def post_run(self, output: OutputType) -> None: + def pre_run( + self, + query: str, + context: Dict[str, Any] | None = None, + ) -> Generator[Message, None, None]: + if False: # This technique is used to maintain the generator function characteristics + yield + return + + def post_run(self, output: OutputType) -> Generator[Message, None, None]: """ Post-processing hook that is called after run() completes successfully. 
@@ -180,4 +195,23 @@ class BaseAgent(ABC, Generic[OutputType]): Args: output: The output from run() method """ - pass \ No newline at end of file + if False: # This technique is used to maintain the generator function characteristics + yield + return + + def yield_tips_messages(self, tips_prompt_name: str, **variables) -> Generator[Message, None, None]: + if self.tips_prompt_template: + yield self._create_tips_messages(tips_prompt_name, **variables) + + def _create_tips_messages(self, tips_prompt_name: str, **variables) -> Message: + content = self.tips_prompt_template.get_prompt( + stage=tips_prompt_name, + variables=variables + ) + return CompleteMessage( + stream_id=str(uuid.uuid4()), + payload=content, + metadata={ + MessageMetadataKey.ADDITION_TYPE: AgentMessageAdditionType.TIPS + } + ) diff --git a/deepinsight/core/agent/planner.py b/deepinsight/core/agent/planner.py index da715cb88a46e8a88d7291e2f845f85c13cb4e7f..cfc18115c9be749386cc7d8c32d12bcd31683915 100644 --- a/deepinsight/core/agent/planner.py +++ b/deepinsight/core/agent/planner.py @@ -13,7 +13,7 @@ import logging import re from copy import deepcopy from enum import Enum -from typing import Any, Dict, List, Optional, TypeAlias, Callable +from typing import Any, Dict, List, Optional, TypeAlias, Callable, Generator from camel.messages import BaseMessage from camel.responses import ChatAgentResponse @@ -23,8 +23,9 @@ from typing_extensions import override from deepinsight.config.model import ModelConfig from deepinsight.core.agent.base import BaseAgent -from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage +from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage, PromptTemplate from deepinsight.core.types.historical_message import HistoricalMessage, HistoricalMessageType +from deepinsight.core.types.messages import Message class NotSupportStreamException(Exception): @@ -172,6 +173,7 @@ class 
Planner(BaseAgent[PlanResult]): model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, plan_parser: PlanParser = None, latest_search_plan: Optional[str] = None, historical_messages: Optional[List[HistoricalMessage]] = None, @@ -187,7 +189,7 @@ class Planner(BaseAgent[PlanResult]): latest_search_plan: Current search plan context historical_messages: List of historical messages """ - super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout) + super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout, tips_prompt_template) self.plan_parser = plan_parser or self._default_plan_parser # Init plan if historical_messages: @@ -233,7 +235,7 @@ class Planner(BaseAgent[PlanResult]): return GLOBAL_DEFAULT_PROMPT_REPOSITORY.get_prompt( stage=PromptStage.PLAN_USER, variables=dict( query=query, - current_search_plan="\n".join([each.origin_plan for each in self.current_search_plan.search_plans]) if self.current_search_plan else "", + current_search_plan=self._search_plan_text(self.current_search_plan) if self.current_search_plan else "", **context if context is not None else {}, ) ) @@ -271,6 +273,8 @@ class Planner(BaseAgent[PlanResult]): ) if full_response.startswith("开始研究"): + if not self.current_search_plan: + raise NoPlanException("No plan can not start") # If need start research final_plan = deepcopy(self.current_search_plan) final_plan.status = PlanStatus.FINALIZED @@ -315,7 +319,15 @@ class Planner(BaseAgent[PlanResult]): ) return plan_result - def post_run(self, output: PlanResult) -> None: + def post_run(self, output: PlanResult) -> Generator[Message, None, None]: """Post run process.""" - super().post_run(output) self.current_search_plan = output + if output.status == PlanStatus.FINALIZED: + yield from self.yield_tips_messages(PromptStage.PLAN_START_TIPS, search_plans=self._search_plan_text(output)) + yield from 
super().post_run(output) + + def _search_plan_text(self, plan_result: PlanResult) -> str: + if plan_result.search_plans: + return "\n".join([each.origin_plan for each in plan_result.search_plans]) + else: + return "" diff --git a/deepinsight/core/agent/reporter.py b/deepinsight/core/agent/reporter.py index 8006c065131db03ac2a8479b021d3a9fbbb49d0a..1dd2fbdb671b936e6164c9cdc272c350204d6475 100644 --- a/deepinsight/core/agent/reporter.py +++ b/deepinsight/core/agent/reporter.py @@ -9,19 +9,19 @@ # See the Mulan PSL v2 for more details. import logging import re - -from camel.responses import ChatAgentResponse - -from deepinsight.config.model import ModelConfig -from deepinsight.core.agent.base import BaseAgent, OutputType +import uuid from typing import Any, Dict, Generator from typing import Optional, TypeAlias, Callable, List +from camel.responses import ChatAgentResponse from pydantic import BaseModel +from deepinsight.config.model import ModelConfig +from deepinsight.core.agent.base import BaseAgent, OutputType from deepinsight.core.agent.researcher import ResearchExecution -from deepinsight.core.types.messages import Message -from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage +from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage, PromptTemplate +from deepinsight.core.types.agent import AgentExecutePhase +from deepinsight.core.types.messages import Message, CompleteMessage, MessageMetadataKey from deepinsight.utils.parallel_worker_utils import Executor @@ -56,11 +56,13 @@ class GenerateSubTaskAgent(BaseAgent[List[WritingTask]]): Inherits from BaseAgent with List[WritingTask] as the concrete output type. 
""" + def __init__( self, model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, report_plan_parser: Optional[ReportPlanParser] = None, ) -> None: """ @@ -72,7 +74,7 @@ class GenerateSubTaskAgent(BaseAgent[List[WritingTask]]): mcp_client_timeout: Timeout for MCP client operations report_plan_parser: Custom parser for converting LLM responses to writing tasks """ - super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout) + super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout, tips_prompt_template) self.report_plan_parser = report_plan_parser def build_system_prompt(self) -> str: @@ -89,6 +91,13 @@ class GenerateSubTaskAgent(BaseAgent[List[WritingTask]]): return self.report_plan_parser(response.msg.content) return super().parse_output(response) + def pre_run(self, query: str, context: Dict[str, Any] | None = None) -> Generator[Message, None, None]: + yield from self.yield_tips_messages(tips_prompt_name=PromptStage.REPORT_PLAN_TIPS) + + def post_run(self, output: OutputType) -> Generator[Message, None, None]: + yield from self.yield_tips_messages(tips_prompt_name=PromptStage.REPORT_WRITE_TIPS) + yield from super().post_run(output) + class ExecuteSubTaskAgent(BaseAgent[str]): """ @@ -97,10 +106,15 @@ class ExecuteSubTaskAgent(BaseAgent[str]): Inherits from BaseAgent with str as the concrete output type. 
""" - def __init__(self, model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, - mcp_client_timeout: Optional[int] = None) -> None: + def __init__( + self, + model_config: ModelConfig, + mcp_tools_config_path: Optional[str] = None, + mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, + ) -> None: """Initialize with model and MCP configuration.""" - super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout) + super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout, tips_prompt_template) def build_system_prompt(self) -> str: return GLOBAL_DEFAULT_PROMPT_REPOSITORY.get_prompt(PromptStage.REPORT_WRITE_SYSTEM) @@ -130,6 +144,7 @@ class Reporter: model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, report_plan_parser: Optional[ReportPlanParser] = None, report_post_processer: Optional[ReportPostProcesser] = None, ): @@ -146,6 +161,7 @@ class Reporter: self.model_config = model_config self.mcp_tools_config_path = mcp_tools_config_path self.mcp_client_timeout = mcp_client_timeout + self.tips_prompt_template = tips_prompt_template self.report_plan_parser = report_plan_parser or self._default_report_plan_parser self.report_post_processer = report_post_processer or self._default_report_post_processer @@ -168,6 +184,14 @@ class Reporter: str: The final generated report """ writing_tasks = yield from self._generate_writing_task(query=query, research_executions=research_executions) + # Report plan has complete, send a phase message + yield CompleteMessage( + stream_id=str(uuid.uuid4()), + payload=None, + metadata={ + MessageMetadataKey.AGENT_EXECUTE_PHASE: AgentExecutePhase.REPORT_WRITING + } + ) writing_report_executor = Executor("writing_report") def report_writing_worker(i, writing_task: WritingTask): @@ -190,6 +214,7 @@ class Reporter: self.model_config, 
self.mcp_tools_config_path, self.mcp_client_timeout, + self.tips_prompt_template, self.report_plan_parser, ) research_info = self._construct_research_info(research_executions) @@ -208,7 +233,13 @@ class Reporter: def _write_task(self, query, writing_task: WritingTask) -> Generator[Message, None, str]: """Execute an individual writing task.""" - write_agent = ExecuteSubTaskAgent(self.model_config, self.mcp_tools_config_path, self.mcp_client_timeout) + write_agent = ExecuteSubTaskAgent( + self.model_config, + self.mcp_tools_config_path, + self.mcp_client_timeout, + self.tips_prompt_template, + ) + report = yield from write_agent.run( query=query, context=dict( diff --git a/deepinsight/core/agent/researcher.py b/deepinsight/core/agent/researcher.py index 1d994b11cb3475b2b2d5c541b8f71d68a33e099b..fce4aaec48ae175d99f1f8365dc8afe9d6dbaee1 100644 --- a/deepinsight/core/agent/researcher.py +++ b/deepinsight/core/agent/researcher.py @@ -19,8 +19,9 @@ from pydantic import BaseModel, Field from deepinsight.config.model import ModelConfig from deepinsight.core.agent.base import BaseAgent from deepinsight.core.agent.planner import SearchPlan, PlanResult -from deepinsight.core.types.messages import Message -from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage +from deepinsight.core.types.agent import AgentMessageAdditionType +from deepinsight.core.types.messages import Message, CompleteMessage, MessageMetadataKey +from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage, PromptTemplate from deepinsight.utils.parallel_worker_utils import Executor @@ -95,6 +96,7 @@ class RolePlayingUser(BaseAgent[ChatAgentResponse]): Specializes BaseAgent to handle user-side interactions in research dialogues. 
""" + def __init__( self, model_config: ModelConfig, @@ -106,7 +108,7 @@ class RolePlayingUser(BaseAgent[ChatAgentResponse]): def build_system_prompt(self) -> str: return GLOBAL_DEFAULT_PROMPT_REPOSITORY.get_prompt(PromptStage.RESEARCH_ROLE_PLAYING_USER_SYSTEM) - def build_user_prompt(self, *, query:str, context: Dict[str, Any] | None = None) -> str: + def build_user_prompt(self, *, query: str, context: Dict[str, Any] | None = None) -> str: return query @@ -116,6 +118,7 @@ class RolePlayingAssistant(BaseAgent[ChatAgentResponse]): Specializes BaseAgent to handle assistant-side interactions in research dialogues. """ + def __init__(self, model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None) -> None: super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout) @@ -123,8 +126,7 @@ class RolePlayingAssistant(BaseAgent[ChatAgentResponse]): def build_system_prompt(self) -> str: return GLOBAL_DEFAULT_PROMPT_REPOSITORY.get_prompt(PromptStage.RESEARCH_ROLE_PLAYING_ASSISTANT_SYSTEM) - - def build_user_prompt(self, *, query:str, context: Dict[str, Any] | None = None) -> str: + def build_user_prompt(self, *, query: str, context: Dict[str, Any] | None = None) -> str: return query @@ -137,6 +139,7 @@ class StreamRolePlaying: - Message exchange - Termination conditions """ + def __init__( self, model_config: ModelConfig, @@ -144,7 +147,7 @@ class StreamRolePlaying: mcp_client_timeout: Optional[int] = None, should_terminate_callback: Optional[ResearchShouldTerminateCallback] = None, **kwargs, - ) -> None: + ) -> None: """ Initialize the role-playing orchestrator. @@ -179,7 +182,7 @@ class StreamRolePlaying: self, query: str, context: Dict[str, Any] | None = None, - ) -> Generator[Message, None, Tuple[ChatAgentResponse, ChatAgentResponse]]: + ) -> Generator[Message, None, Tuple[ChatAgentResponse, ChatAgentResponse]]: """ Execute a complete role-playing dialogue turn. 
@@ -218,7 +221,8 @@ class StreamRolePlaying: ) user_msg = user_response.msg - assistant_response: ChatAgentResponse = yield from self.assistant_agent.run(query=user_msg.content, context=context) + assistant_response: ChatAgentResponse = yield from self.assistant_agent.run(query=user_msg.content, + context=context) if assistant_response.terminated or assistant_response.msgs is None: return ( ChatAgentResponse( @@ -261,6 +265,7 @@ class Researcher: model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, round_limit: int = 15, should_terminate_callback: ResearchShouldTerminateCallback = None, ) -> None: @@ -277,6 +282,7 @@ class Researcher: self.model_config = model_config self.mcp_tools_config_path = mcp_tools_config_path self.mcp_client_timeout = mcp_client_timeout + self.tips_prompt_template = tips_prompt_template self.round_limit = round_limit self.should_terminate_callback: ResearchShouldTerminateCallback = should_terminate_callback or self._default_should_terminate_callback @@ -296,12 +302,27 @@ class Researcher: """ # Parallel search info search_parallel_executor = Executor("Search") + def search_info_worker(i, search_step: SearchPlan): + yield CompleteMessage( + stream_id=str(uuid.uuid4()), + payload=self.tips_prompt_template.get_prompt( + stage=PromptStage.RESEARCH_START_TIPS, + variables=dict( + task_id=i + 1, + task_title=search_step.title, + ) + ), + metadata={ + MessageMetadataKey.ADDITION_TYPE: AgentMessageAdditionType.TIPS + } + ) one_search_content = yield from self._search_info_with_role_playing( query=query, search_plan=search_step, ) return one_search_content or [] + all_content = yield from search_parallel_executor(search_info_worker, list(enumerate(plan_result.search_plans))) flattened = [] @@ -350,7 +371,7 @@ class Researcher: stage=PromptStage.RESEARCH_ROLE_PLAYING_USER_USER, variables=dict( query=query, - 
current_plan=search_plan.origin_plan, + current_plan=search_plan.origin_plan, ) ) @@ -368,7 +389,8 @@ class Researcher: last_user_response = user_response assistant_execution_step = ExecutionStep( content=assistant_response.msg.content, - tool_calls=assistant_response.info.get("tool_calls") if assistant_response.info.get("tool_calls") else None + tool_calls=assistant_response.info.get("tool_calls") if assistant_response.info.get( + "tool_calls") else None ) user_execution_stop = ExecutionStep( diff --git a/deepinsight/core/orchestrator.py b/deepinsight/core/orchestrator.py index 2aa969f34ce97bc47a4bd1e779c89c2d9f6ae9eb..18175ef953be32423b48e59d3718c78e3d4a9c88 100644 --- a/deepinsight/core/orchestrator.py +++ b/deepinsight/core/orchestrator.py @@ -17,9 +17,10 @@ from deepinsight.config.model import ModelConfig from deepinsight.core.agent.planner import Planner, PlanStatus from deepinsight.core.agent.reporter import Reporter from deepinsight.core.agent.researcher import Researcher -from deepinsight.core.types.agent import AgentType +from deepinsight.core.prompt.prompt_template import PromptTemplate, PromptStage +from deepinsight.core.types.agent import AgentType, AgentExecutePhase from deepinsight.core.types.historical_message import HistoricalMessage -from deepinsight.core.types.messages import Message +from deepinsight.core.types.messages import Message, MessageMetadataKey class OrchestratorStatus(BaseModel): @@ -42,7 +43,8 @@ class OrchestratorStatusType(str, Enum): PENDING = OrchestratorStatus(status="pending") PLANNING = OrchestratorStatus(status="planning") RESEARCHING = OrchestratorStatus(status="researching") - REPORTING = OrchestratorStatus(status="reporting") + REPORT_PLANNING = OrchestratorStatus(status="report_planning") + REPORT_WRITING = OrchestratorStatus(status="report_writing") COMPLETED = OrchestratorStatus(status="completed") FAILED = OrchestratorStatus(status="failed") @@ -78,7 +80,7 @@ class OrchestrationRequest(BaseModel): ) -class 
OrchestratorResult(BaseModel): +class OrchestrationResult(BaseModel): """ Container for orchestration process outputs with interactive capabilities. @@ -130,6 +132,7 @@ class Orchestrator: mcp_client_timeout: Optional[int] = None, research_round_limit: int = 5, init_request: Optional[OrchestrationRequest] = None, + execute_tips_template_dict: Optional[Dict[Union[str, PromptStage], str]] = None, ) -> None: """ Initialize the orchestration engine with configuration. @@ -141,10 +144,12 @@ class Orchestrator: research_round_limit: Maximum number of research iterations init_request: Init request for orchestration """ + self.agent_execute_phase_tips_template = self._init_tips_template(execute_tips_template_dict) self.planner = Planner( model_config=model_config, mcp_tools_config_path=mcp_tools_config_path, mcp_client_timeout=mcp_client_timeout, + tips_prompt_template=self.agent_execute_phase_tips_template, historical_messages=init_request.agent_historical_messages.get( AgentType.PLANNER, [] ) if init_request else [] @@ -154,6 +159,7 @@ class Orchestrator: model_config=model_config, mcp_tools_config_path=mcp_tools_config_path, mcp_client_timeout=mcp_client_timeout, + tips_prompt_template=self.agent_execute_phase_tips_template, round_limit=research_round_limit ) @@ -161,6 +167,7 @@ class Orchestrator: model_config=model_config, mcp_tools_config_path=mcp_tools_config_path, mcp_client_timeout=mcp_client_timeout, + tips_prompt_template=self.agent_execute_phase_tips_template, ) self.current_phase = OrchestratorStatusType.PENDING self.start_time = None @@ -169,7 +176,7 @@ class Orchestrator: def run( self, query: str - ) -> Generator[Union[Message, OrchestratorStatusType], None, OrchestratorResult]: + ) -> Generator[Union[Message, OrchestratorStatusType], None, OrchestrationResult]: """ Execute the full orchestration workflow for a given query. 
@@ -180,7 +187,7 @@ class Orchestrator: Union[Message, PhaseStartMessage]: Progress messages during execution Returns: - OrchestratorResult: Final output artifacts + OrchestrationResult: Final output artifacts Raises: OrchestrationException: If any phase fails @@ -197,13 +204,13 @@ class Orchestrator: plan_result = yield from self.planner.run(query) if plan_result.requires_user_input: # Need user feedback - return OrchestratorResult( + return OrchestrationResult( require_user_interactive=True, require_user_feedback=plan_result.information_required, ) elif plan_result.status == PlanStatus.DRAFT: # Need user confirm plan draft - return OrchestratorResult( + return OrchestrationResult( require_user_interactive=True, plan_draft="\n".join([plan.origin_plan for plan in plan_result.search_plans]) ) @@ -218,16 +225,27 @@ class Orchestrator: ) # Phase 3: report - self.current_phase = OrchestratorStatusType.REPORTING - yield OrchestratorStatusType.RESEARCHING - report_result = yield from self.reporter.run( + self.current_phase = OrchestratorStatusType.REPORT_PLANNING + yield OrchestratorStatusType.REPORT_PLANNING + + reporter_run_generator = self.reporter.run( query=query, research_executions=research_executions, ) + try: + while True: + item = next(reporter_run_generator) + if item.metadata.get(MessageMetadataKey.AGENT_EXECUTE_PHASE, None) == AgentExecutePhase.REPORT_WRITING: + self.current_phase = OrchestratorStatusType.REPORT_WRITING + yield OrchestratorStatusType.REPORT_WRITING + else: + yield item + except StopIteration as e: + report_result = e.value self.current_phase = OrchestratorStatusType.COMPLETED - return OrchestratorResult(report=report_result) + return OrchestrationResult(report=report_result) except Exception as exc: self.current_phase = OrchestratorStatusType.FAILED @@ -237,3 +255,13 @@ class Orchestrator: ) from exc finally: self.end_time = datetime.utcnow() + + + def _init_tips_template(self, execute_tips_template_dict: Dict): + tips_template = 
PromptTemplate( + template_dict={} + ) + if execute_tips_template_dict: + for key, value in execute_tips_template_dict.items(): + tips_template.add_template(key, value) + return tips_template diff --git a/deepinsight/core/prompt/prompt_template.py b/deepinsight/core/prompt/prompt_template.py index cdaa9f572cbed31eb2bd1fb1dd7f32f3fa73d594..df6757f1aba26df2713a291d65d836c92a4d3711 100644 --- a/deepinsight/core/prompt/prompt_template.py +++ b/deepinsight/core/prompt/prompt_template.py @@ -31,6 +31,10 @@ class PromptStage(str, Enum): REPORT_PLAN_USER = "report_plan_user" REPORT_WRITE_SYSTEM = "report_write_system" REPORT_WRITE_USER = "report_write_user" + PLAN_START_TIPS = "plan_start_tips" + RESEARCH_START_TIPS = "research_start_tips" + REPORT_PLAN_TIPS = "report_plan_tips" + REPORT_WRITE_TIPS = "report_write_tips" ERROR_RECOVERY = "error_recovery" CUSTOM = "custom" diff --git a/deepinsight/core/types/agent.py b/deepinsight/core/types/agent.py index 562fdef600e9b1dc722405ed5ee197afc59d35c4..8877b28c9e7005b24807d89921d515d9354099d5 100644 --- a/deepinsight/core/types/agent.py +++ b/deepinsight/core/types/agent.py @@ -13,4 +13,13 @@ from enum import Enum class AgentType(str, Enum): PLANNER = "planner" RESEARCHER = "researcher" - REPORTER = "reporter" \ No newline at end of file + REPORTER = "reporter" + + +class AgentMessageAdditionType(str, Enum): + TIPS = "tips" + + +class AgentExecutePhase(str, Enum): + REPORT_PLANING = "report_planning" + REPORT_WRITING = "report_writing" \ No newline at end of file diff --git a/deepinsight/core/types/messages.py b/deepinsight/core/types/messages.py index 99bc6fbda495c0b98b3fd0c2237ed809b3ef1415..589f088e4edc285b8a99fc8288e6357c696e4e86 100644 --- a/deepinsight/core/types/messages.py +++ b/deepinsight/core/types/messages.py @@ -126,6 +126,12 @@ class HeartbeatMessage(BaseMessage): message_type: Literal[MessageType.HEARTBEAT] = MessageType.HEARTBEAT latency_ms: Optional[int] = Field(None, ge=0) + +class MessageMetadataKey(str, Enum): 
+ ADDITION_TYPE = "addition_type" + AGENT_EXECUTE_PHASE = "agent_execute_phase" + + # Union type representing all possible message types Message = Union[ StartMessage[T], diff --git a/deepinsight/app.py b/deepinsight/main.py similarity index 96% rename from deepinsight/app.py rename to deepinsight/main.py index df0d0c23be714872e0bfa08da2faa54efe9405ea..40d4b412a3b337abfea517eb498cdd9150e8b8cc 100644 --- a/deepinsight/app.py +++ b/deepinsight/main.py @@ -13,11 +13,11 @@ from pathlib import Path from camel.types import ModelPlatformType, ModelType from deepinsight.config.model import ModelConfig -from deepinsight.core.orchestrator import Orchestrator, OrchestratorResult +from deepinsight.core.orchestrator import Orchestrator, OrchestrationResult from deepinsight.utils.console_utils import display_stream -def save_artifact(output_dir: Path, data: OrchestratorResult) -> None: +def save_artifact(output_dir: Path, data: OrchestrationResult) -> None: """Save research artifacts to files""" output_dir.mkdir(exist_ok=True) @@ -66,7 +66,7 @@ def main(): mcp_tools_config_path="./mcp_config.json", research_round_limit=1, ) - result: OrchestratorResult = display_stream(orchestration.run(args.query)) + result: OrchestrationResult = display_stream(orchestration.run(args.query)) # Handle interactive states while result.require_user_interactive: diff --git a/deepinsight/server.py b/deepinsight/server.py new file mode 100644 index 0000000000000000000000000000000000000000..80437ebe98358e351f96c57b5f593140394e43e8 --- /dev/null +++ b/deepinsight/server.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import os + +import uvicorn + +if __name__ == "__main__": + # 从环境变量获取配置,或使用默认值 + # 从环境变量获取配置,或使用默认值 + host = os.getenv("HOST", "0.0.0.0") # 从环境变量读取HOST,默认监听所有接口 + port = int(os.getenv("PORT", "8000")) # 从环境变量读取PORT,默认8000 + + # 配置日志 + log_config = uvicorn.config.LOGGING_CONFIG + log_config["formatters"]["access"]["fmt"] = '%(asctime)s - %(levelname)s - %(message)s' + 
log_config["formatters"]["default"]["fmt"] = '%(asctime)s - %(levelname)s - %(message)s' + + # 启动Uvicorn服务器 + uvicorn.run( + "deepinsight.api.app:app_instance", # NOTE(review): api/app.py defines a Flask (WSGI) app named 'app' — confirm 'app_instance' is a real ASGI entry point (e.g. asgiref WsgiToAsgi wrapper) + host=host, + port=port, + reload=True, # 开发时启用热重载 + log_config=log_config, + workers=1, # 生产环境可以增加worker数量 (NOTE(review): uvicorn ignores workers when reload=True) + access_log=True + ) diff --git a/deepinsight/service/__init__.py b/deepinsight/service/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..d2976bf8bcec80688ad297611a6809da4f37fd4a 100644 --- a/deepinsight/service/__init__.py +++ b/deepinsight/service/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. \ No newline at end of file diff --git a/deepinsight/service/auth/session_manager.py b/deepinsight/service/auth/session_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..a602f0edf0b046118ebdd65975fa00b92a29a4ff --- /dev/null +++ b/deepinsight/service/auth/session_manager.py @@ -0,0 +1,46 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details.
+import logging + +from deepinsight.stores.mongodb.database import MongoDB, Session + + +class SessionManager: + """浏览器Session管理""" + + @staticmethod + async def verify_user(session_id: str) -> bool: + """验证用户是否在Session中""" + try: + collection = MongoDB().get_collection("session") + data = await collection.find_one({"_id": session_id}) + if not data: + return False + return Session(**data).user_sub is not None + except Exception as e: + err = "用户不在Session中" + logging.error("[SessionManager] %s", err) + raise e + + @staticmethod + async def get_user_sub(session_id: str) -> str: + """从Session中获取用户""" + try: + collection = MongoDB().get_collection("session") + data = await collection.find_one({"_id": session_id}) + if not data: + return None + user_sub = Session(**data).user_sub + except Exception as e: + err = "从Session中获取用户失败" + logging.error("[SessionManager] %s", err) + raise e + + return user_sub diff --git a/deepinsight/service/auth/session_service.py b/deepinsight/service/auth/session_service.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb59a25c7b841188ce4ec3b26eba75c59a2695c --- /dev/null +++ b/deepinsight/service/auth/session_service.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
+import uuid + +from starlette import status +from starlette.exceptions import HTTPException +from starlette.requests import HTTPConnection +from deepinsight.service.auth.session_manager import SessionManager + + +class UserHTTPException(HTTPException): + def __init__(self, status_code: int, retcode: int, rtmsg: str, data): + super().__init__(status_code=status_code) + self.retcode = retcode + self.rtmsg = rtmsg + self.data = data + + +async def verify_user(request: HTTPConnection): + """验证用户是否在Session中""" + try: + session_id = None + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + session_id = auth_header.split(" ", 1)[1] + elif "ECSESSION" in request.cookies: + session_id = request.cookies["ECSESSION"] + if session_id is None: + raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + retcode=401, rtmsg="Authentication Error.", data="") + except: + raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + retcode=401, rtmsg="Authentication Error.", data="") + if not (await SessionManager.verify_user(session_id)): + raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + retcode=401, rtmsg="Authentication Error.", data="") + + +async def get_user_sub(request: HTTPConnection) -> uuid: + """从Session中获取用户""" + # if config["DEBUG"]: + # await UserManager.add_user((await Convertor.convert_user_sub_to_user_entity('admin'))) + # return "admin" + try: + session_id = None + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + session_id = auth_header.split(" ", 1)[1] + elif "ECSESSION" in request.cookies: + session_id = request.cookies["ECSESSION"] + except: + raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + retcode=401, rtmsg="Authentication Error.", data="") + user_sub = await SessionManager.get_user_sub(session_id) + if not user_sub: + raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + 
retcode=401, rtmsg="Authentication Error.", data="") + return user_sub diff --git a/deepinsight/service/auth/user_dependency.py b/deepinsight/service/auth/user_dependency.py new file mode 100644 index 0000000000000000000000000000000000000000..16f6fe28639b70402eeb96a15e1d27e75d8868a7 --- /dev/null +++ b/deepinsight/service/auth/user_dependency.py @@ -0,0 +1,86 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +import os +from fastapi import Depends +from starlette.requests import HTTPConnection + +from deepinsight.service.auth.session_manager import SessionManager +from deepinsight.service.auth.session_service import UserHTTPException + + +async def get_current_user(request: HTTPConnection): + """ + 获取当前用户的依赖项 + + 此函数会先检查环境变量是否需要认证,不需要则直接返回admin + 需要认证时会验证用户身份,并将用户信息添加到请求对象中 + 可作为依赖项在路由处理函数中使用 + """ + try: + # 检查环境变量是否需要认证鉴权,默认需要 + need_auth = os.getenv("REQUIRE_AUTHENTICATION", "true").lower() != "false" + + # 如果不需要认证,直接返回默认用户admin + if not need_auth: + request.state.user = "admin" + return "admin" + + # 从请求头或cookie中获取session_id + session_id = None + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + session_id = auth_header.split(" ", 1)[1] + elif "ECSESSION" in request.cookies: + session_id = request.cookies["ECSESSION"] + + if not session_id: + raise UserHTTPException( + status_code=401, + retcode=401, + rtmsg="Authentication Error: No session ID found", + data="" + ) + + # 验证用户会话 + if not await SessionManager.verify_user(session_id): + 
raise UserHTTPException( + status_code=401, + retcode=401, + rtmsg="Authentication Error: Invalid session", + data="" + ) + + # 获取用户信息 + user_sub = await SessionManager.get_user_sub(session_id) + if not user_sub: + raise UserHTTPException( + status_code=401, + retcode=401, + rtmsg="Authentication Error: User not found", + data="" + ) + + # 将用户信息添加到请求状态中 + request.state.user = user_sub + return user_sub + + except Exception as e: + if isinstance(e, UserHTTPException): + raise e + raise UserHTTPException( + status_code=401, + retcode=401, + rtmsg=f"Authentication Error: {str(e)}", + data="" + ) + + +# 用于需要验证用户的路由的装饰器依赖 +require_user = Depends(get_current_user) \ No newline at end of file diff --git a/deepinsight/service/conversation.py b/deepinsight/service/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..490bb6e59f26892739bce0b6f4982c45012d0de2 --- /dev/null +++ b/deepinsight/service/conversation.py @@ -0,0 +1,84 @@ +from typing import List + +from deepinsight.service.deep_research import DeepResearchService +from deepinsight.service.schemas.chat import ServiceMessage +from deepinsight.service.schemas.conversation import ConversationListItem +from deepinsight.stores.postgresql.database import get_database_session +from deepinsight.stores.postgresql.repositories.conversation_repository import ConversationRepository +from deepinsight.stores.postgresql.repositories.message_repository import MessageRepository +from deepinsight.stores.postgresql.repositories.report_repository import ReportRepository + + +class ConversationService: + @classmethod + def get_list(cls, user_id, offset: int = 0, limit: int = 100): + db = get_database_session() + conversation_repo = ConversationRepository(db) + conversation_list = conversation_repo.get_by_user_id(user_id=user_id, offset=offset, limit=limit) + conv_item_list = [] + for conv in conversation_list: + conv_item_list.append(ConversationListItem( + conversationId=str(conv.conversation_id), + 
title=conv.title, + createdTime=str(conv.created_time) + ) + ) + return conv_item_list + + @classmethod + def del_conversation(cls, conversation_id): + db = get_database_session() + conversation_repo = ConversationRepository(db) + message_repo = MessageRepository(db) + report_repo = ReportRepository(db) + message_repo.delete_by_conversation_id(conversation_id=conversation_id) + report_repo.delete_by_conversation_id(conversation_id=conversation_id) + conversation_repo.delete_conversation(conversation_id=conversation_id) + + @classmethod + def rename_conversation(cls, conversation_id, new_name): + db = get_database_session() + conversation_repo = ConversationRepository(db) + return conversation_repo.update_title(conversation_id=conversation_id, new_title=new_name) + + @classmethod + def add_conversation(cls, user_id, title, conversation_id): + db = get_database_session() + conversation_repo = ConversationRepository(db) + saved_conversation = conversation_repo.create_conversation(user_id, title, conversation_id) + return saved_conversation + + @classmethod + def get_conversation_info(cls, conversation_id_str: str): + db = get_database_session() + repository = ConversationRepository(db) + return repository.get_by_id(conversation_id_str) + + @classmethod + def get_history_messages(cls, conversation_id_str: str) -> List[ServiceMessage]: + db = get_database_session() + repository = MessageRepository(db) + messages_from_db = repository.get_by_conversation_id(conversation_id_str) + + processed_messages = [] + for msg in messages_from_db: + content_to_use = msg.content + if msg.type == "report": + processed_report = DeepResearchService.get_report_and_thought_by_message_id(msg.message_id) + content_to_use = processed_report.thought.messages + [processed_report.report] + + # @TODO: 还差updated time + + processed_message = ServiceMessage( + id=str(msg.message_id), + content=content_to_use, + role=msg.type, + created_at=msg.created_time.isoformat() if msg.created_time else None + )
+ + processed_messages.append(processed_message) + return processed_messages + + +if __name__ == '__main__': + print(ConversationService.add_conversation(user_id="1")) diff --git a/deepinsight/service/deep_research.py b/deepinsight/service/deep_research.py new file mode 100644 index 0000000000000000000000000000000000000000..2240b1ed498fdafa928b6893282994858bf93f18 --- /dev/null +++ b/deepinsight/service/deep_research.py @@ -0,0 +1,507 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +import logging +import uuid +from datetime import datetime +from enum import Enum +from typing import Generator, Union, List, Dict, Optional + +from camel.types import ModelPlatformType, ModelType +from camel.types.agents import ToolCallingRecord +from pydantic import BaseModel + +from deepinsight.config.model import ModelConfig +from deepinsight.core.orchestrator import Orchestrator, OrchestrationResult, OrchestratorStatusType, OrchestrationRequest +from deepinsight.core.prompt.prompt_template import PromptStage +from deepinsight.core.types.agent import AgentType, AgentMessageAdditionType +from deepinsight.core.types.historical_message import HistoricalMessage, HistoricalMessageType +from deepinsight.core.types.messages import StartMessage, ChunkMessage, EndMessage, MessageMetadataKey, CompleteMessage, \ + ErrorMessage, Message as CoreMessage +from deepinsight.stores.postgresql.database import get_database_session +from deepinsight.stores.postgresql.repositories.conversation_repository import ConversationRepository 
+from deepinsight.stores.postgresql.repositories.message_repository import MessageRepository +from deepinsight.stores.postgresql.repositories.report_repository import ReportRepository +from deepinsight.stores.postgresql.schemas import report as DBSchemaReport +from deepinsight.stores.postgresql.schemas import message as DBSchemaMessage +from deepinsight.service.schemas.chat import ServiceMessage + +AGENT_PROCESS_TIPS_TEMPLATE = { + PromptStage.PLAN_START_TIPS: "研究计划如下\n{search_plans}", + PromptStage.RESEARCH_START_TIPS: "任务{task_id}: {task_title}", + PromptStage.REPORT_PLAN_TIPS: "正在分析结果,计划生成报告", + PromptStage.REPORT_WRITE_TIPS: "正在生成报告", +} + + +class ThoughtProcessType(str, Enum): + TITLE = "title" + CONTENT = "content" + TOOL_CALL = "tool_call" + + +class MessageType(str, Enum): + USER = "user" + SEARCH_PLAN = "search_plan" + REPORT = "report" + + +class MessageItem(BaseModel): + type: MessageType + content: str = "" + created_at: datetime + message_id: str + + +class ThoughtItem(BaseModel): + type: ThoughtProcessType + content: Union[str, ToolCallingRecord] = "" + created_at: datetime + + +class ReportItem(BaseModel): + type: MessageType = MessageType.REPORT + content: str = "" + created_at: datetime + + +class ThoughtMessages(BaseModel): + messages: List[ThoughtItem] + + +class DeepResearchResponse(BaseModel): + messages: List[Union[MessageItem, ThoughtItem, ReportItem]] + + +class ThoughtAndReport(BaseModel): + thought: ThoughtMessages + report: ReportItem + + +class DeepResearchService: + """ + Stateless service for deep research operations that: + - Coordinates research orchestration + - Handles database persistence + - Manages streaming responses + + Designed for single-use per API request with no shared state. 
+ """ + + @classmethod + def research( + cls, + query: str, + conversation_id: str, + user_id: str, + ) -> Generator[List[ServiceMessage], None, None]: + # Initialize repositories with fresh session + db = get_database_session() + try: + conversation_repo = ConversationRepository(db) + message_repo = MessageRepository(db) + report_repo = ReportRepository(db) + + conversation = conversation_repo.get_by_id(conversation_id) + if not conversation: + raise ValueError(f"Conversation {conversation_id} not found") + + # Add user query to history + user_message = DBSchemaMessage( + conversation_id=conversation_id, + content=query, + type=MessageType.USER.value, + ) + message_repo.create(user_message) + + run_orchestration_generator = cls._run_orchestration( + query=query, + conversation=conversation, + user_id=user_id, + report_repo=report_repo, + message_repo=message_repo, + ) + try: + while True: + item = next(run_orchestration_generator) + yield cls._wrap_response(item) + except StopIteration as e: + pass + finally: + db.close() + + @classmethod + def _wrap_response( + cls, + original_response: DeepResearchResponse, + ) -> List[ServiceMessage]: + processed_messages = [] + thought_and_report_message = ServiceMessage( + id=str(uuid.uuid4()), + content=[], + role=MessageType.REPORT.value, + created_at=datetime.now() + ) + for msg in original_response.messages: + if isinstance(msg, MessageItem): + processed_message = ServiceMessage( + id=str(msg.message_id), + content=msg.content, + role=msg.type, + created_at=msg.created_at + ) + processed_messages.append(processed_message) + else: + thought_and_report_message.content.append( + msg + ) + if thought_and_report_message.content: + processed_messages.append(thought_and_report_message) + return processed_messages + + @classmethod + def get_report_and_thought_by_message_id(cls, message_id: str): + db = get_database_session() + try: + report_repo = ReportRepository(db) + report_and_thought_data = 
report_repo.get_by_message_id(message_id) + thought_process = ThoughtMessages.model_validate_json(report_and_thought_data.thought) + report = ReportItem( + content=report_and_thought_data.report_content, + created_at=report_and_thought_data.created_time, + ) + return ThoughtAndReport(thought=thought_process, report=report) + finally: + db.close() + + @classmethod + def _run_orchestration( + cls, + query, + conversation, + user_id, + report_repo: ReportRepository, + message_repo: MessageRepository, + ) -> Generator[DeepResearchResponse, None, None]: + full_response = DeepResearchResponse( + messages=[] + ) + + thought_process = ThoughtMessages( + messages=[], + ) + + final_report: ReportItem = None + + history_interactive_messages = cls._get_messages_by_conversation_id_until_report( + conversation_id=conversation.conversation_id, + message_repo=message_repo, + ) + orchestration = Orchestrator( + model_config=ModelConfig( + model_platform=ModelPlatformType.DEEPSEEK, + model_type=ModelType.DEEPSEEK_CHAT, + model_config_dict=dict( + stream=True + ), + ), + mcp_tools_config_path="./mcp_config.json", + research_round_limit=1, + init_request=OrchestrationRequest( + agent_historical_messages={ + AgentType.PLANNER: [cls._convert_message_to_orchestration_message(each) for each in + history_interactive_messages], + } + ), + execute_tips_template_dict=AGENT_PROCESS_TIPS_TEMPLATE, + ) + orchestration_generator = orchestration.run(query) + current_orchestration_phase = OrchestratorStatusType.PENDING + stream_chunk_caches: Dict[str, Union[MessageItem, ThoughtItem, ReportItem]] = {} + report_db_item: DBSchemaReport = None + message_db_item: DBSchemaMessage = None + + try: + while True: + item = next(orchestration_generator) + if isinstance(item, OrchestratorStatusType): + current_orchestration_phase = item + message_new_db_item = cls._insert_message_by_orchestration_phase( + conversation=conversation, + message_repo=message_repo, + 
current_orchestration_phase=current_orchestration_phase, + message=item, + ) + if message_new_db_item: + message_db_item = message_new_db_item + report_db_item = cls._insert_or_update_thought_and_report_process( + conversation=conversation, + current_orchestration_phase=current_orchestration_phase, + report_repo=report_repo, + relative_message=message_db_item, + has_insert_report=report_db_item, + message=item, + thought_process=thought_process, + ) + else: + if isinstance(item, StartMessage): + stream_id = item.stream_id + cached_item = stream_chunk_caches.setdefault( + stream_id, + cls._create_precess_item_by_orchestration_phase( + current_orchestration_phase, item + ) + ) + + if cached_item: + # Add message to report history + full_response.messages.append( + cached_item + ) + + if isinstance(cached_item, ReportItem): + final_report = cached_item + + elif isinstance(item, ChunkMessage): + stream_id = item.stream_id + cached_item = stream_chunk_caches.setdefault( + stream_id, + cls._create_precess_item_by_orchestration_phase( + current_orchestration_phase, item + ) + ) + cached_item.content += item.payload + elif isinstance(item, EndMessage): + stream_id = item.stream_id + cached_item = stream_chunk_caches.setdefault( + stream_id, + cls._create_precess_item_by_orchestration_phase( + current_orchestration_phase, item + ) + ) + cached_item.content += item.payload + if cached_item.content: + message_new_db_item = cls._insert_message_by_orchestration_phase( + conversation=conversation, + message_repo=message_repo, + current_orchestration_phase=current_orchestration_phase, + message=cached_item, + ) + if message_new_db_item: + message_db_item = message_new_db_item + report_db_item = cls._insert_or_update_thought_and_report_process( + conversation=conversation, + current_orchestration_phase=current_orchestration_phase, + report_repo=report_repo, + relative_message=message_db_item, + has_insert_report=report_db_item, + message=cached_item, + 
thought_process=thought_process, + ) + else: + if cached_item in full_response.messages: + full_response.messages.remove(cached_item) + stream_chunk_caches.pop(stream_id) + else: + if isinstance(item, CompleteMessage) and not item.payload: + continue + if isinstance(item, ErrorMessage) and not item.error_message: + continue + process_item = cls._create_precess_item_by_orchestration_phase( + current_orchestration_phase, item + ) + if process_item: + message_new_db_item = cls._insert_message_by_orchestration_phase( + conversation=conversation, + message_repo=message_repo, + current_orchestration_phase=current_orchestration_phase, + message=process_item, + ) + if message_new_db_item: + message_db_item = message_new_db_item + full_response.messages.append( + process_item + ) + report_db_item = cls._insert_or_update_thought_and_report_process( + conversation=conversation, + current_orchestration_phase=current_orchestration_phase, + report_repo=report_repo, + relative_message=message_db_item, + has_insert_report=report_db_item, + message=process_item, + thought_process=thought_process, + ) + yield full_response + + except StopIteration as e: + report_artifact: OrchestrationResult = e.value + if report_artifact.report: + final_report = ReportItem( + content=report_artifact.report, + created_at=datetime.now(), + ) + full_response.messages.append(final_report) + report_db_item = cls._insert_or_update_thought_and_report_process( + conversation=conversation, + current_orchestration_phase=current_orchestration_phase, + report_repo=report_repo, + relative_message=message_db_item, + has_insert_report=report_db_item, + message=final_report, + thought_process=thought_process, + ) + yield full_response + + @classmethod + def _get_messages_by_conversation_id_until_report(cls, conversation_id: str, message_repo: MessageRepository) -> \ + List[DBSchemaMessage]: + """ + Retrieve messages for a conversation in reverse chronological order, + stopping when a report message is encountered. 
+ + Args: + conversation_id: str of the conversation to query + message_repo: SQLAlchemy session object + + Returns: + List of Message objects in chronological order (oldest first) + """ + # Query messages in descending order (newest first) + messages = message_repo.get_all_by_conversation_id(conversation_id=conversation_id) + + # Collect messages until we hit a report + filtered_messages = [] + for msg in reversed(messages): # 从最新消息开始检查 + if msg.type == MessageType.REPORT: + break + filtered_messages.append(msg) + + # Return in chronological order (oldest first) + return list(reversed(filtered_messages)) + + @classmethod + def _convert_message_to_orchestration_message(cls, message: DBSchemaMessage) -> HistoricalMessage: + return HistoricalMessage( + content=message.content, + type=HistoricalMessageType.RESEARCH_PLAN if message.type == "search_plan" else HistoricalMessageType(message.type), + created_time=message.created_time, + message_id=str(message.message_id) + ) + + @classmethod + def _insert_message_by_orchestration_phase( + cls, + conversation, + message_repo: MessageRepository, + current_orchestration_phase: OrchestratorStatusType, + message: Union[OrchestratorStatusType, MessageItem, ThoughtItem, ReportItem] + ) -> Optional[DBSchemaMessage]: + if isinstance(message, OrchestratorStatusType): + # When receiving a PhaseStartMessage of RESEARCHING phase, + # insert an empty REPORT type message as a placeholder + if message == OrchestratorStatusType.RESEARCHING: + return message_repo.create( + DBSchemaMessage( + conversation_id=conversation.conversation_id, + content="", + type=MessageType.REPORT, + ) + ) + elif current_orchestration_phase == OrchestratorStatusType.PLANNING: + # Stores CompleteMessage/ErrorMessage as SEARCH_PLAN during PLANNING phase + return message_repo.create( + DBSchemaMessage( + conversation_id=conversation.conversation_id, + content=message.content, + type=MessageType.SEARCH_PLAN, + ) + ) + logging.warning(f"Orchestration phase 
{current_orchestration_phase!s} does not require message insertion") + return None + + @classmethod + def _insert_or_update_thought_and_report_process( + cls, + conversation, + current_orchestration_phase: OrchestratorStatusType, + report_repo: ReportRepository, + relative_message: Optional[DBSchemaMessage], + has_insert_report: DBSchemaReport, + message: Optional[Union[OrchestratorStatusType, ThoughtItem, ReportItem]], + thought_process: ThoughtMessages, + ): + if relative_message is None: + logging.warning(f"Relative message is None, can not insert or update report.") + return None + if isinstance(message, OrchestratorStatusType): + if message == OrchestratorStatusType.RESEARCHING and has_insert_report is None: + has_insert_report = report_repo.create( + DBSchemaReport( + message_id=relative_message.message_id, + conversation_id=conversation.conversation_id, + thought="", + report_content="", + ) + ) + else: + if current_orchestration_phase == OrchestratorStatusType.RESEARCHING or current_orchestration_phase == OrchestratorStatusType.REPORT_PLANNING: + if isinstance(message, ThoughtItem): + thought_process.messages.append(message) + if has_insert_report is not None: + has_insert_report.thought = thought_process.model_dump_json() + report_repo.update(has_insert_report) + return has_insert_report + elif current_orchestration_phase == OrchestratorStatusType.REPORT_WRITING or current_orchestration_phase == OrchestratorStatusType.COMPLETED: + if has_insert_report is not None and isinstance(message, ReportItem): + has_insert_report.report_content = message.content + report_repo.update(has_insert_report) + return has_insert_report + + return has_insert_report + + @classmethod + def _create_precess_item_by_orchestration_phase( + cls, + current_orchestration_phase: OrchestratorStatusType, + data: CoreMessage + ) -> Optional[Union[MessageItem, ThoughtItem, ReportItem]]: + if isinstance(data, ErrorMessage): + return MessageItem( + type=MessageType.SEARCH_PLAN, + 
content=data.error_message, + created_at=data.timestamp, + message_id=str(uuid.uuid4()), + ) + elif current_orchestration_phase == OrchestratorStatusType.PLANNING: + if isinstance(data.payload, str): + return MessageItem( + type=MessageType.SEARCH_PLAN, + content=data.payload, + created_at=data.timestamp, + message_id=str(uuid.uuid4()), + ) + elif current_orchestration_phase == OrchestratorStatusType.RESEARCHING or current_orchestration_phase == OrchestratorStatusType.REPORT_PLANNING: + thought_type = ThoughtProcessType.CONTENT + if isinstance(data.payload, ToolCallingRecord): + thought_type = ThoughtProcessType.TOOL_CALL + elif data.metadata.get(MessageMetadataKey.ADDITION_TYPE, + None) == AgentMessageAdditionType.TIPS: + thought_type = ThoughtProcessType.TITLE + return ThoughtItem( + type=thought_type, + content=data.payload.model_dump_json() if isinstance(data.payload, + ToolCallingRecord) else data.payload, + created_at=data.timestamp, + ) + elif current_orchestration_phase == OrchestratorStatusType.REPORT_WRITING: + return ReportItem( + content=data.payload, + created_at=data.timestamp, + ) + return None diff --git a/deepinsight/service/schemas/__init__.py b/deepinsight/service/schemas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deepinsight/service/schemas/chat.py b/deepinsight/service/schemas/chat.py new file mode 100644 index 0000000000000000000000000000000000000000..3f1a7294ac241f6e588a34e7c8f8dee9e134c8b6 --- /dev/null +++ b/deepinsight/service/schemas/chat.py @@ -0,0 +1,42 @@ +from datetime import datetime +from typing import Union, List + +from pydantic import BaseModel, Field + + +class ServiceMessage(BaseModel): + """ + Represents a single message within a conversation. + """ + id: str + content: Union[str, List] + role: str + created_at: datetime + + +class GetChatHistoryData(BaseModel): + """ + Schema for the request body when fetching chat history. 
+ """ + conversationId: str = Field(alias="conversationId") + +class GetChatHistoryStructure(BaseModel): + """ + Represents the structured data for a chat history response. + """ + conversation_id: str = Field(alias="conversationId") + user_id: str = Field(alias="userId") + created_time: str + title: str + status: str + messages: List[ServiceMessage] + # TODO: Missing 'updated_time' field for conversation update timestamp + + +class GetChatHistoryRsp(BaseModel): + """ + The complete response schema for fetching chat history. + """ + code: int + message: str + data: GetChatHistoryStructure diff --git a/deepinsight/service/schemas/conversation.py b/deepinsight/service/schemas/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..63875627d42e57f1da35731358294fa845b1fd5b --- /dev/null +++ b/deepinsight/service/schemas/conversation.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
+from typing import Optional, List + +from pydantic import BaseModel, Field + + +class ConversationListItem(BaseModel): + conversationId: str + title: str + createdTime: str + type: str = "normal" + + +class ConversationListMsg(BaseModel): + conversations: List[ConversationListItem] + + +class ConversationListRsp(BaseModel): + code: int + message: str + data: ConversationListMsg + + +class AddConversationMsg(BaseModel): + conversationId: str + created_time: str + + +class AddConversationRsp(BaseModel): + code: int + message: str + data: AddConversationMsg + + +class BodyAddConversation(BaseModel): + conversation_id: str + user_id: str = "test_user" + title: Optional[str] = "" + + +class ModifyConversationData(BaseModel): + title: str = Field(..., min_length=1, max_length=2000) + + +class UpdateConversationRsp(BaseModel): + code: int + message: str + data: ConversationListItem + + +class DeleteConversationData(BaseModel): + conversation_list: List[str] + + +class ResponseData(BaseModel): + code: int + message: str + data: Optional[dict] = {} + + +class RenameConversationData(BaseModel): + conversation_id: str + new_name: str + ori_name: Optional[str] = "default name" + + +class BodyGetList(BaseModel): + user_id: Optional[str] = "test_user" + offset: Optional[int] = 0 + limit: Optional[int] = 100 + + + diff --git a/deepinsight/service/schemas/model.py b/deepinsight/service/schemas/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3459d0e5e00be81aa2d444c915f5249f4a4ba230 --- /dev/null +++ b/deepinsight/service/schemas/model.py @@ -0,0 +1,15 @@ +# ====== 数据模型定义 ====== +from typing import Optional + +from pydantic import BaseModel + + +class LLMIteam(BaseModel): + icon: Optional[str] = "" + llmId: str = "empty" + modelName: str = "Ollama LLM" + + +class KbIteam(BaseModel): + kbId: str + kbName: str \ No newline at end of file diff --git a/deepinsight/stores/__init__.py b/deepinsight/stores/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..7f294e4646d6117d10d709fd03f03cd74044da42 --- /dev/null +++ b/deepinsight/stores/__init__.py @@ -0,0 +1 @@ +# 数据库操作包 diff --git a/deepinsight/stores/mongodb/__init__.py b/deepinsight/stores/mongodb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deepinsight/stores/mongodb/database.py b/deepinsight/stores/mongodb/database.py new file mode 100644 index 0000000000000000000000000000000000000000..c2cd38010eb3013cc5ce07f5ef9d36566ec22edb --- /dev/null +++ b/deepinsight/stores/mongodb/database.py @@ -0,0 +1,84 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
+ +from __future__ import annotations + +import logging +import os + +from pydantic import BaseModel, Field +from datetime import datetime +from pymongo import AsyncMongoClient +from typing import TYPE_CHECKING, Optional +import uuid + + +class Session(BaseModel): + """ + Session + + collection: session + """ + + id: str = Field(alias="_id") + ip: str + user_sub: Optional[str] = Field(default=None) + nonce: Optional[str] = Field(default=None) + expired_at: datetime + + +class Task(BaseModel): + """ + collection: witchiand_task + """ + + task_id: uuid.UUID = Field(alias="_id") + status: str + created_time: datetime = Field(default_factory=datetime.now) + + +if TYPE_CHECKING: + from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.collection import AsyncCollection + + +class MongoDB: + """MongoDB连接""" + + user = os.getenv('MONGODB_USER', 'admin') + password = os.getenv('MONGODB_PASSWORD', '') + host = os.getenv('MONGODB_HOST', 'localhost') + port = os.getenv('MONGODB_PORT', 27017) + _client: AsyncMongoClient = AsyncMongoClient( + f"mongodb://{user}:{password}@{host}:{port}/?directConnection=true&replicaSet=rs0", + uuidRepresentation="standard" + ) + + @classmethod + def get_collection(cls, collection_name: str) -> AsyncCollection: + """获取MongoDB集合(表)""" + try: + return cls._client[os.getenv('MONGODB_DATABASE', '')][collection_name] + except Exception as e: + logging.exception("[MongoDB] 获取集合 %s 失败", collection_name) + raise RuntimeError(str(e)) from e + + @classmethod + async def clear_collection(cls, collection_name: str) -> None: + """清空MongoDB集合(表)""" + try: + await cls._client[os.getenv('MONGODB_DATABASE', '')][collection_name].delete_many({}) + except Exception: + logging.exception("[MongoDB] 清空集合 %s 失败", collection_name) + + @classmethod + def get_session(cls) -> AsyncClientSession: + """获取MongoDB会话""" + return cls._client.start_session() diff --git a/deepinsight/stores/postgresql/__init__.py 
b/deepinsight/stores/postgresql/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deepinsight/stores/postgresql/database.py b/deepinsight/stores/postgresql/database.py new file mode 100644 index 0000000000000000000000000000000000000000..36ca5008ce2adb9a21b7fa1a8c19f7d1654f0467 --- /dev/null +++ b/deepinsight/stores/postgresql/database.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +import os + +from dotenv import load_dotenv +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +load_dotenv() + + +class DatabaseConfig: + """数据库配置类""" + + def __init__(self, db_type: str = None, connection_string: str = None): + """ + 初始化数据库配置 + + :param db_type: 数据库类型,如"postgresql"或"sqlite" + :param connection_string: 数据库连接字符串 + """ + # 优先使用传入的参数,否则从环境变量获取,最后使用默认值 + self.db_type = db_type or os.getenv("DB_TYPE", "") + + if connection_string: + self.connection_string = connection_string + else: + if self.db_type == "postgresql": + # 从环境变量中获取数据库连接信息 + db_user = os.getenv('POSTGRES_USER', 'default_user') + db_password = os.getenv('POSTGRES_PASSWORD', 'default_password') + db_host = os.getenv('POSTGRES_HOST', 'localhost') + db_port = os.getenv('POSTGRES_PORT', '5432') + db_name = os.getenv('POSTGRES_DB', 'postgres') + + # 构建数据库连接字符串 + self.connection_string = 
f"postgresql+psycopg2://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}" + + elif self.db_type == "sqlite": + self.connection_string = os.getenv( + "SQLITE_CONN_STR", + "sqlite:///./chat_db.stores" + ) + else: + raise ValueError(f"不支持的数据库类型: {self.db_type}") + + +# 创建默认数据库配置 +default_config = DatabaseConfig() + +# 创建引擎 +engine = create_engine( + default_config.connection_string, + echo=False, # 设置为True可打印SQL语句,调试时使用 + connect_args={"check_same_thread": False} if default_config.db_type == "sqlite" else {} +) + +# 创建会话 +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# SQLAlchemy模型基类 +# 所有数据模型都应继承此类,用于表结构定义 +DatabaseModel = declarative_base() + +# 数据库会话工厂 +DatabaseSession = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +def get_database_session(): + """ + 获取数据库会话的生成器函数 + + Yields: + Session: 数据库会话对象 + + Note: + 使用FastAPI等框架时,可作为依赖项注入 + 会话会在使用后自动关闭 + """ + session = DatabaseSession() + try: + yield session + finally: + session.close() diff --git a/deepinsight/stores/postgresql/repositories/__init__.py b/deepinsight/stores/postgresql/repositories/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aea4ccb7fa397dd2590213f78ddc057b7549c350 --- /dev/null +++ b/deepinsight/stores/postgresql/repositories/__init__.py @@ -0,0 +1 @@ +# 数据访问层 diff --git a/deepinsight/stores/postgresql/repositories/base_repository.py b/deepinsight/stores/postgresql/repositories/base_repository.py new file mode 100644 index 0000000000000000000000000000000000000000..33c2dce4349c4b587e92e452590363b5e9f885f6 --- /dev/null +++ b/deepinsight/stores/postgresql/repositories/base_repository.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. 
+# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +from sqlalchemy.orm import Session +from typing import Generic, TypeVar, List, Optional, Type + +# 泛型类型变量 +ModelType = TypeVar("ModelType") + + +class BaseRepository(Generic[ModelType]): + """基础数据访问类,提供通用的CRUD操作""" + + def __init__(self, db: Session, model: Type[ModelType]): + """ + 初始化基础仓库 + + :param db: 数据库会话 + :param model: 数据模型类 + """ + self.db = db + self.model = model + + def create(self, obj: ModelType) -> ModelType: + """ + 创建新记录 + + :param obj: 要创建的对象 + :return: 创建后的对象 + """ + self.db.add(obj) + self.db.commit() + self.db.refresh(obj) + return obj + + def get_by_id(self, id: str) -> Optional[ModelType]: + """ + 根据ID获取记录 + + :param id: 记录ID + :return: 找到的对象或None + """ + return self.db.query(self.model).filter(self.model.id == id).first() + + def get_all(self, skip: int = 0, limit: int = 100) -> List[ModelType]: + """ + 获取所有记录 + + :param skip: 跳过的记录数 + :param limit: 最大返回记录数 + :return: 记录列表 + """ + return self.db.query(self.model).offset(skip).limit(limit).all() + + def update(self, obj: ModelType) -> ModelType: + """ + 更新记录 + + :param obj: 要更新的对象 + :return: 更新后的对象 + """ + self.db.merge(obj) + self.db.commit() + self.db.refresh(obj) + return obj + + def delete(self, obj: ModelType) -> None: + """ + 删除记录 + + :param obj: 要删除的对象 + """ + self.db.delete(obj) + self.db.commit() diff --git a/deepinsight/stores/postgresql/repositories/conversation_repository.py b/deepinsight/stores/postgresql/repositories/conversation_repository.py new file mode 100644 index 0000000000000000000000000000000000000000..7361c8be790f19a405d236b02267c2edc8e8a261 --- /dev/null +++ b/deepinsight/stores/postgresql/repositories/conversation_repository.py @@ 
-0,0 +1,77 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +from sqlalchemy.orm import Session +from typing import List, Optional + +from deepinsight.stores.postgresql.repositories.base_repository import BaseRepository +from deepinsight.stores.postgresql.schemas.conversation import Conversation + + +class ConversationRepository(BaseRepository[Conversation]): + """对话数据访问类""" + + def __init__(self, db: Session): + """初始化对话仓库""" + super().__init__(db, Conversation) + + def get_by_id(self, conversation_id: str) -> Optional[Conversation]: + """ + 根据对话ID获取对话 + + :param conversation_id: 对话ID + :return: 找到的对话或None + """ + return self.db.query(Conversation).filter(Conversation.conversation_id == conversation_id).first() + + def get_by_user_id(self, user_id: str, offset: int = 0, limit: int = 100) -> List[Conversation]: + """ + 根据用户ID获取对话列表 + + :param user_id: 用户ID + :param offset: 跳过的记录数 + :param limit: 最大返回记录数 + :return: 对话列表 + """ + return self.db.query(Conversation)\ + .filter(Conversation.user_id == user_id)\ + .order_by(Conversation.created_time.desc())\ + .offset(offset)\ + .limit(limit)\ + .all() + + def get_active_by_user_id(self, user_id: str) -> List[Conversation]: + """ + 获取用户的活跃对话 + + :param user_id: 用户ID + :return: 活跃对话列表 + """ + return self.db.query(Conversation)\ + .filter( + Conversation.user_id == user_id, + Conversation.status == "active" + )\ + .order_by(Conversation.created_time.desc())\ + .all() + + def update_status(self, conversation_id: str, status: str) -> 
Optional[Conversation]: + """ + 更新对话状态 + + :param conversation_id: 对话ID + :param status: 新状态 + :return: 更新后的对话或None + """ + conversation = self.get_by_id(conversation_id) + if conversation: + conversation.status = status + self.db.commit() + self.db.refresh(conversation) + return conversation diff --git a/deepinsight/stores/postgresql/repositories/message_repository.py b/deepinsight/stores/postgresql/repositories/message_repository.py new file mode 100644 index 0000000000000000000000000000000000000000..7c6734ed8040062c519d41cf78640eae80553083 --- /dev/null +++ b/deepinsight/stores/postgresql/repositories/message_repository.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
+from sqlalchemy.orm import Session +from typing import List, Optional + +from deepinsight.stores.postgresql.repositories.base_repository import BaseRepository +from deepinsight.stores.postgresql.schemas.message import Message + + +class MessageRepository(BaseRepository[Message]): + """消息数据访问类""" + + def __init__(self, db: Session): + """初始化消息仓库""" + super().__init__(db, Message) + + def get_by_id(self, message_id: str) -> Optional[Message]: + """ + 根据消息ID获取消息 + + :param message_id: 消息ID + :return: 找到的消息或None + """ + return self.db.query(Message).filter(Message.message_id == message_id).first() + + def get_by_conversation_id(self, conversation_id: str, skip: int = 0, limit: int = 100) -> List[Message]: + """ + 根据对话ID获取消息列表 + + :param conversation_id: 对话ID + :param skip: 跳过的记录数 + :param limit: 最大返回记录数 + :return: 消息列表 + """ + return self.db.query(Message)\ + .filter(Message.conversation_id == conversation_id)\ + .order_by(Message.created_time.asc())\ + .offset(skip)\ + .limit(limit)\ + .all() + + def get_all_by_conversation_id(self, conversation_id: str) -> List[Message]: + """ + 根据对话ID获取消息列表 + + :param conversation_id: 对话ID + :param skip: 跳过的记录数 + :param limit: 最大返回记录数 + :return: 消息列表 + """ + return self.db.query(Message) \ + .filter(Message.conversation_id == conversation_id) \ + .order_by(Message.created_time.asc()) \ + .all() + + def get_by_type(self, conversation_id: str, message_type: str) -> List[Message]: + """ + 根据类型获取特定对话中的消息 + + :param conversation_id: 对话ID + :param message_type: 消息类型 + :return: 消息列表 + """ + return self.db.query(Message)\ + .filter( + Message.conversation_id == conversation_id, + Message.type == message_type + )\ + .order_by(Message.created_time.asc())\ + .all() + + def delete_by_conversation_id(self, conversation_id: str) -> None: + """ + 删除特定对话的所有消息 + + :param conversation_id: 对话ID + """ + self.db.query(Message)\ + .filter(Message.conversation_id == conversation_id)\ + .delete() + self.db.commit() diff --git 
a/deepinsight/stores/postgresql/repositories/report_repository.py b/deepinsight/stores/postgresql/repositories/report_repository.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2a30fe0c1bc1d5963934db70c726e5b4b4e382 --- /dev/null +++ b/deepinsight/stores/postgresql/repositories/report_repository.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +from sqlalchemy.orm import Session +from typing import List, Optional + +from deepinsight.stores.postgresql.repositories.base_repository import BaseRepository +from deepinsight.stores.postgresql.schemas.report import Report + + +class ReportRepository(BaseRepository[Report]): + """报告数据访问类""" + + def __init__(self, db: Session): + """初始化报告仓库""" + super().__init__(db, Report) + + def get_by_id(self, report_id: str) -> Optional[Report]: + """ + 根据报告ID获取报告 + + :param report_id: 报告ID + :return: 找到的报告或None + """ + return self.db.query(Report).filter(Report.report_id == report_id).first() + + def get_by_message_id(self, message_id: str) -> Optional[Report]: + """ + 根据消息ID获取报告 + + :param message_id: 消息ID + :return: 找到的报告或None + """ + return self.db.query(Report).filter(Report.message_id == message_id).first() + + def get_by_conversation_id(self, conversation_id: str, skip: int = 0, limit: int = 100) -> List[Report]: + """ + 根据对话ID获取报告列表 + + :param conversation_id: 对话ID + :param skip: 跳过的记录数 + :param limit: 最大返回记录数 + :return: 报告列表 + """ + return self.db.query(Report)\ + .filter(Report.conversation_id == conversation_id)\ 
+ .order_by(Report.created_time.desc())\ + .offset(skip)\ + .limit(limit)\ + .all() + + def delete_by_conversation_id(self, conversation_id: str) -> None: + """ + 删除特定对话的所有报告 + + :param conversation_id: 对话ID + """ + self.db.query(Report)\ + .filter(Report.conversation_id == conversation_id)\ + .delete() + self.db.commit() diff --git a/deepinsight/stores/postgresql/schemas/__init__.py b/deepinsight/stores/postgresql/schemas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deepinsight/stores/postgresql/schemas/conversation.py b/deepinsight/stores/postgresql/schemas/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..cc6e1c266af727e4c7c137b4f8a843ec60bef227 --- /dev/null +++ b/deepinsight/stores/postgresql/schemas/conversation.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
+from sqlalchemy import Column, String, DateTime +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import func +import uuid + +from deepinsight.stores.postgresql.database import DatabaseModel + + +class Conversation(DatabaseModel): + """对话表模型""" + __tablename__ = "conversation" + + conversation_id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + comment="对话ID" + ) + user_id = Column( + String(36), + nullable=False, + comment="用户ID" + ) + created_time = Column( + DateTime(timezone=True), + nullable=False, + default=func.now(), + comment="创建时间" + ) + title = Column( + String(255), + default="新建对话", + comment="对话标题" + ) + status = Column( + String(50), + nullable=False, + default="active", + comment="对话状态" + ) + + def __repr__(self): + return f"" diff --git a/deepinsight/stores/postgresql/schemas/message.py b/deepinsight/stores/postgresql/schemas/message.py new file mode 100644 index 0000000000000000000000000000000000000000..8409f3a3f6c942851604dbec5ad894c81254f593 --- /dev/null +++ b/deepinsight/stores/postgresql/schemas/message.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
+from sqlalchemy import Column, String, DateTime, Text, ForeignKey +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import func +from sqlalchemy.orm import relationship +import uuid + +from deepinsight.stores.postgresql.database import DatabaseModel + + +class Message(DatabaseModel): + """消息表模型""" + __tablename__ = "message" + + message_id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + comment="消息ID" + ) + conversation_id = Column( + UUID(as_uuid=True), + ForeignKey("conversation.conversation_id", ondelete="CASCADE"), + nullable=False, + comment="关联的对话ID" + ) + content = Column( + Text, + nullable=False, + comment="消息内容" + ) + type = Column( + String(50), + nullable=False, + comment="消息类型:chat, planner, report" + ) + created_time = Column( + DateTime(timezone=True), + nullable=False, + default=func.now(), + comment="创建时间" + ) + + # 关系 + conversation = relationship("Conversation", backref="messages") + report = relationship("Report", backref="message", uselist=False, cascade="all, delete-orphan") + + def __repr__(self): + return f"" diff --git a/deepinsight/stores/postgresql/schemas/report.py b/deepinsight/stores/postgresql/schemas/report.py new file mode 100644 index 0000000000000000000000000000000000000000..11f0e32b015999cf1c1b5fa83a2a428ef4648fe7 --- /dev/null +++ b/deepinsight/stores/postgresql/schemas/report.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
+from sqlalchemy import Column, DateTime, Text, ForeignKey +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import func +from sqlalchemy.orm import relationship +import uuid + +from deepinsight.stores.postgresql.database import DatabaseModel + + +class Report(DatabaseModel): + """报告表模型""" + __tablename__ = "report" + + report_id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + comment="报告ID" + ) + message_id = Column( + UUID(as_uuid=True), + ForeignKey("message.message_id", ondelete="CASCADE"), + nullable=False, + unique=True, + comment="关联的消息ID" + ) + conversation_id = Column( + UUID(as_uuid=True), + ForeignKey("conversation.conversation_id", ondelete="CASCADE"), + nullable=False, + comment="关联的对话ID" + ) + thought = Column( + Text, + comment="思考过程" + ) + report_content = Column( + Text, + nullable=False, + comment="报告内容" + ) + created_time = Column( + DateTime(timezone=True), + nullable=False, + default=func.now(), + comment="创建时间" + ) + + # 关系 + conversation = relationship("Conversation", backref="reports") + + def __repr__(self): + return f"" diff --git a/pyproject.toml b/pyproject.toml index 0452803a47e222418bc81d6e9b13da2cb9b011c7..de952b1a98020873eaecb4ad643e60a1562be003 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,13 +3,39 @@ name = "deepInsight" version = "0.1.0" description = "" readme = "README.md" -packages = [{include = "deepinsight", from = '.'}] +# 明确指定需要包含的顶级包 +packages = [ + {include = "deepinsight"}, + {include = "integrations"} +] +# 排除非Python包目录 +exclude = ["License"] [tool.poetry.dependencies] python = ">=3.10,<3.13" camel-ai = ">=0.2.60" pydantic = ">=2.10.6" rich = ">=14.0.0" +fastapi = "0.116.1" +sqlalchemy = "2.0.41" +pymongo = "4.13.2" +flask = "3.1.1" +flask_cors = "6.0.1" +python-dotenv = "*" # 新增环境变量管理依赖 + +# 跨平台可选依赖组 +[tool.poetry.group.win.dependencies] +psycopg2-binary = "2.9.9" # Windows平台专用 + +[tool.poetry.group.unix.dependencies] +psycopg2 = "2.9.9" # Linux/macOS平台专用 
[tool.poetry.group.dev.dependencies] -pytest = "*" \ No newline at end of file +pytest = "*" +black = "*" +flake8 = "*" +mypy = "*" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/tests/core/agents/test_planner.py b/tests/core/agents/test_planner.py new file mode 100644 index 0000000000000000000000000000000000000000..8cbd0c55e6f43ffddd97b933e3c9ac6a3619af94 --- /dev/null +++ b/tests/core/agents/test_planner.py @@ -0,0 +1,36 @@ +import os +import unittest +from unittest.mock import patch + +from camel.types import ModelPlatformType, ModelType + +from deepinsight.config.model import ModelConfig +from deepinsight.core.agent.planner import Planner + + +class TestPlannerAgent(unittest.TestCase): + def setUp(self): + """Test setup that runs before each test method.""" + self.patcher1 = patch('camel.models.openai_model.OpenAIModel.token_counter') + self.mock_token_counter = self.patcher1.start() + + def side_effect(text): + return len(text.split()) + self.mock_token_counter.side_effect = side_effect + + os.environ["OPENAI_API_KEY"] = "sk-test" + + self.model_config = { + "stream": True, + } + + def test_default_plan_parser(self): + planner = Planner( + model_config=ModelConfig( + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, + model_config_dict=self.model_config + ) + ) + parser_response = planner._default_plan_parser("1. 分析谷歌a2a基本信息") + self.assertEqual(len(parser_response.search_plans), 1)