diff --git a/config.yaml b/config.yaml index b8275c54fd3643367c3c0eb66b748f0a00365030..3ae71a13e0cd2c14cf0f9782533fab3c0db50335 100644 --- a/config.yaml +++ b/config.yaml @@ -17,6 +17,27 @@ llms: max_tokens: 4096 timeout: 120 +# 文件存储服务(可选)默认使用本地存储且不提供 HTTP 访问 +file_storage: + type: local # 默认 local 模式(本地磁盘)。可选: s3 (AWS S3 API 兼容的 OBS 服务) + # map_rule: # 当 DeepInsight 与其他框架协同工作且共用存储服务,建议启用并编辑 map_rule 以防意外覆盖 + # kb_doc_image: + # # 使用 HTTP API 访问解析服务时,当前版本会自动按请求的img_base_url生成为 img_base_url/{bucket}/{object} 的图片链接 + # # 其中doc_id使用请求中的文件名(url encoded)作为值,img_base_url默认为workspace.resource_base_uri + # bucket: deepinsight # 如果该bucket不存在,自动创建时将会同时将其设置为允许匿名读 + # object: parsed-doc-imgs/{kb_id}/{doc_id}/{img_path} + # kb_doc_binary: + # bucket: deepinsight + # object: kb/original_files/{kb_id}/{doc_id}/{doc_name} + # report_image: + # bucket: deepinsight + # object: charts/report/{img_path} + # s3: # 当模式设置为 s3 时,以下三个参数为必要参数 + # endpoint: + # ak: + # sk: + # remote_access: false # 是否开放存储服务的远端访问,当前总是为False + # 通用工作路径配置(独立于 RAG) workspace: work_root: ./data @@ -28,6 +49,13 @@ workspace: # - base_url: 拼接 base_url,适合api模式,如:http://127.0.0.1:8888/api/v1/deepinsight/charts/image/chart_123.png image_path_mode: base_url image_base_url: http://127.0.0.1:8888/api/v1/deepinsight/charts/image + # 在 Markdown 中由 DeepInsight 生成的图片等超链接资源使用的 uri 前缀。 + # 对于本地运行模式,总是保持 ../../ + # 对于使用 S3 兼容的 OBS 且允许匿名 GET 请求,可以将其设置为 OBS 服务相对浏览器的地址 + # 由于 file_storage.remote_access 未启用,当前暂不支持使用本地存储且需要由 http 访问的场景 + resource_base_uri: "../../" +# resource_base_uri: "http://127.0.0.1:8888/api/v1/deepinsight/res/". # 浏览器使用自定义转发服务时 +# resource_base_uri: "http://127.0.0.1:9000/". # 浏览器直接访问OBS时 # RAG 相关工作路径配置 # 将作为所有 RAG 本地数据的前缀目录,例如: diff --git a/deepinsight/api/app.py b/deepinsight/api/app.py index 68e3cf8666705fe2817b5c8ce54a500c842f66c3..728ef78db8b0927d1e93d94166b317d8027744cd 100755 --- a/deepinsight/api/app.py +++ b/deepinsight/api/app.py @@ -9,6 +9,7 @@ # See the Mulan PSL v2 for more details. 
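# A minimal sketch (illustrative values only, not part of the patch) of how the parse API
# builds image links under the default kb_doc_image rule in the config above: the final URL
# is img_base_url/{bucket}/{object}, with doc_id taken from the URL-encoded filename.
import urllib.parse

bucket = "deepinsight"
object_tpl = "parsed-doc-imgs/{kb_id}/{doc_id}/{img_path}"
img_base_url = "http://127.0.0.1:8888/api/v1/deepinsight/res"  # example value

doc_id = urllib.parse.quote("paper.pdf", safe="")
obj = object_tpl.format(kb_id="kb1", doc_id=doc_id, img_path="images/fig_1.png")
link = f"{img_base_url}/{bucket}/{obj}"
# -> http://127.0.0.1:8888/api/v1/deepinsight/res/deepinsight/parsed-doc-imgs/kb1/paper.pdf/images/fig_1.png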
import argparse +import base64 import logging import os import re @@ -19,15 +20,17 @@ from urllib.parse import quote import dotenv import uvicorn -from fastapi import FastAPI, APIRouter, Header +from fastapi import FastAPI, APIRouter, Body, Header from fastapi.responses import HTMLResponse from fastapi.responses import FileResponse from starlette import status from deepinsight.config.config import load_config +from deepinsight.service.conference import ConferenceService from deepinsight.service.research.research import ResearchService from deepinsight.service.conference.paper_extractor import PaperExtractionService, PaperParseException from deepinsight.utils.log_utils import initRootLogger +from deepinsight.utils.file_storage import get_storage_impl from deepinsight.core.utils.research_utils import load_expert_config from deepinsight.service.schemas.common import ResponseModel from deepinsight.service.schemas.research import ResearchRequest, PPTGenerateRequest, PdfGenerateRequest @@ -58,6 +61,8 @@ config = load_config(args.config) research_service = ResearchService(config) paper_extract_service = PaperExtractionService(config) +conference_service = ConferenceService(config) +get_storage_impl(config) # 加载专家数据 experts = load_expert_config(args.expert_config) router = APIRouter(tags=["deepinsight"]) @@ -146,6 +151,66 @@ async def parse_paper_meta(request: ExtractPaperMetaRequest): return dict(error=str(e)) +@router.post("/deepinsight/paper/conference_meta") +async def get_conference_meta( + kb_id: str = Body(description="ID of knowledge base"), + kb_name: str = Body(description="Name of knowledge base. Currently should be in format 'conf_name+year'" + " such as 'CAD+2025'.") +): + """Get or create a conference of the specified knowledge base if it exists. + + If no conference refer to this knowledge base, create a new conference record by `kb_name`.""" + _ = kb_id # unsupported yet + split = kb_name.rsplit("+", 1) + if len(split) != 2 or not split[-1].isdigit(): + return dict(error="Only knowledge base named as 'CONF+year' such as 'CAD+2025' can use Paper parser. " + "Rename your database or select another document parser.") + conf_name = split[0] + year = int(split[1]) + try: + id_, fullname = await conference_service.get_or_create_conference(conf_name, year) + return dict(id=id_, fullname=fullname) + except conference_service.ConferenceQueryException as e: + return dict(error=str(e)) + + +@router.post("/deepinsight/paper/parse/binary") +async def parse_paper_binary( + filename: str = Body(), + binary: str = Body(description="File binary in Base64 format"), + conference_id: int = Body(), + external_kb_id: str = Body(description="Only for storage and generate image URL."), + from_page: int | None = Body(default=None, description="(todo) The first page index to parse (included). " + "`None` means the first page of the file."), + to_page: int | None = Body(default=None, description="(todo) The last page index to parse (included). " + "`None` means the last page of the file."), + img_base_url: str | None = Body( + default=None, + description="The prefix part of images in parsed doc. 
Default is the value of `workspace.resource_base_uri` " + "in config file (whose default value is '../../'.") +): + """Parse metadata (title, author, abstract, keywords and number of sections) from a paper binary file.""" + _ = from_page, to_page + binary = base64.b64decode(binary) + try: + doc, meta = await conference_service.ingest_single_paper( + conference_id=conference_id, kb_id_external=external_kb_id, filename=filename, + binary=binary, resource_prefix=img_base_url) + return dict( + title=meta.paper_title, + author_info=meta.author_info.model_dump(), + abstract=meta.abstract, + keywords=meta.keywords, + topic=meta.topic, + sections=[ + (chunk.page_content, chunk.page_content.lstrip().split("\n", 1)[0].strip(" \n#")) + for chunk in doc.text + ] + ) + except PaperParseException as e: + return dict(error=str(e)) + + @router.get("/deepinsight/experts") async def get_experts(type: Optional[str] = None): """ diff --git a/deepinsight/cli/commands/conference.py b/deepinsight/cli/commands/conference.py index 4b4d0632a668950427d4c8fcf51685d9e8b70c1b..3c037396c8c829cd2c21c8d79c5fa3ecc226b4de 100644 --- a/deepinsight/cli/commands/conference.py +++ b/deepinsight/cli/commands/conference.py @@ -113,6 +113,8 @@ Examples: def _get_service(self) -> ConferenceService: if self._service is None: config = self._get_config() + from deepinsight.utils.file_storage import get_storage_impl + get_storage_impl(config) self._service = ConferenceService(config) return self._service diff --git a/deepinsight/cli/commands/research.py b/deepinsight/cli/commands/research.py index 8c489ef79459783a7f4ba10f47a5056968d09e44..1db823fe713fba9034f5fbb5ecd10b6008655737 100644 --- a/deepinsight/cli/commands/research.py +++ b/deepinsight/cli/commands/research.py @@ -308,6 +308,8 @@ def run_expert_review(question: str, insight_service: ResearchService, conversat ) def run_insight(config: Config, gen_pdf: bool = True, initial_topic: str | None = None) -> int: + from deepinsight.utils.file_storage import get_storage_impl + get_storage_impl(config) insight_service = ResearchService(config) with Live(refresh_per_second=4, vertical_overflow="ellipsis") as live: live.console.print("[bold green]✅ DeepInsight CLI 已成功启动!输入 'exit' 或 'quit' 可退出程序。[/bold green]") diff --git a/deepinsight/config/config.py b/deepinsight/config/config.py index 5988955ca79ee7993975575731543f6538d9425c..aba5a60c65e9ea4fc090d3e0984abcf9840734bd 100644 --- a/deepinsight/config/config.py +++ b/deepinsight/config/config.py @@ -15,6 +15,7 @@ from pydantic import BaseModel, Field from deepinsight.config.app_config import AppInfo from deepinsight.config.database_config import DatabaseConfig +from deepinsight.config.file_storage_config import FileStorageConfig from deepinsight.config.prompt_management_config import PromptManagementConfig from deepinsight.config.llm_config import LLMConfig from deepinsight.config.scenarios_config import ScenariosConfig @@ -49,6 +50,8 @@ class Config(BaseModel): description="General workspace path configuration", ) + file_storage: FileStorageConfig = Field(default_factory=FileStorageConfig) + CONFIG: Optional[Config] = None diff --git a/deepinsight/config/file_storage_config.py b/deepinsight/config/file_storage_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f070e3720ce2dfc42ddafb5c661c79ae7c14af2e --- /dev/null +++ b/deepinsight/config/file_storage_config.py @@ -0,0 +1,127 @@ +"""Configuration about how to store files referenced by Markdown text.""" +from enum import Enum +from typing import Annotated, Any, 
ClassVar, Type
+
+from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator, ValidationError
+from pydantic_core import ErrorDetails, InitErrorDetails
+
+
+class StorageType(str, Enum):
+    LOCAL = "local"
+    """Storage on local disk."""
+    S3_OBS = "s3"
+    """AWS S3 compatible OBS (Object Storage Service)."""
+
+
+class _ConfigModel(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+
+class ConfigS3(_ConfigModel):
+    """Configuration for an AWS S3 compatible OBS (Object Storage Service) client."""
+    endpoint: str
+    ak: SecretStr
+    sk: SecretStr
+
+
+class ListenConfig(_ConfigModel):
+    """How to start an HTTP server that handles file GET requests. Currently, HTTPS is unsupported."""
+    attach: bool = True
+    """Whether to attach to the DeepInsight main service."""
+    path_prefix: Annotated[str, Field(default_factory=lambda: ...)]
+    """If `attach` is `True`, default is '/resources'. Otherwise, default is '/' to be compatible with S3 OBS."""
+
+    name: str = "DeepInsight file accessor"
+    """Server process name. Only takes effect when `attach` is `False`."""
+    host: str | None = None
+    """Server listen IP. Only takes effect when `attach` is `False`."""
+    port: int | None = None
+    """Server listen port. Only takes effect when `attach` is `False`."""
+
+    def model_post_init(self, context: Any, /) -> None:
+        if self.path_prefix is ...:
+            self.path_prefix = "/resources" if self.attach else "/"
+
+class ConfigLocal(_ConfigModel):
+    """Configuration for how to store files on disk."""
+
+    root_dir: str | None = None
+    """The base directory to store files.
+    If it's `None`, the DeepInsight workspace root (`workspace.work_root`) is used.
+    """
+
+    def actual_root_dir(self, workspace_root: str) -> str:
+        return self.root_dir or workspace_root
+
+
+class MappingItem(_ConfigModel):
+    """Specify how to map a storage request to an OBS bucket name and filename.
+
+    `bucket` and `object` are in Python str.format() style. The available keys differ for each usage.
+    """
+    model_config = ConfigDict(frozen=True)
+    bucket: str
+    object: str
+
+
+_MAPPING_AVAILABLE_KEYS: dict[str, tuple[tuple[str, ...], tuple[str, ...]]] = dict(
+    kb_doc_image=(("kb_id",), ("kb_id", "doc_id", "img_path")),
+    kb_doc_binary=(("kb_id", "owner_type", "owner_id"), ("kb_id", "owner_type", "owner_id", "doc_id", "doc_name")),
+    report_image=((), ("img_path",))
+)
+
+
+class ObsMappingConfig(_ConfigModel):
+    model_config = ConfigDict(frozen=True)
+
+    kb_doc_image: MappingItem = MappingItem(bucket="rag_storage", object="{kb_id}/{doc_id}/{img_path}")
+    kb_doc_binary: MappingItem = MappingItem(bucket="original_files", object="{owner_type}/{owner_id}/{doc_name}")
+    report_image: MappingItem = MappingItem(bucket="charts", object="{img_path}")
+
+    @model_validator(mode="after")
+    def _check_mapping_keys(self):
+        errors: list[tuple[str, tuple, str]] = []
+        for rule_name, key_rules in _MAPPING_AVAILABLE_KEYS.items():
+            mapping: MappingItem = getattr(self, rule_name)
+            for field, keys in zip(("bucket", "object"), key_rules):
+                try:
+                    getattr(mapping, field).format(**{k: "" for k in keys})
+                except KeyError as e:
+                    key_msg = "', '".join(keys)
+                    errors.append((f"Rule has an unsupported key {e}. 
Available: '{key_msg}'.", + (rule_name, field), mapping.bucket)) + except ValueError as e: + errors.append((str(e), (rule_name, field), mapping.bucket)) + if errors: + raise ValidationError.from_exception_data( + type(self).__name__, + [ + InitErrorDetails(loc=loc, type="value_error", input=inputs, ctx=dict(error=msg)) + for msg, loc, inputs in errors + ] + ) + return self + + +class FileStorageConfig(_ConfigModel): + type: StorageType = StorageType.LOCAL + s3: ConfigS3 | None = None + local: Annotated[ConfigLocal | None, Field(default_factory=ConfigLocal)] + remote_access: bool | ListenConfig = False + map_rule: Annotated[ObsMappingConfig, Field(default_factory=ObsMappingConfig)] + + _REQUIRED_FIELD_MAP: ClassVar[dict[StorageType, str]] = { + StorageType.LOCAL: "local", + StorageType.S3_OBS: "s3" + } + + def model_post_init(self, context: Any, /) -> None: + if self.remote_access is True: + self.remote_access = ListenConfig() + + @model_validator(mode="after") + def _check_configs(self): + required_config = self._REQUIRED_FIELD_MAP[self.type] + if getattr(self, required_config) is None: + raise ValueError(f"For storage type '{self.type}', config field '{required_config}' is required.") + return self diff --git a/deepinsight/config/workspace_config.py b/deepinsight/config/workspace_config.py index bdc1c7fdee50f4d9badc5b154632a22d6a0bf1c9..d59b8f15c686ce6f35a67a978c38b0a70fcbeded 100644 --- a/deepinsight/config/workspace_config.py +++ b/deepinsight/config/workspace_config.py @@ -1,5 +1,5 @@ -from typing import Optional -from pydantic import BaseModel, Field +from typing import Optional, Annotated, Literal +from pydantic import BaseModel, Field, AnyHttpUrl class WorkspaceConfig(BaseModel): @@ -30,4 +30,11 @@ class WorkspaceConfig(BaseModel): conference_ppt_template_path: Optional[str] = Field( default=None, description="PPT 模板路径(用于会议洞察报告生成)", - ) \ No newline at end of file + ) + + resource_base_uri: Literal["../../"] | Annotated[str, AnyHttpUrl] = "../../" + """在 Markdown 中由 DeepInsight 生成的图片等超链接资源使用的 uri 前缀。 + + 对于本地运行模式,总是保持 ../../ + 对于需要由 http 访问的场景,则应当开启 file_storage.remote_access 且与其设置或其他可访问方式保持一致。 + """ diff --git a/deepinsight/service/conference/conference.py b/deepinsight/service/conference/conference.py index ca2c9068d2c9ec2ec16b35d431c9189364a60e4e..1531a44068be5c19e013ef85210d3088c2fedc29 100644 --- a/deepinsight/service/conference/conference.py +++ b/deepinsight/service/conference/conference.py @@ -9,17 +9,22 @@ # See the Mulan PSL v2 for more details. 
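# A standalone sketch of the map_rule template mechanics shown above (plain str.format, not the
# project's classes; values are only examples): each rule expands its placeholders into a
# bucket/object pair, and an unknown placeholder is exactly what _check_mapping_keys reports
# as a configuration error.
object_tpl = "{owner_type}/{owner_id}/{doc_name}"
print(object_tpl.format(owner_type="user", owner_id="u1", doc_name="paper.pdf", kb_id="7"))
# -> user/u1/paper.pdf  (extra keyword arguments are ignored by str.format)

bad_tpl = "kb/{unknown_key}/{doc_name}"
try:
    bad_tpl.format(kb_id="7", doc_name="paper.pdf")
except KeyError as e:
    print(f"unsupported key {e}")  # the same KeyError branch used by _check_mapping_keys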
from __future__ import annotations +import asyncio import os +import random import shutil import logging from datetime import datetime from typing import List, Optional, Annotated -from pydantic import BaseModel, Field, ConfigDict, ValidationError, AnyHttpUrl +from pydantic import BaseModel, Field, ConfigDict, AnyHttpUrl +from sqlalchemy.orm import Session +from sqlalchemy.exc import IntegrityError from langchain_core.messages import HumanMessage from langchain.agents import create_agent from langchain.agents.structured_output import ToolStrategy +from deepinsight.service.rag.loaders.base import ParseResult from deepinsight.utils.file_utils import compute_md5 from deepinsight.databases.models.academic import Conference, Paper, PaperAuthorRelation, Author from deepinsight.databases.models.knowledge import KnowledgeBase @@ -49,9 +54,11 @@ from deepinsight.service.schemas.conference import ( ) from deepinsight.utils.progress import ProgressReporter from deepinsight.utils.llm_utils import init_langchain_models_from_llm_config -from deepinsight.service.conference.paper_extractor import PaperExtractionService -from deepinsight.service.schemas.paper_extract import ExtractPaperMetaRequest, ExtractPaperMetaFromDocsRequest, DocSegment +from deepinsight.service.conference.paper_extractor import PaperExtractionService, PaperParseException +from deepinsight.service.schemas.paper_extract import ExtractPaperMetaRequest, ExtractPaperMetaFromDocsRequest, \ + DocSegment, PaperMeta from deepinsight.core.agent.conference_research.conf_topic import get_conference_topics +from deepinsight.utils.file_storage.factory import get_storage_impl class ConferenceService: """ @@ -356,6 +363,58 @@ explores the interaction of computer systems with related areas such as computer await self._incremental_ingest_for_conference(kb, conf_id, req, reporter) return conf_id, kb.kb_id + async def get_or_create_conference(self, conf_name: str, year: int) -> tuple[int, str]: + with self._db.get_session() as db: # type: Session + conf: Conference = db.query(Conference).filter( + Conference.short_name == conf_name, Conference.year == year # type: ignore + ).first() + if conf: + return conf.conference_id, conf.full_name or conf_name + new_conf_meta = await self._query_conference_meta(conf_name, year) + max_retry = 3 + for retry in range(max_retry): + try: + with self._db.get_session() as db: # type: Session + conf = Conference( + full_name=new_conf_meta.full_name, + short_name=conf_name, + year=year, + website=new_conf_meta.website, + topics=new_conf_meta.topics, + ) + db.add(conf) + db.commit() + return conf.conference_id, conf.full_name or conf_name + except IntegrityError: + await asyncio.sleep(random.random() * 2 + 0.5) # retry with a random interval + continue + raise self.ConferenceQueryException("Try creating new conference with too many conflicts.") + + async def ingest_single_paper(self, conference_id: int, kb_id_external: str, + filename: str, binary: bytes, + resource_prefix: str = None) -> tuple[ParseResult, PaperMeta]: + """API implementation for HTTP server.""" + from deepinsight.service.rag.parsers.mineru_vl import MineruVLParser + from deepinsight.service.schemas.rag import DocumentPayload + + storage = get_storage_impl() + await storage.document_images_init_bucket(kb_id_external, set_allow_anonymous=True) + + if not resource_prefix: + resource_prefix = self._config.workspace.resource_base_uri or "../../" + if not resource_prefix.endswith("/"): + resource_prefix = resource_prefix + "/" + parser = 
MineruVLParser(self._config.rag.parser.mineru_vl) + doc_id = filename.replace("/", "_").replace("\\", "_") + parsed = await parser.parse( + DocumentPayload(doc_id=doc_id, filename=filename, binary_content=binary, + raw_text="", source_path=filename), + kb_id=kb_id_external, resource_prefix=resource_prefix) + text = "\n\n".join(chunk.page_content for chunk in parsed.result.text) + extract_request = ExtractPaperMetaRequest(conference_id=conference_id, filename=filename, paper=text) + response = await self._paper_extractor.extract_and_store(extract_request) + return parsed.result, response.full_meta + async def _reparse_unfinished_docs_for_conference(self, kb_id: int, conference_id: int, reporter: Optional[ProgressReporter]) -> None: try: docs = await self._knowledge.retry_unfinished_docs(kb_id, reporter=reporter) diff --git a/deepinsight/service/conference/paper_extractor.py b/deepinsight/service/conference/paper_extractor.py index 5064f55971836a1495f8ad68ba292d69c71bdcda..4dd6a6c0093481f1f8d24b212d762a210e22097d 100644 --- a/deepinsight/service/conference/paper_extractor.py +++ b/deepinsight/service/conference/paper_extractor.py @@ -121,6 +121,7 @@ class PaperExtractionService: conference_id=conf_id, author_ids=author_ids, topic=paper_meta.topic, + full_meta=paper_meta, ) async def extract_and_store_from_documents(self, req: ExtractPaperMetaFromDocsRequest) -> ExtractPaperMetaResponse: @@ -154,6 +155,7 @@ class PaperExtractionService: conference_id=conf_id, author_ids=author_ids, topic=metadata.topic, + full_meta=metadata, ) # --------------------- Conference helpers --------------------- diff --git a/deepinsight/service/knowledge/knowledge.py b/deepinsight/service/knowledge/knowledge.py index 6fcae2abbc026cc138d1d393edeb547800e7f5fd..24972708699bfa1823a9ea3fab7407f8ada6dbf3 100644 --- a/deepinsight/service/knowledge/knowledge.py +++ b/deepinsight/service/knowledge/knowledge.py @@ -6,6 +6,7 @@ import hashlib from datetime import datetime from typing import List, Optional, Tuple +from deepinsight.utils.file_storage import get_storage_impl from deepinsight.utils.file_utils import compute_md5 from deepinsight.config.config import Config from deepinsight.databases.connection import Database @@ -119,15 +120,28 @@ class KnowledgeService: extracted_text: Optional[str] = None try: + binary = req.binary + if not binary: + with open(req.file_path, "rb") as f: + binary = f.read() + + # todo: put origin doc here + # await get_storage_impl().knowledge_file_put( + # kb.kb_id, kb.owner_type, kb.owner_id, str(doc.doc_id), + # doc.file_name or os.path.basename(doc.file_path), + # binary + # ) payload = DocumentPayload( doc_id=str(doc.doc_id), + filename=req.file_name or os.path.basename(req.file_path), + binary_content=binary, raw_text="", # let engine extract from source_path source_path=req.file_path, title=req.file_name or os.path.basename(req.file_path), hash=req.md5, origin="knowledge", ) - idx = await self._rag_engine.ingest_document(payload, working_dir) + idx = await self._rag_engine.ingest_document(payload, working_dir, req.kb_id) doc.parse_status = ( idx.process_status.value if hasattr(idx.process_status, "value") else idx.process_status ) or doc.parse_status @@ -297,7 +311,7 @@ class KnowledgeService: async def reparse_document(self, kb_id: int, doc_id: int) -> KnowledgeDocumentResponse: with self._db.get_session() as session: kb, working_dir = await self._get_or_create_rag_for_kb(session, kb_id) - doc = ( + doc: KnowledgeDocument = ( session.query(KnowledgeDocument) 
.filter(KnowledgeDocument.kb_id == kb_id, KnowledgeDocument.doc_id == doc_id) .first() @@ -310,15 +324,19 @@ class KnowledgeService: extracted_text: Optional[str] = None idx = None try: + binary = await get_storage_impl().knowledge_file_get(kb.kb_id, kb.owner_type, kb.owner_id, str(doc_id), + doc.file_name or os.path.basename(doc.file_path)) payload = DocumentPayload( doc_id=str(doc.doc_id), + filename=doc.file_name or os.path.basename(doc.file_path), + binary_content=binary, raw_text="", source_path=doc.file_path, title=doc.file_name or os.path.basename(doc.file_path), hash=doc.md5, origin="knowledge_retry", ) - idx = await self._rag_engine.ingest_document(payload, working_dir) + idx = await self._rag_engine.ingest_document(payload, working_dir, kb_id) doc.parse_status = ( idx.process_status.value if hasattr(idx.process_status, "value") else idx.process_status ) or doc.parse_status diff --git a/deepinsight/service/rag/backends/lightrag_backend.py b/deepinsight/service/rag/backends/lightrag_backend.py index 400bc4202f064a49f41fb4a9af1c976a6a1a0928..afa82c2b07fa4e75604535ebbe76da85304cf893 100644 --- a/deepinsight/service/rag/backends/lightrag_backend.py +++ b/deepinsight/service/rag/backends/lightrag_backend.py @@ -65,7 +65,7 @@ class LightRAGBackend(BaseRAGBackend): {"page_content": chunk.page_content, "metadata": getattr(chunk, "metadata", {})} for chunk in text_chunks ] - file_paths = parsed.file_paths or ([payload.source_path] if payload.source_path else None) + file_paths = parsed.file_paths or ([payload.source_path] if payload.source_path else payload.filename) await rag.ainsert([text], ids=[payload.doc_id], file_paths=file_paths) chunks_count = _estimate_chunks(text) diff --git a/deepinsight/service/rag/engine.py b/deepinsight/service/rag/engine.py index b59b5322da533bc0cafa851a904b267443837fda..65596f42f97e620571ca9b4980f4182850024158 100644 --- a/deepinsight/service/rag/engine.py +++ b/deepinsight/service/rag/engine.py @@ -7,6 +7,7 @@ import logging from langchain_core.documents import Document as LCDocument +import deepinsight.config.config as config_file from deepinsight.config.config import Config from deepinsight.config.rag_config import RAGEngineType, RAGParserType from deepinsight.service.rag.backends import ( @@ -41,13 +42,13 @@ class RAGEngine: async def ingest_document( self, doc: DocumentPayload, - working_dir: str, + working_dir: str, kb_id: int, make_knowledge_graph: bool | None = None, ) -> IndexResult: if not working_dir: raise ValueError("working_dir must not be empty") os.makedirs(working_dir, exist_ok=True) - parsed = await self._prepare_document(doc, working_dir) + parsed = await self._prepare_document(doc, kb_id) return await self._backend.ingest( doc, working_dir, @@ -217,7 +218,7 @@ class RAGEngine: return LlamaIndexParser(parser_cfg.llamaindex) return None - async def _prepare_document(self, doc: DocumentPayload, working_dir: str) -> LoaderOutput: + async def _prepare_document(self, doc: DocumentPayload, kb_id: int) -> LoaderOutput: if doc.raw_text and doc.raw_text.strip(): parse_result = ParseResult( text=[ @@ -227,12 +228,12 @@ class RAGEngine: ) ] ) - file_paths = [doc.source_path] if doc.source_path else None + file_paths = [doc.source_path] if doc.source_path else doc.filename return LoaderOutput(result=parse_result, file_paths=file_paths) if not self._parser: raise ValueError("Document parser not configured, raw_text missing.") - return await self._parser.parse(doc, working_dir) + return await self._parser.parse(doc, kb_id, 
config_file.CONFIG.workspace.resource_base_uri) __all__ = [ diff --git a/deepinsight/service/rag/parsers/base.py b/deepinsight/service/rag/parsers/base.py index 952be52c150b6bb2b55137ad0357cf6fe157957d..e53ec8b5c14254915372c838db3636c34ff2cfd1 100644 --- a/deepinsight/service/rag/parsers/base.py +++ b/deepinsight/service/rag/parsers/base.py @@ -10,6 +10,6 @@ class BaseDocumentParser(ABC): """Abstract parser interface to normalize document ingestion.""" @abstractmethod - async def parse(self, payload: DocumentPayload, working_dir: str) -> LoaderOutput: + async def parse(self, payload: DocumentPayload, kb_id: int, resource_prefix: str) -> LoaderOutput: raise NotImplementedError diff --git a/deepinsight/service/rag/parsers/llama_index.py b/deepinsight/service/rag/parsers/llama_index.py index ad48434ff7d030524882ca417a633aaff9b50ab4..91a536c6e86b7b9ec7028a02be65ef21a87d5cbc 100644 --- a/deepinsight/service/rag/parsers/llama_index.py +++ b/deepinsight/service/rag/parsers/llama_index.py @@ -21,7 +21,7 @@ class LlamaIndexParser(BaseDocumentParser): self._config = config self._file_extractor = self._init_file_extractors(config) - async def parse(self, payload: DocumentPayload, working_dir: str) -> LoaderOutput: + async def parse(self, payload: DocumentPayload, kb_id: int, resource_prefix: str) -> LoaderOutput: if not payload.source_path: raise ValueError("LlamaIndex parser requires payload.source_path to be provided") file_path = payload.source_path diff --git a/deepinsight/service/rag/parsers/mineru_vl.py b/deepinsight/service/rag/parsers/mineru_vl.py index f125cf4ef811200a46ba72d7c86eaa6a86156fec..61bc89290099b12705e51ddf2c0d0db5ebc81b22 100644 --- a/deepinsight/service/rag/parsers/mineru_vl.py +++ b/deepinsight/service/rag/parsers/mineru_vl.py @@ -1,13 +1,11 @@ from __future__ import annotations -import logging -import os - - import asyncio import base64 +import logging import re -from typing import Any, Dict, List, Optional +import os +import urllib.parse from langchain_core.language_models import BaseChatModel from langchain_core.messages import HumanMessage @@ -18,7 +16,7 @@ from deepinsight.service.rag.loaders.base import ParseResult from deepinsight.service.rag.parsers.base import BaseDocumentParser from deepinsight.service.rag.types import LoaderOutput from deepinsight.service.schemas.rag import DocumentPayload - +from deepinsight.utils.file_storage import get_storage_impl class MineruVLParser(BaseDocumentParser): @@ -27,146 +25,56 @@ class MineruVLParser(BaseDocumentParser): def __init__(self, config: MineruParserConfig): self._config = config - async def parse(self, payload: DocumentPayload, working_dir: str) -> LoaderOutput: - if not payload.source_path: - raise ValueError("MinerU parser requires payload.source_path to be provided") - file_path = payload.source_path - if not os.path.isfile(file_path): - raise FileNotFoundError(file_path) - - parse_result = await _parse_file_content(file_path) - - # Process images if present - if getattr(parse_result, "images", None): - doc_name = payload.source_path or (payload.metadata or {}).get("source") or payload.doc_id - await _store_images( - working_dir, - payload.doc_id, - doc_name, - parse_result.images, - ) + async def parse(self, payload: DocumentPayload, kb_id: int | str, resource_prefix: str) -> LoaderOutput: + parse_result = await _parse_file_content(payload.filename, payload.binary_content) + + if parse_result.images: + await get_storage_impl().document_images_init_bucket(str(kb_id), exist_ok=True) + img_map_with_path = {f"images/{k}": 
v for k, v in parse_result.images.items()} + path_map = await get_storage_impl().document_images_store(str(kb_id), payload.doc_id, img_map_with_path) await _replace_image_link( - parse_result, + parse_result, path_map=path_map, replace_alt_text=bool(self._config.enable_vl if self._config else True), - prefix=os.path.join("..", "..", working_dir, payload.doc_id), + prefix=resource_prefix, ) - return LoaderOutput(result=parse_result, file_paths=[file_path]) + return LoaderOutput(result=parse_result, file_paths=[payload.filename]) -async def _parse_file_content(file_path: str) -> ParseResult: +async def _parse_file_content(filename: str, binary: bytes) -> ParseResult: """Parse file content using MinerU for office documents or LangChain loaders for other types.""" - ext = os.path.splitext(file_path.lower())[1] - doc_with_resource: ParseResult | None = None - docs = [] - try: - if ext in {".pdf", ".docx", ".doc", ".pptx", ".ppt"}: - try: - from deepinsight.service.rag.loaders.mineru_online import MinerUOnlineClient - - with open(file_path, mode="rb") as f: - doc_with_resource = await MinerUOnlineClient().process(os.path.basename(file_path), f.read()) - except Exception as e: - logging.error("Failed to parse %r using MinerU: %s", file_path, e) - docs = [] - elif ext in {".txt", ".md", ".markdown"}: - try: - from langchain_community.document_loaders import TextLoader - - loader = TextLoader(file_path, encoding="utf-8") - docs = loader.load() - except Exception: - docs = [] - elif ext == ".csv": - try: - from langchain_community.document_loaders import CSVLoader - - loader = CSVLoader(file_path) - docs = loader.load() - except Exception: - docs = [] - else: - try: - from langchain_community.document_loaders import TextLoader - - loader = TextLoader(file_path, encoding="utf-8") - docs = loader.load() - except Exception: - docs = [] - except Exception: - docs = [] - - if not (docs or doc_with_resource): - logging.warning("Extraction on file %s failed, fallback to plain text reader.", file_path) - text = _extract_text(file_path) - return ParseResult(text=[LCDocument(page_content=text, metadata={"source": file_path})]) - - return doc_with_resource or ParseResult(text=docs) - - -def _extract_text(file_path: str) -> str: - _, ext = os.path.splitext(file_path.lower()) - text_based_exts = {".txt", ".md", ".markdown"} - try: - if ext in text_based_exts: - with open(file_path, "r", encoding="utf-8", errors="ignore") as f: - return f.read() + ext = filename.lower().rsplit(".")[-1] + if ext in {"pdf", "docx", "doc", "pptx", "ppt"}: try: - import textract - - content = textract.process(file_path) - return content.decode("utf-8", errors="ignore") - except Exception: - with open(file_path, "rb") as f: - raw = f.read() - try: - return raw.decode("utf-8") - except Exception: - return raw.decode("utf-8", errors="ignore") - except Exception as e: - raise RuntimeError(f"Text extraction failed for {file_path}: {e}") from e - + from deepinsight.service.rag.loaders.mineru_online import MinerUOnlineClient + return await MinerUOnlineClient().process(filename, binary) + except Exception as e: + logging.error("Failed to parse %r using MinerU: %s", filename, e) + raise + text = _extract_text(ext, binary) + return ParseResult(text=[LCDocument(page_content=text, metadata={"source": filename})]) -async def _store_images(working_dir: str, doc_id: str, doc_name: str, images: dict[str, bytes]) -> None: - doc_name = os.path.basename(doc_name) - img_dir = os.path.join(working_dir, doc_id, "images") - logging.debug("Begin to store %d 
images to %r for document %r", len(images), img_dir, doc_name) - if not os.path.exists(img_dir): - os.makedirs(img_dir, exist_ok=True) - belong_file_path = os.path.join(img_dir, "belongs_to.txt") +def _extract_text(ext: str, binary: bytes) -> str: try: - with open(belong_file_path, mode="xt+", encoding="utf8") as belong_file: - belong_file.write(doc_name) - existed = "an unknown document" - except FileExistsError: - with open(belong_file_path, mode="rt", encoding="utf8") as belong_file: - existed = belong_file.read() - existed = existed if len(existed) < 256 else (existed[:256] + "...") - existed = f"document named {existed!r}" - - for filename, content in images.items(): - file_path = os.path.join(img_dir, filename) + return binary.decode("utf8") + except UnicodeDecodeError: + pass + if ext in {"txt", "md", "markdown"}: try: - with open(file_path, mode="xb") as f: - f.write(content) - continue - except FileExistsError: + return binary.decode("gb2312") + except Exception: # noqa: fallback pass - logging.debug("Begin to store image for %r to %r", doc_name, file_path) - with open(file_path, mode="wb+") as f: - if f.read() != content: - logging.warning("Image %s already exists for %s, overwrite.", file_path, existed) - f.write(content) - else: - logging.debug("Image %s already exists for %s with same content, skip.", file_path, existed) - logging.debug("End to store %d images to %r for document %r", len(images), img_dir, doc_name) + return binary.decode("utf8", errors="ignore") + from deepinsight.service.conference.paper_extractor import PaperParseException + raise PaperParseException("Unsupported file. Please select another parser to parse this file.") -async def _replace_image_link(doc: ParseResult, replace_alt_text: bool = True, +async def _replace_image_link(doc: ParseResult, replace_alt_text: bool = True, path_map: dict[str, str] = None, vl: BaseChatModel = None, prefix: str = "") -> ParseResult: if not doc.images: return doc + path_map = path_map or {} if prefix and not prefix.endswith("/"): prefix += "/" used_images: dict[str, bytes] = {} @@ -187,6 +95,8 @@ async def _replace_image_link(doc: ParseResult, replace_alt_text: bool = True, replacement = {f"images/{k}": v for k, v in replacement.items() if v} if not replacement: logging.warning(f"All failure on creating image description for {len(used_images)} images.") + else: + logging.info("Try create %s image description and %s succeeded.", len(used_images), len(replacement)) else: replacement = {} @@ -196,7 +106,7 @@ async def _replace_image_link(doc: ParseResult, replace_alt_text: bool = True, return m.group() new_alt = replacement[url] if url in replacement else m.group(1) - new_path = f"{prefix}{url}" + new_path = f"{prefix}{urllib.parse.quote(path_map.get(url, url), safe='/')}" return f"![{new_alt}]({new_path})" for chunk in doc.text: diff --git a/deepinsight/service/schemas/knowledge.py b/deepinsight/service/schemas/knowledge.py index b44e9e4850e703e045a85088510f5eb190f89173..0ed10685520b423094b693b8a98a38513c241068 100644 --- a/deepinsight/service/schemas/knowledge.py +++ b/deepinsight/service/schemas/knowledge.py @@ -23,6 +23,7 @@ class KnowledgeBaseCreateRequest(BaseModel): class KnowledgeDocumentCreateRequest(BaseModel): kb_id: int file_path: str + binary: bytes | None = None file_name: Optional[str] = None md5: Optional[str] = None diff --git a/deepinsight/service/schemas/paper_extract.py b/deepinsight/service/schemas/paper_extract.py index 763cba1afaa71201f46e7dc52f9e9ba65837dd1c..61003567daaed7dff396d6852a57c29a74af21c7 
100644 --- a/deepinsight/service/schemas/paper_extract.py +++ b/deepinsight/service/schemas/paper_extract.py @@ -86,3 +86,5 @@ class ExtractPaperMetaResponse(BaseModel): conference_id: int = Field(..., description="ID of the conference") author_ids: List[int] = Field(default_factory=list, description="List of author IDs") topic: Optional[str] = Field(None, description="Main topic of the paper") + full_meta: PaperMeta + """All extracted information of this paper.""" diff --git a/deepinsight/service/schemas/rag.py b/deepinsight/service/schemas/rag.py index 880a3af4572891b88fe0e3c957a9b3bad460f816..f52611386e2ebba7759110cc8f250fe1ddfbb01f 100644 --- a/deepinsight/service/schemas/rag.py +++ b/deepinsight/service/schemas/rag.py @@ -11,6 +11,7 @@ class DocumentPayload(BaseModel): Fields: - doc_id: unique document id (idempotency key) + - binary_content: document binary content - raw_text: plain text content - source_path: original file path (optional) - title: optional title @@ -20,8 +21,11 @@ class DocumentPayload(BaseModel): """ doc_id: str = Field(..., description="Unique document ID") + filename: str + binary_content: bytes + """The raw binary of document file.""" raw_text: str = Field(..., description="Document plain text") - source_path: Optional[str] = Field(None, description="Original file path") + source_path: Optional[str] = Field(None, description="Depreciated. Original file path") title: Optional[str] = Field(None, description="Title") hash: Optional[str] = Field(None, description="Content hash") origin: Optional[str] = Field(None, description="Source tag") diff --git a/deepinsight/utils/file_storage/__init__.py b/deepinsight/utils/file_storage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18142932818348f9876c1efd50835873fce8fccf --- /dev/null +++ b/deepinsight/utils/file_storage/__init__.py @@ -0,0 +1,2 @@ +from deepinsight.utils.file_storage.base import StorageOp, StorageError, BaseFileStorage +from deepinsight.utils.file_storage.factory import get_storage_impl diff --git a/deepinsight/utils/file_storage/base.py b/deepinsight/utils/file_storage/base.py new file mode 100644 index 0000000000000000000000000000000000000000..22084966242952e6a76d15fdbd0ba6c735f81a4f --- /dev/null +++ b/deepinsight/utils/file_storage/base.py @@ -0,0 +1,182 @@ +"""Interface definition (compatible with AWS S3 OBS) and storage mapping definition for any files.""" +__all__ = ["StorageError", "StorageOp", "BaseFileStorage"] + +import asyncio +import logging +from abc import ABC, abstractmethod +from enum import Enum +from typing import Type, TypeVar, TYPE_CHECKING + +from pydantic import BaseModel, PrivateAttr + +from deepinsight.config.file_storage_config import ObsMappingConfig + +if TYPE_CHECKING: + from deepinsight.config.config import Config +else: + from pydantic import BaseModel as Config + + +_Self = TypeVar("_Self") +logger = logging.getLogger(__name__) + + +class StorageOp(str, Enum): + CREATE = "create" + DELETE = "delete" + GET = "get" + LIST = "list" + CONFIG = "config" + + +class StorageError(RuntimeError): + + class Reason(str, Enum): + BUCKET_NOT_FOUND = "bucket_not_found" + FILE_NOT_FOUND = "file_not_found" + ALREADY_EXISTS = "already_exists" + PERMISSION = "permission_denied" + SPACE_LIMITED = "space_limited" + NETWORK = "network_error" + NAME_ILLEGAL = "name_illegal" + BUCKET_NOT_EMPTY = "bucket_not_empty" + OTHER = "other" + + op: StorageOp + bucket: str + filename: str | None + """May be a file prefix""" + reason: Reason + + def __init__(self, op: 
StorageOp, bucket: str, filename: str | None = None, *, reason: Reason):
+        self.op = op
+        self.bucket = bucket
+        self.filename = filename
+        self.reason = reason
+        if filename:
+            task = f"{op.value} object {filename!r} on bucket {bucket!r}"
+        else:
+            task = f"{op.value} bucket {bucket!r}"
+        super().__init__(f"Storage subsystem failed with code {reason.value!r} when going to {task}.")
+
+
+class BaseFileStorage(ABC, BaseModel):
+    """
+    Defines these necessary interfaces (all subclasses must implement these methods):
+    - Create a bucket.
+    - List buckets.
+    - List files in a specified bucket (optionally filtered by prefix).
+    - Add / Get / Delete a file from a specified bucket.
+
+    Implements useful methods:
+    - Store images for a document.
+    - Store images for a report.
+    """
+    keymap: ObsMappingConfig = ObsMappingConfig()
+    _warned_unsupported_method: set[str] = PrivateAttr(default_factory=set)
+
+    def __aenter__(self):
+        return self
+
+    def __aexit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+    @classmethod
+    @abstractmethod
+    def from_config(cls: Type[_Self], config: "Config") -> _Self:
+        raise NotImplementedError(f"{cls.__name__}.from_config")
+
+    # bucket operations begin
+    @abstractmethod
+    async def bucket_create(self, bucket: str, *, exist_ok: bool = False) -> bool:
+        """Return True if the bucket was actually created and False if it already exists."""
+        raise NotImplementedError("bucket_create")
+
+    @abstractmethod
+    async def list_buckets(self) -> list[str]:
+        raise NotImplementedError("list_buckets")
+
+    @abstractmethod
+    async def list_files(self, bucket: str, prefix: str = None) -> list[str]:
+        raise NotImplementedError("list_files")
+
+    # file operations begin
+    @abstractmethod
+    async def file_add(self, bucket: str, filename: str, content: bytes) -> None:
+        raise NotImplementedError("file_add")
+
+    @abstractmethod
+    async def file_delete(self, bucket: str, filename: str, allow_not_exists: bool = True) -> None:
+        raise NotImplementedError("file_delete")
+
+    @abstractmethod
+    async def file_get(self, bucket: str, filename: str) -> bytes:
+        raise NotImplementedError("file_get")
+
+    # optional interfaces definition
+
+    async def bucket_allow_anonymous_get(self, bucket: str) -> None:
+        if "set_anonymous_get" not in self._warned_unsupported_method:
+            self._warned_unsupported_method.add("set_anonymous_get")
+            logger.warning(f"'{type(self).__name__}.bucket_allow_anonymous_get({bucket!r})' is not implemented"
+                           f" and has no effect.")
+
+    # utils begin
+    async def document_images_init_bucket(self, knowledge_base_id: str, exist_ok: bool = True,
+                                          set_allow_anonymous: bool = False) -> None:
+        bucket = self.keymap.kb_doc_image.bucket.format(kb_id=knowledge_base_id)
+        if await self.bucket_create(bucket, exist_ok=exist_ok) and set_allow_anonymous:
+            await self.bucket_allow_anonymous_get(bucket)
+
+    async def document_images_store(self, knowledge_base_id: str, document_id: str,
+                                    images: dict[str, bytes]) -> dict[str, str]:
+        """Store images and return a mapping from the original image path to its stored path as '{bucket}/{object}'."""
+        if not images:
+            return {}
+        bucket = self.keymap.kb_doc_image.bucket.format(kb_id=knowledge_base_id)
+        obj_names = {
+            name: self.keymap.kb_doc_image.object.format(kb_id=knowledge_base_id, doc_id=document_id, img_path=name)
+            for name in images
+        }
+        upload_tasks = [self.file_add(bucket, obj_names[name], content) for name, content in images.items()]
+        await asyncio.gather(*upload_tasks)
+        return {k: f"{bucket}/{v}" for k, v in obj_names.items()}
+
+    async def chart_store(self, name: 
str, content: bytes) -> None: + bucket = self.keymap.report_image.bucket + obj_name = self.keymap.report_image.object.format(img_path=name) + try: + await self.file_add(bucket, obj_name, content) + return + except StorageError as e: + if e.reason != e.Reason.BUCKET_NOT_FOUND: + raise + await self.bucket_create(bucket, exist_ok=True) + await self.file_add(bucket, obj_name, content) + + async def knowledge_file_init_bucket(self, knowledge_base_id: str, owner_type: str, owner_id: str, + exist_ok: bool = True): + bucket = self.keymap.kb_doc_binary.bucket.format(kb_id=knowledge_base_id, owner_type=owner_type, + owner_id=owner_id) + await self.bucket_create(bucket, exist_ok=exist_ok) + + async def knowledge_file_get(self, knowledge_base_id: str, + owner_type: str, owner_id: str, + doc_id: str, doc_name: str) -> bytes: + bucket, obj_name = self._knowledge_file_info(knowledge_base_id, owner_type, owner_id, doc_id, doc_name) + return await self.file_get(bucket, obj_name) + + async def knowledge_file_put(self, knowledge_base_id: str, + owner_type: str, owner_id: str, + doc_id: str, doc_name: str, binary: bytes) -> None: + bucket, obj_name = self._knowledge_file_info(knowledge_base_id, owner_type, owner_id, doc_id, doc_name) + await self.file_add(bucket, obj_name, binary) + + def _knowledge_file_info(self, knowledge_base_id: str, owner_type: str, owner_id: str, + doc_id: str, doc_name: str) -> tuple[str, str]: + bucket_args = dict(kb_id=knowledge_base_id, owner_type=owner_type, owner_id=owner_id) + object_args = dict(kb_id=knowledge_base_id, owner_type=owner_type, owner_id=owner_id, + doc_id=doc_id, doc_name=doc_name) + bucket = self.keymap.kb_doc_binary.bucket.format_map(bucket_args) + obj_name = self.keymap.kb_doc_binary.object.format_map(object_args) + return bucket, obj_name diff --git a/deepinsight/utils/file_storage/factory.py b/deepinsight/utils/file_storage/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..5fec667872a5eb00c381b8adfce11a9b0fad92d6 --- /dev/null +++ b/deepinsight/utils/file_storage/factory.py @@ -0,0 +1,33 @@ +"""Init storage or get existing storage implementation.""" +__all__ = ["get_storage_impl"] + +from typing import TYPE_CHECKING + +from deepinsight.config.file_storage_config import StorageType +from deepinsight.utils.file_storage.base import BaseFileStorage + +if TYPE_CHECKING: + from deepinsight.config.config import Config +else: + from typing import Any as Config + + +_current: BaseFileStorage | None = None + + +def get_storage_impl(config: Config = None) -> BaseFileStorage: + """Init storage or get existing storage implementation.""" + global _current + if config is None: + if not _current: + raise RuntimeError("Deepinsight file storage subsystem not fully inited.") + return _current + if config.file_storage.type == StorageType.LOCAL: + from deepinsight.utils.file_storage.local import LocalStorage + _current = LocalStorage.from_config(config) + elif config.file_storage.type == StorageType.S3_OBS: + from deepinsight.utils.file_storage.s3_compatible import S3CompatibleObsClient + _current = S3CompatibleObsClient.from_config(config) + else: + raise NotImplementedError(f"Unsupported storage type {config.file_storage.type}") + return _current diff --git a/deepinsight/utils/file_storage/local.py b/deepinsight/utils/file_storage/local.py new file mode 100644 index 0000000000000000000000000000000000000000..f9ef014f2f87276782920aa684c295a14499a297 --- /dev/null +++ b/deepinsight/utils/file_storage/local.py @@ -0,0 +1,134 @@ +import io +import 
logging +import os.path +import pathlib +from typing import Any + +from pydantic import ConfigDict + +from deepinsight.config.config import Config +from deepinsight.utils.file_storage.base import BaseFileStorage, StorageError, StorageOp + + +logger = logging.getLogger(__name__) + + +class LocalStorage(BaseFileStorage): + """Storage implementation via local disk storage.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + + root_dir: str + + @classmethod + def from_config(cls, config: Config) -> "LocalStorage": + return LocalStorage( + root_dir=config.file_storage.local.root_dir or config.workspace.work_root, + keymap=config.file_storage.map_rule + ) + + def model_post_init(self, context: Any, /) -> None: + path = pathlib.Path(self.root_dir) + if path.exists(): + if not path.is_dir(): + raise RuntimeError(f"Path {self.root_dir!r} want to be a directory but actually not.") + else: + os.makedirs(path, exist_ok=True) + + async def bucket_create(self, bucket: str, *, exist_ok: bool = False) -> bool: + exist = self._check_bucket_exists(StorageOp.CREATE, bucket, allow_miss=True) + if exist: + if exist_ok: + return False + raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.ALREADY_EXISTS) + os.makedirs(pathlib.Path(self.root_dir, bucket), exist_ok=True) + return True + + async def list_buckets(self) -> list[str]: + return [ + item.name + for item in pathlib.Path(self.root_dir).iterdir() + if item.is_dir() + ] + + async def list_files(self, bucket: str, prefix: str = None) -> list[str]: + self._check_bucket_exists(StorageOp.LIST, bucket, prefix) + bucket_path = pathlib.Path(self.root_dir, bucket) + if prefix: + path = pathlib.Path(self._path_of(StorageOp.LIST, bucket, prefix)) + else: + path = pathlib.Path(bucket_path) + if not path.is_dir(): + return [] + return [ + str(item.relative_to(bucket_path)) + for item in path.rglob("*") + if item.is_file() + ] + + async def file_add(self, bucket: str, filename: str, content: bytes) -> None: + self._check_bucket_exists(StorageOp.CREATE, bucket, filename) + with self._open_file(StorageOp.CREATE, bucket, filename) as f: + f.write(content) + + async def file_delete(self, bucket: str, filename: str, allow_not_exists: bool = True) -> None: + self._check_bucket_exists(StorageOp.DELETE, bucket, filename) + path = self._path_of(StorageOp.GET, bucket, filename) + try: + os.remove(path) + except FileNotFoundError: + if allow_not_exists: + return + raise StorageError(StorageOp.DELETE, bucket, filename, reason=StorageError.Reason.FILE_NOT_FOUND) from None + + async def file_get(self, bucket: str, filename: str) -> bytes: + self._check_bucket_exists(StorageOp.GET, bucket, filename) + with self._open_file(StorageOp.GET, bucket, filename) as f: + return f.read() + + def _check_bucket_exists(self, op: StorageOp, bucket: str, file: str | None = None, *, + allow_miss: bool = False) -> bool: + bucket_path = pathlib.PurePath(bucket) + if any(["\\" in bucket, + len(bucket_path.parts) > 1, + any(part in {"..", ".", ""} for part in bucket_path.parts)]): + raise StorageError(op, bucket, reason=StorageError.Reason.NAME_ILLEGAL) + + path = pathlib.PurePath(self.root_dir, bucket) + if os.path.isdir(path): + return True + if os.path.exists(path): + logger.error(f"Local storage want {str(path)!r} is a directory but got a file.") + raise StorageError(op, bucket, file, reason=StorageError.Reason.NAME_ILLEGAL) + if not allow_miss: + raise StorageError(op, bucket, file, reason=StorageError.Reason.BUCKET_NOT_FOUND) + return False + + def _open_file(self, 
op: StorageOp, bucket: str, filename: str) -> io.BufferedReader | io.BufferedWriter: + path = self._path_of(op, bucket, filename) + + directory = path.parent + if not os.path.exists(directory): + os.makedirs(directory, exist_ok=True) + elif not os.path.isdir(directory): + logger.error(f"Local storage want {str(directory)!r} is a directory but got a file.") + raise StorageError(op, bucket, filename, reason=StorageError.Reason.NAME_ILLEGAL) + + if op == StorageOp.GET: + try: + return open(path, mode="rb") + except (FileNotFoundError, IsADirectoryError): + raise StorageError(op, bucket, filename, reason=StorageError.Reason.FILE_NOT_FOUND) from None + elif op == StorageOp.CREATE: + try: + return open(path, mode="xb") + except FileExistsError: + pass + logger.warning(f"File {str(path)!r} conflicts with exiting, next writing will overwrite its content.") + return open(path, mode="wb") + raise AssertionError("Illegal execute path. DeepInsight has a bug on local file storage.") + + def _path_of(self, op: StorageOp, bucket: str, filename: str) -> pathlib.PurePath: + filename_path = pathlib.PurePath(filename) + if "\\" in filename or any(part in {"..", ".", ""} for part in filename_path.parts): + raise StorageError(op, bucket, filename, reason=StorageError.Reason.NAME_ILLEGAL) + return pathlib.PurePath(self.root_dir, bucket, filename) diff --git a/deepinsight/utils/file_storage/s3_compatible.py b/deepinsight/utils/file_storage/s3_compatible.py new file mode 100644 index 0000000000000000000000000000000000000000..e72f49ab3340e4663c6c91ca73bd297b21290d9e --- /dev/null +++ b/deepinsight/utils/file_storage/s3_compatible.py @@ -0,0 +1,322 @@ +"""S3 compatible OBS(Object Storage Service) client implementation with AWS V4 signature authentication.""" +__all__ = ["S3CompatibleObsClient"] + +import hashlib +import hmac +import json +import logging +import urllib.parse +from datetime import datetime + +import aiohttp +from pydantic import ConfigDict, PrivateAttr + +from deepinsight.config.config import Config +from deepinsight.config.file_storage_config import ConfigS3 +from deepinsight.utils.file_storage.base import BaseFileStorage, StorageError, StorageOp + +logger = logging.getLogger(__name__) + + +class S3CompatibleObsClient(BaseFileStorage): + """S3 compatible OBS(Object Storage Service) client implementation with AWS V4 signature authentication.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + + config: ConfigS3 + _session: aiohttp.ClientSession | None = PrivateAttr(None) + _warn_delete_always_allow_unexist: bool = PrivateAttr(True) + + async def __aenter__(self): + if not self._session: + self._session = aiohttp.ClientSession() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self._session: + await self._session.close() + + @staticmethod + def _parse_list_buckets_xml(xml_content: str) -> list[str]: + """Parse XML response from list buckets operation.""" + import xml.etree.ElementTree as ElementTree + + try: + root = ElementTree.fromstring(xml_content) + # Find all Bucket/Name elements - handle namespace + buckets = [] + + for bucket in root.findall(".//{http://s3.amazonaws.com/doc/2006-03-01/}Bucket"): + name_elem = bucket.find("{http://s3.amazonaws.com/doc/2006-03-01/}Name") + if name_elem is not None and name_elem.text: + buckets.append(name_elem.text) + # If no buckets found with namespace, try without namespace as fallback + if not buckets: + for bucket in root.findall(".//{*}Bucket"): + name_elem = bucket.find("{*}Name") + if name_elem is not None and 
name_elem.text: + buckets.append(name_elem.text) + + return buckets + except ElementTree.ParseError as e: + logger.error(f"Failed to parse XML response: {e}") + raise StorageError(StorageOp.LIST, "", reason=StorageError.Reason.OTHER) from e + + @staticmethod + def _parse_list_objects_xml(xml_content: str, bucket: str, prefix: str = None) -> list[str]: + """Parse XML response from list objects operation.""" + import xml.etree.ElementTree as ElementTree + + try: + root = ElementTree.fromstring(xml_content) + files = [] + for content in root.findall(".//{http://s3.amazonaws.com/doc/2006-03-01/}Contents"): + key_elem = content.find("{http://s3.amazonaws.com/doc/2006-03-01/}Key") + if key_elem is not None and key_elem.text: + file_name = key_elem.text + if prefix is None or file_name.startswith(prefix): + files.append(file_name) + + # If no files found with namespace, try without namespace as fallback + if not files: + for content in root.findall(".//{*}Contents"): + key_elem = content.find("{*}Key") + if key_elem is not None and key_elem.text: + file_name = key_elem.text + if prefix is None or file_name.startswith(prefix): + files.append(file_name) + return files + except ElementTree.ParseError as e: + logger.error(f"Failed to parse XML response: {e}") + raise StorageError(StorageOp.LIST, bucket, prefix, reason=StorageError.Reason.OTHER) from e + + @classmethod + def from_config(cls, config: Config) -> "S3CompatibleObsClient": + return cls(config=config.file_storage.s3, keymap=config.file_storage.map_rule) + + async def bucket_create(self, bucket: str, *, exist_ok: bool = False) -> bool: + """Create a new bucket.""" + url = self._request_url(bucket) + try: + status, content, headers = await self._make_request("HEAD", url) + if status == 200: + if exist_ok: + return False + raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.ALREADY_EXISTS) + elif status == 400: + raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.NAME_ILLEGAL) + elif status == 404: + pass + else: + raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.OTHER) + except aiohttp.ClientError as e: + raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.NETWORK) from e + + try: + status, content, headers = await self._make_request("PUT", url) + if status not in [200, 201]: + error_text = content.decode("utf-8", errors="ignore") + logger.error(f"Failed to create bucket {bucket}: {error_text}") + if status == 403: + raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.PERMISSION) + raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.OTHER) + except aiohttp.ClientError as e: + logger.error(f"Network error creating bucket: {e}") + raise StorageError(StorageOp.CREATE, bucket, reason=StorageError.Reason.NETWORK) from e + return True + + async def list_buckets(self) -> list[str]: + try: + status, content, headers = await self._make_request("GET", self._request_url()) + if status != 200: + error_text = content.decode("utf-8", errors="ignore") + logger.error(f"Failed to list buckets: {error_text}") + if status == 403: + raise StorageError(StorageOp.LIST, '', reason=StorageError.Reason.PERMISSION) + raise StorageError(StorageOp.LIST, '', reason=StorageError.Reason.OTHER) + return self._parse_list_buckets_xml(content.decode("utf-8")) + except aiohttp.ClientError as e: + logger.error(f"Network error listing buckets: {e}") + raise StorageError(StorageOp.LIST, '', reason=StorageError.Reason.NETWORK) from e + + async def 
list_files(self, bucket: str, prefix: str = None) -> list[str]: + params = {} + if prefix: + params["prefix"] = prefix + params["delimiter"] = "/" + url = self._request_url(bucket) + if params: + url += "?" + "&".join(f"{k}={urllib.parse.quote(v, safe='')}" + for k, v in sorted(params.items(), key=lambda i: i[0])) + + try: + status, content, headers = await self._make_request("GET", url) + if status == 404: + raise StorageError(StorageOp.LIST, bucket, prefix, reason=StorageError.Reason.BUCKET_NOT_FOUND) + elif status != 200: + error_text = content.decode("utf-8", errors="ignore") + logger.error(f"Failed to list files in bucket {bucket}: {error_text}") + if status == 403: + raise StorageError(StorageOp.LIST, bucket, prefix, reason=StorageError.Reason.PERMISSION) + raise StorageError(StorageOp.LIST, bucket, prefix, reason=StorageError.Reason.OTHER) + return self._parse_list_objects_xml(content.decode("utf-8"), bucket, prefix) + except aiohttp.ClientError as e: + logger.error(f"Network error listing files: {e}") + raise StorageError(StorageOp.LIST, bucket, prefix, reason=StorageError.Reason.NETWORK) from e + + async def file_add(self, bucket: str, filename: str, content: bytes) -> None: + url = self._request_url(bucket, filename) + + try: + status, content_resp, headers = await self._make_request("PUT", url, data=content) + if status == 404: + raise StorageError(StorageOp.CREATE, bucket, filename, reason=StorageError.Reason.BUCKET_NOT_FOUND) + elif status not in [200, 201]: + error_text = content_resp.decode("utf-8", errors="ignore") + logger.error(f"Failed to add file {filename} to bucket {bucket}: {error_text}") + if status == 403: + raise StorageError(StorageOp.CREATE, bucket, filename, reason=StorageError.Reason.PERMISSION) + raise StorageError(StorageOp.CREATE, bucket, filename, reason=StorageError.Reason.OTHER) + except aiohttp.ClientError as e: + logger.error(f"Network error adding file: {e}") + raise StorageError(StorageOp.CREATE, bucket, filename, reason=StorageError.Reason.NETWORK) from e + + async def file_delete(self, bucket: str, filename: str, allow_not_exists: bool = True) -> None: + url = self._request_url(bucket, filename) + if not allow_not_exists: + if self._warn_delete_always_allow_unexist: + self._warn_delete_always_allow_unexist = False + logger.warning(f"Storage implementation {type(self).__name__} always allows deleting files that do not exist.") + try: + status, content, headers = await self._make_request("DELETE", url) + if status not in [200, 204]: + error_text = content.decode("utf-8", errors="ignore") + logger.error(f"Failed to delete file {filename} from bucket {bucket}: {error_text}") + if status == 403: + raise StorageError(StorageOp.DELETE, bucket, filename, reason=StorageError.Reason.PERMISSION) + raise StorageError(StorageOp.DELETE, bucket, filename, reason=StorageError.Reason.OTHER) + except aiohttp.ClientError as e: + logger.error(f"Network error deleting file: {e}") + raise StorageError(StorageOp.DELETE, bucket, filename, reason=StorageError.Reason.NETWORK) from e + + async def file_get(self, bucket: str, filename: str) -> bytes: + try: + status, content, headers = await self._make_request("GET", self._request_url(bucket, filename)) + if status == 404: + raise StorageError(StorageOp.GET, bucket, filename, reason=StorageError.Reason.FILE_NOT_FOUND) + elif status == 403: + raise StorageError(StorageOp.GET, bucket, filename, reason=StorageError.Reason.PERMISSION) + elif status != 200: + error_text = content.decode("utf-8", errors="ignore") + if "InvalidBucketName" in 
error_text: + raise StorageError(StorageOp.GET, bucket, filename, reason=StorageError.Reason.BUCKET_NOT_FOUND) + logger.error(f"Failed to get file {filename} from bucket {bucket}: {error_text}") + raise StorageError(StorageOp.GET, bucket, filename, reason=StorageError.Reason.OTHER) + + return content + except aiohttp.ClientError as e: + logger.error(f"Network error getting file: {e}") + raise StorageError(StorageOp.GET, bucket, filename, reason=StorageError.Reason.NETWORK) from e + + async def bucket_allow_anonymous_get(self, bucket: str) -> None: + url = self._request_url(bucket) + "?policy=" + try: + status, content, headers = await self._make_request( + "PUT", url, headers={"Content-Type": "application/json"}, + data=json.dumps({ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": "*", + "Action": "s3:GetObject", + "Resource": f"arn:aws:s3:::{bucket}/*" + } + ] + }).encode("utf8") + ) + if status == 404: + raise StorageError(StorageOp.GET, bucket, reason=StorageError.Reason.FILE_NOT_FOUND) + elif status == 403: + raise StorageError(StorageOp.GET, bucket, reason=StorageError.Reason.PERMISSION) + elif status not in {200, 204}: + error_text = content.decode("utf-8", errors="ignore") + logger.error(f"Failed to set bucket {bucket} policy to allow anonymous get: {error_text}") + raise StorageError(StorageOp.GET, bucket, reason=StorageError.Reason.OTHER) + except aiohttp.ClientError as e: + logger.error(f"Network error setting bucket policy to allow anonymous get: {e}") + raise StorageError(StorageOp.CONFIG, bucket, reason=StorageError.Reason.NETWORK) from e + + def _get_aws_v4_signature(self, method: str, path: str, headers: dict, payload: bytes = b'', + query: str = '') -> dict: + """Generate AWS V4 signature for authentication.""" + parsed_url = urllib.parse.urlparse(self.config.endpoint) + host = parsed_url.netloc + + # AWS V4 signature parameters + service = "s3" + region = "us-east-1" + algorithm = "AWS4-HMAC-SHA256" + + now = datetime.utcnow() + amz_date = now.strftime("%Y%m%dT%H%M%SZ") + date_stamp = now.strftime("%Y%m%d") + + # Calculate signature + canonical_uri = path + canonical_querystring = query + canonical_headers = f"host:{host}\nx-amz-date:{amz_date}\n" + signed_headers = "host;x-amz-date" + payload_hash = hashlib.sha256(payload).hexdigest() + canonical_request = (f"{method}\n{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n" + f"{signed_headers}\n{payload_hash}") + credential_scope = f"{date_stamp}/{region}/{service}/aws4_request" + string_to_sign = (f"{algorithm}\n{amz_date}\n{credential_scope}\n" + f"{hashlib.sha256(canonical_request.encode()).hexdigest()}") + signing_key = self._aws_v4_signature_key(date_stamp, region, service) + signature = hmac.new(signing_key, string_to_sign.encode(), hashlib.sha256).hexdigest() + + # Create authorization header + authorization_header = (f"{algorithm} Credential={self.config.ak.get_secret_value()}/{credential_scope}," + f" SignedHeaders={signed_headers}, Signature={signature}") + auth_headers = { + "Authorization": authorization_header, + "x-amz-date": amz_date, + "x-amz-content-sha256": payload_hash + } + auth_headers.update(headers) + + return auth_headers + + def _aws_v4_signature_key(self, date_stamp: str, region: str, service: str) -> bytes: + """Get AWS V4 signature key.""" + key = f"AWS4{self.config.sk.get_secret_value()}".encode() + k_date = hmac.new(key, date_stamp.encode(), hashlib.sha256).digest() + k_region = hmac.new(k_date, region.encode(), hashlib.sha256).digest() + k_service 
= hmac.new(k_region, service.encode(), hashlib.sha256).digest() + k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest() + return k_signing + + def _request_url(self, bucket: str = None, key: str = None) -> str: + base_url = self.config.endpoint.rstrip("/") + if bucket: + bucket = urllib.parse.quote(bucket, safe="~/") + key = urllib.parse.quote(key, safe="~/") if key else key + return f"{base_url}/{bucket}/{key}" if key else f"{base_url}/{bucket}" + return f"{base_url}/" + + async def _make_request(self, method: str, url: str, headers: dict = None, + data: bytes = None) -> tuple[int, bytes, dict]: + headers = headers or {} + if data: + headers.setdefault("Content-Type", "application/octet-stream") + + # Get AWS V4 signature headers + parsed_url = urllib.parse.urlparse(url) + path = parsed_url.path + query = parsed_url.query + auth_headers = self._get_aws_v4_signature(method, path, headers, data or b'', query) + if not self._session: + self._session = aiohttp.ClientSession(trust_env=True) + async with self._session.request(method, url, headers=auth_headers, data=data) as response: + content = await response.read() + return response.status, content, dict(response.headers) diff --git a/tests/utils/file_storage/__init__.py b/tests/utils/file_storage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..609078e652fb52affe782a0efedb5a5810c086a0 --- /dev/null +++ b/tests/utils/file_storage/__init__.py @@ -0,0 +1 @@ +"""Testcase for package `deepinsight.utils.file_storage`.""" \ No newline at end of file diff --git a/tests/utils/file_storage/test_storage_local.py b/tests/utils/file_storage/test_storage_local.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e7c470b3064c34892db3d040875d6978563850 --- /dev/null +++ b/tests/utils/file_storage/test_storage_local.py @@ -0,0 +1,81 @@ +import asyncio +import os.path +from unittest import TestCase +from unittest.mock import MagicMock + +from deepinsight.config.config import Config +from deepinsight.config.file_storage_config import FileStorageConfig +from deepinsight.utils.file_storage import get_storage_impl, StorageOp, StorageError + + +class TestStorageLocal(TestCase): + target_dir = "./local_storage_test" + + def setUp(self): + self.assertFalse(os.path.exists(self.target_dir)) + + def test_storage(self): + config: Config = MagicMock() + config.workspace.work_root = self.target_dir + config.file_storage = FileStorageConfig(type="local") # type: ignore + get_storage_impl(config) + self.assertTrue(os.path.exists(self.target_dir)) + + async def test_main(): + storage = get_storage_impl() + bucket = "ab" + fake_bucket = "abb" + + file1 = "1.txt" + file2 = "1/1.txt" + fake_file = "1" + + content1 = b"123" + content2 = b"12345" + + r = StorageError.Reason + + self.assertEqual([], await storage.list_buckets()) + await storage.bucket_create(bucket, exist_ok=True) + self.assertEqual([bucket], await storage.list_buckets()) + + await self._assert_raises(storage.file_add(fake_bucket, file1, content1), + StorageOp.CREATE, fake_bucket, file1, r.BUCKET_NOT_FOUND) + + await storage.file_add(bucket, file1, content1) + self.assertEqual([file1], await storage.list_files(bucket)) + self.assertEqual([], await storage.list_files(bucket, "1")) + + await storage.file_add(bucket, file2, content2) + self.assertEqual([file2], await storage.list_files(bucket, "1")) + self.assertEqual({file1, file2}, set(await storage.list_files(bucket))) + await self._assert_raises(storage.file_get(bucket, fake_file), + 
StorageOp.GET, bucket, fake_file, r.FILE_NOT_FOUND) + + self.assertEqual(content1, await storage.file_get(bucket, file1)) + self.assertEqual(content2, await storage.file_get(bucket, file2)) + await self._assert_raises(storage.file_get(fake_bucket, file2), + StorageOp.GET, fake_bucket, file2, r.BUCKET_NOT_FOUND) + + await storage.file_delete(bucket, file2, allow_not_exists=False) + await storage.file_delete(bucket, file2, allow_not_exists=True) + await self._assert_raises(storage.file_delete(bucket, file2, allow_not_exists=False), + StorageOp.DELETE, bucket, file2, r.FILE_NOT_FOUND) + self.assertEqual([file1], await storage.list_files(bucket)) + + asyncio.run(test_main()) + + def tearDown(self): + import shutil + shutil.rmtree(self.target_dir, ignore_errors=True) + + async def _assert_raises(self, awaitable, op: StorageOp, bucket: str, file: str, reason): + try: + await awaitable + except StorageError as e: + self.assertEqual(e.op, op) + self.assertEqual(e.bucket, bucket) + self.assertEqual(e.filename, file) + self.assertEqual(e.reason, reason) + else: + self.fail("Expected StorageError to be raised") diff --git a/tests/utils/file_storage/test_storage_s3.py b/tests/utils/file_storage/test_storage_s3.py new file mode 100644 index 0000000000000000000000000000000000000000..a94187d0054854cbdd77a73f70f8e73341fa0875 --- /dev/null +++ b/tests/utils/file_storage/test_storage_s3.py @@ -0,0 +1,86 @@ +import logging +import os.path +from unittest import IsolatedAsyncioTestCase + +import boto3 + +from deepinsight.config.file_storage_config import ConfigS3 +from deepinsight.utils.file_storage import StorageOp, StorageError +from deepinsight.utils.file_storage.s3_compatible import S3CompatibleObsClient + + +class TestStorageS3(IsolatedAsyncioTestCase): + async def test_storage(self): + endpoint = os.getenv("ST_OBS_S3_ENDPOINT") + ak = os.getenv("ST_OBS_S3_AK") + sk = os.getenv("ST_OBS_S3_SK") + bucket = os.getenv("ST_OBS_S3_BUCKET1") + fake_bucket = os.getenv("ST_OBS_S3_BUCKET2") + + if not all((endpoint, ak, sk, bucket, fake_bucket)): + self.skipTest("No available S3 compatible endpoint. 
Set 'ST_OBS_S3_ENDPOINT', 'ST_OBS_S3_AK', " + "'ST_OBS_S3_SK', 'ST_OBS_S3_BUCKET1', 'ST_OBS_S3_BUCKET2' to test this case.") + + async with S3CompatibleObsClient(config=ConfigS3(endpoint=endpoint, ak=ak, sk=sk)) as storage: # type: ignore + file1 = "100%20.txt" + file2 = "1/中文~ 带空格.txt" + fake_file = "1" + + content1 = b"123" + content2 = b"12345" + + r = StorageError.Reason + already_exists = set(await storage.list_buckets()) + await storage.bucket_create(bucket, exist_ok=False) + self.assertEqual({*already_exists, bucket}, set(await storage.list_buckets())) + await storage.bucket_create(bucket, exist_ok=True) + self.assertEqual({*already_exists, bucket}, set(await storage.list_buckets())) + + await self._assert_raises(storage.file_add(fake_bucket, file1, content1), + StorageOp.CREATE, fake_bucket, file1, r.BUCKET_NOT_FOUND) + + await storage.file_add(bucket, file1, content1) + self.assertEqual([file1], await storage.list_files(bucket)) + self.assertEqual([], await storage.list_files(bucket, "2")) + + await storage.file_add(bucket, file2, content2) + self.assertEqual([file2], await storage.list_files(bucket, "1/")) + self.assertEqual({file1, file2}, set(await storage.list_files(bucket))) + await self._assert_raises(storage.file_get(bucket, fake_file), + StorageOp.GET, bucket, fake_file, r.FILE_NOT_FOUND) + + self.assertEqual(content1, await storage.file_get(bucket, file1)) + self.assertEqual(content2, await storage.file_get(bucket, file2)) + await self._assert_raises(storage.file_get(fake_bucket, file2), + StorageOp.GET, fake_bucket, file2, r.FILE_NOT_FOUND) + + await storage.file_delete(bucket, file2) + await storage.file_delete(bucket, file2, allow_not_exists=True) + await storage.file_delete(bucket, file2, allow_not_exists=False) # always allow + self.assertEqual([file1], await storage.list_files(bucket)) + + def tearDown(self): + endpoint = os.getenv("ST_OBS_S3_ENDPOINT") + ak = os.getenv("ST_OBS_S3_AK") + sk = os.getenv("ST_OBS_S3_SK") + bucket_name = os.getenv("ST_OBS_S3_BUCKET1") + if not all((endpoint, ak, sk, bucket_name)): + return + try: + s3 = boto3.resource("s3", endpoint_url=endpoint, aws_access_key_id=ak, aws_secret_access_key=sk) + bucket = s3.Bucket(bucket_name) + bucket.objects.all().delete() + bucket.delete() + except Exception as e: + logging.warning(f"Cleanup failed with {e}") + + async def _assert_raises(self, awaitable, op: StorageOp, bucket: str, file: str, reason): + try: + await awaitable + except StorageError as e: + self.assertEqual(e.op, op) + self.assertEqual(e.bucket, bucket) + self.assertEqual(e.filename, file) + self.assertEqual(e.reason, reason) + else: + self.fail("Expected StorageError to be raised") diff --git a/tests/utils/file_storage/test_utils.py b/tests/utils/file_storage/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7488aef9a3537e53006cf3332b5806de451b4df0 --- /dev/null +++ b/tests/utils/file_storage/test_utils.py @@ -0,0 +1,73 @@ +from os import path +from unittest import IsolatedAsyncioTestCase + +from deepinsight.config.file_storage_config import ObsMappingConfig, MappingItem +from deepinsight.utils.file_storage.local import LocalStorage + + +FOR_TEST_CONFIG = ObsMappingConfig( + kb_doc_image=MappingItem(bucket="aaa{kb_id}", object="bbb/{doc_id}/{img_path}"), + kb_doc_binary=MappingItem(bucket="bbb{kb_id}", object="ccc/{kb_id}/{doc_id}/{doc_name}"), + report_image=MappingItem(bucket="report-img-bucket-test", object="some/of/the/{img_path}") +) + + +class TestUtilFuncs(IsolatedAsyncioTestCase): + def setUp(self): + self.path = 
path.join(path.dirname(path.abspath(__file__)), "./ut_file_storage_utils") + self.assertFalse(path.exists(self.path)) + self.storage = LocalStorage(root_dir=self.path, keymap=FOR_TEST_CONFIG) + + async def test_document_images(self): + self.assertEqual([], await self.storage.list_buckets()) + kb_id = "_x" + bucket = "aaa_x" + await self.storage.document_images_init_bucket(kb_id) + self.assertEqual([bucket], await self.storage.list_buckets()) + images = {f"some/{i}.jpg": (f"{i}" * i).encode("utf8") for i in range(3, 6)} + doc_id = "test1" + await self.storage.document_images_store(kb_id, doc_id, images) + actual = set(await self.storage.list_files(bucket)) + want = {f"bbb/{doc_id}/{img}" for img in images} + self.assertEqual(want, actual) + for name, content in images.items(): + self.assertEqual(content, await self.storage.file_get(bucket, f"bbb/{doc_id}/{name}")) + + async def test_document_binary(self): + self.assertEqual([], await self.storage.list_buckets()) + kb_id = "_x" + bucket = "bbb_x" + owner = "unused" + owner_id = "unused_id" + await self.storage.knowledge_file_init_bucket(kb_id, owner, owner_id) + self.assertEqual([bucket], await self.storage.list_buckets()) + docs = [ + (f"some_{i}.pdf", str(i), (f"{i}1" * i).encode("utf8")) + for i in range(4, 7) + ] + for name, doc_id, content in docs: + await self.storage.knowledge_file_put(kb_id, owner, owner_id, doc_id, name, content) + actual = set(await self.storage.list_files(bucket)) + want = {f"ccc/{kb_id}/{doc_id}/{name}" for name, doc_id, _ in docs} + self.assertEqual(want, actual) + for name, doc_id, content in docs: + self.assertEqual(content, await self.storage.knowledge_file_get(kb_id, owner, owner_id, doc_id, name)) + self.assertEqual(content, await self.storage.file_get(bucket, f"ccc/{kb_id}/{doc_id}/{name}")) + + async def test_chart_images(self): + self.assertEqual([], await self.storage.list_buckets()) + bucket = "report-img-bucket-test" + images = {f"some/{i}.png": (f"{i}" * i).encode("utf8") for i in range(10, 13)} + for name, content in images.items(): + await self.storage.chart_store(name, content) + self.assertEqual([bucket], await self.storage.list_buckets()) + + actual = set(await self.storage.list_files(bucket)) + want = {f"some/of/the/{img}" for img in images} + self.assertEqual(want, actual) + for name, content in images.items(): + self.assertEqual(content, await self.storage.file_get(bucket, f"some/of/the/{name}")) + + def tearDown(self): + import shutil + shutil.rmtree(self.path, ignore_errors=True)
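For reference, the following is a minimal usage sketch of the S3-compatible client, based only on the interface exercised by the tests above (async context manager, bucket_create, file_add, list_files, file_get). The endpoint, credentials, bucket name, and object key are placeholders rather than values from any real deployment; inside the application the client would normally be built from the loaded configuration via S3CompatibleObsClient.from_config(config).

import asyncio

from deepinsight.config.file_storage_config import ConfigS3
from deepinsight.utils.file_storage.s3_compatible import S3CompatibleObsClient


async def main() -> None:
    # Placeholder endpoint and credentials; in practice these come from the
    # file_storage.s3 section of the loaded configuration.
    s3_config = ConfigS3(endpoint="http://localhost:9000", ak="my-access-key", sk="my-secret-key")  # type: ignore
    async with S3CompatibleObsClient(config=s3_config) as storage:
        # Create the bucket if it does not exist yet, then round-trip a small object.
        await storage.bucket_create("my-bucket", exist_ok=True)
        await storage.file_add("my-bucket", "reports/demo.txt", b"hello")
        print(await storage.list_files("my-bucket", "reports/"))
        print(await storage.file_get("my-bucket", "reports/demo.txt"))


asyncio.run(main())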