From 93fae85c8815fa3b5d45563e1edaadf94260eb09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B7=A9=E6=97=AD=E7=BA=A2?= <2250244225@qq.com> Date: Sat, 26 Jul 2025 14:25:42 +0800 Subject: [PATCH 1/4] implement chat streaming capability & DeepInsight enhances backend data IO capabilities --- .gitignore | 2 +- deepinsight/api/app.py | 153 ++++++ deepinsight/app.py | 6 +- deepinsight/core/agent/base.py | 40 +- deepinsight/core/agent/planner.py | 24 +- deepinsight/core/agent/reporter.py | 55 +- deepinsight/core/agent/researcher.py | 42 +- deepinsight/core/orchestrator.py | 52 +- deepinsight/core/prompt/prompt_template.py | 4 + deepinsight/core/types/agent.py | 11 +- deepinsight/core/types/messages.py | 6 + deepinsight/db/__init__.py | 1 + deepinsight/db/config.py | 66 +++ deepinsight/db/init_db.py | 73 +++ deepinsight/db/main.py | 87 +++ deepinsight/db/repositories/__init__.py | 1 + .../db/repositories/base_repository.py | 71 +++ .../repositories/conversation_repository.py | 68 +++ .../db/repositories/message_repository.py | 80 +++ .../db/repositories/report_repository.py | 58 ++ deepinsight/db/schemas/__init__.py | 1 + deepinsight/db/schemas/conversation.py | 43 ++ deepinsight/db/schemas/message.py | 49 ++ deepinsight/db/schemas/report.py | 53 ++ deepinsight/service/__init__.py | 9 + deepinsight/service/conversation.py | 62 +++ deepinsight/service/deep_research.py | 507 ++++++++++++++++++ deepinsight/service/schemas/__init__.py | 0 deepinsight/service/schemas/chat.py | 42 ++ pyproject.toml | 2 + tests/core/agents/test_planner.py | 36 ++ 31 files changed, 1656 insertions(+), 48 deletions(-) create mode 100644 deepinsight/api/app.py create mode 100644 deepinsight/db/__init__.py create mode 100644 deepinsight/db/config.py create mode 100644 deepinsight/db/init_db.py create mode 100644 deepinsight/db/main.py create mode 100644 deepinsight/db/repositories/__init__.py create mode 100644 deepinsight/db/repositories/base_repository.py create mode 100644 deepinsight/db/repositories/conversation_repository.py create mode 100644 deepinsight/db/repositories/message_repository.py create mode 100644 deepinsight/db/repositories/report_repository.py create mode 100644 deepinsight/db/schemas/__init__.py create mode 100644 deepinsight/db/schemas/conversation.py create mode 100644 deepinsight/db/schemas/message.py create mode 100644 deepinsight/db/schemas/report.py create mode 100644 deepinsight/service/conversation.py create mode 100644 deepinsight/service/deep_research.py create mode 100644 deepinsight/service/schemas/__init__.py create mode 100644 deepinsight/service/schemas/chat.py create mode 100644 tests/core/agents/test_planner.py diff --git a/.gitignore b/.gitignore index 383d8a0..f0c638e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ .idea output -**/*/__pycache__ +**/__pycache__ mcp_config.json \ No newline at end of file diff --git a/deepinsight/api/app.py b/deepinsight/api/app.py new file mode 100644 index 0000000..038ce27 --- /dev/null +++ b/deepinsight/api/app.py @@ -0,0 +1,153 @@ +import asyncio +import os +import uuid +from datetime import datetime +from typing import Optional, Dict + +from fastapi import FastAPI, Request, APIRouter, HTTPException, Query, Body +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse + +from deepinsight.db.config import get_db +from deepinsight.service.conversation import ConversationService +from deepinsight.service.deep_research import MessageType, DeepResearchService +from deepinsight.service.schemas.chat import (ConversationListRsp, ConversationListMsg, ConversationListItem, + AddConversationRsp, BodyAddConversation, AddConversationMsg, + ResponseData, DeleteConversationData) +from deepinsight.service.schemas.chat import GetChatHistoryData, GetChatHistoryStructure, GetChatHistoryRsp +from deepinsight.service.schemas.model import LLMIteam, KbIteam + +# 读取环境变量中的 API 前缀 +API_PREFIX = os.getenv("API_PREFIX", "") + +# 创建 FastAPI 实例 +app_instance = FastAPI( + title="DeepInsight API", + description="A streaming chat API for DeepInsight", + version="1.0.0" +) +_conversations: Dict[str, ConversationListItem] = {} +# 创建路由 +router = APIRouter() + +# 跨域中间件配置 +app_instance.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@router.post("/api/chat") +async def chat_stream(request: Request): + try: + body = await request.json() + conversation_id = body.get("conversation_id", "") or str(uuid.uuid4()) + db_generator = get_db() + db_session = next(db_generator) + service = ConversationService(db_session) + conversation_info = service.get_conversation_info(conversation_id) + if not conversation_info: + raise HTTPException(status_code=404, detail="Conversation not found") + messages = body.get("messages", []) + if not isinstance(messages, list): + raise HTTPException(status_code=400, detail="messages must be a list") + query = None + for item in reversed(messages): + if item.get('role') == MessageType.USER.value and item.get("content", None): + query = item.get("content") + + async def fake_model_stream(): + for item in DeepResearchService.research(query=query, conversation_id=conversation_id, user_id=""): + resp = GetChatHistoryRsp( + code=0, + message="", + data=GetChatHistoryStructure( + conversation_id=conversation_id, + user_id=conversation_info.user_id, + created_time=str(datetime.now()), + title=conversation_info.title, # 原 name + status=conversation_info.status, + messages=item + ) + ).model_dump_json() + yield f"data: {resp}\n\n" + yield 'data: [DONE]\n\n' + + return StreamingResponse(fake_model_stream(), media_type="text/event-stream") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/api/conversation", response_model=ConversationListRsp, tags=["conversation"]) +async def get_conversation_list(): + try: + return ConversationListRsp( + code=200, + message="OK", + result=ConversationListMsg(conversations=list(_conversations.values())) + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/api/conversation", response_model=AddConversationRsp, tags=["conversation"]) +async def add_conversation( + appId: Optional[str] = Query(default=""), + debug: Optional[bool] = Query(default=False), + body: BodyAddConversation = Body(...) +): + try: + conversation_id = str(uuid.uuid4()) + new_conversation = ConversationListItem( + conversationId=conversation_id, + title="新会话", + docCount=0, + createdTime=datetime.utcnow().isoformat(), + appId=appId, + debug=debug, + llm=LLMIteam(llmId=body.llm_id), + kbList=[KbIteam(kbId=kb, kbName=f"知识库-{kb}") for kb in (body.kb_ids or [])] + ) + _conversations[conversation_id] = new_conversation + return AddConversationRsp( + code=200, + message="OK", + result=AddConversationMsg(conversationId=conversation_id) + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/api/conversation", response_model=ResponseData, tags=["conversation"]) +async def delete_conversation(data: DeleteConversationData = Body(...)): + try: + for cid in data.conversationList: + _conversations.pop(cid, None) + return ResponseData(code=200, message="Deleted", result={}) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/api/history", response_model=GetChatHistoryRsp, tags=["conversation"]) +async def get_chat_history(data: GetChatHistoryData): + db_generator = get_db() + db_session = next(db_generator) + service = ConversationService(db_session) + + conversation_info = service.get_conversation_info(data.conversationId) + history_present = service.get_history_messages(data.conversationId) + data = GetChatHistoryStructure( + conversation_id=data.conversationId, + user_id=conversation_info.user_id, + created_time=str(conversation_info.created_time), + title=conversation_info.title, # 原 name + status=conversation_info.status, + messages=history_present + ) + return GetChatHistoryRsp(code=0, message="ok", data=data).model_dump() + + +app_instance.include_router(router, prefix=API_PREFIX) diff --git a/deepinsight/app.py b/deepinsight/app.py index df0d0c2..40d4b41 100644 --- a/deepinsight/app.py +++ b/deepinsight/app.py @@ -13,11 +13,11 @@ from pathlib import Path from camel.types import ModelPlatformType, ModelType from deepinsight.config.model import ModelConfig -from deepinsight.core.orchestrator import Orchestrator, OrchestratorResult +from deepinsight.core.orchestrator import Orchestrator, OrchestrationResult from deepinsight.utils.console_utils import display_stream -def save_artifact(output_dir: Path, data: OrchestratorResult) -> None: +def save_artifact(output_dir: Path, data: OrchestrationResult) -> None: """Save research artifacts to files""" output_dir.mkdir(exist_ok=True) @@ -66,7 +66,7 @@ def main(): mcp_tools_config_path="./mcp_config.json", research_round_limit=1, ) - result: OrchestratorResult = display_stream(orchestration.run(args.query)) + result: OrchestrationResult = display_stream(orchestration.run(args.query)) # Handle interactive states while result.require_user_interactive: diff --git a/deepinsight/core/agent/base.py b/deepinsight/core/agent/base.py index 4675b73..5cbdc20 100644 --- a/deepinsight/core/agent/base.py +++ b/deepinsight/core/agent/base.py @@ -9,6 +9,7 @@ # See the Mulan PSL v2 for more details. from __future__ import annotations +import uuid from abc import abstractmethod, ABC from typing import Any, Dict, Optional, TypeVar, Generator, Generic @@ -18,7 +19,9 @@ from camel.toolkits import MCPToolkit from deepinsight.config.model import ModelConfig from deepinsight.core.agent.stream_chat_agent import StreamChatAgent -from deepinsight.core.types.messages import Message +from deepinsight.core.prompt.prompt_template import PromptTemplate +from deepinsight.core.types.agent import AgentMessageAdditionType +from deepinsight.core.types.messages import Message, CompleteMessage, MessageMetadataKey from deepinsight.utils.aio import get_or_create_loop OutputType = TypeVar("OutputType") @@ -39,6 +42,7 @@ class BaseAgent(ABC, Generic[OutputType]): model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, ) -> None: """ Initialize the base agent with configuration. @@ -51,6 +55,7 @@ class BaseAgent(ABC, Generic[OutputType]): self.mcp_toolkit_instance = None self.mcp_tools_config_path = mcp_tools_config_path self.mcp_client_timeout = mcp_client_timeout + self.tips_prompt_template = tips_prompt_template self.connect_mcp() if model_config.model_config_dict and model_config.model_config_dict.get("stream", False): @@ -161,6 +166,7 @@ class BaseAgent(ABC, Generic[OutputType]): Returns: T: The parsed output from parse_output() """ + yield from self.pre_run(query, context) prompt = self.build_user_prompt(query=query, context=context) if isinstance(self.agent, StreamChatAgent): response = yield from self.agent.stream_step(prompt) @@ -170,7 +176,16 @@ class BaseAgent(ABC, Generic[OutputType]): self.post_run(output) return output - def post_run(self, output: OutputType) -> None: + def pre_run( + self, + query: str, + context: Dict[str, Any] | None = None, + ) -> Generator[Message, None, None]: + if False: # This technique is used to maintain the generator function characteristics + yield + return + + def post_run(self, output: OutputType) -> Generator[Message, None, None]: """ Post-processing hook that is called after run() completes successfully. @@ -180,4 +195,23 @@ class BaseAgent(ABC, Generic[OutputType]): Args: output: The output from run() method """ - pass \ No newline at end of file + if False: # This technique is used to maintain the generator function characteristics + yield + return + + def yield_tips_messages(self, tips_prompt_name: str, **variables) -> Generator[Message, None, None]: + if self.tips_prompt_template: + yield self._create_tips_messages(tips_prompt_name, **variables) + + def _create_tips_messages(self, tips_prompt_name: str, **variables) -> Message: + content = self.tips_prompt_template.get_prompt( + stage=tips_prompt_name, + variables=variables + ) + return CompleteMessage( + stream_id=str(uuid.uuid4()), + payload=content, + metadata={ + MessageMetadataKey.ADDITION_TYPE: AgentMessageAdditionType.TIPS + } + ) diff --git a/deepinsight/core/agent/planner.py b/deepinsight/core/agent/planner.py index da715cb..cfc1811 100644 --- a/deepinsight/core/agent/planner.py +++ b/deepinsight/core/agent/planner.py @@ -13,7 +13,7 @@ import logging import re from copy import deepcopy from enum import Enum -from typing import Any, Dict, List, Optional, TypeAlias, Callable +from typing import Any, Dict, List, Optional, TypeAlias, Callable, Generator from camel.messages import BaseMessage from camel.responses import ChatAgentResponse @@ -23,8 +23,9 @@ from typing_extensions import override from deepinsight.config.model import ModelConfig from deepinsight.core.agent.base import BaseAgent -from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage +from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage, PromptTemplate from deepinsight.core.types.historical_message import HistoricalMessage, HistoricalMessageType +from deepinsight.core.types.messages import Message class NotSupportStreamException(Exception): @@ -172,6 +173,7 @@ class Planner(BaseAgent[PlanResult]): model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, plan_parser: PlanParser = None, latest_search_plan: Optional[str] = None, historical_messages: Optional[List[HistoricalMessage]] = None, @@ -187,7 +189,7 @@ class Planner(BaseAgent[PlanResult]): latest_search_plan: Current search plan context historical_messages: List of historical messages """ - super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout) + super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout, tips_prompt_template) self.plan_parser = plan_parser or self._default_plan_parser # Init plan if historical_messages: @@ -233,7 +235,7 @@ class Planner(BaseAgent[PlanResult]): return GLOBAL_DEFAULT_PROMPT_REPOSITORY.get_prompt( stage=PromptStage.PLAN_USER, variables=dict( query=query, - current_search_plan="\n".join([each.origin_plan for each in self.current_search_plan.search_plans]) if self.current_search_plan else "", + current_search_plan=self._search_plan_text(self.current_search_plan) if self.current_search_plan else "", **context if context is not None else {}, ) ) @@ -271,6 +273,8 @@ class Planner(BaseAgent[PlanResult]): ) if full_response.startswith("开始研究"): + if not self.current_search_plan: + raise NoPlanException("No plan can not start") # If need start research final_plan = deepcopy(self.current_search_plan) final_plan.status = PlanStatus.FINALIZED @@ -315,7 +319,15 @@ class Planner(BaseAgent[PlanResult]): ) return plan_result - def post_run(self, output: PlanResult) -> None: + def post_run(self, output: PlanResult) -> Generator[Message, None, None]: """Post run process.""" - super().post_run(output) self.current_search_plan = output + if output.status == PlanStatus.FINALIZED: + yield from self.yield_tips_messages(PromptStage.PLAN_START_TIPS, search_plans=self._search_plan_text(output)) + yield from super().post_run(output) + + def _search_plan_text(self, plan_result: PlanResult) -> str: + if plan_result.search_plans: + return "\n".join([each.origin_plan for each in plan_result.search_plans]) + else: + return "" diff --git a/deepinsight/core/agent/reporter.py b/deepinsight/core/agent/reporter.py index 8006c06..1dd2fbd 100644 --- a/deepinsight/core/agent/reporter.py +++ b/deepinsight/core/agent/reporter.py @@ -9,19 +9,19 @@ # See the Mulan PSL v2 for more details. import logging import re - -from camel.responses import ChatAgentResponse - -from deepinsight.config.model import ModelConfig -from deepinsight.core.agent.base import BaseAgent, OutputType +import uuid from typing import Any, Dict, Generator from typing import Optional, TypeAlias, Callable, List +from camel.responses import ChatAgentResponse from pydantic import BaseModel +from deepinsight.config.model import ModelConfig +from deepinsight.core.agent.base import BaseAgent, OutputType from deepinsight.core.agent.researcher import ResearchExecution -from deepinsight.core.types.messages import Message -from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage +from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage, PromptTemplate +from deepinsight.core.types.agent import AgentExecutePhase +from deepinsight.core.types.messages import Message, CompleteMessage, MessageMetadataKey from deepinsight.utils.parallel_worker_utils import Executor @@ -56,11 +56,13 @@ class GenerateSubTaskAgent(BaseAgent[List[WritingTask]]): Inherits from BaseAgent with List[WritingTask] as the concrete output type. """ + def __init__( self, model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, report_plan_parser: Optional[ReportPlanParser] = None, ) -> None: """ @@ -72,7 +74,7 @@ class GenerateSubTaskAgent(BaseAgent[List[WritingTask]]): mcp_client_timeout: Timeout for MCP client operations report_plan_parser: Custom parser for converting LLM responses to writing tasks """ - super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout) + super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout, tips_prompt_template) self.report_plan_parser = report_plan_parser def build_system_prompt(self) -> str: @@ -89,6 +91,13 @@ class GenerateSubTaskAgent(BaseAgent[List[WritingTask]]): return self.report_plan_parser(response.msg.content) return super().parse_output(response) + def pre_run(self, query: str, context: Dict[str, Any] | None = None) -> Generator[Message, None, None]: + yield from self.yield_tips_messages(tips_prompt_name=PromptStage.REPORT_PLAN_TIPS) + + def post_run(self, output: OutputType) -> Generator[Message, None, None]: + yield from self.yield_tips_messages(tips_prompt_name=PromptStage.REPORT_WRITE_TIPS) + yield from super().post_run(output) + class ExecuteSubTaskAgent(BaseAgent[str]): """ @@ -97,10 +106,15 @@ class ExecuteSubTaskAgent(BaseAgent[str]): Inherits from BaseAgent with str as the concrete output type. """ - def __init__(self, model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, - mcp_client_timeout: Optional[int] = None) -> None: + def __init__( + self, + model_config: ModelConfig, + mcp_tools_config_path: Optional[str] = None, + mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, + ) -> None: """Initialize with model and MCP configuration.""" - super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout) + super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout, tips_prompt_template) def build_system_prompt(self) -> str: return GLOBAL_DEFAULT_PROMPT_REPOSITORY.get_prompt(PromptStage.REPORT_WRITE_SYSTEM) @@ -130,6 +144,7 @@ class Reporter: model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, report_plan_parser: Optional[ReportPlanParser] = None, report_post_processer: Optional[ReportPostProcesser] = None, ): @@ -146,6 +161,7 @@ class Reporter: self.model_config = model_config self.mcp_tools_config_path = mcp_tools_config_path self.mcp_client_timeout = mcp_client_timeout + self.tips_prompt_template = tips_prompt_template self.report_plan_parser = report_plan_parser or self._default_report_plan_parser self.report_post_processer = report_post_processer or self._default_report_post_processer @@ -168,6 +184,14 @@ class Reporter: str: The final generated report """ writing_tasks = yield from self._generate_writing_task(query=query, research_executions=research_executions) + # Report plan has complete, send a phase message + yield CompleteMessage( + stream_id=str(uuid.uuid4()), + payload=None, + metadata={ + MessageMetadataKey.AGENT_EXECUTE_PHASE: AgentExecutePhase.REPORT_WRITING + } + ) writing_report_executor = Executor("writing_report") def report_writing_worker(i, writing_task: WritingTask): @@ -190,6 +214,7 @@ class Reporter: self.model_config, self.mcp_tools_config_path, self.mcp_client_timeout, + self.tips_prompt_template, self.report_plan_parser, ) research_info = self._construct_research_info(research_executions) @@ -208,7 +233,13 @@ class Reporter: def _write_task(self, query, writing_task: WritingTask) -> Generator[Message, None, str]: """Execute an individual writing task.""" - write_agent = ExecuteSubTaskAgent(self.model_config, self.mcp_tools_config_path, self.mcp_client_timeout) + write_agent = ExecuteSubTaskAgent( + self.model_config, + self.mcp_tools_config_path, + self.mcp_client_timeout, + self.tips_prompt_template, + ) + report = yield from write_agent.run( query=query, context=dict( diff --git a/deepinsight/core/agent/researcher.py b/deepinsight/core/agent/researcher.py index 1d994b1..fce4aae 100644 --- a/deepinsight/core/agent/researcher.py +++ b/deepinsight/core/agent/researcher.py @@ -19,8 +19,9 @@ from pydantic import BaseModel, Field from deepinsight.config.model import ModelConfig from deepinsight.core.agent.base import BaseAgent from deepinsight.core.agent.planner import SearchPlan, PlanResult -from deepinsight.core.types.messages import Message -from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage +from deepinsight.core.types.agent import AgentMessageAdditionType +from deepinsight.core.types.messages import Message, CompleteMessage, MessageMetadataKey +from deepinsight.core.prompt.prompt_template import GLOBAL_DEFAULT_PROMPT_REPOSITORY, PromptStage, PromptTemplate from deepinsight.utils.parallel_worker_utils import Executor @@ -95,6 +96,7 @@ class RolePlayingUser(BaseAgent[ChatAgentResponse]): Specializes BaseAgent to handle user-side interactions in research dialogues. """ + def __init__( self, model_config: ModelConfig, @@ -106,7 +108,7 @@ class RolePlayingUser(BaseAgent[ChatAgentResponse]): def build_system_prompt(self) -> str: return GLOBAL_DEFAULT_PROMPT_REPOSITORY.get_prompt(PromptStage.RESEARCH_ROLE_PLAYING_USER_SYSTEM) - def build_user_prompt(self, *, query:str, context: Dict[str, Any] | None = None) -> str: + def build_user_prompt(self, *, query: str, context: Dict[str, Any] | None = None) -> str: return query @@ -116,6 +118,7 @@ class RolePlayingAssistant(BaseAgent[ChatAgentResponse]): Specializes BaseAgent to handle assistant-side interactions in research dialogues. """ + def __init__(self, model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None) -> None: super().__init__(model_config, mcp_tools_config_path, mcp_client_timeout) @@ -123,8 +126,7 @@ class RolePlayingAssistant(BaseAgent[ChatAgentResponse]): def build_system_prompt(self) -> str: return GLOBAL_DEFAULT_PROMPT_REPOSITORY.get_prompt(PromptStage.RESEARCH_ROLE_PLAYING_ASSISTANT_SYSTEM) - - def build_user_prompt(self, *, query:str, context: Dict[str, Any] | None = None) -> str: + def build_user_prompt(self, *, query: str, context: Dict[str, Any] | None = None) -> str: return query @@ -137,6 +139,7 @@ class StreamRolePlaying: - Message exchange - Termination conditions """ + def __init__( self, model_config: ModelConfig, @@ -144,7 +147,7 @@ class StreamRolePlaying: mcp_client_timeout: Optional[int] = None, should_terminate_callback: Optional[ResearchShouldTerminateCallback] = None, **kwargs, - ) -> None: + ) -> None: """ Initialize the role-playing orchestrator. @@ -179,7 +182,7 @@ class StreamRolePlaying: self, query: str, context: Dict[str, Any] | None = None, - ) -> Generator[Message, None, Tuple[ChatAgentResponse, ChatAgentResponse]]: + ) -> Generator[Message, None, Tuple[ChatAgentResponse, ChatAgentResponse]]: """ Execute a complete role-playing dialogue turn. @@ -218,7 +221,8 @@ class StreamRolePlaying: ) user_msg = user_response.msg - assistant_response: ChatAgentResponse = yield from self.assistant_agent.run(query=user_msg.content, context=context) + assistant_response: ChatAgentResponse = yield from self.assistant_agent.run(query=user_msg.content, + context=context) if assistant_response.terminated or assistant_response.msgs is None: return ( ChatAgentResponse( @@ -261,6 +265,7 @@ class Researcher: model_config: ModelConfig, mcp_tools_config_path: Optional[str] = None, mcp_client_timeout: Optional[int] = None, + tips_prompt_template: Optional[PromptTemplate] = None, round_limit: int = 15, should_terminate_callback: ResearchShouldTerminateCallback = None, ) -> None: @@ -277,6 +282,7 @@ class Researcher: self.model_config = model_config self.mcp_tools_config_path = mcp_tools_config_path self.mcp_client_timeout = mcp_client_timeout + self.tips_prompt_template = tips_prompt_template self.round_limit = round_limit self.should_terminate_callback: ResearchShouldTerminateCallback = should_terminate_callback or self._default_should_terminate_callback @@ -296,12 +302,27 @@ class Researcher: """ # Parallel search info search_parallel_executor = Executor("Search") + def search_info_worker(i, search_step: SearchPlan): + yield CompleteMessage( + stream_id=str(uuid.uuid4()), + payload=self.tips_prompt_template.get_prompt( + stage=PromptStage.RESEARCH_START_TIPS, + variables=dict( + task_id=i + 1, + task_title=search_step.title, + ) + ), + metadata={ + MessageMetadataKey.ADDITION_TYPE: AgentMessageAdditionType.TIPS + } + ) one_search_content = yield from self._search_info_with_role_playing( query=query, search_plan=search_step, ) return one_search_content or [] + all_content = yield from search_parallel_executor(search_info_worker, list(enumerate(plan_result.search_plans))) flattened = [] @@ -350,7 +371,7 @@ class Researcher: stage=PromptStage.RESEARCH_ROLE_PLAYING_USER_USER, variables=dict( query=query, - current_plan=search_plan.origin_plan, + current_plan=search_plan.origin_plan, ) ) @@ -368,7 +389,8 @@ class Researcher: last_user_response = user_response assistant_execution_step = ExecutionStep( content=assistant_response.msg.content, - tool_calls=assistant_response.info.get("tool_calls") if assistant_response.info.get("tool_calls") else None + tool_calls=assistant_response.info.get("tool_calls") if assistant_response.info.get( + "tool_calls") else None ) user_execution_stop = ExecutionStep( diff --git a/deepinsight/core/orchestrator.py b/deepinsight/core/orchestrator.py index 2aa969f..18175ef 100644 --- a/deepinsight/core/orchestrator.py +++ b/deepinsight/core/orchestrator.py @@ -17,9 +17,10 @@ from deepinsight.config.model import ModelConfig from deepinsight.core.agent.planner import Planner, PlanStatus from deepinsight.core.agent.reporter import Reporter from deepinsight.core.agent.researcher import Researcher -from deepinsight.core.types.agent import AgentType +from deepinsight.core.prompt.prompt_template import PromptTemplate, PromptStage +from deepinsight.core.types.agent import AgentType, AgentExecutePhase from deepinsight.core.types.historical_message import HistoricalMessage -from deepinsight.core.types.messages import Message +from deepinsight.core.types.messages import Message, MessageMetadataKey class OrchestratorStatus(BaseModel): @@ -42,7 +43,8 @@ class OrchestratorStatusType(str, Enum): PENDING = OrchestratorStatus(status="pending") PLANNING = OrchestratorStatus(status="planning") RESEARCHING = OrchestratorStatus(status="researching") - REPORTING = OrchestratorStatus(status="reporting") + REPORT_PLANNING = OrchestratorStatus(status="report_planning") + REPORT_WRITING = OrchestratorStatus(status="report_writing") COMPLETED = OrchestratorStatus(status="completed") FAILED = OrchestratorStatus(status="failed") @@ -78,7 +80,7 @@ class OrchestrationRequest(BaseModel): ) -class OrchestratorResult(BaseModel): +class OrchestrationResult(BaseModel): """ Container for orchestration process outputs with interactive capabilities. @@ -130,6 +132,7 @@ class Orchestrator: mcp_client_timeout: Optional[int] = None, research_round_limit: int = 5, init_request: Optional[OrchestrationRequest] = None, + execute_tips_template_dict: Optional[Dict[Union[str, PromptStage], str]] = None, ) -> None: """ Initialize the orchestration engine with configuration. @@ -141,10 +144,12 @@ class Orchestrator: research_round_limit: Maximum number of research iterations init_request: Init request for orchestration """ + self.agent_execute_phase_tips_template = self._init_tips_template(execute_tips_template_dict) self.planner = Planner( model_config=model_config, mcp_tools_config_path=mcp_tools_config_path, mcp_client_timeout=mcp_client_timeout, + tips_prompt_template=self.agent_execute_phase_tips_template, historical_messages=init_request.agent_historical_messages.get( AgentType.PLANNER, [] ) if init_request else [] @@ -154,6 +159,7 @@ class Orchestrator: model_config=model_config, mcp_tools_config_path=mcp_tools_config_path, mcp_client_timeout=mcp_client_timeout, + tips_prompt_template=self.agent_execute_phase_tips_template, round_limit=research_round_limit ) @@ -161,6 +167,7 @@ class Orchestrator: model_config=model_config, mcp_tools_config_path=mcp_tools_config_path, mcp_client_timeout=mcp_client_timeout, + tips_prompt_template=self.agent_execute_phase_tips_template, ) self.current_phase = OrchestratorStatusType.PENDING self.start_time = None @@ -169,7 +176,7 @@ class Orchestrator: def run( self, query: str - ) -> Generator[Union[Message, OrchestratorStatusType], None, OrchestratorResult]: + ) -> Generator[Union[Message, OrchestratorStatusType], None, OrchestrationResult]: """ Execute the full orchestration workflow for a given query. @@ -180,7 +187,7 @@ class Orchestrator: Union[Message, PhaseStartMessage]: Progress messages during execution Returns: - OrchestratorResult: Final output artifacts + OrchestrationResult: Final output artifacts Raises: OrchestrationException: If any phase fails @@ -197,13 +204,13 @@ class Orchestrator: plan_result = yield from self.planner.run(query) if plan_result.requires_user_input: # Need user feedback - return OrchestratorResult( + return OrchestrationResult( require_user_interactive=True, require_user_feedback=plan_result.information_required, ) elif plan_result.status == PlanStatus.DRAFT: # Need user confirm plan draft - return OrchestratorResult( + return OrchestrationResult( require_user_interactive=True, plan_draft="\n".join([plan.origin_plan for plan in plan_result.search_plans]) ) @@ -218,16 +225,27 @@ class Orchestrator: ) # Phase 3: report - self.current_phase = OrchestratorStatusType.REPORTING - yield OrchestratorStatusType.RESEARCHING - report_result = yield from self.reporter.run( + self.current_phase = OrchestratorStatusType.REPORT_PLANNING + yield OrchestratorStatusType.REPORT_PLANNING + + reporter_run_generator = self.reporter.run( query=query, research_executions=research_executions, ) + try: + while True: + item = next(reporter_run_generator) + if item.metadata.get(MessageMetadataKey.AGENT_EXECUTE_PHASE, None) == AgentExecutePhase.REPORT_WRITING: + self.current_phase = OrchestratorStatusType.REPORT_WRITING + yield OrchestratorStatusType.REPORT_WRITING + else: + yield item + except StopIteration as e: + report_result = e.value self.current_phase = OrchestratorStatusType.COMPLETED - return OrchestratorResult(report=report_result) + return OrchestrationResult(report=report_result) except Exception as exc: self.current_phase = OrchestratorStatusType.FAILED @@ -237,3 +255,13 @@ class Orchestrator: ) from exc finally: self.end_time = datetime.utcnow() + + + def _init_tips_template(self, execute_tips_template_dict: Dict): + tips_template = PromptTemplate( + template_dict={} + ) + if execute_tips_template_dict: + for key, value in execute_tips_template_dict.items(): + tips_template.add_template(key, value) + return tips_template diff --git a/deepinsight/core/prompt/prompt_template.py b/deepinsight/core/prompt/prompt_template.py index cdaa9f5..df6757f 100644 --- a/deepinsight/core/prompt/prompt_template.py +++ b/deepinsight/core/prompt/prompt_template.py @@ -31,6 +31,10 @@ class PromptStage(str, Enum): REPORT_PLAN_USER = "report_plan_user" REPORT_WRITE_SYSTEM = "report_write_system" REPORT_WRITE_USER = "report_write_user" + PLAN_START_TIPS = "plan_start_tips" + RESEARCH_START_TIPS = "research_start_tips" + REPORT_PLAN_TIPS = "report_plan_tips" + REPORT_WRITE_TIPS = "report_write_tips" ERROR_RECOVERY = "error_recovery" CUSTOM = "custom" diff --git a/deepinsight/core/types/agent.py b/deepinsight/core/types/agent.py index 562fdef..8877b28 100644 --- a/deepinsight/core/types/agent.py +++ b/deepinsight/core/types/agent.py @@ -13,4 +13,13 @@ from enum import Enum class AgentType(str, Enum): PLANNER = "planner" RESEARCHER = "researcher" - REPORTER = "reporter" \ No newline at end of file + REPORTER = "reporter" + + +class AgentMessageAdditionType(str, Enum): + TIPS = "tips" + + +class AgentExecutePhase(str, Enum): + REPORT_PLANING = "report_planning" + REPORT_WRITING = "report_writing" \ No newline at end of file diff --git a/deepinsight/core/types/messages.py b/deepinsight/core/types/messages.py index 99bc6fb..589f088 100644 --- a/deepinsight/core/types/messages.py +++ b/deepinsight/core/types/messages.py @@ -126,6 +126,12 @@ class HeartbeatMessage(BaseMessage): message_type: Literal[MessageType.HEARTBEAT] = MessageType.HEARTBEAT latency_ms: Optional[int] = Field(None, ge=0) + +class MessageMetadataKey(str, Enum): + ADDITION_TYPE = "addition_type" + AGENT_EXECUTE_PHASE = "agent_execute_phase" + + # Union type representing all possible message types Message = Union[ StartMessage[T], diff --git a/deepinsight/db/__init__.py b/deepinsight/db/__init__.py new file mode 100644 index 0000000..7f294e4 --- /dev/null +++ b/deepinsight/db/__init__.py @@ -0,0 +1 @@ +# 数据库操作包 diff --git a/deepinsight/db/config.py b/deepinsight/db/config.py new file mode 100644 index 0000000..3f0881c --- /dev/null +++ b/deepinsight/db/config.py @@ -0,0 +1,66 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +import os +from dotenv import load_dotenv + +os.environ['POSTGRES_CONN_STR'] = 'postgresql+psycopg2://xuhonggong:@127.0.0.1:5432/postgres' +os.environ['DB_TYPE'] = 'postgresql' +# 加载环境变量 +load_dotenv() + + +class DatabaseConfig: + """数据库配置类""" + + def __init__(self, db_type: str = None, connection_string: str = None): + """ + 初始化数据库配置 + + :param db_type: 数据库类型,如"postgresql"或"sqlite" + :param connection_string: 数据库连接字符串 + """ + # 优先使用传入的参数,否则从环境变量获取,最后使用默认值 + self.db_type = db_type or os.getenv("DB_TYPE", "postgresql") + + if connection_string: + self.connection_string = connection_string + else: + if self.db_type == "postgresql": + self.connection_string = os.getenv( + "POSTGRES_CONN_STR", + "postgresql+psycopg2://postgres:postgres@localhost:5432/chat_db" + ) + elif self.db_type == "sqlite": + self.connection_string = os.getenv( + "SQLITE_CONN_STR", + "sqlite:///./chat_db.db" + ) + else: + raise ValueError(f"不支持的数据库类型: {self.db_type}") + + +# 创建默认数据库配置 +default_config = DatabaseConfig() + +# 创建引擎 +engine = create_engine( + default_config.connection_string, + echo=False, # 设置为True可打印SQL语句,调试时使用 + connect_args={"check_same_thread": False} if default_config.db_type == "sqlite" else {} +) + +# 创建会话 +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 声明基类 +Base = declarative_base() + + +def get_db(): + """获取数据库会话的依赖项""" + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/deepinsight/db/init_db.py b/deepinsight/db/init_db.py new file mode 100644 index 0000000..f19c356 --- /dev/null +++ b/deepinsight/db/init_db.py @@ -0,0 +1,73 @@ +from sqlalchemy import create_engine, text +from sqlalchemy.exc import SQLAlchemyError +import os + + +def init_database(): + """初始化数据库,创建UUID扩展和所需表结构""" + try: + # 获取数据库连接URL + db_url = 'postgresql+psycopg2://xuhonggong:@127.0.0.1:5432/postgres' + engine = create_engine(db_url) + + # 连接数据库并执行初始化操作 + # 使用begin()替代connect(),自动管理事务 + with engine.begin() as conn: + # 创建UUID扩展(仅PostgreSQL需要) + if db_url.startswith('postgresql'): + print("创建UUID扩展...") + conn.execute(text("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";")) + + # 创建conversation表 + print("创建conversation表...") + conn.execute(text(""" + CREATE TABLE IF NOT EXISTS conversation ( + conversation_id UUID PRIMARY KEY NOT NULL DEFAULT uuid_generate_v4(), + user_id VARCHAR(36) NOT NULL, + created_time TIMESTAMP(3) WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP(3), + title VARCHAR(255) DEFAULT '新建对话', + status VARCHAR(50) NOT NULL DEFAULT 'active' + ); + """)) + + # 创建message表 + print("创建message表...") + conn.execute(text(""" + CREATE TABLE IF NOT EXISTS message ( + message_id UUID PRIMARY KEY NOT NULL DEFAULT uuid_generate_v4(), + conversation_id UUID NOT NULL, + content TEXT NOT NULL, + type VARCHAR(50) NOT NULL, + created_time TIMESTAMP(3) WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP(3), + FOREIGN KEY (conversation_id) REFERENCES conversation(conversation_id) ON DELETE CASCADE + ); + """)) + + # 创建report表 + print("创建report表...") + conn.execute(text(""" + CREATE TABLE IF NOT EXISTS report ( + report_id UUID PRIMARY KEY NOT NULL DEFAULT uuid_generate_v4(), + message_id UUID NOT NULL, + conversation_id UUID NOT NULL, + thought TEXT, + report_content TEXT NOT NULL, + created_time TIMESTAMP(3) WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP(3), + FOREIGN KEY (message_id) REFERENCES message(message_id) ON DELETE CASCADE, + FOREIGN KEY (conversation_id) REFERENCES conversation(conversation_id) ON DELETE CASCADE + ); + """)) + + # 不需要手动调用commit(),engine.begin()会自动提交 + print("数据库表结构创建成功!") + + except SQLAlchemyError as e: + print(f"数据库初始化失败: {str(e)}") + raise + except Exception as e: + print(f"发生错误: {str(e)}") + raise + + +if __name__ == "__main__": + init_database() diff --git a/deepinsight/db/main.py b/deepinsight/db/main.py new file mode 100644 index 0000000..5317869 --- /dev/null +++ b/deepinsight/db/main.py @@ -0,0 +1,87 @@ +import uuid + +from deepinsight.db.config import get_db +from deepinsight.db.repositories.conversation_repository import ConversationRepository +from deepinsight.db.repositories.message_repository import MessageRepository +from deepinsight.db.repositories.report_repository import ReportRepository +from deepinsight.db.schemas.conversation import Conversation +from deepinsight.db.schemas.message import Message +from deepinsight.db.schemas.report import Report + + +def example_usage(): + """数据库操作示例""" + # 获取数据库会话 + db = next(get_db()) + + try: + # 初始化仓库 + conversation_repo = ConversationRepository(db) + message_repo = MessageRepository(db) + report_repo = ReportRepository(db) + + # 创建新对话 + new_conversation = \ + Conversation( + user_id=str(uuid.uuid4()), + title="示例对话" + ) + saved_conversation = conversation_repo.create(new_conversation) + print(f"创建对话: {saved_conversation}") + + # 添加消息 + user_message = Message( + conversation_id=saved_conversation.conversation_id, + content="你好,这是一条用户消息", + type="chat" + ) + saved_message = message_repo.create(user_message) + print(f"创建消息: {saved_message}") + + # 添加报告消息 + report_message = Message( + conversation_id=saved_conversation.conversation_id, + content="正在生成报告...", + type="report" + ) + saved_report_message = message_repo.create(report_message) + print(f"创建报告消息: {saved_report_message}") + + # 添加报告 + new_report = Report( + message_id=saved_report_message.message_id, + conversation_id=saved_conversation.conversation_id, + thought="这是思考过程...", + report_content="这是生成的报告内容..." + ) + saved_report = report_repo.create(new_report) + print(f"创建报告: {saved_report}") + + # 查询对话中的所有消息 + messages = message_repo.get_by_conversation_id(saved_conversation.conversation_id) + print(f"\n对话 {saved_conversation.conversation_id} 中的消息:") + for msg in messages: + print(f"- {msg.type}: {msg.content}") + + # 查询对话中的报告 + reports = report_repo.get_by_conversation_id(saved_conversation.conversation_id) + print(f"\n对话 {saved_conversation.conversation_id} 中的报告:") + for rep in reports: + print(f"- 报告内容: {rep.report_content}") + + # 更新对话标题 + saved_conversation.title = "更新后的对话标题" + updated_conversation = conversation_repo.update(saved_conversation) + print(f"\n更新后的对话: {updated_conversation}") + + # 删除对话(会级联删除相关消息和报告) + # conversation_repo.delete(saved_conversation) + # print(f"\n已删除对话: {saved_conversation.conversation_id}") + + finally: + # 关闭数据库会话 + db.close() + + +if __name__ == "__main__": + example_usage() diff --git a/deepinsight/db/repositories/__init__.py b/deepinsight/db/repositories/__init__.py new file mode 100644 index 0000000..aea4ccb --- /dev/null +++ b/deepinsight/db/repositories/__init__.py @@ -0,0 +1 @@ +# 数据访问层 diff --git a/deepinsight/db/repositories/base_repository.py b/deepinsight/db/repositories/base_repository.py new file mode 100644 index 0000000..3923e86 --- /dev/null +++ b/deepinsight/db/repositories/base_repository.py @@ -0,0 +1,71 @@ +from sqlalchemy.orm import Session +from typing import Generic, TypeVar, List, Optional, Type + +# 泛型类型变量 +ModelType = TypeVar("ModelType") + + +class BaseRepository(Generic[ModelType]): + """基础数据访问类,提供通用的CRUD操作""" + + def __init__(self, db: Session, model: Type[ModelType]): + """ + 初始化基础仓库 + + :param db: 数据库会话 + :param model: 数据模型类 + """ + self.db = db + self.model = model + + def create(self, obj: ModelType) -> ModelType: + """ + 创建新记录 + + :param obj: 要创建的对象 + :return: 创建后的对象 + """ + self.db.add(obj) + self.db.commit() + self.db.refresh(obj) + return obj + + def get_by_id(self, id: str) -> Optional[ModelType]: + """ + 根据ID获取记录 + + :param id: 记录ID + :return: 找到的对象或None + """ + return self.db.query(self.model).filter(self.model.id == id).first() + + def get_all(self, skip: int = 0, limit: int = 100) -> List[ModelType]: + """ + 获取所有记录 + + :param skip: 跳过的记录数 + :param limit: 最大返回记录数 + :return: 记录列表 + """ + return self.db.query(self.model).offset(skip).limit(limit).all() + + def update(self, obj: ModelType) -> ModelType: + """ + 更新记录 + + :param obj: 要更新的对象 + :return: 更新后的对象 + """ + self.db.merge(obj) + self.db.commit() + self.db.refresh(obj) + return obj + + def delete(self, obj: ModelType) -> None: + """ + 删除记录 + + :param obj: 要删除的对象 + """ + self.db.delete(obj) + self.db.commit() diff --git a/deepinsight/db/repositories/conversation_repository.py b/deepinsight/db/repositories/conversation_repository.py new file mode 100644 index 0000000..f48398d --- /dev/null +++ b/deepinsight/db/repositories/conversation_repository.py @@ -0,0 +1,68 @@ +from sqlalchemy.orm import Session +from typing import List, Optional + +from deepinsight.db.repositories.base_repository import BaseRepository +from deepinsight.db.schemas.conversation import Conversation + + +class ConversationRepository(BaseRepository[Conversation]): + """对话数据访问类""" + + def __init__(self, db: Session): + """初始化对话仓库""" + super().__init__(db, Conversation) + + def get_by_id(self, conversation_id: str) -> Optional[Conversation]: + """ + 根据对话ID获取对话 + + :param conversation_id: 对话ID + :return: 找到的对话或None + """ + return self.db.query(Conversation).filter(Conversation.conversation_id == conversation_id).first() + + def get_by_user_id(self, user_id: str, skip: int = 0, limit: int = 100) -> List[Conversation]: + """ + 根据用户ID获取对话列表 + + :param user_id: 用户ID + :param skip: 跳过的记录数 + :param limit: 最大返回记录数 + :return: 对话列表 + """ + return self.db.query(Conversation)\ + .filter(Conversation.user_id == user_id)\ + .order_by(Conversation.created_time.desc())\ + .offset(skip)\ + .limit(limit)\ + .all() + + def get_active_by_user_id(self, user_id: str) -> List[Conversation]: + """ + 获取用户的活跃对话 + + :param user_id: 用户ID + :return: 活跃对话列表 + """ + return self.db.query(Conversation)\ + .filter( + Conversation.user_id == user_id, + Conversation.status == "active" + )\ + .order_by(Conversation.created_time.desc())\ + .all() + + def update_status(self, conversation_id: str, status: str) -> Optional[Conversation]: + """ + 更新对话状态 + + :param conversation_id: 对话ID + :param status: 新状态 + :return: 更新后的对话或None + """ + conversation = self.get_by_id(conversation_id) + if conversation: + conversation.status = status + self.db.commit() + self.db.refresh(conversation) + return conversation diff --git a/deepinsight/db/repositories/message_repository.py b/deepinsight/db/repositories/message_repository.py new file mode 100644 index 0000000..5ae9b6a --- /dev/null +++ b/deepinsight/db/repositories/message_repository.py @@ -0,0 +1,80 @@ +from sqlalchemy import desc +from sqlalchemy.orm import Session +from typing import List, Optional + +from deepinsight.db.repositories.base_repository import BaseRepository +from deepinsight.db.schemas.message import Message + + +class MessageRepository(BaseRepository[Message]): + """消息数据访问类""" + + def __init__(self, db: Session): + """初始化消息仓库""" + super().__init__(db, Message) + + def get_by_id(self, message_id: str) -> Optional[Message]: + """ + 根据消息ID获取消息 + + :param message_id: 消息ID + :return: 找到的消息或None + """ + return self.db.query(Message).filter(Message.message_id == message_id).first() + + def get_by_conversation_id(self, conversation_id: str, skip: int = 0, limit: int = 100) -> List[Message]: + """ + 根据对话ID获取消息列表 + + :param conversation_id: 对话ID + :param skip: 跳过的记录数 + :param limit: 最大返回记录数 + :return: 消息列表 + """ + return self.db.query(Message)\ + .filter(Message.conversation_id == conversation_id)\ + .order_by(Message.created_time.asc())\ + .offset(skip)\ + .limit(limit)\ + .all() + + def get_all_by_conversation_id(self, conversation_id: str) -> List[Message]: + """ + 根据对话ID获取消息列表 + + :param conversation_id: 对话ID + :param skip: 跳过的记录数 + :param limit: 最大返回记录数 + :return: 消息列表 + """ + return self.db.query(Message) \ + .filter(Message.conversation_id == conversation_id) \ + .order_by(Message.created_time.asc()) \ + .all() + + def get_by_type(self, conversation_id: str, message_type: str) -> List[Message]: + """ + 根据类型获取特定对话中的消息 + + :param conversation_id: 对话ID + :param message_type: 消息类型 + :return: 消息列表 + """ + return self.db.query(Message)\ + .filter( + Message.conversation_id == conversation_id, + Message.type == message_type + )\ + .order_by(Message.created_time.asc())\ + .all() + + def delete_by_conversation_id(self, conversation_id: str) -> None: + """ + 删除特定对话的所有消息 + + :param conversation_id: 对话ID + """ + self.db.query(Message)\ + .filter(Message.conversation_id == conversation_id)\ + .delete() + self.db.commit() diff --git a/deepinsight/db/repositories/report_repository.py b/deepinsight/db/repositories/report_repository.py new file mode 100644 index 0000000..7fdcb08 --- /dev/null +++ b/deepinsight/db/repositories/report_repository.py @@ -0,0 +1,58 @@ +from sqlalchemy.orm import Session +from typing import List, Optional + +from deepinsight.db.repositories.base_repository import BaseRepository +from deepinsight.db.schemas.report import Report + + +class ReportRepository(BaseRepository[Report]): + """报告数据访问类""" + + def __init__(self, db: Session): + """初始化报告仓库""" + super().__init__(db, Report) + + def get_by_id(self, report_id: str) -> Optional[Report]: + """ + 根据报告ID获取报告 + + :param report_id: 报告ID + :return: 找到的报告或None + """ + return self.db.query(Report).filter(Report.report_id == report_id).first() + + def get_by_message_id(self, message_id: str) -> Optional[Report]: + """ + 根据消息ID获取报告 + + :param message_id: 消息ID + :return: 找到的报告或None + """ + return self.db.query(Report).filter(Report.message_id == message_id).first() + + def get_by_conversation_id(self, conversation_id: str, skip: int = 0, limit: int = 100) -> List[Report]: + """ + 根据对话ID获取报告列表 + + :param conversation_id: 对话ID + :param skip: 跳过的记录数 + :param limit: 最大返回记录数 + :return: 报告列表 + """ + return self.db.query(Report)\ + .filter(Report.conversation_id == conversation_id)\ + .order_by(Report.created_time.desc())\ + .offset(skip)\ + .limit(limit)\ + .all() + + def delete_by_conversation_id(self, conversation_id: str) -> None: + """ + 删除特定对话的所有报告 + + :param conversation_id: 对话ID + """ + self.db.query(Report)\ + .filter(Report.conversation_id == conversation_id)\ + .delete() + self.db.commit() diff --git a/deepinsight/db/schemas/__init__.py b/deepinsight/db/schemas/__init__.py new file mode 100644 index 0000000..ede9439 --- /dev/null +++ b/deepinsight/db/schemas/__init__.py @@ -0,0 +1 @@ +# 数据模型定义 diff --git a/deepinsight/db/schemas/conversation.py b/deepinsight/db/schemas/conversation.py new file mode 100644 index 0000000..21de321 --- /dev/null +++ b/deepinsight/db/schemas/conversation.py @@ -0,0 +1,43 @@ +from sqlalchemy import Column, String, DateTime, Text +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import func +import uuid + +from deepinsight.db.config import Base + + +class Conversation(Base): + """对话表模型""" + __tablename__ = "conversation" + + conversation_id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + comment="对话ID" + ) + user_id = Column( + String(36), + nullable=False, + comment="用户ID" + ) + created_time = Column( + DateTime(timezone=True), + nullable=False, + default=func.now(), + comment="创建时间" + ) + title = Column( + String(255), + default="新建对话", + comment="对话标题" + ) + status = Column( + String(50), + nullable=False, + default="active", + comment="对话状态" + ) + + def __repr__(self): + return f"" diff --git a/deepinsight/db/schemas/message.py b/deepinsight/db/schemas/message.py new file mode 100644 index 0000000..4349f1f --- /dev/null +++ b/deepinsight/db/schemas/message.py @@ -0,0 +1,49 @@ +from sqlalchemy import Column, String, DateTime, Text, ForeignKey +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import func +from sqlalchemy.orm import relationship +import uuid + +from deepinsight.db.config import Base +from .report import Report # 保证先定义Report,再定义Message + + +class Message(Base): + """消息表模型""" + __tablename__ = "message" + + message_id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + comment="消息ID" + ) + conversation_id = Column( + UUID(as_uuid=True), + ForeignKey("conversation.conversation_id", ondelete="CASCADE"), + nullable=False, + comment="关联的对话ID" + ) + content = Column( + Text, + nullable=False, + comment="消息内容" + ) + type = Column( + String(50), + nullable=False, + comment="消息类型:chat, planner, report" + ) + created_time = Column( + DateTime(timezone=True), + nullable=False, + default=func.now(), + comment="创建时间" + ) + + # 关系 + conversation = relationship("Conversation", backref="messages") + report = relationship("Report", backref="message", uselist=False, cascade="all, delete-orphan") + + def __repr__(self): + return f"" diff --git a/deepinsight/db/schemas/report.py b/deepinsight/db/schemas/report.py new file mode 100644 index 0000000..71570cd --- /dev/null +++ b/deepinsight/db/schemas/report.py @@ -0,0 +1,53 @@ +from sqlalchemy import Column, String, DateTime, Text, ForeignKey +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import func +from sqlalchemy.orm import relationship +import uuid + +from deepinsight.db.config import Base + + +class Report(Base): + """报告表模型""" + __tablename__ = "report" + + report_id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + comment="报告ID" + ) + message_id = Column( + UUID(as_uuid=True), + ForeignKey("message.message_id", ondelete="CASCADE"), + nullable=False, + unique=True, + comment="关联的消息ID" + ) + conversation_id = Column( + UUID(as_uuid=True), + ForeignKey("conversation.conversation_id", ondelete="CASCADE"), + nullable=False, + comment="关联的对话ID" + ) + thought = Column( + Text, + comment="思考过程" + ) + report_content = Column( + Text, + nullable=False, + comment="报告内容" + ) + created_time = Column( + DateTime(timezone=True), + nullable=False, + default=func.now(), + comment="创建时间" + ) + + # 关系 + conversation = relationship("Conversation", backref="reports") + + def __repr__(self): + return f"" diff --git a/deepinsight/service/__init__.py b/deepinsight/service/__init__.py index e69de29..d2976bf 100644 --- a/deepinsight/service/__init__.py +++ b/deepinsight/service/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. \ No newline at end of file diff --git a/deepinsight/service/conversation.py b/deepinsight/service/conversation.py new file mode 100644 index 0000000..8b60b7e --- /dev/null +++ b/deepinsight/service/conversation.py @@ -0,0 +1,62 @@ +import json +from typing import List + +from deepinsight.db.config import get_db +from deepinsight.db.repositories.conversation_repository import ConversationRepository +from deepinsight.db.repositories.message_repository import MessageRepository +from deepinsight.service.deep_research import DeepResearchService +from deepinsight.service.schemas.chat import ServiceMessage + + +class ConversationService: + def __init__(self, storage): + self.storage = storage + + def get_conversation_info(self, conversation_id_str: str): + repository = ConversationRepository(self.storage) + return repository.get_by_id(conversation_id_str) + + def get_history_messages(self, conversation_id_str: str) -> List[ServiceMessage]: + repository = MessageRepository(self.storage) + messages_from_db = repository.get_by_conversation_id(conversation_id_str) + + processed_messages = [] + for msg in messages_from_db: + content_to_use = msg.content + if msg.type == "report": + processed_report = DeepResearchService.get_report_and_thought_by_message_id(msg.message_id) + content_to_use = processed_report.thought.messages + [processed_report.report] + + # @TODO: 还差updated time + + processed_message = ServiceMessage( + id=str(msg.message_id), + content=content_to_use, + role=msg.type, + created_at=msg.created_time.isoformat() if msg.created_time else None + ) + + processed_messages.append(processed_message) + return processed_messages + + +# 示例用法 +if __name__ == "__main__": + db_generator = get_db() + db_session = next(db_generator) + + service = ConversationService(db_session) + + # 尝试获取一个存在的对话历史 + conv_id_present = 'e6653788-7aa0-489a-b08d-79b2c3fb0b76' + print(f"Fetching history for conversation ID: {conv_id_present}") + # try: + history_present = service.get_history_messages(conv_id_present) + for each in history_present: + print(each.model_dump_json()) + + # print(history_present) + # except ValueError as e: + # print(f"Error: {e}") + + print("\n" + "=" * 50 + "\n") diff --git a/deepinsight/service/deep_research.py b/deepinsight/service/deep_research.py new file mode 100644 index 0000000..5d925d0 --- /dev/null +++ b/deepinsight/service/deep_research.py @@ -0,0 +1,507 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +import logging +import uuid +from datetime import datetime +from enum import Enum +from typing import Generator, Union, List, Dict, Optional + +from camel.types import ModelPlatformType, ModelType +from camel.types.agents import ToolCallingRecord +from pydantic import BaseModel + +from deepinsight.config.model import ModelConfig +from deepinsight.core.orchestrator import Orchestrator, OrchestrationResult, OrchestratorStatusType, OrchestrationRequest +from deepinsight.core.prompt.prompt_template import PromptStage +from deepinsight.core.types.agent import AgentType, AgentMessageAdditionType +from deepinsight.core.types.historical_message import HistoricalMessage, HistoricalMessageType +from deepinsight.core.types.messages import StartMessage, ChunkMessage, EndMessage, MessageMetadataKey, CompleteMessage, \ + ErrorMessage, Message as CoreMessage +from deepinsight.db.config import get_db +from deepinsight.db.repositories.conversation_repository import ConversationRepository +from deepinsight.db.repositories.message_repository import MessageRepository +from deepinsight.db.repositories.report_repository import ReportRepository +from deepinsight.db.schemas.message import Message as DBSchemaMessage +from deepinsight.db.schemas.report import Report as DBSchemaReport +from deepinsight.service.schemas.chat import ServiceMessage + +AGENT_PROCESS_TIPS_TEMPLATE = { + PromptStage.PLAN_START_TIPS: "研究计划如下\n{search_plans}", + PromptStage.RESEARCH_START_TIPS: "任务{task_id}: {task_title}", + PromptStage.REPORT_PLAN_TIPS: "正在分析结果,计划生成报告", + PromptStage.REPORT_WRITE_TIPS: "正在生成报告", +} + + +class ThoughtProcessType(str, Enum): + TITLE = "title" + CONTENT = "content" + TOOL_CALL = "tool_call" + + +class MessageType(str, Enum): + USER = "user" + SEARCH_PLAN = "search_plan" + REPORT = "report" + + +class MessageItem(BaseModel): + type: MessageType + content: str = "" + created_at: datetime + message_id: str + + +class ThoughtItem(BaseModel): + type: ThoughtProcessType + content: Union[str, ToolCallingRecord] = "" + created_at: datetime + + +class ReportItem(BaseModel): + type: MessageType = MessageType.REPORT + content: str = "" + created_at: datetime + + +class ThoughtMessages(BaseModel): + messages: List[ThoughtItem] + + +class DeepResearchResponse(BaseModel): + messages: List[Union[MessageItem, ThoughtItem, ReportItem]] + + +class ThoughtAndReport(BaseModel): + thought: ThoughtMessages + report: ReportItem + + +class DeepResearchService: + """ + Stateless service for deep research operations that: + - Coordinates research orchestration + - Handles database persistence + - Manages streaming responses + + Designed for single-use per API request with no shared state. + """ + + @classmethod + def research( + cls, + query: str, + conversation_id: str, + user_id: str, + ) -> Generator[List[ServiceMessage], None, None]: + # Initialize repositories with fresh session + db = next(get_db()) + try: + conversation_repo = ConversationRepository(db) + message_repo = MessageRepository(db) + report_repo = ReportRepository(db) + + conversation = conversation_repo.get_by_id(conversation_id) + if not conversation: + raise ValueError(f"Conversation {conversation_id} not found") + + # Add user query to history + user_message = DBSchemaMessage( + conversation_id=conversation_id, + content=query, + type=MessageType.USER.value, + ) + message_repo.create(user_message) + + run_orchestration_generator = cls._run_orchestration( + query=query, + conversation=conversation, + user_id=user_id, + report_repo=report_repo, + message_repo=message_repo, + ) + try: + while True: + item = next(run_orchestration_generator) + yield cls._wrap_response(item) + except StopIteration as e: + pass + finally: + db.close() + + @classmethod + def _wrap_response( + cls, + original_response: DeepResearchResponse, + ) -> List[ServiceMessage]: + processed_messages = [] + thought_and_report_message = ServiceMessage( + id=str(uuid.uuid4()), + content=[], + role=MessageType.REPORT.value, + created_at=datetime.now() + ) + for msg in original_response.messages: + if isinstance(msg, MessageItem): + processed_message = ServiceMessage( + id=str(msg.message_id), + content=msg.content, + role=msg.type, + created_at=msg.created_at + ) + processed_messages.append(processed_message) + else: + thought_and_report_message.content.append( + msg + ) + if thought_and_report_message.content: + processed_messages.append(thought_and_report_message) + return processed_messages + + @classmethod + def get_report_and_thought_by_message_id(cls, message_id: str): + db = next(get_db()) + try: + report_repo = ReportRepository(db) + report_and_thought_data = report_repo.get_by_message_id(message_id) + thought_process = ThoughtMessages.model_validate_json(report_and_thought_data.thought) + report = ReportItem( + content=report_and_thought_data.report_content, + created_at=report_and_thought_data.created_time, + ) + return ThoughtAndReport(thought=thought_process, report=report) + finally: + db.close() + + @classmethod + def _run_orchestration( + cls, + query, + conversation, + user_id, + report_repo: ReportRepository, + message_repo: MessageRepository, + ) -> Generator[DeepResearchResponse, None, None]: + full_response = DeepResearchResponse( + messages=[] + ) + + thought_process = ThoughtMessages( + messages=[], + ) + + final_report: ReportItem = None + + history_interactive_messages = cls._get_messages_by_conversation_id_until_report( + conversation_id=conversation.conversation_id, + message_repo=message_repo, + ) + orchestration = Orchestrator( + model_config=ModelConfig( + model_platform=ModelPlatformType.DEEPSEEK, + model_type=ModelType.DEEPSEEK_CHAT, + model_config_dict=dict( + stream=True + ), + ), + mcp_tools_config_path="./mcp_config.json", + research_round_limit=1, + init_request=OrchestrationRequest( + agent_historical_messages={ + AgentType.PLANNER: [cls._convert_message_to_orchestration_message(each) for each in + history_interactive_messages], + } + ), + execute_tips_template_dict=AGENT_PROCESS_TIPS_TEMPLATE, + ) + orchestration_generator = orchestration.run(query) + current_orchestration_phase = OrchestratorStatusType.PENDING + stream_chunk_caches: Dict[str, Union[MessageItem, ThoughtItem, ReportItem]] = {} + report_db_item: DBSchemaReport = None + message_db_item: DBSchemaMessage = None + + try: + while True: + item = next(orchestration_generator) + if isinstance(item, OrchestratorStatusType): + current_orchestration_phase = item + message_new_db_item = cls._insert_message_by_orchestration_phase( + conversation=conversation, + message_repo=message_repo, + current_orchestration_phase=current_orchestration_phase, + message=item, + ) + if message_new_db_item: + message_db_item = message_new_db_item + report_db_item = cls._insert_or_update_thought_and_report_process( + conversation=conversation, + current_orchestration_phase=current_orchestration_phase, + report_repo=report_repo, + relative_message=message_db_item, + has_insert_report=report_db_item, + message=item, + thought_process=thought_process, + ) + else: + if isinstance(item, StartMessage): + stream_id = item.stream_id + cached_item = stream_chunk_caches.setdefault( + stream_id, + cls._create_precess_item_by_orchestration_phase( + current_orchestration_phase, item + ) + ) + + if cached_item: + # Add message to report history + full_response.messages.append( + cached_item + ) + + if isinstance(cached_item, ReportItem): + final_report = cached_item + + elif isinstance(item, ChunkMessage): + stream_id = item.stream_id + cached_item = stream_chunk_caches.setdefault( + stream_id, + cls._create_precess_item_by_orchestration_phase( + current_orchestration_phase, item + ) + ) + cached_item.content += item.payload + elif isinstance(item, EndMessage): + stream_id = item.stream_id + cached_item = stream_chunk_caches.setdefault( + stream_id, + cls._create_precess_item_by_orchestration_phase( + current_orchestration_phase, item + ) + ) + cached_item.content += item.payload + if cached_item.content: + message_new_db_item = cls._insert_message_by_orchestration_phase( + conversation=conversation, + message_repo=message_repo, + current_orchestration_phase=current_orchestration_phase, + message=cached_item, + ) + if message_new_db_item: + message_db_item = message_new_db_item + report_db_item = cls._insert_or_update_thought_and_report_process( + conversation=conversation, + current_orchestration_phase=current_orchestration_phase, + report_repo=report_repo, + relative_message=message_db_item, + has_insert_report=report_db_item, + message=cached_item, + thought_process=thought_process, + ) + else: + if cached_item in full_response.messages: + full_response.messages.remove(cached_item) + stream_chunk_caches.pop(stream_id) + else: + if isinstance(item, CompleteMessage) and not item.payload: + continue + if isinstance(item, ErrorMessage) and not item.error_message: + continue + process_item = cls._create_precess_item_by_orchestration_phase( + current_orchestration_phase, item + ) + if process_item: + message_new_db_item = cls._insert_message_by_orchestration_phase( + conversation=conversation, + message_repo=message_repo, + current_orchestration_phase=current_orchestration_phase, + message=process_item, + ) + if message_new_db_item: + message_db_item = message_new_db_item + full_response.messages.append( + process_item + ) + report_db_item = cls._insert_or_update_thought_and_report_process( + conversation=conversation, + current_orchestration_phase=current_orchestration_phase, + report_repo=report_repo, + relative_message=message_db_item, + has_insert_report=report_db_item, + message=process_item, + thought_process=thought_process, + ) + yield full_response + + except StopIteration as e: + report_artifact: OrchestrationResult = e.value + if report_artifact.report: + final_report = ReportItem( + content=report_artifact.report, + created_at=datetime.now(), + ) + full_response.messages.append(final_report) + report_db_item = cls._insert_or_update_thought_and_report_process( + conversation=conversation, + current_orchestration_phase=current_orchestration_phase, + report_repo=report_repo, + relative_message=message_db_item, + has_insert_report=report_db_item, + message=final_report, + thought_process=thought_process, + ) + yield full_response + + @classmethod + def _get_messages_by_conversation_id_until_report(cls, conversation_id: str, message_repo: MessageRepository) -> \ + List[DBSchemaMessage]: + """ + Retrieve messages for a conversation in reverse chronological order, + stopping when a report message is encountered. + + Args: + conversation_id: str of the conversation to query + message_repo: SQLAlchemy session object + + Returns: + List of Message objects in chronological order (oldest first) + """ + # Query messages in descending order (newest first) + messages = message_repo.get_all_by_conversation_id(conversation_id=conversation_id) + + # Collect messages until we hit a report + filtered_messages = [] + for msg in reversed(messages): # 从最新消息开始检查 + if msg.type == MessageType.REPORT: + break + filtered_messages.append(msg) + + # Return in chronological order (oldest first) + return list(reversed(filtered_messages)) + + @classmethod + def _convert_message_to_orchestration_message(cls, message: DBSchemaMessage) -> HistoricalMessage: + return HistoricalMessage( + content=message.content, + type=HistoricalMessageType.RESEARCH_PLAN if message.type == "search_plan" else HistoricalMessageType(message.type), + created_time=message.created_time, + message_id=str(message.message_id) + ) + + @classmethod + def _insert_message_by_orchestration_phase( + cls, + conversation, + message_repo: MessageRepository, + current_orchestration_phase: OrchestratorStatusType, + message: Union[OrchestratorStatusType, MessageItem, ThoughtItem, ReportItem] + ) -> Optional[DBSchemaMessage]: + if isinstance(message, OrchestratorStatusType): + # When receiving a PhaseStartMessage of RESEARCHING phase, + # insert an empty REPORT type message as a placeholder + if message == OrchestratorStatusType.RESEARCHING: + return message_repo.create( + DBSchemaMessage( + conversation_id=conversation.conversation_id, + content="", + type=MessageType.REPORT, + ) + ) + elif current_orchestration_phase == OrchestratorStatusType.PLANNING: + # Stores CompleteMessage/ErrorMessage as SEARCH_PLAN during PLANNING phase + return message_repo.create( + DBSchemaMessage( + conversation_id=conversation.conversation_id, + content=message.content, + type=MessageType.SEARCH_PLAN, + ) + ) + logging.warning(f"Orchestration phase {current_orchestration_phase!s} does not require message insertion") + return None + + @classmethod + def _insert_or_update_thought_and_report_process( + cls, + conversation, + current_orchestration_phase: OrchestratorStatusType, + report_repo: ReportRepository, + relative_message: Optional[DBSchemaMessage], + has_insert_report: DBSchemaReport, + message: Optional[Union[OrchestratorStatusType, ThoughtItem, ReportItem]], + thought_process: ThoughtMessages, + ): + if relative_message is None: + logging.warning(f"Relative message is None, can not insert or update report.") + return None + if isinstance(message, OrchestratorStatusType): + if message == OrchestratorStatusType.RESEARCHING and has_insert_report is None: + has_insert_report = report_repo.create( + DBSchemaReport( + message_id=relative_message.message_id, + conversation_id=conversation.conversation_id, + thought="", + report_content="", + ) + ) + else: + if current_orchestration_phase == OrchestratorStatusType.RESEARCHING or current_orchestration_phase == OrchestratorStatusType.REPORT_PLANNING: + if isinstance(message, ThoughtItem): + thought_process.messages.append(message) + if has_insert_report is not None: + has_insert_report.thought = thought_process.model_dump_json() + report_repo.update(has_insert_report) + return has_insert_report + elif current_orchestration_phase == OrchestratorStatusType.REPORT_WRITING or current_orchestration_phase == OrchestratorStatusType.COMPLETED: + if has_insert_report is not None and isinstance(message, ReportItem): + has_insert_report.report_content = message.content + report_repo.update(has_insert_report) + return has_insert_report + + return has_insert_report + + @classmethod + def _create_precess_item_by_orchestration_phase( + cls, + current_orchestration_phase: OrchestratorStatusType, + data: CoreMessage + ) -> Optional[Union[MessageItem, ThoughtItem, ReportItem]]: + if isinstance(data, ErrorMessage): + return MessageItem( + type=MessageType.SEARCH_PLAN, + content=data.error_message, + created_at=data.timestamp, + message_id=str(uuid.uuid4()), + ) + elif current_orchestration_phase == OrchestratorStatusType.PLANNING: + if isinstance(data.payload, str): + return MessageItem( + type=MessageType.SEARCH_PLAN, + content=data.payload, + created_at=data.timestamp, + message_id=str(uuid.uuid4()), + ) + elif current_orchestration_phase == OrchestratorStatusType.RESEARCHING or current_orchestration_phase == OrchestratorStatusType.REPORT_PLANNING: + thought_type = ThoughtProcessType.CONTENT + if isinstance(data.payload, ToolCallingRecord): + thought_type = ThoughtProcessType.TOOL_CALL + elif data.metadata.get(MessageMetadataKey.ADDITION_TYPE, + None) == AgentMessageAdditionType.TIPS: + thought_type = ThoughtProcessType.TITLE + return ThoughtItem( + type=thought_type, + content=data.payload.model_dump_json() if isinstance(data.payload, + ToolCallingRecord) else data.payload, + created_at=data.timestamp, + ) + elif current_orchestration_phase == OrchestratorStatusType.REPORT_WRITING: + return ReportItem( + content=data.payload, + created_at=data.timestamp, + ) + return None diff --git a/deepinsight/service/schemas/__init__.py b/deepinsight/service/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepinsight/service/schemas/chat.py b/deepinsight/service/schemas/chat.py new file mode 100644 index 0000000..3f1a729 --- /dev/null +++ b/deepinsight/service/schemas/chat.py @@ -0,0 +1,42 @@ +from datetime import datetime +from typing import Union, List + +from pydantic import BaseModel, Field + + +class ServiceMessage(BaseModel): + """ + Represents a single message within a conversation. + """ + id: str + content: Union[str, List] + role: str + created_at: datetime + + +class GetChatHistoryData(BaseModel): + """ + Schema for the request body when fetching chat history. + """ + conversationId: str = Field(alias="conversationId") + +class GetChatHistoryStructure(BaseModel): + """ + Represents the structured data for a chat history response. + """ + conversation_id: str = Field(alias="conversationId") + user_id: str = Field(alias="userId") + created_time: str + title: str + status: str + messages: List[ServiceMessage] + # TODO: Missing 'updated_time' field for conversation update timestamp + + +class GetChatHistoryRsp(BaseModel): + """ + The complete response schema for fetching chat history. + """ + code: int + message: str + data: GetChatHistoryStructure diff --git a/pyproject.toml b/pyproject.toml index 0452803..6d34389 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,8 @@ python = ">=3.10,<3.13" camel-ai = ">=0.2.60" pydantic = ">=2.10.6" rich = ">=14.0.0" +sqlalchemy = ">=2.0.41" +psycopg2 = ">=2.9.10" [tool.poetry.group.dev.dependencies] pytest = "*" \ No newline at end of file diff --git a/tests/core/agents/test_planner.py b/tests/core/agents/test_planner.py new file mode 100644 index 0000000..8cbd0c5 --- /dev/null +++ b/tests/core/agents/test_planner.py @@ -0,0 +1,36 @@ +import os +import unittest +from unittest.mock import patch + +from camel.types import ModelPlatformType, ModelType + +from deepinsight.config.model import ModelConfig +from deepinsight.core.agent.planner import Planner + + +class TestPlannerAgent(unittest.TestCase): + def setUp(self): + """Test setup that runs before each test method.""" + self.patcher1 = patch('camel.models.openai_model.OpenAIModel.token_counter') + self.mock_token_counter = self.patcher1.start() + + def side_effect(text): + return len(text.split()) + self.mock_token_counter.side_effect = side_effect + + os.environ["OPENAI_API_KEY"] = "sk-test" + + self.model_config = { + "stream": True, + } + + def test_default_plan_parser(self): + planner = Planner( + model_config=ModelConfig( + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, + model_config_dict=self.model_config + ) + ) + parser_response = planner._default_plan_parser("1. 分析谷歌a2a基本信息") + self.assertEqual(len(parser_response.search_plans), 1) -- Gitee From 11ca4b9c91373cad0445a69ae032ab4135211e24 Mon Sep 17 00:00:00 2001 From: Tech1024Wizard Date: Sat, 26 Jul 2025 21:37:47 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E3=80=90feat=E3=80=91=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E6=8E=A5=E5=85=A5=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E4=BB=A5=E5=8F=8A=E8=AE=A4=E8=AF=81=E9=89=B4=E6=9D=83?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=EF=BC=88=E5=AF=B9=E6=8E=A5openeuler=20intell?= =?UTF-8?q?gence=E6=9C=8D=E5=8A=A1=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 18 ++ .gitignore | 3 +- deepinsight/api/app.py | 164 +++++++++--------- deepinsight/db/init_db.py | 73 -------- deepinsight/db/main.py | 87 ---------- deepinsight/db/schemas/__init__.py | 1 - deepinsight/service/auth/session_manager.py | 46 +++++ deepinsight/service/auth/session_service.py | 65 +++++++ deepinsight/service/auth/user_dependency.py | 86 +++++++++ deepinsight/service/conversation.py | 19 +- deepinsight/service/deep_research.py | 16 +- deepinsight/service/schemas/conversation.py | 73 ++++++++ deepinsight/service/schemas/model.py | 15 ++ deepinsight/{db => stores}/__init__.py | 0 deepinsight/stores/mongodb/__init__.py | 0 deepinsight/stores/mongodb/database.py | 84 +++++++++ deepinsight/stores/postgresql/__init__.py | 0 .../postgresql/database.py} | 64 +++++-- .../postgresql}/repositories/__init__.py | 0 .../repositories/base_repository.py | 9 + .../repositories/conversation_repository.py | 13 +- .../repositories/message_repository.py | 14 +- .../repositories/report_repository.py | 13 +- .../stores/postgresql/schemas/__init__.py | 0 .../postgresql}/schemas/conversation.py | 15 +- .../postgresql}/schemas/message.py | 14 +- .../postgresql}/schemas/report.py | 15 +- pyproject.toml | 30 +++- 28 files changed, 641 insertions(+), 296 deletions(-) create mode 100644 .env.example delete mode 100644 deepinsight/db/init_db.py delete mode 100644 deepinsight/db/main.py delete mode 100644 deepinsight/db/schemas/__init__.py create mode 100644 deepinsight/service/auth/session_manager.py create mode 100644 deepinsight/service/auth/session_service.py create mode 100644 deepinsight/service/auth/user_dependency.py create mode 100644 deepinsight/service/schemas/conversation.py create mode 100644 deepinsight/service/schemas/model.py rename deepinsight/{db => stores}/__init__.py (100%) create mode 100644 deepinsight/stores/mongodb/__init__.py create mode 100644 deepinsight/stores/mongodb/database.py create mode 100644 deepinsight/stores/postgresql/__init__.py rename deepinsight/{db/config.py => stores/postgresql/database.py} (44%) rename deepinsight/{db => stores/postgresql}/repositories/__init__.py (100%) rename deepinsight/{db => stores/postgresql}/repositories/base_repository.py (77%) rename deepinsight/{db => stores/postgresql}/repositories/conversation_repository.py (75%) rename deepinsight/{db => stores/postgresql}/repositories/message_repository.py (78%) rename deepinsight/{db => stores/postgresql}/repositories/report_repository.py (72%) create mode 100644 deepinsight/stores/postgresql/schemas/__init__.py rename deepinsight/{db => stores/postgresql}/schemas/conversation.py (58%) rename deepinsight/{db => stores/postgresql}/schemas/message.py (67%) rename deepinsight/{db => stores/postgresql}/schemas/report.py (64%) diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..186b622 --- /dev/null +++ b/.env.example @@ -0,0 +1,18 @@ +# .env.example +# 数据库配置,默认使用postgresql +DB_TYPE= +# postgresql配置 +POSTGRES_USER= +POSTGRES_PASSWORD= +POSTGRES_HOST= +POSTGRES_PORT= +POSTGRES_DB= + +# MongoDB 配置 +MONGODB_USER= +MONGODB_PASSWORD= +MONGODB_HOST= +MONGODB_PORT= + +# 是否需要认证鉴权: deepInsight集成到openeuler intelligence依赖认证健全,如果单用户部署则不需要,默认用户为admin +REQUIRE_AUTHENTICATION= diff --git a/.gitignore b/.gitignore index f0c638e..ead8a8d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea output **/__pycache__ -mcp_config.json \ No newline at end of file +mcp_config.json +/.env diff --git a/deepinsight/api/app.py b/deepinsight/api/app.py index 038ce27..81bf6d7 100644 --- a/deepinsight/api/app.py +++ b/deepinsight/api/app.py @@ -1,21 +1,19 @@ -import asyncio import os import uuid from datetime import datetime -from typing import Optional, Dict +from typing import Dict -from fastapi import FastAPI, Request, APIRouter, HTTPException, Query, Body +from fastapi import FastAPI, Request, APIRouter, HTTPException, Body, Depends from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse -from deepinsight.db.config import get_db from deepinsight.service.conversation import ConversationService from deepinsight.service.deep_research import MessageType, DeepResearchService -from deepinsight.service.schemas.chat import (ConversationListRsp, ConversationListMsg, ConversationListItem, - AddConversationRsp, BodyAddConversation, AddConversationMsg, - ResponseData, DeleteConversationData) -from deepinsight.service.schemas.chat import GetChatHistoryData, GetChatHistoryStructure, GetChatHistoryRsp -from deepinsight.service.schemas.model import LLMIteam, KbIteam +from deepinsight.service.schemas.conversation import (ConversationListRsp, ConversationListMsg, ConversationListItem, + AddConversationRsp, BodyAddConversation, AddConversationMsg, + ResponseData, DeleteConversationData, RenameConversationData, + BodyGetList) +from deepinsight.service.schemas.chat import GetChatHistoryStructure, GetChatHistoryRsp # 读取环境变量中的 API 前缀 API_PREFIX = os.getenv("API_PREFIX", "") @@ -42,80 +40,72 @@ app_instance.add_middleware( @router.post("/api/chat") async def chat_stream(request: Request): + # try: + body = await request.json() + conversation_id = body.get("conversation_id", "") or str(uuid.uuid4()) + conversation_info = ConversationService.get_conversation_info(conversation_id) + if not conversation_info: + raise HTTPException(status_code=404, detail="Conversation not found") + messages = body.get("messages", []) + if not isinstance(messages, list): + raise HTTPException(status_code=400, detail="messages must be a list") + query = None + for item in reversed(messages): + if item.get('role') == MessageType.USER.value and item.get("content", None): + query = item.get("content") + + async def fake_model_stream(): + for item in DeepResearchService.research(query=query, conversation_id=conversation_id, user_id=""): + resp = GetChatHistoryRsp( + code=0, + message="", + data=GetChatHistoryStructure( + conversation_id=conversation_id, + user_id=conversation_info.user_id, + created_time=str(datetime.now()), + title=conversation_info.title, # 原 name + status=conversation_info.status, + messages=item + ) + ).model_dump_json() + yield f"data: {resp}\n\n" + yield 'data: [DONE]\n\n' + + return StreamingResponse(fake_model_stream(), media_type="text/event-stream") + # except Exception as e: + # raise HTTPException(status_code=500, detail=str(e)) + + +# TODO +@router.get("/api/conversations", response_model=ConversationListRsp, tags=["conversation"]) +# 是否需要,如果需要是否只需要返回conversation id +async def get_conversation_list(body: BodyGetList = Depends()): try: - body = await request.json() - conversation_id = body.get("conversation_id", "") or str(uuid.uuid4()) - db_generator = get_db() - db_session = next(db_generator) - service = ConversationService(db_session) - conversation_info = service.get_conversation_info(conversation_id) - if not conversation_info: - raise HTTPException(status_code=404, detail="Conversation not found") - messages = body.get("messages", []) - if not isinstance(messages, list): - raise HTTPException(status_code=400, detail="messages must be a list") - query = None - for item in reversed(messages): - if item.get('role') == MessageType.USER.value and item.get("content", None): - query = item.get("content") - - async def fake_model_stream(): - for item in DeepResearchService.research(query=query, conversation_id=conversation_id, user_id=""): - resp = GetChatHistoryRsp( - code=0, - message="", - data=GetChatHistoryStructure( - conversation_id=conversation_id, - user_id=conversation_info.user_id, - created_time=str(datetime.now()), - title=conversation_info.title, # 原 name - status=conversation_info.status, - messages=item - ) - ).model_dump_json() - yield f"data: {resp}\n\n" - yield 'data: [DONE]\n\n' - - return StreamingResponse(fake_model_stream(), media_type="text/event-stream") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/api/conversation", response_model=ConversationListRsp, tags=["conversation"]) -async def get_conversation_list(): - try: + conversation_list = ConversationService.get_list(user_id=body.user_id, offset=body.offset, limit=body.limit) return ConversationListRsp( - code=200, + code=0, message="OK", - result=ConversationListMsg(conversations=list(_conversations.values())) + data=ConversationListMsg(conversations=conversation_list) ) + except Exception as e: raise HTTPException(status_code=500, detail=str(e)) +# temporarily deprecated @router.post("/api/conversation", response_model=AddConversationRsp, tags=["conversation"]) async def add_conversation( - appId: Optional[str] = Query(default=""), - debug: Optional[bool] = Query(default=False), body: BodyAddConversation = Body(...) ): try: - conversation_id = str(uuid.uuid4()) - new_conversation = ConversationListItem( - conversationId=conversation_id, - title="新会话", - docCount=0, - createdTime=datetime.utcnow().isoformat(), - appId=appId, - debug=debug, - llm=LLMIteam(llmId=body.llm_id), - kbList=[KbIteam(kbId=kb, kbName=f"知识库-{kb}") for kb in (body.kb_ids or [])] - ) - _conversations[conversation_id] = new_conversation + new_conversation = ConversationService.add_conversation(user_id=body.user_id, title=body.title, + conversation_id=body.conversation_id) + return AddConversationRsp( - code=200, + code=0, message="OK", - result=AddConversationMsg(conversationId=conversation_id) + data=AddConversationMsg(conversationId=str(new_conversation.conversation_id), + created_time=str(new_conversation.created_time)) ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -124,30 +114,40 @@ async def add_conversation( @router.delete("/api/conversation", response_model=ResponseData, tags=["conversation"]) async def delete_conversation(data: DeleteConversationData = Body(...)): try: - for cid in data.conversationList: - _conversations.pop(cid, None) - return ResponseData(code=200, message="Deleted", result={}) + for cid in data.conversation_list: + ConversationService.del_conversation(conversation_id=cid) + return ResponseData(code=0, message="Deleted", data={}) + except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@router.post("/api/history", response_model=GetChatHistoryRsp, tags=["conversation"]) -async def get_chat_history(data: GetChatHistoryData): - db_generator = get_db() - db_session = next(db_generator) - service = ConversationService(db_session) +@router.put("/api/conversation", response_model=ResponseData, tags=["conversation"]) +async def rename_conversation(data: RenameConversationData = Body(...)): + try: + conversation, is_succeed = ConversationService.rename_conversation(conversation_id=data.conversation_id, + new_name=data.new_name) + if is_succeed: + return ResponseData(code=0, message="Modified", data={"new_name": data.new_name}) + else: + return ResponseData(code=100, message="Conversation Not Found", data={"new_name": data.new_name}) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + - conversation_info = service.get_conversation_info(data.conversationId) - history_present = service.get_history_messages(data.conversationId) - data = GetChatHistoryStructure( - conversation_id=data.conversationId, +@router.get("/api/conversations/{conversation_id}/messages", response_model=GetChatHistoryRsp, tags=["conversation"]) +async def get_conversation_messages(conversation_id: str): + conversation_info = ConversationService.get_conversation_info(conversation_id) + history_present = ConversationService.get_history_messages(conversation_id) + new_data = GetChatHistoryStructure( + conversation_id=conversation_id, user_id=conversation_info.user_id, created_time=str(conversation_info.created_time), title=conversation_info.title, # 原 name status=conversation_info.status, messages=history_present ) - return GetChatHistoryRsp(code=0, message="ok", data=data).model_dump() + return GetChatHistoryRsp(code=0, message="ok", data=new_data) -app_instance.include_router(router, prefix=API_PREFIX) +app_instance.include_router(router, prefix=API_PREFIX) \ No newline at end of file diff --git a/deepinsight/db/init_db.py b/deepinsight/db/init_db.py deleted file mode 100644 index f19c356..0000000 --- a/deepinsight/db/init_db.py +++ /dev/null @@ -1,73 +0,0 @@ -from sqlalchemy import create_engine, text -from sqlalchemy.exc import SQLAlchemyError -import os - - -def init_database(): - """初始化数据库,创建UUID扩展和所需表结构""" - try: - # 获取数据库连接URL - db_url = 'postgresql+psycopg2://xuhonggong:@127.0.0.1:5432/postgres' - engine = create_engine(db_url) - - # 连接数据库并执行初始化操作 - # 使用begin()替代connect(),自动管理事务 - with engine.begin() as conn: - # 创建UUID扩展(仅PostgreSQL需要) - if db_url.startswith('postgresql'): - print("创建UUID扩展...") - conn.execute(text("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";")) - - # 创建conversation表 - print("创建conversation表...") - conn.execute(text(""" - CREATE TABLE IF NOT EXISTS conversation ( - conversation_id UUID PRIMARY KEY NOT NULL DEFAULT uuid_generate_v4(), - user_id VARCHAR(36) NOT NULL, - created_time TIMESTAMP(3) WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP(3), - title VARCHAR(255) DEFAULT '新建对话', - status VARCHAR(50) NOT NULL DEFAULT 'active' - ); - """)) - - # 创建message表 - print("创建message表...") - conn.execute(text(""" - CREATE TABLE IF NOT EXISTS message ( - message_id UUID PRIMARY KEY NOT NULL DEFAULT uuid_generate_v4(), - conversation_id UUID NOT NULL, - content TEXT NOT NULL, - type VARCHAR(50) NOT NULL, - created_time TIMESTAMP(3) WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP(3), - FOREIGN KEY (conversation_id) REFERENCES conversation(conversation_id) ON DELETE CASCADE - ); - """)) - - # 创建report表 - print("创建report表...") - conn.execute(text(""" - CREATE TABLE IF NOT EXISTS report ( - report_id UUID PRIMARY KEY NOT NULL DEFAULT uuid_generate_v4(), - message_id UUID NOT NULL, - conversation_id UUID NOT NULL, - thought TEXT, - report_content TEXT NOT NULL, - created_time TIMESTAMP(3) WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP(3), - FOREIGN KEY (message_id) REFERENCES message(message_id) ON DELETE CASCADE, - FOREIGN KEY (conversation_id) REFERENCES conversation(conversation_id) ON DELETE CASCADE - ); - """)) - - # 不需要手动调用commit(),engine.begin()会自动提交 - print("数据库表结构创建成功!") - - except SQLAlchemyError as e: - print(f"数据库初始化失败: {str(e)}") - raise - except Exception as e: - print(f"发生错误: {str(e)}") - raise - - -if __name__ == "__main__": - init_database() diff --git a/deepinsight/db/main.py b/deepinsight/db/main.py deleted file mode 100644 index 5317869..0000000 --- a/deepinsight/db/main.py +++ /dev/null @@ -1,87 +0,0 @@ -import uuid - -from deepinsight.db.config import get_db -from deepinsight.db.repositories.conversation_repository import ConversationRepository -from deepinsight.db.repositories.message_repository import MessageRepository -from deepinsight.db.repositories.report_repository import ReportRepository -from deepinsight.db.schemas.conversation import Conversation -from deepinsight.db.schemas.message import Message -from deepinsight.db.schemas.report import Report - - -def example_usage(): - """数据库操作示例""" - # 获取数据库会话 - db = next(get_db()) - - try: - # 初始化仓库 - conversation_repo = ConversationRepository(db) - message_repo = MessageRepository(db) - report_repo = ReportRepository(db) - - # 创建新对话 - new_conversation = \ - Conversation( - user_id=str(uuid.uuid4()), - title="示例对话" - ) - saved_conversation = conversation_repo.create(new_conversation) - print(f"创建对话: {saved_conversation}") - - # 添加消息 - user_message = Message( - conversation_id=saved_conversation.conversation_id, - content="你好,这是一条用户消息", - type="chat" - ) - saved_message = message_repo.create(user_message) - print(f"创建消息: {saved_message}") - - # 添加报告消息 - report_message = Message( - conversation_id=saved_conversation.conversation_id, - content="正在生成报告...", - type="report" - ) - saved_report_message = message_repo.create(report_message) - print(f"创建报告消息: {saved_report_message}") - - # 添加报告 - new_report = Report( - message_id=saved_report_message.message_id, - conversation_id=saved_conversation.conversation_id, - thought="这是思考过程...", - report_content="这是生成的报告内容..." - ) - saved_report = report_repo.create(new_report) - print(f"创建报告: {saved_report}") - - # 查询对话中的所有消息 - messages = message_repo.get_by_conversation_id(saved_conversation.conversation_id) - print(f"\n对话 {saved_conversation.conversation_id} 中的消息:") - for msg in messages: - print(f"- {msg.type}: {msg.content}") - - # 查询对话中的报告 - reports = report_repo.get_by_conversation_id(saved_conversation.conversation_id) - print(f"\n对话 {saved_conversation.conversation_id} 中的报告:") - for rep in reports: - print(f"- 报告内容: {rep.report_content}") - - # 更新对话标题 - saved_conversation.title = "更新后的对话标题" - updated_conversation = conversation_repo.update(saved_conversation) - print(f"\n更新后的对话: {updated_conversation}") - - # 删除对话(会级联删除相关消息和报告) - # conversation_repo.delete(saved_conversation) - # print(f"\n已删除对话: {saved_conversation.conversation_id}") - - finally: - # 关闭数据库会话 - db.close() - - -if __name__ == "__main__": - example_usage() diff --git a/deepinsight/db/schemas/__init__.py b/deepinsight/db/schemas/__init__.py deleted file mode 100644 index ede9439..0000000 --- a/deepinsight/db/schemas/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# 数据模型定义 diff --git a/deepinsight/service/auth/session_manager.py b/deepinsight/service/auth/session_manager.py new file mode 100644 index 0000000..a602f0e --- /dev/null +++ b/deepinsight/service/auth/session_manager.py @@ -0,0 +1,46 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +import logging + +from deepinsight.stores.mongodb.database import MongoDB, Session + + +class SessionManager: + """浏览器Session管理""" + + @staticmethod + async def verify_user(session_id: str) -> bool: + """验证用户是否在Session中""" + try: + collection = MongoDB().get_collection("session") + data = await collection.find_one({"_id": session_id}) + if not data: + return False + return Session(**data).user_sub is not None + except Exception as e: + err = "用户不在Session中" + logging.error("[SessionManager] %s", err) + raise e + + @staticmethod + async def get_user_sub(session_id: str) -> str: + """从Session中获取用户""" + try: + collection = MongoDB().get_collection("session") + data = await collection.find_one({"_id": session_id}) + if not data: + return None + user_sub = Session(**data).user_sub + except Exception as e: + err = "从Session中获取用户失败" + logging.error("[SessionManager] %s", err) + raise e + + return user_sub diff --git a/deepinsight/service/auth/session_service.py b/deepinsight/service/auth/session_service.py new file mode 100644 index 0000000..7cb59a2 --- /dev/null +++ b/deepinsight/service/auth/session_service.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +import uuid + +from starlette import status +from starlette.exceptions import HTTPException +from starlette.requests import HTTPConnection +from deepinsight.service.auth.session_manager import SessionManager + + +class UserHTTPException(HTTPException): + def __init__(self, status_code: int, retcode: int, rtmsg: str, data): + super().__init__(status_code=status_code) + self.retcode = retcode + self.rtmsg = rtmsg + self.data = data + + +async def verify_user(request: HTTPConnection): + """验证用户是否在Session中""" + try: + session_id = None + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + session_id = auth_header.split(" ", 1)[1] + elif "ECSESSION" in request.cookies: + session_id = request.cookies["ECSESSION"] + if session_id is None: + raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + retcode=401, rtmsg="Authentication Error.", data="") + except: + raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + retcode=401, rtmsg="Authentication Error.", data="") + if not (await SessionManager.verify_user(session_id)): + raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + retcode=401, rtmsg="Authentication Error.", data="") + + +async def get_user_sub(request: HTTPConnection) -> uuid: + """从Session中获取用户""" + # if config["DEBUG"]: + # await UserManager.add_user((await Convertor.convert_user_sub_to_user_entity('admin'))) + # return "admin" + try: + session_id = None + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + session_id = auth_header.split(" ", 1)[1] + elif "ECSESSION" in request.cookies: + session_id = request.cookies["ECSESSION"] + except: + raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + retcode=401, rtmsg="Authentication Error.", data="") + user_sub = await SessionManager.get_user_sub(session_id) + if not user_sub: + raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + retcode=401, rtmsg="Authentication Error.", data="") + return user_sub diff --git a/deepinsight/service/auth/user_dependency.py b/deepinsight/service/auth/user_dependency.py new file mode 100644 index 0000000..16f6fe2 --- /dev/null +++ b/deepinsight/service/auth/user_dependency.py @@ -0,0 +1,86 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +import os +from fastapi import Depends +from starlette.requests import HTTPConnection + +from deepinsight.service.auth.session_manager import SessionManager +from deepinsight.service.auth.session_service import UserHTTPException + + +async def get_current_user(request: HTTPConnection): + """ + 获取当前用户的依赖项 + + 此函数会先检查环境变量是否需要认证,不需要则直接返回admin + 需要认证时会验证用户身份,并将用户信息添加到请求对象中 + 可作为依赖项在路由处理函数中使用 + """ + try: + # 检查环境变量是否需要认证鉴权,默认需要 + need_auth = os.getenv("REQUIRE_AUTHENTICATION", "true").lower() != "false" + + # 如果不需要认证,直接返回默认用户admin + if not need_auth: + request.state.user = "admin" + return "admin" + + # 从请求头或cookie中获取session_id + session_id = None + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + session_id = auth_header.split(" ", 1)[1] + elif "ECSESSION" in request.cookies: + session_id = request.cookies["ECSESSION"] + + if not session_id: + raise UserHTTPException( + status_code=401, + retcode=401, + rtmsg="Authentication Error: No session ID found", + data="" + ) + + # 验证用户会话 + if not await SessionManager.verify_user(session_id): + raise UserHTTPException( + status_code=401, + retcode=401, + rtmsg="Authentication Error: Invalid session", + data="" + ) + + # 获取用户信息 + user_sub = await SessionManager.get_user_sub(session_id) + if not user_sub: + raise UserHTTPException( + status_code=401, + retcode=401, + rtmsg="Authentication Error: User not found", + data="" + ) + + # 将用户信息添加到请求状态中 + request.state.user = user_sub + return user_sub + + except Exception as e: + if isinstance(e, UserHTTPException): + raise e + raise UserHTTPException( + status_code=401, + retcode=401, + rtmsg=f"Authentication Error: {str(e)}", + data="" + ) + + +# 用于需要验证用户的路由的装饰器依赖 +require_user = Depends(get_current_user) \ No newline at end of file diff --git a/deepinsight/service/conversation.py b/deepinsight/service/conversation.py index 8b60b7e..fb2fef8 100644 --- a/deepinsight/service/conversation.py +++ b/deepinsight/service/conversation.py @@ -1,9 +1,17 @@ -import json +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. from typing import List -from deepinsight.db.config import get_db -from deepinsight.db.repositories.conversation_repository import ConversationRepository -from deepinsight.db.repositories.message_repository import MessageRepository +from deepinsight.stores.postgresql.database import get_database_session +from deepinsight.stores.postgresql.repositories.conversation_repository import ConversationRepository +from deepinsight.stores.postgresql.repositories.message_repository import MessageRepository from deepinsight.service.deep_research import DeepResearchService from deepinsight.service.schemas.chat import ServiceMessage @@ -42,7 +50,8 @@ class ConversationService: # 示例用法 if __name__ == "__main__": - db_generator = get_db() + + db_generator = get_database_session() db_session = next(db_generator) service = ConversationService(db_session) diff --git a/deepinsight/service/deep_research.py b/deepinsight/service/deep_research.py index 5d925d0..2240b1e 100644 --- a/deepinsight/service/deep_research.py +++ b/deepinsight/service/deep_research.py @@ -24,12 +24,12 @@ from deepinsight.core.types.agent import AgentType, AgentMessageAdditionType from deepinsight.core.types.historical_message import HistoricalMessage, HistoricalMessageType from deepinsight.core.types.messages import StartMessage, ChunkMessage, EndMessage, MessageMetadataKey, CompleteMessage, \ ErrorMessage, Message as CoreMessage -from deepinsight.db.config import get_db -from deepinsight.db.repositories.conversation_repository import ConversationRepository -from deepinsight.db.repositories.message_repository import MessageRepository -from deepinsight.db.repositories.report_repository import ReportRepository -from deepinsight.db.schemas.message import Message as DBSchemaMessage -from deepinsight.db.schemas.report import Report as DBSchemaReport +from deepinsight.stores.postgresql.database import get_database_session +from deepinsight.stores.postgresql.repositories.conversation_repository import ConversationRepository +from deepinsight.stores.postgresql.repositories.message_repository import MessageRepository +from deepinsight.stores.postgresql.repositories.report_repository import ReportRepository +from deepinsight.stores.postgresql.schemas import report as DBSchemaReport +from deepinsight.stores.postgresql.schemas import message as DBSchemaMessage from deepinsight.service.schemas.chat import ServiceMessage AGENT_PROCESS_TIPS_TEMPLATE = { @@ -102,7 +102,7 @@ class DeepResearchService: user_id: str, ) -> Generator[List[ServiceMessage], None, None]: # Initialize repositories with fresh session - db = next(get_db()) + db = get_database_session() try: conversation_repo = ConversationRepository(db) message_repo = MessageRepository(db) @@ -167,7 +167,7 @@ class DeepResearchService: @classmethod def get_report_and_thought_by_message_id(cls, message_id: str): - db = next(get_db()) + db = get_database_session() try: report_repo = ReportRepository(db) report_and_thought_data = report_repo.get_by_message_id(message_id) diff --git a/deepinsight/service/schemas/conversation.py b/deepinsight/service/schemas/conversation.py new file mode 100644 index 0000000..2b4da73 --- /dev/null +++ b/deepinsight/service/schemas/conversation.py @@ -0,0 +1,73 @@ +from typing import Optional, List + +from pydantic import BaseModel, Field + +from deepinsight.service.schemas.model import LLMIteam, KbIteam + + +class ConversationListItem(BaseModel): + conversationId: str + title: str + createdTime: str + type: str = "normal" + + +class ConversationListMsg(BaseModel): + conversations: List[ConversationListItem] + + +class ConversationListRsp(BaseModel): + code: int + message: str + data: ConversationListMsg + + +class AddConversationMsg(BaseModel): + conversationId: str + created_time: str + + +class AddConversationRsp(BaseModel): + code: int + message: str + data: AddConversationMsg + + +class BodyAddConversation(BaseModel): + user_id: str = "empty" + title: Optional[str] = "" + + +class ModifyConversationData(BaseModel): + title: str = Field(..., min_length=1, max_length=2000) + + +class UpdateConversationRsp(BaseModel): + code: int + message: str + data: ConversationListItem + + +class DeleteConversationData(BaseModel): + conversation_list: List[str] + + +class ResponseData(BaseModel): + code: int + message: str + data: Optional[dict] = {} + + +class RenameConversationData(BaseModel): + conversation_id: str + new_name: str + ori_name: Optional[str] = "default name" + + +class BodyGetList(BaseModel): + user_id: Optional[str] = "test_user" + offset: Optional[int] = 0 + limit: Optional[int] = 100 + + + diff --git a/deepinsight/service/schemas/model.py b/deepinsight/service/schemas/model.py new file mode 100644 index 0000000..3459d0e --- /dev/null +++ b/deepinsight/service/schemas/model.py @@ -0,0 +1,15 @@ +# ====== 数据模型定义 ====== +from typing import Optional + +from pydantic import BaseModel + + +class LLMIteam(BaseModel): + icon: Optional[str] = "" + llmId: str = "empty" + modelName: str = "Ollama LLM" + + +class KbIteam(BaseModel): + kbId: str + kbName: str \ No newline at end of file diff --git a/deepinsight/db/__init__.py b/deepinsight/stores/__init__.py similarity index 100% rename from deepinsight/db/__init__.py rename to deepinsight/stores/__init__.py diff --git a/deepinsight/stores/mongodb/__init__.py b/deepinsight/stores/mongodb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepinsight/stores/mongodb/database.py b/deepinsight/stores/mongodb/database.py new file mode 100644 index 0000000..c2cd380 --- /dev/null +++ b/deepinsight/stores/mongodb/database.py @@ -0,0 +1,84 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +from __future__ import annotations + +import logging +import os + +from pydantic import BaseModel, Field +from datetime import datetime +from pymongo import AsyncMongoClient +from typing import TYPE_CHECKING, Optional +import uuid + + +class Session(BaseModel): + """ + Session + + collection: session + """ + + id: str = Field(alias="_id") + ip: str + user_sub: Optional[str] = Field(default=None) + nonce: Optional[str] = Field(default=None) + expired_at: datetime + + +class Task(BaseModel): + """ + collection: witchiand_task + """ + + task_id: uuid.UUID = Field(alias="_id") + status: str + created_time: datetime = Field(default_factory=datetime.now) + + +if TYPE_CHECKING: + from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.collection import AsyncCollection + + +class MongoDB: + """MongoDB连接""" + + user = os.getenv('MONGODB_USER', 'admin') + password = os.getenv('MONGODB_PASSWORD', '') + host = os.getenv('MONGODB_HOST', 'localhost') + port = os.getenv('MONGODB_PORT', 27017) + _client: AsyncMongoClient = AsyncMongoClient( + f"mongodb://{user}:{password}@{host}:{port}/?directConnection=true&replicaSet=rs0", + uuidRepresentation="standard" + ) + + @classmethod + def get_collection(cls, collection_name: str) -> AsyncCollection: + """获取MongoDB集合(表)""" + try: + return cls._client[os.getenv('MONGODB_DATABASE', '')][collection_name] + except Exception as e: + logging.exception("[MongoDB] 获取集合 %s 失败", collection_name) + raise RuntimeError(str(e)) from e + + @classmethod + async def clear_collection(cls, collection_name: str) -> None: + """清空MongoDB集合(表)""" + try: + await cls._client[os.getenv('MONGODB_DATABASE', '')][collection_name].delete_many({}) + except Exception: + logging.exception("[MongoDB] 清空集合 %s 失败", collection_name) + + @classmethod + def get_session(cls) -> AsyncClientSession: + """获取MongoDB会话""" + return cls._client.start_session() diff --git a/deepinsight/stores/postgresql/__init__.py b/deepinsight/stores/postgresql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepinsight/db/config.py b/deepinsight/stores/postgresql/database.py similarity index 44% rename from deepinsight/db/config.py rename to deepinsight/stores/postgresql/database.py index 3f0881c..36ca500 100644 --- a/deepinsight/db/config.py +++ b/deepinsight/stores/postgresql/database.py @@ -1,12 +1,19 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +import os + +from dotenv import load_dotenv from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -import os -from dotenv import load_dotenv -os.environ['POSTGRES_CONN_STR'] = 'postgresql+psycopg2://xuhonggong:@127.0.0.1:5432/postgres' -os.environ['DB_TYPE'] = 'postgresql' -# 加载环境变量 load_dotenv() @@ -16,25 +23,31 @@ class DatabaseConfig: def __init__(self, db_type: str = None, connection_string: str = None): """ 初始化数据库配置 - + :param db_type: 数据库类型,如"postgresql"或"sqlite" :param connection_string: 数据库连接字符串 """ # 优先使用传入的参数,否则从环境变量获取,最后使用默认值 - self.db_type = db_type or os.getenv("DB_TYPE", "postgresql") + self.db_type = db_type or os.getenv("DB_TYPE", "") if connection_string: self.connection_string = connection_string else: if self.db_type == "postgresql": - self.connection_string = os.getenv( - "POSTGRES_CONN_STR", - "postgresql+psycopg2://postgres:postgres@localhost:5432/chat_db" - ) + # 从环境变量中获取数据库连接信息 + db_user = os.getenv('POSTGRES_USER', 'default_user') + db_password = os.getenv('POSTGRES_PASSWORD', 'default_password') + db_host = os.getenv('POSTGRES_HOST', 'localhost') + db_port = os.getenv('POSTGRES_PORT', '5432') + db_name = os.getenv('POSTGRES_DB', 'postgres') + + # 构建数据库连接字符串 + self.connection_string = f"postgresql+psycopg2://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}" + elif self.db_type == "sqlite": self.connection_string = os.getenv( "SQLITE_CONN_STR", - "sqlite:///./chat_db.db" + "sqlite:///./chat_db.stores" ) else: raise ValueError(f"不支持的数据库类型: {self.db_type}") @@ -53,14 +66,27 @@ engine = create_engine( # 创建会话 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -# 声明基类 -Base = declarative_base() +# SQLAlchemy模型基类 +# 所有数据模型都应继承此类,用于表结构定义 +DatabaseModel = declarative_base() + +# 数据库会话工厂 +DatabaseSession = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +def get_database_session(): + """ + 获取数据库会话的生成器函数 + Yields: + Session: 数据库会话对象 -def get_db(): - """获取数据库会话的依赖项""" - db = SessionLocal() + Note: + 使用FastAPI等框架时,可作为依赖项注入 + 会话会在使用后自动关闭 + """ + session = DatabaseSession() try: - yield db + yield session finally: - db.close() + session.close() diff --git a/deepinsight/db/repositories/__init__.py b/deepinsight/stores/postgresql/repositories/__init__.py similarity index 100% rename from deepinsight/db/repositories/__init__.py rename to deepinsight/stores/postgresql/repositories/__init__.py diff --git a/deepinsight/db/repositories/base_repository.py b/deepinsight/stores/postgresql/repositories/base_repository.py similarity index 77% rename from deepinsight/db/repositories/base_repository.py rename to deepinsight/stores/postgresql/repositories/base_repository.py index 3923e86..33c2dce 100644 --- a/deepinsight/db/repositories/base_repository.py +++ b/deepinsight/stores/postgresql/repositories/base_repository.py @@ -1,3 +1,12 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. from sqlalchemy.orm import Session from typing import Generic, TypeVar, List, Optional, Type diff --git a/deepinsight/db/repositories/conversation_repository.py b/deepinsight/stores/postgresql/repositories/conversation_repository.py similarity index 75% rename from deepinsight/db/repositories/conversation_repository.py rename to deepinsight/stores/postgresql/repositories/conversation_repository.py index f48398d..33a869c 100644 --- a/deepinsight/db/repositories/conversation_repository.py +++ b/deepinsight/stores/postgresql/repositories/conversation_repository.py @@ -1,8 +1,17 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. from sqlalchemy.orm import Session from typing import List, Optional -from deepinsight.db.repositories.base_repository import BaseRepository -from deepinsight.db.schemas.conversation import Conversation +from deepinsight.stores.postgresql.repositories.base_repository import BaseRepository +from deepinsight.stores.postgresql.schemas.conversation import Conversation class ConversationRepository(BaseRepository[Conversation]): diff --git a/deepinsight/db/repositories/message_repository.py b/deepinsight/stores/postgresql/repositories/message_repository.py similarity index 78% rename from deepinsight/db/repositories/message_repository.py rename to deepinsight/stores/postgresql/repositories/message_repository.py index 5ae9b6a..7c6734e 100644 --- a/deepinsight/db/repositories/message_repository.py +++ b/deepinsight/stores/postgresql/repositories/message_repository.py @@ -1,9 +1,17 @@ -from sqlalchemy import desc +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. from sqlalchemy.orm import Session from typing import List, Optional -from deepinsight.db.repositories.base_repository import BaseRepository -from deepinsight.db.schemas.message import Message +from deepinsight.stores.postgresql.repositories.base_repository import BaseRepository +from deepinsight.stores.postgresql.schemas.message import Message class MessageRepository(BaseRepository[Message]): diff --git a/deepinsight/db/repositories/report_repository.py b/deepinsight/stores/postgresql/repositories/report_repository.py similarity index 72% rename from deepinsight/db/repositories/report_repository.py rename to deepinsight/stores/postgresql/repositories/report_repository.py index 7fdcb08..6c2a30f 100644 --- a/deepinsight/db/repositories/report_repository.py +++ b/deepinsight/stores/postgresql/repositories/report_repository.py @@ -1,8 +1,17 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. from sqlalchemy.orm import Session from typing import List, Optional -from deepinsight.db.repositories.base_repository import BaseRepository -from deepinsight.db.schemas.report import Report +from deepinsight.stores.postgresql.repositories.base_repository import BaseRepository +from deepinsight.stores.postgresql.schemas.report import Report class ReportRepository(BaseRepository[Report]): diff --git a/deepinsight/stores/postgresql/schemas/__init__.py b/deepinsight/stores/postgresql/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepinsight/db/schemas/conversation.py b/deepinsight/stores/postgresql/schemas/conversation.py similarity index 58% rename from deepinsight/db/schemas/conversation.py rename to deepinsight/stores/postgresql/schemas/conversation.py index 21de321..cc6e1c2 100644 --- a/deepinsight/db/schemas/conversation.py +++ b/deepinsight/stores/postgresql/schemas/conversation.py @@ -1,12 +1,21 @@ -from sqlalchemy import Column, String, DateTime, Text +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +from sqlalchemy import Column, String, DateTime from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.sql import func import uuid -from deepinsight.db.config import Base +from deepinsight.stores.postgresql.database import DatabaseModel -class Conversation(Base): +class Conversation(DatabaseModel): """对话表模型""" __tablename__ = "conversation" diff --git a/deepinsight/db/schemas/message.py b/deepinsight/stores/postgresql/schemas/message.py similarity index 67% rename from deepinsight/db/schemas/message.py rename to deepinsight/stores/postgresql/schemas/message.py index 4349f1f..8409f3a 100644 --- a/deepinsight/db/schemas/message.py +++ b/deepinsight/stores/postgresql/schemas/message.py @@ -1,14 +1,22 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. from sqlalchemy import Column, String, DateTime, Text, ForeignKey from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.sql import func from sqlalchemy.orm import relationship import uuid -from deepinsight.db.config import Base -from .report import Report # 保证先定义Report,再定义Message +from deepinsight.stores.postgresql.database import DatabaseModel -class Message(Base): +class Message(DatabaseModel): """消息表模型""" __tablename__ = "message" diff --git a/deepinsight/db/schemas/report.py b/deepinsight/stores/postgresql/schemas/report.py similarity index 64% rename from deepinsight/db/schemas/report.py rename to deepinsight/stores/postgresql/schemas/report.py index 71570cd..11f0e32 100644 --- a/deepinsight/db/schemas/report.py +++ b/deepinsight/stores/postgresql/schemas/report.py @@ -1,13 +1,22 @@ -from sqlalchemy import Column, String, DateTime, Text, ForeignKey +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. +from sqlalchemy import Column, DateTime, Text, ForeignKey from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.sql import func from sqlalchemy.orm import relationship import uuid -from deepinsight.db.config import Base +from deepinsight.stores.postgresql.database import DatabaseModel -class Report(Base): +class Report(DatabaseModel): """报告表模型""" __tablename__ = "report" diff --git a/pyproject.toml b/pyproject.toml index 6d34389..cb49cf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,15 +3,37 @@ name = "deepInsight" version = "0.1.0" description = "" readme = "README.md" -packages = [{include = "deepinsight", from = '.'}] +# 明确指定需要包含的顶级包 +packages = [ + {include = "deepinsight"}, + {include = "integrations"} +] +# 排除非Python包目录 +exclude = ["License"] [tool.poetry.dependencies] python = ">=3.10,<3.13" camel-ai = ">=0.2.60" pydantic = ">=2.10.6" rich = ">=14.0.0" -sqlalchemy = ">=2.0.41" -psycopg2 = ">=2.9.10" +fastapi = "0.116.1" +sqlalchemy = "2.0.41" +pymongo = "4.13.2" +python-dotenv = "*" # 新增环境变量管理依赖 + +# 跨平台可选依赖组 +[tool.poetry.group.win.dependencies] +psycopg2-binary = "2.9.9" # Windows平台专用 + +[tool.poetry.group.unix.dependencies] +psycopg2 = "2.9.9" # Linux/macOS平台专用 [tool.poetry.group.dev.dependencies] -pytest = "*" \ No newline at end of file +pytest = "*" +black = "*" +flake8 = "*" +mypy = "*" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" \ No newline at end of file -- Gitee From 2d062e0df5c7e731b7489a8d122a8bb8e358ba3d Mon Sep 17 00:00:00 2001 From: Tech1024Wizard Date: Sat, 26 Jul 2025 22:31:54 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E3=80=90feat=E3=80=91=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E6=8E=A5=E5=85=A5=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E4=BB=A5=E5=8F=8A=E8=AE=A4=E8=AF=81=E9=89=B4=E6=9D=83?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=EF=BC=88=E5=AF=B9=E6=8E=A5openeuler=20intell?= =?UTF-8?q?gence=E6=9C=8D=E5=8A=A1=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- deepinsight/api/app.py | 5 +- deepinsight/service/conversation.py | 91 +++++++++++-------- deepinsight/service/schemas/conversation.py | 14 ++- .../repositories/conversation_repository.py | 6 +- 4 files changed, 69 insertions(+), 47 deletions(-) diff --git a/deepinsight/api/app.py b/deepinsight/api/app.py index 81bf6d7..9c03ce8 100644 --- a/deepinsight/api/app.py +++ b/deepinsight/api/app.py @@ -1,4 +1,5 @@ import os +import os import uuid from datetime import datetime from typing import Dict @@ -9,11 +10,11 @@ from fastapi.responses import StreamingResponse from deepinsight.service.conversation import ConversationService from deepinsight.service.deep_research import MessageType, DeepResearchService +from deepinsight.service.schemas.chat import GetChatHistoryStructure, GetChatHistoryRsp from deepinsight.service.schemas.conversation import (ConversationListRsp, ConversationListMsg, ConversationListItem, AddConversationRsp, BodyAddConversation, AddConversationMsg, ResponseData, DeleteConversationData, RenameConversationData, BodyGetList) -from deepinsight.service.schemas.chat import GetChatHistoryStructure, GetChatHistoryRsp # 读取环境变量中的 API 前缀 API_PREFIX = os.getenv("API_PREFIX", "") @@ -150,4 +151,4 @@ async def get_conversation_messages(conversation_id: str): return GetChatHistoryRsp(code=0, message="ok", data=new_data) -app_instance.include_router(router, prefix=API_PREFIX) \ No newline at end of file +app_instance.include_router(router, prefix=API_PREFIX) diff --git a/deepinsight/service/conversation.py b/deepinsight/service/conversation.py index fb2fef8..490bb6e 100644 --- a/deepinsight/service/conversation.py +++ b/deepinsight/service/conversation.py @@ -1,31 +1,63 @@ -# Copyright (c) 2025 Huawei Technologies Co. Ltd. -# deepinsight is licensed under Mulan PSL v2. -# You can use this software according to the terms and conditions of the Mulan PSL v2. -# You may obtain a copy of Mulan PSL v2 at: -# http://license.coscl.org.cn/MulanPSL2 -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, -# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, -# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. -# See the Mulan PSL v2 for more details. from typing import List +from deepinsight.service.deep_research import DeepResearchService +from deepinsight.service.schemas.chat import ServiceMessage +from deepinsight.service.schemas.conversation import ConversationListItem from deepinsight.stores.postgresql.database import get_database_session from deepinsight.stores.postgresql.repositories.conversation_repository import ConversationRepository from deepinsight.stores.postgresql.repositories.message_repository import MessageRepository -from deepinsight.service.deep_research import DeepResearchService -from deepinsight.service.schemas.chat import ServiceMessage +from deepinsight.stores.postgresql.repositories.report_repository import ReportRepository class ConversationService: - def __init__(self, storage): - self.storage = storage - - def get_conversation_info(self, conversation_id_str: str): - repository = ConversationRepository(self.storage) + @classmethod + def get_list(cls, user_id, offset: int = 0, limit: int = 100): + db = get_database_session() + conversation_repo = ConversationRepository(db) + conversation_list = conversation_repo.get_by_user_id(user_id=user_id, offset=offset, limit=limit) + conv_item_list = [] + for conv in conversation_list: + conv_item_list.append(ConversationListItem( + conversationId=str(conv.conversation_id), + title=conv.title, + createdTime=str(conv.created_time) + ) + ) + return conv_item_list + + @classmethod + def del_conversation(cls, conversation_id): + db = get_database_session() + conversation_repo = ConversationRepository(db) + message_repo = MessageRepository(db) + report_repo = ReportRepository(db) + message_repo.delete_by_conversation_id(conversation_id=conversation_id) + report_repo.delete_by_conversation_id(conversation_id=conversation_id) + conversation_repo.delete_conversation(conversation_id=conversation_id) + + @classmethod + def rename_conversation(cls, conversation_id, new_name): + db = get_database_session() + conversation_repo = ConversationRepository(db) + return conversation_repo.update_title(conversation_id=conversation_id, new_title=new_name) + + @classmethod + def add_conversation(cls, user_id, title, conversation_id): + db = get_database_session() + conversation_repo = ConversationRepository(db) + saved_conversation = conversation_repo.create_conversation(user_id, title, conversation_id) + return saved_conversation + + @classmethod + def get_conversation_info(cls, conversation_id_str: str): + db = get_database_session + repository = ConversationRepository(db) return repository.get_by_id(conversation_id_str) - def get_history_messages(self, conversation_id_str: str) -> List[ServiceMessage]: - repository = MessageRepository(self.storage) + @classmethod + def get_history_messages(cls, conversation_id_str: str) -> List[ServiceMessage]: + db = get_database_session() + repository = MessageRepository(db) messages_from_db = repository.get_by_conversation_id(conversation_id_str) processed_messages = [] @@ -48,24 +80,5 @@ class ConversationService: return processed_messages -# 示例用法 -if __name__ == "__main__": - - db_generator = get_database_session() - db_session = next(db_generator) - - service = ConversationService(db_session) - - # 尝试获取一个存在的对话历史 - conv_id_present = 'e6653788-7aa0-489a-b08d-79b2c3fb0b76' - print(f"Fetching history for conversation ID: {conv_id_present}") - # try: - history_present = service.get_history_messages(conv_id_present) - for each in history_present: - print(each.model_dump_json()) - - # print(history_present) - # except ValueError as e: - # print(f"Error: {e}") - - print("\n" + "=" * 50 + "\n") +if __name__ == '__main__': + print(ConversationService.add_conversation(user_id="1")) diff --git a/deepinsight/service/schemas/conversation.py b/deepinsight/service/schemas/conversation.py index 2b4da73..6387562 100644 --- a/deepinsight/service/schemas/conversation.py +++ b/deepinsight/service/schemas/conversation.py @@ -1,9 +1,16 @@ +# Copyright (c) 2025 Huawei Technologies Co. Ltd. +# deepinsight is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. from typing import Optional, List from pydantic import BaseModel, Field -from deepinsight.service.schemas.model import LLMIteam, KbIteam - class ConversationListItem(BaseModel): conversationId: str @@ -34,7 +41,8 @@ class AddConversationRsp(BaseModel): class BodyAddConversation(BaseModel): - user_id: str = "empty" + conversation_id: str + user_id: str = "test_user" title: Optional[str] = "" diff --git a/deepinsight/stores/postgresql/repositories/conversation_repository.py b/deepinsight/stores/postgresql/repositories/conversation_repository.py index 33a869c..7361c8b 100644 --- a/deepinsight/stores/postgresql/repositories/conversation_repository.py +++ b/deepinsight/stores/postgresql/repositories/conversation_repository.py @@ -30,19 +30,19 @@ class ConversationRepository(BaseRepository[Conversation]): """ return self.db.query(Conversation).filter(Conversation.conversation_id == conversation_id).first() - def get_by_user_id(self, user_id: str, skip: int = 0, limit: int = 100) -> List[Conversation]: + def get_by_user_id(self, user_id: str, offset: int = 0, limit: int = 100) -> List[Conversation]: """ 根据用户ID获取对话列表 :param user_id: 用户ID - :param skip: 跳过的记录数 + :param offset: 跳过的记录数 :param limit: 最大返回记录数 :return: 对话列表 """ return self.db.query(Conversation)\ .filter(Conversation.user_id == user_id)\ .order_by(Conversation.created_time.desc())\ - .offset(skip)\ + .offset(offset)\ .limit(limit)\ .all() -- Gitee From a22683f502bbd06e7d332e4c6142cf4d0a32caba Mon Sep 17 00:00:00 2001 From: Tech1024Wizard Date: Sat, 26 Jul 2025 22:55:16 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E3=80=90feat=E3=80=91=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E6=8E=A5=E5=85=A5=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E4=BB=A5=E5=8F=8A=E8=AE=A4=E8=AF=81=E9=89=B4=E6=9D=83?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=EF=BC=88=E5=AF=B9=E6=8E=A5openeuler=20intell?= =?UTF-8?q?gence=E6=9C=8D=E5=8A=A1=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 7 +- deepinsight/api/app.py | 221 ++++++++++++++++---------------- deepinsight/{app.py => main.py} | 0 deepinsight/server.py | 27 ++++ pyproject.toml | 2 + 5 files changed, 145 insertions(+), 112 deletions(-) rename deepinsight/{app.py => main.py} (100%) create mode 100644 deepinsight/server.py diff --git a/.env.example b/.env.example index 186b622..145189c 100644 --- a/.env.example +++ b/.env.example @@ -15,4 +15,9 @@ MONGODB_HOST= MONGODB_PORT= # 是否需要认证鉴权: deepInsight集成到openeuler intelligence依赖认证健全,如果单用户部署则不需要,默认用户为admin -REQUIRE_AUTHENTICATION= +REQUIRE_AUTHENTICATION=false + +# 默认监听地址和端口 +HOST=0.0.0.0 +PORT=8000 + diff --git a/deepinsight/api/app.py b/deepinsight/api/app.py index 9c03ce8..2be7a8d 100644 --- a/deepinsight/api/app.py +++ b/deepinsight/api/app.py @@ -1,154 +1,153 @@ import os -import os import uuid from datetime import datetime from typing import Dict -from fastapi import FastAPI, Request, APIRouter, HTTPException, Body, Depends -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse +from flask import Flask, request, jsonify, Response, stream_with_context +from flask_cors import CORS from deepinsight.service.conversation import ConversationService from deepinsight.service.deep_research import MessageType, DeepResearchService from deepinsight.service.schemas.chat import GetChatHistoryStructure, GetChatHistoryRsp -from deepinsight.service.schemas.conversation import (ConversationListRsp, ConversationListMsg, ConversationListItem, - AddConversationRsp, BodyAddConversation, AddConversationMsg, - ResponseData, DeleteConversationData, RenameConversationData, - BodyGetList) +from deepinsight.service.schemas.conversation import ( + ConversationListRsp, ConversationListMsg, ConversationListItem, + AddConversationRsp, AddConversationMsg +) # 读取环境变量中的 API 前缀 API_PREFIX = os.getenv("API_PREFIX", "") -# 创建 FastAPI 实例 -app_instance = FastAPI( - title="DeepInsight API", - description="A streaming chat API for DeepInsight", - version="1.0.0" -) +# 创建 Flask 实例 +app = Flask(__name__) _conversations: Dict[str, ConversationListItem] = {} -# 创建路由 -router = APIRouter() - -# 跨域中间件配置 -app_instance.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) + +# 跨域配置 +CORS(app, resources={r"/*": {"origins": "*"}}) -@router.post("/api/chat") -async def chat_stream(request: Request): - # try: - body = await request.json() - conversation_id = body.get("conversation_id", "") or str(uuid.uuid4()) - conversation_info = ConversationService.get_conversation_info(conversation_id) - if not conversation_info: - raise HTTPException(status_code=404, detail="Conversation not found") - messages = body.get("messages", []) - if not isinstance(messages, list): - raise HTTPException(status_code=400, detail="messages must be a list") - query = None - for item in reversed(messages): - if item.get('role') == MessageType.USER.value and item.get("content", None): - query = item.get("content") - - async def fake_model_stream(): - for item in DeepResearchService.research(query=query, conversation_id=conversation_id, user_id=""): - resp = GetChatHistoryRsp( - code=0, - message="", - data=GetChatHistoryStructure( - conversation_id=conversation_id, - user_id=conversation_info.user_id, - created_time=str(datetime.now()), - title=conversation_info.title, # 原 name - status=conversation_info.status, - messages=item - ) - ).model_dump_json() - yield f"data: {resp}\n\n" - yield 'data: [DONE]\n\n' - - return StreamingResponse(fake_model_stream(), media_type="text/event-stream") - # except Exception as e: - # raise HTTPException(status_code=500, detail=str(e)) - - -# TODO -@router.get("/api/conversations", response_model=ConversationListRsp, tags=["conversation"]) -# 是否需要,如果需要是否只需要返回conversation id -async def get_conversation_list(body: BodyGetList = Depends()): +@app.route(f"{API_PREFIX}/api/chat", methods=['POST']) +def chat_stream(): try: - conversation_list = ConversationService.get_list(user_id=body.user_id, offset=body.offset, limit=body.limit) - return ConversationListRsp( + body = request.get_json() + conversation_id = body.get("conversation_id", "") or str(uuid.uuid4()) + conversation_info = ConversationService.get_conversation_info(conversation_id) + if not conversation_info: + return jsonify({"error": "Conversation not found"}), 404 + + messages = body.get("messages", []) + if not isinstance(messages, list): + return jsonify({"error": "messages must be a list"}), 400 + + query = None + for item in reversed(messages): + if item.get('role') == MessageType.USER.value and item.get("content", None): + query = item.get("content") + + def generate(): + uuid_str = str(uuid.uuid4()) + for item in DeepResearchService.research(query=query, conversation_id=conversation_id, user_id=""): + resp = GetChatHistoryRsp( + code=0, + message="", + data=GetChatHistoryStructure( + id=uuid_str, + conversation_id=conversation_id, + user_id=conversation_info.user_id, + created_time=str(datetime.now()), + title=conversation_info.title, + status=conversation_info.status, + messages=item + ) + ).model_dump_json() + yield f"data: {resp}\n\n" + yield 'data: [DONE]\n\n' + + return Response(stream_with_context(generate()), mimetype="text/event-stream") + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route(f"{API_PREFIX}/api/conversations", methods=['GET']) +def get_conversation_list(): + try: + user_id = "test_user" + offset = int(request.args.get('offset', 0)) + limit = int(request.args.get('limit', 100)) + + conversation_list = ConversationService.get_list(user_id=user_id, offset=offset, limit=limit) + response = ConversationListRsp( code=0, message="OK", data=ConversationListMsg(conversations=conversation_list) ) - + return jsonify(response.dict()) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + return jsonify({"error": str(e)}), 500 -# temporarily deprecated -@router.post("/api/conversation", response_model=AddConversationRsp, tags=["conversation"]) -async def add_conversation( - body: BodyAddConversation = Body(...) -): +@app.route(f"{API_PREFIX}/api/conversation", methods=['POST']) +def add_conversation(): try: - new_conversation = ConversationService.add_conversation(user_id=body.user_id, title=body.title, - conversation_id=body.conversation_id) + body = request.get_json() + new_conversation = ConversationService.add_conversation( + user_id="test_user", + title=body['title'], + conversation_id=body.get('conversation_id') + ) - return AddConversationRsp( + response = AddConversationRsp( code=0, message="OK", - data=AddConversationMsg(conversationId=str(new_conversation.conversation_id), - created_time=str(new_conversation.created_time)) + data=AddConversationMsg( + conversationId=str(new_conversation.conversation_id), + created_time=str(new_conversation.created_time) + ) ) + return jsonify(response.dict()) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + return jsonify({"error": str(e)}), 500 -@router.delete("/api/conversation", response_model=ResponseData, tags=["conversation"]) -async def delete_conversation(data: DeleteConversationData = Body(...)): +@app.route(f"{API_PREFIX}/api/conversation", methods=['DELETE']) +def delete_conversation(): try: - for cid in data.conversation_list: + data = request.get_json() + for cid in data['conversation_list']: ConversationService.del_conversation(conversation_id=cid) - return ResponseData(code=0, message="Deleted", data={}) - + return jsonify({"code": 0, "message": "Deleted", "data": {}}) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + return jsonify({"error": str(e)}), 500 -@router.put("/api/conversation", response_model=ResponseData, tags=["conversation"]) -async def rename_conversation(data: RenameConversationData = Body(...)): +@app.route(f"{API_PREFIX}/api/conversation", methods=['PUT']) +def rename_conversation(): try: - conversation, is_succeed = ConversationService.rename_conversation(conversation_id=data.conversation_id, - new_name=data.new_name) + data = request.get_json() + conversation, is_succeed = ConversationService.rename_conversation( + conversation_id=data['conversation_id'], + new_name=data['new_name'] + ) if is_succeed: - return ResponseData(code=0, message="Modified", data={"new_name": data.new_name}) + return jsonify({"code": 0, "message": "Modified", "data": {"new_name": data['new_name']}}) else: - return ResponseData(code=100, message="Conversation Not Found", data={"new_name": data.new_name}) + return jsonify({"code": 100, "message": "Conversation Not Found", "data": {"new_name": data['new_name']}}) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - + return jsonify({"error": str(e)}), 500 -@router.get("/api/conversations/{conversation_id}/messages", response_model=GetChatHistoryRsp, tags=["conversation"]) -async def get_conversation_messages(conversation_id: str): - conversation_info = ConversationService.get_conversation_info(conversation_id) - history_present = ConversationService.get_history_messages(conversation_id) - new_data = GetChatHistoryStructure( - conversation_id=conversation_id, - user_id=conversation_info.user_id, - created_time=str(conversation_info.created_time), - title=conversation_info.title, # 原 name - status=conversation_info.status, - messages=history_present - ) - return GetChatHistoryRsp(code=0, message="ok", data=new_data) - -app_instance.include_router(router, prefix=API_PREFIX) +@app.route(f"{API_PREFIX}/api/conversations//messages", methods=['GET']) +def get_conversation_messages(conversation_id): + try: + conversation_info = ConversationService.get_conversation_info(conversation_id) + history_present = ConversationService.get_history_messages(conversation_id) + new_data = GetChatHistoryStructure( + conversation_id=conversation_id, + user_id=conversation_info.user_id, + created_time=str(conversation_info.created_time), + title=conversation_info.title, + status=conversation_info.status, + messages=history_present + ) + return jsonify(GetChatHistoryRsp(code=0, message="ok", data=new_data).dict()) + except Exception as e: + return jsonify({"error": str(e)}), 500 \ No newline at end of file diff --git a/deepinsight/app.py b/deepinsight/main.py similarity index 100% rename from deepinsight/app.py rename to deepinsight/main.py diff --git a/deepinsight/server.py b/deepinsight/server.py new file mode 100644 index 0000000..80437eb --- /dev/null +++ b/deepinsight/server.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import os + +import uvicorn + +if __name__ == "__main__": + # 从环境变量获取配置,或使用默认值 + # 从环境变量获取配置,或使用默认值 + host = os.getenv("HOST", "0.0.0.0") # 从环境变量读取HOST,默认监听所有接口 + port = int(os.getenv("PORT", "8000")) # 从环境变量读取PORT,默认8000 + + # 配置日志 + log_config = uvicorn.config.LOGGING_CONFIG + log_config["formatters"]["access"]["fmt"] = '%(asctime)s - %(levelname)s - %(message)s' + log_config["formatters"]["default"]["fmt"] = '%(asctime)s - %(levelname)s - %(message)s' + + # 启动UVicorn服务器 + uvicorn.run( + "deepinsight.api.app:app_instance", + host=host, + port=port, + reload=True, # 开发时启用热重载 + log_config=log_config, + workers=1, # 生产环境可以增加worker数量 + access_log=True + ) diff --git a/pyproject.toml b/pyproject.toml index cb49cf0..de952b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,8 @@ rich = ">=14.0.0" fastapi = "0.116.1" sqlalchemy = "2.0.41" pymongo = "4.13.2" +flask = "3.1.1" +flask_cors = "6.0.1" python-dotenv = "*" # 新增环境变量管理依赖 # 跨平台可选依赖组 -- Gitee