diff --git a/README.md b/README.md index 3016bef3aefe1f4da0b5e2a66f4297609a8a19d3..c5d1eea70e7d266cf95998f6995b165377ffe469 100644 --- a/README.md +++ b/README.md @@ -244,6 +244,7 @@ openYuanrong datasystem 还提供了基于 Kubernetes 容器化部署方式, ``` - object + 通过 object 接口,实现基于引用计数的缓存数据管理: ```python diff --git a/VERSION b/VERSION index 79a2734bbf3de7aaf00e385c644d30704c03c7c8..09a3acfa138db01dc50a688796b2861638c369ec 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.5.0 \ No newline at end of file +0.6.0 \ No newline at end of file diff --git a/build.sh b/build.sh index 3b0030255c65b1f79043c0f7cd5debcbdf714242..a8066ca8b0f587c3f7d143f4f380d295dc1907db 100755 --- a/build.sh +++ b/build.sh @@ -20,8 +20,9 @@ readonly USAGE=" Usage: bash build.sh [-h] [-r] [-d] [-c off/on/html] [-t off|build|run] [-s on|off] [-j ] [-p on|off] [-S address|thread|undefined|off] [-o ] [-u ] [-B ] [-J on|off] [-P on/off] [-G on/off] [-X on/off] [-e on/off] [-T ] - [-R on/off] [-O on/off] [-I ] [-M \"on|off \"/off] [-D \"on \"/off] - [-C on/off] [-l ] [-i on/off] [-n on/off][-x on/off] + [-R on/off] [-O on/off] [-I ] [-M \"on|off \"/off] + [-D \"on|off \"/off] [-A on/off] + [-C on/off] [-l ] [-i on/off] [-n on/off] [-x on/off] Options: -h Output this help and exit. @@ -58,6 +59,10 @@ Options: Command to set up OFED environment: ./mlnxofedinstall --without-depcheck --without-fw-update --force /etc/init.d/openibd restart + -A Build with UCX framework to support RDMA transport, choose from on/off, default: off. + Notes for compiling and running with RDMA support: + 1. An RDMA-capable NIC and its driver must be installed and properly configured. + 2. The RDMA userspace libraries (libibverbs, librdmacm) from rdma-core must be installed. For debug code: -p Generate perf point logs, choose from: on/off, default: off. @@ -129,6 +134,7 @@ function init_default_opts() { export DOWNLOAD_UB="off" export UB_URL="" export UB_SHA256="" + export BUILD_WITH_RDMA="off" # For testcase export BUILD_TESTCASE="off" @@ -517,6 +523,7 @@ function build_datasystem() "-DSUPPORT_JEPROF:BOOL=${SUPPORT_JEPROF}" "-DBUILD_WITH_URMA:BOOL=${BUILD_WITH_URMA}" "-DURMA_OVER_UB:BOOL=${URMA_OVER_UB}" + "-DBUILD_WITH_RDMA:BOOL=${BUILD_WITH_RDMA}" ) if [[ "${BUILD_WITH_URMA}" == "on" ]]; then @@ -605,7 +612,7 @@ function main() { echo "Can't get logical cpu count, set to default 16" logical_cpu_cout=16 fi - while getopts 'hdro:j:t:u:c:e:p:s:l:i:n:B:F:S:P:T:X:R:D:C:M:x:m:' OPT; do + while getopts 'hdro:j:t:u:c:e:p:s:l:i:n:A:B:F:S:P:T:X:R:D:C:M:x:m:' OPT; do case "${OPT}" in d) BUILD_TYPE="Debug" @@ -697,6 +704,10 @@ function main() { D) parse_ub_download_options ${OPTARG} ;; + A) + check_on_off "${OPTARG}" A + BUILD_WITH_RDMA="${OPTARG}" + ;; h) echo -e "${USAGE}" exit 0 diff --git a/cli/deploy/conf/worker_config.json b/cli/deploy/conf/worker_config.json index 2bee2d4a4e0e0d139fe64ade0fea48d816219d8e..e768f6cf2ce894ea3eb5eadb1aff5332a5a8dec5 100644 --- a/cli/deploy/conf/worker_config.json +++ b/cli/deploy/conf/worker_config.json @@ -506,5 +506,9 @@ "stream_idle_time_s": { "value": "300", "description": "stream idle time. 
default 300s (5 minutes)" + }, + "skip_authenticate": { + "value": "false", + "description": "Skip authentication for worker requests" } } \ No newline at end of file diff --git a/cmake/dependency.cmake b/cmake/dependency.cmake index 4970b88f4cd3e3f8e3bcfa2376121e57ba67c3f3..489874ddc24dd68e23fbe6e913366f1f669f5e3c 100644 --- a/cmake/dependency.cmake +++ b/cmake/dependency.cmake @@ -26,6 +26,9 @@ if (BUILD_WITH_URMA) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ub.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/urma.cmake) endif() +if (BUILD_WITH_RDMA) + include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ucx.cmake) +endif() if (WITH_TESTS) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/gtest.cmake) diff --git a/cmake/external_libs/ucx.cmake b/cmake/external_libs/ucx.cmake new file mode 100644 index 0000000000000000000000000000000000000000..0c5233a72f549a7baf25cfb10c3257dc3adb4c7c --- /dev/null +++ b/cmake/external_libs/ucx.cmake @@ -0,0 +1,51 @@ +set(UCX_VERSION 1.18.0) +if ("$ENV{DS_PACKAGE}" STREQUAL "") + set(UCX_URL "https://gitee.com/mirrors/ucxsource/repository/archive/v1.18.0.zip") + set(UCX_SHA256 "99b94e14630b9f72044d965166c4b0985d80a9914cb52f015c573e3d27ee9f81") +else() + gen_thirdparty_pkg(UCX UCX_URL UCX_SHA256 UCX_FAKE_SHA256 UCX_VERSION) +endif() + +include(CheckIncludeFile) + +check_include_file("rdma/rdma_cma.h" RDMA_CORE_FOUND) + +if (RDMA_CORE_FOUND) + message(STATUS "rdma-core found: rdma/rdma_cma.h header is available.") +else() + message(FATAL_ERROR "rdma-core not found. Please install rdma-core to proceed.") +endif() + +set(UCX_CONF_OPTIONS + --enable-optimizations + --with-verbs=${rdma_core_ROOT} + --with-rdmacm=${rdma_core_ROOT} + --enable-mt + ) + +set(UCX_C_FLAGS ${THIRDPARTY_SAFE_FLAGS}) +set(UCX_LINK_FLAGS "-Wl,-z,now") + +set(UCX_AUTOGEN sh autogen.sh) + +set(_ORG_LD_FLAGS $ENV{LDFLAGS}) +set(ENV{LDFLAGS} "${THIRDPARTY_SAFE_FLAGS} ${_ORG_LD_FLAGS}") + +add_thirdparty_lib(UCX + URL ${UCX_URL} + SHA256 ${UCX_SHA256} + FAKE_SHA256 ${UCX_FAKE_SHA256} + VERSION ${UCX_VERSION} + CONF_OPTIONS ${UCX_CONF_OPTIONS} + C_FLAGS ${UCX_C_FLAGS} + TOOLCHAIN configure + PRE_CONFIGURE ${UCX_AUTOGEN} + ) + +set(UCX_DIR ${UCX_ROOT}) +find_package(UCX ${UCX_VERSION} REQUIRED) + +add_definitions(-DUSE_RDMA) +if(UCX_FOUND) + include_directories(${UCX_INCLUDE_DIR}) +endif() \ No newline at end of file diff --git a/cmake/modules/FindUCX.cmake b/cmake/modules/FindUCX.cmake new file mode 100644 index 0000000000000000000000000000000000000000..550497ce96aaa0ac43b3aab0ac02f1b9b7c278f8 --- /dev/null +++ b/cmake/modules/FindUCX.cmake @@ -0,0 +1,94 @@ +# - Find UCX (ucp/api/ucp.h, libucp.so, libuct.so, libucs.so) +# This module defines +# UCX_INCLUDE_DIR - Directory containing ucp/api/ucp.h +# UCX_UCP_LIBRARY - Path to libucp +# UCX_UCT_LIBRARY - Path to libuct +# UCX_UCS_LIBRARY - Path to libucs +# UCX_UCM_LIBRARY - Path to libucm (optional) +# UCX_LIBRARIES - All required UCX libraries (ucp, uct, ucs) +# UCX_FOUND - True if UCX is found + +set(_UCX_SEARCH_DIRS ${ucx_ROOT}) + +# Find include directory (look for ucp/api/ucp.h) +find_path(UCX_INCLUDE_DIR + NAMES ucp/api/ucp.h + PATHS ${_UCX_SEARCH_DIRS} + PATH_SUFFIXES include + DOC "Path to UCX include directory (containing ucp/api/ucp.h)" + NO_CMAKE_SYSTEM_PATH + NO_SYSTEM_ENVIRONMENT_PATH +) + +# Find libraries +find_library(UCX_UCP_LIBRARY + NAMES ucp + PATHS ${_UCX_SEARCH_DIRS} + PATH_SUFFIXES lib lib64 + DOC "UCX UCP library" + NO_CMAKE_SYSTEM_PATH + NO_SYSTEM_ENVIRONMENT_PATH +) + +find_library(UCX_UCT_LIBRARY + NAMES uct + PATHS 
${_UCX_SEARCH_DIRS} + PATH_SUFFIXES lib lib64 + DOC "UCX UCT library" + NO_CMAKE_SYSTEM_PATH + NO_SYSTEM_ENVIRONMENT_PATH +) + +find_library(UCX_UCS_LIBRARY + NAMES ucs + PATHS ${_UCX_SEARCH_DIRS} + PATH_SUFFIXES lib lib64 + DOC "UCX UCS library" + NO_CMAKE_SYSTEM_PATH + NO_SYSTEM_ENVIRONMENT_PATH +) + +find_library(UCX_UCM_LIBRARY + NAMES ucm + PATHS ${_UCX_SEARCH_DIRS} + PATH_SUFFIXES lib lib64 + NO_CMAKE_SYSTEM_PATH + NO_SYSTEM_ENVIRONMENT_PATH +) + +# Build full library list +set(UCX_LIBRARIES ${UCX_UCP_LIBRARY} ${UCX_UCT_LIBRARY} ${UCX_UCS_LIBRARY}) +if(UCX_UCM_LIBRARY) + list(APPEND UCX_LIBRARIES ${UCX_UCM_LIBRARY}) +endif() + +# Standard argument handling +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args( + UCX + REQUIRED_VARS + UCX_INCLUDE_DIR + UCX_UCP_LIBRARY + UCX_UCT_LIBRARY + UCX_UCS_LIBRARY +) + +# Set variables as advanced (hide in CMake GUI) +mark_as_advanced( + UCX_INCLUDE_DIR + UCX_UCP_LIBRARY + UCX_UCT_LIBRARY + UCX_UCS_LIBRARY + UCX_UCM_LIBRARY +) + +# Print status for debugging +if(UCX_FOUND) + message(STATUS "UCX_INCLUDE_DIR = ${UCX_INCLUDE_DIR}") + message(STATUS "UCX_UCP_LIBRARY = ${UCX_UCP_LIBRARY}") + message(STATUS "UCX_UCT_LIBRARY = ${UCX_UCT_LIBRARY}") + message(STATUS "UCX_UCS_LIBRARY = ${UCX_UCS_LIBRARY}") + if(UCX_UCM_LIBRARY) + message(STATUS "UCX_UCM_LIBRARY = ${UCX_UCM_LIBRARY}") + endif() +endif() \ No newline at end of file diff --git a/cmake/package.cmake b/cmake/package.cmake index 184293269d8115b8013309f50a7d147e245380de..d66ce808d2452f422c64bcb1c09f9db23af7a220 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -25,7 +25,10 @@ if (BUILD_WITH_URMA) list(APPEND RPC_LIB_PATH "${URMA_LIB_LOCATION}/liburma*.so*") list(APPEND URMA_LIB_PATH "${URMA_IP_IB_LIB_LOCATION}/liburma_*.so*") endif() - +if (BUILD_WITH_RDMA) + list(APPEND RPC_LIB_PATH "${UCX_LIB_PATH}/libuc*.so*") + list(APPEND UCX_LIB_PATH "${UCX_LIB_PATH}/libuc*.so*") +endif() ############################################################ # Datasystem header files and share libraries. ############################################################ @@ -220,6 +223,13 @@ if (BUILD_WITH_URMA) ) endif() +if (BUILD_WITH_RDMA) + install_file_pattern( + PATH_PATTERN ${UCX_LIB_PATH} + DEST_DIR ${DATASYSTEM_SERVICE_LIBPATH}/rdma + PERMISSIONS OWNER_EXECUTE OWNER_WRITE OWNER_READ + ) +endif() ############################################################ # Datasystem deploy scripts generate zone. 
############################################################ diff --git a/cmake/util.cmake b/cmake/util.cmake index aea57de9241c0aa76a170e42685232a2f674c1f6..aa4365fceb265f0368a44cb4755663b43d86dbe7 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -476,7 +476,7 @@ function(ADD_THIRDPARTY_LIB LIB_NAME) endif() list(APPEND ARG_CONF_OPTIONS "--prefix=${${LIB_NAME}_ROOT}") - if (EXISTS ${${_LIB_NAME_LOWER}_SOURCE_DIR}/config) + if (EXISTS ${${_LIB_NAME_LOWER}_SOURCE_DIR}/config AND NOT IS_DIRECTORY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/config) set(_CONFIG_FILE ${${_LIB_NAME_LOWER}_SOURCE_DIR}/config) else() set(_CONFIG_FILE ${${_LIB_NAME_LOWER}_SOURCE_DIR}/configure) diff --git a/docs/source_zh_cn/installation/installation_linux.md b/docs/source_zh_cn/installation/installation_linux.md index 02878a14ea877b70bdae00a0bb3a7340be20c499..75ab081aa1faa8a1176f5fcdc311d9488ab58a18 100644 --- a/docs/source_zh_cn/installation/installation_linux.md +++ b/docs/source_zh_cn/installation/installation_linux.md @@ -44,6 +44,7 @@ pip install openyuanrong-datasystem-sdk |openEuler|22.03|运行/编译openYuanrong datasystem的操作系统| |[Python](#安装-python)|3.9-3.11|openYuanrong datasystem的运行/编译依赖Python环境| |[CANN](#安装-cann)|8.2.RC1|运行/编译异构相关特性的依赖库| +|rdma-core|35.1|运行/编译RDMA特性的依赖库| #### 安装 Python @@ -115,6 +116,33 @@ source ${HOME}/Ascend/ascend-toolkit/set_env.sh ``` +#### 安装 rdma-core + +> 如无需运行/编译RDMA特性,可跳过 rdma-core 安装步骤 + +
+rdma-core安装步骤(点我展开) + +[rdma-core](https://github.com/linux-rdma/rdma-core)可通过yum进行安装。 + +安装rdma-core: +```bash +sudo yum install rdma-core-devel +``` + +安装完成后,可通过以下命令查看软件是否安装成功: +```bash +ls -l /usr/lib64/libibverbs.so +ls -l /usr/lib64/librdmacm.so +``` +若输出类似以下内容则说明安装成功: +```bash +lrwxrwxrwx. 1 root root 15 Mar 23 2022 /usr/lib64/libibverbs.so -> libibverbs.so.1 +lrwxrwxrwx. 1 root root 14 Mar 23 2022 /usr/lib64/librdmacm.so -> librdmacm.so.1 +``` + +
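Aside for readers verifying the RDMA userspace stack beyond the `ls` checks above: a minimal sketch (not part of this patch) that enumerates RDMA devices through libibverbs. If it links against `-libverbs` and reports at least one device, rdma-core and the NIC driver are wired up correctly. The file name `check_rdma.cpp` is illustrative.

```cpp
// check_rdma.cpp -- build: g++ check_rdma.cpp -o check_rdma -libverbs
#include <cstdio>
#include <infiniband/verbs.h>

int main()
{
    int num = 0;
    // Returns a NULL-terminated array of available RDMA devices, or nullptr on error.
    ibv_device **list = ibv_get_device_list(&num);
    if (list == nullptr) {
        std::perror("ibv_get_device_list");
        return 1;
    }
    std::printf("found %d RDMA device(s)\n", num);
    for (int i = 0; i < num; ++i) {
        std::printf("  %s\n", ibv_get_device_name(list[i]));
    }
    ibv_free_device_list(list);
    return 0;
}
```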
+ ### 源码编译额外依赖 > 如无需源码编译 openYuanrong datasystem,请跳过本章节。 diff --git a/include/datasystem/utils/connection.h b/include/datasystem/utils/connection.h index 03352c2583258a2376d2451f4c88f4c3791a510b..4b1f00b2de96951209feda8ab864a7bc111ec2d7 100644 --- a/include/datasystem/utils/connection.h +++ b/include/datasystem/utils/connection.h @@ -41,6 +41,7 @@ struct ConnectOptions { SensitiveValue secretKey = ""; std::string tenantId = ""; bool enableCrossNodeConnection = false; + bool enableExclusiveConnection = false; }; } // namespace datasystem diff --git a/include/datasystem/utils/status.h b/include/datasystem/utils/status.h index 9f90bd1814086f5a241177ef5b824c8bf1186d2d..a3f39e58575dd3479c64193a5ca86617d4c13660 100644 --- a/include/datasystem/utils/status.h +++ b/include/datasystem/utils/status.h @@ -20,6 +20,7 @@ #ifndef DATASYSTEM_UTILS_STATUS_H #define DATASYSTEM_UTILS_STATUS_H +#include #include #include @@ -108,11 +109,11 @@ class Status { public: Status() noexcept; - Status(const Status &other) = default; + Status(const Status &other) noexcept; Status(Status &&other) noexcept; - Status &operator=(const Status &other); + Status &operator=(const Status &other) noexcept; Status &operator=(Status &&other) noexcept; @@ -123,7 +124,7 @@ public: * @param[in] code Return code. * @param[in] msg Return msg. */ - Status(StatusCode code, std::string msg); + Status(StatusCode code, std::string msg) noexcept; /** * @brief Set return info of Status. @@ -219,8 +220,13 @@ public: static std::string StatusCodeName(StatusCode code); private: - StatusCode code_; - std::string errMsg_; + void Assign(const Status &other) noexcept; + + struct State { + StatusCode code; + std::string errMsg; + }; + std::unique_ptr state_{ nullptr }; }; } // namespace datasystem diff --git a/k8s/helm_chart/datasystem/templates/worker_daemonset.yaml b/k8s/helm_chart/datasystem/templates/worker_daemonset.yaml index 58d580d2ae65aaa52d0e8bec0af94acb79525cbc..57ab9e8201e1b8cdddd7f590dbd514af134c1507 100644 --- a/k8s/helm_chart/datasystem/templates/worker_daemonset.yaml +++ b/k8s/helm_chart/datasystem/templates/worker_daemonset.yaml @@ -91,6 +91,7 @@ spec: - -log_monitor_exporter={{ $.Values.global.observability.logMonitorExporter }} - -log_monitor_interval_ms={{ $.Values.global.observability.logMonitorIntervalMs }} - -enable_meta_replica={{ $.Values.global.metadata.enableMetaReplica }} + - -enable_redirect={{ $.Values.global.metadata.enableRedirect }} - -log_async={{ $.Values.global.log.logAsync }} - -logbufsecs={{ $.Values.global.log.logBufSecs }} - -log_compress={{ $.Values.global.log.logCompress }} @@ -135,6 +136,7 @@ spec: - -system_secret_key={{ $.Values.global.akSk.systemSecretKey }} - -tenant_access_key={{ $.Values.global.akSk.tenantAccessKey }} - -tenant_secret_key={{ $.Values.global.akSk.tenantSecretKey }} + - -skip_authenticate={{ $.Values.global.akSk.skipAuthenticate }} - -request_expire_time_s={{ $.Values.global.akSk.requestExpireTimeS }} {{- if eq ($.Values.global.l2Cache.l2CacheType) "obs" }} - -obs_access_key={{ $.Values.global.l2Cache.obs.obsAccessKey }} diff --git a/k8s/helm_chart/datasystem/values.yaml b/k8s/helm_chart/datasystem/values.yaml index 4073f358f745b5c20dff7a8d1473ff93ca3dd7c0..e2782c702c06e90aebcefd4ee1a8eee27ddcc902 100644 --- a/k8s/helm_chart/datasystem/values.yaml +++ b/k8s/helm_chart/datasystem/values.yaml @@ -212,6 +212,8 @@ global: rocksdbBackgroundThreads: 16 # Config the rocksdb support none, sync or async, async by default. Optional value: 'none', 'sync', 'async'. 
This represents the method of writing metadata to rocksdb. rocksdbWriteMode: "async" + # enable query meta redirect when scale up or voluntary scale down, default is false + enableRedirect: "true" rpc: # Whether to enable the authentication function between components(worker, master) @@ -397,6 +399,8 @@ global: tenantSecretKey: "" # Request expiration time in seconds, the maximum value is 300s. requestExpireTimeS: 300 + # Skip authentication for worker requests + skipAuthenticate: false # fsGroup configuration # All processes of the container are also part of the supplementary group ID. diff --git a/python/object_client.py b/python/object_client.py index 89da5c042a2baa2a6d0c5a784d7b84057800c3af..a2ee7d9bc38c2db9ffb4b55566fe2fd9f7f7fed1 100644 --- a/python/object_client.py +++ b/python/object_client.py @@ -297,6 +297,7 @@ class ObjectClient: access_key(str): The access key used by AK/SK authorize. secret_key(str): The secret key for AK/SK authorize. tenant_id(str): The tenant ID. + enable_exclusive_connection(bool): Indicates if the client connects using exclusive conn mode, default off. Raises: TypeError: Raise a type error if the input parameter is invalid. @@ -317,6 +318,7 @@ class ObjectClient: access_key="", secret_key="", tenant_id="", + enable_exclusive_connection=False ): """Constructor of the ObjectClient class @@ -348,6 +350,7 @@ class ObjectClient: ["access_key", access_key, str], ["secret_key", secret_key, str], ["tenant_id", tenant_id, str], + ["enable_exclusive_connection", enable_exclusive_connection, bool] ] validator.check_args_types(args) self.client = ds.ObjectClient( @@ -359,7 +362,8 @@ class ObjectClient: server_public_key, access_key, secret_key, - tenant_id + tenant_id, + enable_exclusive_connection ) @staticmethod diff --git a/src/datasystem/client/CMakeLists.txt b/src/datasystem/client/CMakeLists.txt index 869c6be50db1d252df7a587443f1ed3f78261e0e..0b46097e30497e356a19d68df4a38fc412c6a5c1 100644 --- a/src/datasystem/client/CMakeLists.txt +++ b/src/datasystem/client/CMakeLists.txt @@ -43,6 +43,7 @@ list(APPEND CLIENT_DEPEND_LIBS common_shm_unit_info common_util common_immutable_string + string_ref nlohmann_json::nlohmann_json posix_protos_client share_memory_protos_client @@ -50,7 +51,8 @@ list(APPEND CLIENT_DEPEND_LIBS worker_stream_protos_client common_acl_device common_shared_memory - common_rdma) + common_rdma + common_parallel) if (ENABLE_PERF) file(GLOB_RECURSE PERF_CLIENT_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} perf_client/*.cpp ) diff --git a/src/datasystem/client/client_worker_common_api.cpp b/src/datasystem/client/client_worker_common_api.cpp index 62bba00f31c43b9e68088256eb688b0375c4aba2..adc89206abd76e2126a99ac3be47f4c27c1ea90c 100644 --- a/src/datasystem/client/client_worker_common_api.cpp +++ b/src/datasystem/client/client_worker_common_api.cpp @@ -38,6 +38,7 @@ #include "datasystem/common/perf/perf_manager.h" #include "datasystem/common/rpc/rpc_auth_key_manager.h" #include "datasystem/common/rpc/unix_sock_fd.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/fd_manager.h" #include "datasystem/common/util/fd_pass.h" #include "datasystem/common/util/format.h" @@ -53,9 +54,13 @@ namespace datasystem { namespace client { + +// Static/global id generator init +std::atomic ClientWorkerCommonApi::exclusiveIdGen_ = 0; + ClientWorkerCommonApi::ClientWorkerCommonApi(HostPort hostPort, RpcCredential cred, HeartbeatType heartbeatType, Signature *signature, std::string tenantId, - bool enableCrossNodeConnection) + bool 
enableCrossNodeConnection, bool enableExclusiveConnection) : hostPort_(std::move(hostPort)), cred_(std::move(cred)), pageSize_(0), @@ -63,7 +68,8 @@ ClientWorkerCommonApi::ClientWorkerCommonApi(HostPort hostPort, RpcCredential cr heartbeatType_(heartbeatType), signature_(signature), tenantId_(std::move(tenantId)), - enableCrossNodeConnection_(enableCrossNodeConnection) + enableCrossNodeConnection_(enableCrossNodeConnection), + enableExclusiveConnection_(enableExclusiveConnection) { recvPageThread_ = Thread(&ClientWorkerCommonApi::RecvPageFd, this); } @@ -81,11 +87,11 @@ ClientWorkerCommonApi::~ClientWorkerCommonApi() void ClientWorkerCommonApi::SetRpcTimeout() { - constexpr int32_t rpcMaxTimeout = 600'000; // 10min + constexpr int32_t rpcMaxTimeout = 600000; // 10min int32_t rpcTimeout = timeoutMs_ / retryTimes_; - int32_t shorterSplitTime = 30'000; // 30s - int32_t longerSplitTime = 90'000; // 90s + int32_t shorterSplitTime = 30000; // 30s + int32_t longerSplitTime = 90000; // 90s if (timeoutMs_ <= shorterSplitTime) { rpcTimeoutMs_ = timeoutMs_; } else if (timeoutMs_ <= longerSplitTime) { @@ -301,7 +307,7 @@ void ClientWorkerCommonApi::ConstructDecShmUnit(const RegisterClientRspPb &rsp) decShmUnit_->fd = rsp.store_fd(); decShmUnit_->mmapSize = rsp.mmap_size(); decShmUnit_->offset = static_cast(rsp.offset()); - decShmUnit_->id = rsp.shm_id(); + decShmUnit_->id = ShmKey::Intern(rsp.shm_id()); } } @@ -338,6 +344,7 @@ Status ClientWorkerCommonApi::RegisterClient(RegisterClientReqPb &req, int32_t t req.set_shm_enabled(shmEnabled_); req.set_tenant_id(tenantId_); req.set_enable_cross_node(enableCrossNodeConnection_); + req.set_enable_exclusive_connection(enableExclusiveConnection_); req.set_pod_name(Logging::PodName()); RegisterClientRspPb rsp; @@ -373,6 +380,11 @@ Status ClientWorkerCommonApi::RegisterClient(RegisterClientReqPb &req, int32_t t workerUuid_ = rsp.worker_uuid(); workerEnableP2Ptransfer_ = rsp.enable_p2p_transfer(); SetHealthy(!rsp.unhealthy()); + exclusiveConnSockPath_ = rsp.exclusive_conn_sockpath(); + if (enableExclusiveConnection_ && exclusiveConnSockPath_.empty()) { + LOG(WARNING) << "Client requested exclusive connection, but the older worker did not support the feature."; + enableExclusiveConnection_ = false; + } std::vector heartBeatTimeoutMsOptions = { static_cast(timeoutMs), MAX_HEARTBEAT_TIMEOUT_MS }; uint64_t clientDeadTimeoutMs = rsp.client_dead_timeout_s() * TO_MILLISECOND; diff --git a/src/datasystem/client/client_worker_common_api.h b/src/datasystem/client/client_worker_common_api.h index 0158a4e09f7eb6bc1aa2ba411c2b7d86441e5ee6..ce3fab7ea846af51b131e9606a66a82088f1b628 100644 --- a/src/datasystem/client/client_worker_common_api.h +++ b/src/datasystem/client/client_worker_common_api.h @@ -65,7 +65,7 @@ public: explicit ClientWorkerCommonApi(HostPort hostPort, RpcCredential cred = {}, HeartbeatType heartbeatType = HeartbeatType::RPC_HEARTBEAT, Signature *signature = nullptr, std::string tenantId = "", - bool enableCrossNodeConnection = false); + bool enableCrossNodeConnection = false, bool enableExclusiveConnection = false); virtual ~ClientWorkerCommonApi(); @@ -447,6 +447,7 @@ protected: int32_t timeoutMs_{ 0 }; int32_t rpcTimeoutMs_{ 0 }; bool enableCrossNodeConnection_{ false }; + bool enableExclusiveConnection_{ false }; int64_t heartBeatTimeoutMs_{ 0 }; int64_t heartBeatIntervalMs_{ MIN_HEARTBEAT_INTERVAL_MS }; uint64_t clientDeadTimeoutMs_{ 0 }; @@ -472,6 +473,9 @@ protected: } }; RecvClientFdState recvClientFdState_; + static std::atomic exclusiveIdGen_; + 
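// exclusiveId_ is only populated when the client is constructed with
// enableExclusiveConnection; the ctor takes a value from exclusiveIdGen_ above.
// exclusiveConnSockPath_ is filled in from the worker's RegisterClient response
// and stays empty on older workers, in which case the feature is disabled with
// a warning.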
std::optional exclusiveId_; + std::string exclusiveConnSockPath_; }; } // namespace client } // namespace datasystem diff --git a/src/datasystem/client/object_cache/client_worker_api.cpp b/src/datasystem/client/object_cache/client_worker_api.cpp index 9b4e9709450de1b9dd78244aadb6cb8148cce8e1..a8353e2f31b7a73e3cc98b21bfa03da1cff5cb0c 100644 --- a/src/datasystem/client/object_cache/client_worker_api.cpp +++ b/src/datasystem/client/object_cache/client_worker_api.cpp @@ -28,6 +28,7 @@ #include "datasystem/common/rpc/rpc_auth_key_manager.h" #include "datasystem/common/rpc/rpc_constants.h" #include "datasystem/common/rpc/unix_sock_fd.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/format.h" #include "datasystem/common/util/rpc_util.h" #include "datasystem/common/util/raii.h" @@ -50,15 +51,21 @@ static constexpr uint32_t BIT_NUM_OF_INT = 32; const std::unordered_set RETRY_ERROR_CODE{ StatusCode::K_TRY_AGAIN, StatusCode::K_RPC_CANCELLED, StatusCode::K_RPC_DEADLINE_EXCEEDED, StatusCode::K_RPC_UNAVAILABLE, StatusCode::K_OUT_OF_MEMORY }; -static constexpr uint64_t P2P_TIMEOUT_MS = 60'000; -constexpr uint64_t P2P_SUBSCRIBE_TIMEOUT_MS = 20'000; +static constexpr uint64_t P2P_TIMEOUT_MS = 60000; +constexpr uint64_t P2P_SUBSCRIBE_TIMEOUT_MS = 20000; ClientWorkerApi::ClientWorkerApi(HostPort hostPort, RpcCredential cred, HeartbeatType heartbeatType, Signature *signature, std::string tenantId, - bool enableCrossNodeConnection) + bool enableCrossNodeConnection, bool enableExclusiveConnection) : ClientWorkerCommonApi(hostPort, cred, heartbeatType, signature, std::move(tenantId), - enableCrossNodeConnection) + enableCrossNodeConnection, enableExclusiveConnection) { + if (enableExclusiveConnection) { + // Assign a value and then bump the counter. 
This id is a client-side-only identifier, a bit like a + // client id but lighter weight for performance sensitive comparisons (existing client id is a large + // string and costly for lookups and string compare) + exclusiveId_ = exclusiveIdGen_++; + } } Status ClientWorkerApi::Init(int32_t timeoutMs) @@ -75,6 +82,10 @@ Status ClientWorkerApi::Init(int32_t timeoutMs) timeoutMs = std::min(clientDeadTimeoutMs_, static_cast(timeoutMs)); } stub_ = std::make_unique(channel, timeoutMs); + if (enableExclusiveConnection_ && exclusiveId_.has_value()) { + // Note: exclusiveConnSockPath_ will be initialized during client register call driven from base class Init() + stub_->SetExclusiveConnInfo(exclusiveId_, exclusiveConnSockPath_); + } return Status::OK(); } @@ -138,7 +149,7 @@ Status ClientWorkerApi::Create(const std::string &objectKey, int64_t dataSize, u shmBuf->fd = rsp.store_fd(); shmBuf->mmapSize = rsp.mmap_size(); shmBuf->offset = static_cast(rsp.offset()); - shmBuf->id = rsp.shm_id(); + shmBuf->id = ShmKey::Intern(rsp.shm_id()); metadataSize = rsp.metadata_size(); version = workerVersion_.load(std::memory_order_relaxed); return Status::OK(); @@ -150,6 +161,9 @@ Status ClientWorkerApi::MultiCreate(bool skipCheckExistence, std::vector(createParams.size()); + req.mutable_object_key()->Reserve(sz); + req.mutable_data_size()->Reserve(sz); for (auto ¶m : createParams) { req.add_object_key(param.objectKey); req.add_data_size(param.dataSize); @@ -173,24 +187,30 @@ Status ClientWorkerApi::MultiCreate(bool skipCheckExistence, std::vector(rsp.results().size()), K_INVALID, FormatString("The length of objectKeyList (%zu) and dataSizeList (%zu) should be the same.", createParams.size(), rsp.results().size())); - auto rspExists = rsp.exists(); - CHECK_FAIL_RETURN_STATUS(static_cast(rspExists.size()) == createParams.size(), K_INVALID, - "The size of rspExists is not consistent with createParams"); - exists.reserve(createParams.size()); - for (auto val : rspExists) { - exists.emplace_back(val); - } - for (auto res : rsp.results()) { - if (!res.shm_id().empty()) { - useShmTransfer = true; - break; + if (!skipCheckExistence) { + CHECK_FAIL_RETURN_STATUS(static_cast(rsp.exists_size()) == createParams.size(), K_INVALID, + "The size of rspExists is not consistent with createParams"); + for (int i = 0; i < rsp.exists_size(); i++) { + exists[i] = rsp.exists(i); } } + auto checkUseShm = [&rsp, &skipCheckExistence]() { + if (skipCheckExistence) { + return true; + } + for (const auto& res : rsp.results()) { + if (!res.shm_id().empty()) { + return true; + } + } + return false; + }; + useShmTransfer = checkUseShm(); if (!useShmTransfer) { return Status::OK(); } for (auto i = 0ul; i < createParams.size(); i++) { - if (exists[i]) { + if (!skipCheckExistence && exists[i]) { continue; } auto &shmBuf = createParams[i].shmBuf; @@ -198,7 +218,7 @@ Status ClientWorkerApi::MultiCreate(bool skipCheckExistence, std::vectorfd = subRsp.store_fd(); shmBuf->mmapSize = subRsp.mmap_size(); shmBuf->offset = static_cast(subRsp.offset()); - shmBuf->id = subRsp.shm_id(); + shmBuf->id = ShmKey::Intern(subRsp.shm_id()); createParams[i].metadataSize = subRsp.metadata_size(); } version = workerVersion_.load(std::memory_order_relaxed); @@ -259,6 +279,7 @@ Status ClientWorkerApi::Get(const GetParam &getParam, uint32_t &version, GetRspP req.set_no_query_l2cache(!getParam.queryL2Cache); req.set_sub_timeout(ClientGetRequestTimeout(subTimeoutMs)); req.set_client_id(GetClientId()); + req.set_return_object_index(true); PerfPoint 
perfPoint(PerfKey::RPC_CLIENT_GET_OBJECT); int64_t rpcTimeout = std::max(subTimeoutMs, rpcTimeoutMs_); @@ -394,10 +415,11 @@ Status ClientWorkerApi::MultiPublish(const std::vector(param.existence)); req.set_is_replica(param.isReplica); - req.set_auto_release_memory_ref(!bufferInfo[0]->shmId.empty()); + req.set_auto_release_memory_ref(!bufferInfo[0]->shmId.Empty()); std::vector payloads; + req.mutable_object_info()->Reserve(static_cast(bufferInfo.size())); for (size_t i = 0; i < bufferInfo.size(); ++i) { - if (bufferInfo[i]->shmId.empty()) { + if (bufferInfo[i]->shmId.Empty()) { payloads.emplace_back(bufferInfo[i]->pointer, bufferInfo[i]->dataSize); } MultiPublishReqPb::ObjectInfoPb objectInfoPb; @@ -476,7 +498,7 @@ Status ClientWorkerApi::CheckShmFutexResult(uint32_t *waitFlag, uint32_t waitNum return ShmCircularQueue::CheckFutexErrno(result); } -Status ClientWorkerApi::DecreaseWorkerRefByShm(const std::string &shmId, const std::function &connectCheck) +Status ClientWorkerApi::DecreaseWorkerRefByShm(const ShmKey &shmId, const std::function &connectCheck) { RETURN_RUNTIME_ERROR_IF_NULL(decreaseRPCQ_); std::string decElement; @@ -532,7 +554,7 @@ Status ClientWorkerApi::DecreaseWorkerRefByShm(const std::string &shmId, const s return Status::OK(); } -Status ClientWorkerApi::DecreaseWorkerRef(const std::vector &objectKeys) +Status ClientWorkerApi::DecreaseWorkerRef(const std::vector &objectKeys) { DecreaseReferenceRequest req; req.set_client_id(GetClientId()); diff --git a/src/datasystem/client/object_cache/client_worker_api.h b/src/datasystem/client/object_cache/client_worker_api.h index ecaf9b2ba4aa27d727b5dbe7661342cbf4c5b5fd..ed7c43fae8f21541c879db76d9cdf540a929b09b 100644 --- a/src/datasystem/client/object_cache/client_worker_api.h +++ b/src/datasystem/client/object_cache/client_worker_api.h @@ -30,6 +30,7 @@ #include "datasystem/client/mmap_table.h" #include "datasystem/common/ak_sk/signature.h" #include "datasystem/common/object_cache/object_base.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/fd_pass.h" #include "datasystem/common/util/net_util.h" #include "datasystem/common/util/queue/shm_circular_queue.h" @@ -89,11 +90,12 @@ public: * @param[in] signature Used to do AK/SK authenticate. * @param[in] tenantId The tenant id. * @param[in] enableCrossNodeConnection Indicates whether the client can connect to the standby node. + * @param[in] enableExclusiveConnection Indicates whether the client will use exclusive, per-thread connections */ explicit ClientWorkerApi(HostPort hostPort, RpcCredential cred, HeartbeatType heartbeatType = HeartbeatType::RPC_HEARTBEAT, Signature *signature = nullptr, std::string tenantId = "", - bool enableCrossNodeConnection = false); + bool enableCrossNodeConnection = false, bool enableExclusiveConnection = false); /** * @brief Initialize ClientWorkerApi. @@ -178,7 +180,7 @@ public: * @param[in] connectCheck the connect check with local server. * @return Status of the call. */ - Status DecreaseWorkerRefByShm(const std::string &shmId, const std::function &connectCheck); + Status DecreaseWorkerRefByShm(const ShmKey &shmId, const std::function &connectCheck); /** * @brief Wakes up all waiting processes in the shared memory queue and clears the pointers to the queue. @@ -191,7 +193,7 @@ public: * @param[in] objectKey The ID of the object to decrease ref. * @return Status of the call. 
*/ - Status DecreaseWorkerRef(const std::vector &objectKeys); + Status DecreaseWorkerRef(const std::vector &objectKeys); /** * @brief Send getting object rpc request to worker. diff --git a/src/datasystem/client/object_cache/object_client_impl.cpp b/src/datasystem/client/object_cache/object_client_impl.cpp index 35c7ccc5e52f892a23c2d63db8803160aa683f7f..2289b6b59e6816323f651a95cb4719fca07ca038 100644 --- a/src/datasystem/client/object_cache/object_client_impl.cpp +++ b/src/datasystem/client/object_cache/object_client_impl.cpp @@ -50,7 +50,9 @@ #include "datasystem/common/log/logging.h" #include "datasystem/common/log/trace.h" #include "datasystem/common/log/spdlog/provider.h" +#include "datasystem/common/parallel/parallel_for.h" #include "datasystem/common/rpc/rpc_constants.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/format.h" #include "datasystem/common/util/memory.h" #include "datasystem/common/util/net_util.h" @@ -72,6 +74,9 @@ const size_t BATCH_SET_MAX_KEY_COUNT = 2000; static constexpr size_t OBJ_META_MAX_SIZE_LIMIT = 64; static constexpr size_t QUERY_SIZE_OBJECT_LIMIT = 10000; const std::string K_SEPARATOR = "$"; +const std::string CLIENT_PARALLEL_THREAD_MIN_NUM_ENV = "CLIENT_PARALLEL_THREAD_MIN_NUM"; +const std::string CLIENT_PARALLEL_THREAD_MAX_NUM_ENV = "CLIENT_PARALLEL_THREAD_MAX_NUM"; +const std::string CLIENT_MEMORY_COPY_THREAD_NUM_ENV = "CLIENT_MEMORY_COPY_THREAD_NUM"; namespace datasystem { inline void ReadFromEnv(std::string ¶m, std::string env) @@ -130,6 +135,7 @@ ObjectClientImpl::ObjectClientImpl(const ConnectOptions &connectOptions1) timeoutMs_ = connectOptions.connectTimeoutMs; tenantId_ = connectOptions.tenantId; signature_ = std::make_unique(connectOptions.accessKey, connectOptions.secretKey); + enableExclusiveConnection_ = connectOptions.enableExclusiveConnection; enableCrossNodeConnection_ = connectOptions.enableCrossNodeConnection; (void)authKeys_.SetClientPublicKey(connectOptions.clientPublicKey); (void)authKeys_.SetClientPrivateKey(connectOptions.clientPrivateKey); @@ -183,6 +189,7 @@ Status ObjectClientImpl::ShutDown(bool &needRollbackState, bool isDestruct) } } } + asyncReleasePool_.reset(); } // The destructor of devOcImpl_ should occur after the client disconnect request so that the device asynchronous @@ -211,7 +218,7 @@ Status ObjectClientImpl::Init(bool &needRollbackState, bool enableHeartbeat) workerApi_.resize(STANDBY2_WORKER + 1); workerApi_[LOCAL_WORKER] = std::make_shared(hostPort, cred_, heartbeatType, signature_.get(), tenantId_, - enableCrossNodeConnection_); + enableCrossNodeConnection_, enableExclusiveConnection_); RETURN_IF_NOT_OK(workerApi_[LOCAL_WORKER]->Init(timeoutMs_)); mmapManager_ = std::make_unique(workerApi_[LOCAL_WORKER]); memoryCopyThreadPool_ = std::make_shared(0, GetRecommendedMemoryCopyThreadsNum()); @@ -221,6 +228,7 @@ Status ObjectClientImpl::Init(bool &needRollbackState, bool enableHeartbeat) asyncGetRPCPool_ = std::make_shared(0, threadCount, "async_get_rpc"); asyncSwitchWorkerPool_ = std::make_shared(0, 1, "switch"); asyncDevDeletePool_ = std::make_shared(0, threadCount); + asyncReleasePool_ = std::make_shared(0, 1, "async_release_buffer"); std::shared_ptr decShmUnit; if (workerApi_[LOCAL_WORKER]->GetShmQueueUnit(decShmUnit)) { RETURN_IF_NOT_OK(mmapManager_->LookupUnitsAndMmapFd("", decShmUnit)); @@ -245,9 +253,33 @@ Status ObjectClientImpl::Init(bool &needRollbackState, bool enableHeartbeat) devOcImpl_ = std::make_unique(this); RETURN_IF_NOT_OK(devOcImpl_->Init()); 
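// Start perf sampling, then size the shared parallel-for pool. InitParallelFor()
// below reads CLIENT_MEMORY_COPY_THREAD_NUM for the copy parallelism and
// CLIENT_PARALLEL_THREAD_MIN_NUM / CLIENT_PARALLEL_THREAD_MAX_NUM for the pool
// bounds, falling back to 4 threads when a variable is unset or unparsable.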
StartPerfThread(); + InitParallelFor(); return Status::OK(); } +void ObjectClientImpl::InitParallelFor() +{ + static const int defaultThreadNum = 4; + auto getEnvInt = [](const std::string &envName, int defaultValue) -> int { + const char* val = std::getenv(envName.c_str()); + int result = defaultValue; + if (val && !Uri::StrToInt(val, result)) { + result = defaultValue; + } + return result; + }; + parallismNum_ = getEnvInt(CLIENT_MEMORY_COPY_THREAD_NUM_ENV, defaultThreadNum); + int minThreadNum = getEnvInt(CLIENT_PARALLEL_THREAD_MIN_NUM_ENV, defaultThreadNum); + minThreadNum = minThreadNum < parallismNum_ ? parallismNum_ : minThreadNum; + int maxThreadNum = getEnvInt(CLIENT_PARALLEL_THREAD_MAX_NUM_ENV, minThreadNum); + LOG(INFO) << FormatString("Init parallel for with parallismNum: %d, minThreadNum: %d, maxThreadNum: %d", + parallismNum_, minThreadNum, maxThreadNum); + if (minThreadNum == 0) { + return; + } + datasystem::Parallel::InitParallelThreadPool(minThreadNum, maxThreadNum); +} + void ObjectClientImpl::MGetAsyncRpcThread(const std::shared_ptr &resourcePtr) { auto result = resourcePtr->rpcFuture.get(); @@ -413,7 +445,7 @@ bool ObjectClientImpl::SwitchToStandbyWorkerImpl(const std::shared_ptrGetHeartbeatType(); workerApi_[next] = std::make_shared(standbyWorker, cred_, heartbeatType, signature_.get(), tenantId_, - enableCrossNodeConnection_); + enableCrossNodeConnection_, enableExclusiveConnection_); workerApi_[next]->SetIsUseStandbyWorker(true); Status rc = workerApi_[next]->Init(timeoutMs_); if (rc.IsError()) { @@ -493,7 +525,7 @@ bool ObjectClientImpl::WaitStandbyWorkerReady(const std::shared_ptrGetWorkHost(), clientWorkerApi->GetWorkPort()); - constexpr uint64_t maxWaitMilliseconds = 10'000; + constexpr uint64_t maxWaitMilliseconds = 10000; constexpr uint64_t waitIntervalMs = 500; uint64_t waitMilliseconds = std::min(clientWorkerApi->GetHeartBeatInterval() * 2, maxWaitMilliseconds); Timer timer; @@ -629,6 +661,7 @@ Status ObjectClientImpl::DeviceDataCreate(const std::vector &object BlobListInfo blobInfo; RETURN_IF_NOT_OK(PrepareDataSizeList(dataSizeList, devBlobList, blobInfo)); LOG(INFO) << blobInfo.ToString(true); + exists.resize(objectKeys.size(), false); RETURN_IF_NOT_OK(MultiCreate(objectKeys, dataSizeList, param, false, bufferList, exists)); std::vector> filterBufferList; std::vector filterDevBlobList; @@ -839,25 +872,21 @@ Status ObjectClientImpl::Create(const std::string &objectKey, uint64_t dataSize, Status ObjectClientImpl::ConstructMultiCreateParam(const std::vector &objectKeyList, const std::vector &dataSizeList, std::vector> &bufferList, - std::vector &multiCreateParamList) + std::vector &multiCreateParamList, + uint64_t &dataSizeSum) { - CHECK_FAIL_RETURN_STATUS(objectKeyList.size() == dataSizeList.size(), K_INVALID, + auto sz = objectKeyList.size(); + CHECK_FAIL_RETURN_STATUS(sz == dataSizeList.size(), K_INVALID, "The length of objectKeyList and dataSizeList should be the same."); - for (size_t i = 0; i < objectKeyList.size(); i++) { + multiCreateParamList.reserve(sz); + for (size_t i = 0; i < sz; i++) { auto &objectKey = objectKeyList[i]; auto dataSize = dataSizeList[i]; - CHECK_FAIL_RETURN_STATUS(!objectKey.empty(), K_INVALID, "The objectKey is empty"); - RETURN_IF_NOT_OK(CheckValidObjectKey(objectKey)); CHECK_FAIL_RETURN_STATUS(dataSize > 0, K_INVALID, "The dataSize value should be bigger than zero."); + dataSizeSum += dataSize; + multiCreateParamList.emplace_back(i, objectKey, dataSize); } - bufferList.resize(objectKeyList.size()); - // if total size >=500k , 
transfer by shm - size_t index = 0; - std::for_each(objectKeyList.begin(), objectKeyList.end(), - [&index, &multiCreateParamList, &dataSizeList](const std::string &objectKey) { - multiCreateParamList.emplace_back(index, objectKey, dataSizeList[index]); - index++; - }); + bufferList.resize(sz); return Status::OK(); } @@ -873,21 +902,21 @@ Status ObjectClientImpl::MultiCreate(const std::vector &objectKeyLi std::shared_lock lck(memoryRefMutex_); std::vector multiCreateParamList; - RETURN_IF_NOT_OK(ConstructMultiCreateParam(objectKeyList, dataSizeList, bufferList, multiCreateParamList)); + uint64_t dataSizeSum = 0; + RETURN_IF_NOT_OK( + ConstructMultiCreateParam(objectKeyList, dataSizeList, bufferList, multiCreateParamList, dataSizeSum)); // If failed with create, need to rollback. auto version = 0u; auto useShmTransfer = false; - auto sizeSum = std::accumulate(dataSizeList.begin(), dataSizeList.end(), 0); - if (!skipCheckExistence || static_cast(sizeSum) >= workerApi_[LOCAL_WORKER]->GetShmThreshold()) { + if (dataSizeSum >= workerApi_[LOCAL_WORKER]->GetShmThreshold()) { RETURN_IF_NOT_OK(workerApi_[LOCAL_WORKER]->MultiCreate(skipCheckExistence, multiCreateParamList, version, exists, useShmTransfer)); } else { exists.resize(objectKeyList.size(), false); } - if (!useShmTransfer) { for (size_t i = 0; i < objectKeyList.size(); i++) { - if (exists[i]) { + if (!skipCheckExistence && exists[i]) { continue; } auto &objectKey = objectKeyList[i]; @@ -919,7 +948,7 @@ Status ObjectClientImpl::MultiCreate(const std::vector &objectKeyLi }); Status injectRC = Status::OK(); for (auto &createParam : multiCreateParamList) { - if (exists[createParam.index]) { + if (!skipCheckExistence && exists[createParam.index]) { continue; } PerfPoint mmapPoint(PerfKey::CLIENT_LOOK_UP_MMAP_FD); @@ -951,7 +980,7 @@ Status ObjectClientImpl::MultiCreate(const std::vector &objectKeyLi void ObjectClientImpl::BatchReleaseBufferPtr(const std::vector &buffers) { - std::vector> shmInfos; + std::vector> shmInfos; for (auto &buffer : buffers) { if (!buffer || !buffer->isShm_) { @@ -963,17 +992,17 @@ void ObjectClientImpl::BatchReleaseBufferPtr(const std::vector &buffer BatchDecreaseRefCnt(shmInfos); } -void ObjectClientImpl::BatchDecreaseRefCnt(const std::vector> &shmInfos) +void ObjectClientImpl::BatchDecreaseRefCnt(const std::vector> &shmInfos) { - auto DecreaseRefCnt = [this](const std::vector> &shmInfos) { + auto DecreaseRefCnt = [this](const std::vector> &shmInfos) { std::shared_lock lck(memoryRefMutex_); std::vector> batchLock; - std::vector descreaseShms; + std::vector descreaseShms; for (auto &info : shmInfos) { if (!IsBufferAlive(info.second)) { continue; } - const std::string &shmId = info.first; + const auto &shmId = info.first; auto accessorPtr = std::make_shared(); auto &accessor = *accessorPtr; auto found = memoryRefCount_.find(accessor, shmId); @@ -1011,11 +1040,11 @@ void ObjectClientImpl::BatchDecreaseRefCnt(const std::vectorGetClientId(), shmId); - auto DecreaseRefCnt = [this](const std::string &shmId, bool isShm) { + auto DecreaseRefCnt = [this](const ShmKey &shmId, bool isShm) { std::shared_lock lck(memoryRefMutex_); TbbMemoryRefTable::accessor accessor; auto found = memoryRefCount_.find(accessor, shmId); @@ -1355,12 +1384,12 @@ Status ObjectClientImpl::SetShmObjectBuffer(const std::string &objectKey, const param.cacheType = CacheType(info.cache_type()); ObjectBufferInfo bufferInfo = SetObjectBufferInfo(objectKey, pointer, info.data_size(), info.metadata_size(), param, info.is_seal(), version, - info.shm_id(), 
nullptr, std::move(mmapEntry)); + ShmKey::Intern(info.shm_id()), nullptr, std::move(mmapEntry)); std::shared_lock lck(memoryRefMutex_); // Update shared memory reference count. TbbMemoryRefTable::accessor accessor; - auto found = memoryRefCount_.insert(accessor, info.shm_id()); + auto found = memoryRefCount_.insert(accessor, ShmKey::Intern(info.shm_id())); accessor->second = (found ? 1 : accessor->second + 1); return Buffer::CreateBuffer(bufferInfo, shared_from_this(), buffer); } @@ -1383,7 +1412,7 @@ Status ObjectClientImpl::MmapShmUnit(int64_t fd, uint64_t mmapSize, ptrdiff_t of ObjectBufferInfo ObjectClientImpl::SetObjectBufferInfo(const std::string &objectKey, uint8_t *pointer, uint64_t size, uint64_t metaSize, const FullParam ¶m, bool isSeal, - uint32_t version, const std::string &shmId, + uint32_t version, const ShmKey &shmId, const std::shared_ptr &payloadPointer, std::shared_ptr mmapEntry) { @@ -1459,7 +1488,17 @@ Status ObjectClientImpl::GetObjectBuffers(const std::vector &object const std::string &objectKey = objectsNeedToGet[index]; Status status; std::shared_ptr &bufferPtr = buffers[i + j]; - if (i < shmCount && objectKey == rsp.objects(i).object_key()) { + bool isShm = false; + bool isNoShm = false; + if (i < shmCount) { + isShm = rsp.objects(i).object_key().empty() ? index == rsp.objects(i).object_index() + : objectKey == rsp.objects(i).object_key(); + } + if (j < noShmCount) { + isNoShm = rsp.payload_info(j).object_key().empty() ? index == rsp.payload_info(j).object_index() + : objectKey == rsp.payload_info(j).object_key(); + } + if (isShm) { const GetRspPb::ObjectInfoPb &info = rsp.objects(i); i++; if (info.store_fd() == -1) { @@ -1472,7 +1511,7 @@ Status ObjectClientImpl::GetObjectBuffers(const std::vector &object status = SetOffsetReadObjectBuffer(objectKey, info, version, readParams[index].offset, readParams[index].size, bufferPtr); } - } else if (j < noShmCount && objectKey == rsp.payload_info(j).object_key()) { + } else if (isNoShm) { const GetRspPb::PayloadInfoPb &payloadInfo = rsp.payload_info(j); status = SetNonShmObjectBuffer(objectKey, payloadInfo, version, payloads, bufferPtr); j++; @@ -1558,14 +1597,14 @@ Status ObjectClientImpl::SetOffsetReadObjectBuffer(const std::string &objectKey, param.cacheType = CacheType(info.cache_type()); ObjectBufferInfo bufferInfo = SetObjectBufferInfo(objectKey, pointer, info.data_size(), info.metadata_size(), param, info.is_seal(), version, - info.shm_id(), nullptr, std::move(mmapEntry)); + ShmKey::Intern(info.shm_id()), nullptr, std::move(mmapEntry)); std::shared_lock lck(memoryRefMutex_); // Update shared memory reference count. std::shared_ptr tmpbuffer; { TbbMemoryRefTable::accessor accessor; - auto found = memoryRefCount_.insert(accessor, info.shm_id()); + auto found = memoryRefCount_.insert(accessor, ShmKey::Intern(info.shm_id())); accessor->second = (found ? 
1 : accessor->second + 1); RETURN_IF_NOT_OK(Buffer::CreateBuffer(bufferInfo, shared_from_this(), tmpbuffer)); } @@ -1884,22 +1923,35 @@ Status ObjectClientImpl::Set(const StringView &val, const SetParam &setParam, st Status ObjectClientImpl::CheckMultiSetInputParamValidationNtx(const std::vector &keys, const std::vector &vals, std::vector &outFailedKeys, - std::map &kv) + std::vector &deduplicateKeys, + std::vector &deduplicateVals) { - CHECK_FAIL_RETURN_STATUS(keys.size() > 0, K_INVALID, "The keys should not be empty."); + std::unordered_set keySet; + keySet.reserve(keys.size()); + CHECK_FAIL_RETURN_STATUS(!keys.empty(), K_INVALID, "The keys should not be empty."); CHECK_FAIL_RETURN_STATUS(keys.size() == vals.size(), K_INVALID, "The number of key and value is not the same."); + RETURN_IF_NOT_OK(CheckValidObjectKey(*keys.begin())); for (size_t i = 0; i < keys.size(); ++i) { CHECK_FAIL_RETURN_STATUS(!keys[i].empty(), K_INVALID, "The key should not be empty."); - RETURN_IF_NOT_OK(CheckValidObjectKey(keys[i])); CHECK_FAIL_RETURN_STATUS(vals[i].data() != nullptr, K_INVALID, FormatString("The value associated with key %s should not be empty.", keys[i])); - if (kv.find(keys[i]) == kv.end()) { - kv[keys[i]] = vals[i]; - } else { + auto [it, inserted] = keySet.emplace(keys[i]); + (void)it; + if (!inserted) { LOG(ERROR) << "The input parameter contains duplicate key " << keys[i]; outFailedKeys.emplace_back(keys[i]); } } + if (!outFailedKeys.empty()) { + for (size_t i = 0; i < keys.size(); ++i) { + if (keySet.find(keys[i]) == keySet.end()) { + continue; + } + deduplicateKeys.emplace_back(keys[i]); + deduplicateVals.emplace_back(vals[i]); + keySet.erase(keys[i]); + } + } return Status::OK(); } @@ -1915,9 +1967,9 @@ Status ObjectClientImpl::CheckMultiSetInputParamValidation(const std::vector keyRecord; + RETURN_IF_NOT_OK(CheckValidObjectKey(*keys.begin())); for (size_t i = 0; i < keys.size(); ++i) { CHECK_FAIL_RETURN_STATUS(!keys[i].empty(), K_INVALID, "The key should not be empty."); - RETURN_IF_NOT_OK(CheckValidObjectKey(keys[i])); CHECK_FAIL_RETURN_STATUS(vals[i].data() != nullptr, K_INVALID, FormatString("The value associated with key %s should not be empty.", keys[i])); CHECK_FAIL_RETURN_STATUS(kv.find(keys[i]) == kv.end(), K_INVALID, @@ -1973,12 +2025,48 @@ Status ObjectClientImpl::AllocateMemoryForMSet(const std::map &keys, + const std::vector &vals, const FullParam &creatParam, + std::vector> &bufferList, + std::vector> &bufferInfoList) +{ + const int sz = static_cast(bufferList.size()); + auto memoryCopy = [&](int start, int end) { + for (int i = start; i < end; i++) { + auto &buffer = bufferList[i]; + if (buffer == nullptr) { + bufferInfoList[i] = std::make_shared( + SetObjectBufferInfo(keys[i], reinterpret_cast(const_cast(vals[i].data())), + vals[i].size(), 0, creatParam, false, 0)); + continue; + } + RETURN_IF_NOT_OK(buffer->CheckDeprecated()); + CHECK_FAIL_RETURN_STATUS(!buffer->bufferInfo_->isSeal, K_OC_ALREADY_SEALED, + "Client object is already sealed"); + RETURN_IF_NOT_OK(buffer->MemoryCopy(vals[i].data(), vals[i].size())); + bufferInfoList[i] = buffer->bufferInfo_; + } + return Status::OK(); + }; + if (!isParallel || parallismNum_ == 0) { + return memoryCopy(0, sz); + } + int workerNum = parallismNum_; + size_t chunkSize = 4; + if (sz <= parallismNum_) { + workerNum = sz; + chunkSize = 1; + } + return Parallel::ParallelFor(0, bufferInfoList.size(), memoryCopy, chunkSize, workerNum); +} + Status ObjectClientImpl::MSet(const std::vector &keys, const std::vector &vals, const MSetParam 
¶m, std::vector &outFailedKeys) { RETURN_IF_NOT_OK(IsClientReady()); - std::map kv; - RETURN_IF_NOT_OK(CheckMultiSetInputParamValidationNtx(keys, vals, outFailedKeys, kv)); + std::vector deduplicateKeys; + std::vector deduplicateVals; + RETURN_IF_NOT_OK(CheckMultiSetInputParamValidationNtx(keys, vals, outFailedKeys, deduplicateKeys, deduplicateVals)); std::shared_ptr workerApi; std::unique_ptr raii; RETURN_IF_NOT_OK(GetAvailableWorkerApi(workerApi, raii)); @@ -1987,47 +2075,43 @@ Status ObjectClientImpl::MSet(const std::vector &keys, const std::v creatParam.writeMode = param.writeMode; creatParam.consistencyType = ConsistencyType::CAUSAL; creatParam.cacheType = param.cacheType; - std::vector filteredKeys; - std::vector filteredValues; - filteredKeys.reserve(kv.size()); - filteredValues.reserve(kv.size()); - for (const auto &key : keys) { - if (kv.find(key) != kv.end()) { - filteredKeys.emplace_back(key); - filteredValues.emplace_back(kv[key]); - } - } + const std::vector &filteredKeys = deduplicateKeys.empty() ? keys : deduplicateKeys; + const std::vector &filteredValues = deduplicateVals.empty() ? vals : deduplicateVals; PerfPoint point(PerfKey::CLIENT_MSET_MULTICREATE); std::vector dataSizeList; + uint64_t dataSizeSum = 0; + dataSizeList.reserve(filteredValues.size()); for (const auto &val : filteredValues) { dataSizeList.emplace_back(val.size()); + dataSizeSum += val.size(); } std::vector> bufferList; std::vector exist; RETURN_IF_NOT_OK(MultiCreate(filteredKeys, dataSizeList, creatParam, true, bufferList, exist)); - std::vector> bufferInfoList; - bufferInfoList.reserve(bufferList.size()); + std::vector> bufferInfoList(bufferList.size()); + static const int minSizeThreshold = 500 * KB; + static const int sizeThreshold = 4 * MB_TO_BYTES; + static const int countThreshold = 32; + bool isParallel = + dataSizeSum > minSizeThreshold && (dataSizeSum >= sizeThreshold || filteredKeys.size() >= countThreshold); point.RecordAndReset(PerfKey::CLIENT_MSET_MEMCOPY); - auto idx = 0; - for (auto &buffer : bufferList) { - if (buffer == nullptr) { - bufferInfoList.emplace_back(std::make_shared(SetObjectBufferInfo( - filteredKeys[idx], reinterpret_cast(const_cast(filteredValues[idx].data())), - filteredValues[idx].size(), 0, creatParam, false, 0))); - continue; - } - RETURN_IF_NOT_OK(buffer->CheckDeprecated()); - CHECK_FAIL_RETURN_STATUS(!buffer->bufferInfo_->isSeal, K_OC_ALREADY_SEALED, "Client object is already sealed"); - RETURN_IF_NOT_OK(buffer->MemoryCopy(filteredValues[idx].data(), filteredValues[idx].size())); - bufferInfoList.emplace_back(buffer->bufferInfo_); - idx++; - } + RETURN_IF_NOT_OK( + MemoryCopyParallel(isParallel, filteredKeys, filteredValues, creatParam, bufferList, bufferInfoList)); point.RecordAndReset(PerfKey::CLIENT_MSET_MULTI_PUBLSIH); MultiPublishRspPb rsp; PublishParam publishParam{ .isTx = false, .isReplica = false, .existence = param.existence, .ttlSecond = param.ttlSecond }; RETURN_IF_NOT_OK(workerApi->MultiPublish(bufferInfoList, publishParam, rsp)); + asyncReleasePool_->Execute([this, buffers = std::move(bufferList)]() mutable { + std::shared_lock shutdownLck(shutdownMux_); + if (!IsClientReady()) { + return; + } + for (const auto &buf : buffers) { + buf->Release(); + } + }); for (const auto &objKey : rsp.failed_object_keys()) { outFailedKeys.emplace_back(objKey); } diff --git a/src/datasystem/client/object_cache/object_client_impl.h b/src/datasystem/client/object_cache/object_client_impl.h index 
40d9a80c4d3c2541bd7e073ae348bf55781c59ee..e329c0d5367c6131fbc9858eb72d98a2b87e1699 100644 --- a/src/datasystem/client/object_cache/object_client_impl.h +++ b/src/datasystem/client/object_cache/object_client_impl.h @@ -42,6 +42,7 @@ #include "datasystem/common/object_cache/object_base.h" #include "datasystem/common/rpc/rpc_credential.h" #include "datasystem/common/rpc/rpc_helper.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/status_helper.h" #include "datasystem/common/util/thread_local.h" #include "datasystem/common/util/thread_pool.h" @@ -59,7 +60,7 @@ namespace datasystem { namespace object_cache { -using TbbMemoryRefTable = tbb::concurrent_hash_map; +using TbbMemoryRefTable = tbb::concurrent_hash_map; using TbbGlobalRefTable = tbb::concurrent_hash_map; using GlobalRefInfo = std::pair>; @@ -149,7 +150,7 @@ public: * @param[in] isShm A flag indicating how the object will be published (shm or non-shm). * @param[in] version Worker version. */ - void DecreaseReferenceCnt(const std::string &shmId, bool isShm, uint32_t version = 0); + void DecreaseReferenceCnt(const ShmKey &shmId, bool isShm, uint32_t version = 0); /** * @brief Increase the global reference count to objects in worker. @@ -584,6 +585,11 @@ private: std::vector failList; }; + /** + * @brief Init ParallelFor thread pool. + */ + void InitParallelFor(); + /** * @brief process too handler MGetH2D rpc response. * @param[in] asyncResource A struct type containing a future with rpc, a buffer that cannot be destructured for the @@ -603,7 +609,7 @@ private: Status ConstructMultiCreateParam(const std::vector &objectKeyList, const std::vector &dataSizeList, std::vector> &bufferList, - std::vector &multiCreateParamList); + std::vector &multiCreateParamList, uint64_t &dataSizeSum); /** * @brief For device object, to async get multiple objects @@ -749,7 +755,7 @@ private: * @brief Batch release local memory ref by zmq RPC. * @param[in] shmInfos The shared memory info of buffers. */ - void BatchDecreaseRefCnt(const std::vector> &shmInfos); + void BatchDecreaseRefCnt(const std::vector> &shmInfos); /** * @brief Set the offset read object buffer. @@ -786,7 +792,7 @@ private: */ static ObjectBufferInfo SetObjectBufferInfo(const std::string &objectKey, uint8_t *pointer, uint64_t size, uint64_t metaSize, const FullParam ¶m, bool isSeal, - uint32_t version, const std::string &shmId = {}, + uint32_t version, const ShmKey &shmId = {}, const std::shared_ptr &payloadPointer = nullptr, std::shared_ptr mmapEntry = nullptr); @@ -847,12 +853,10 @@ private: static Status CheckValidObjectKeyVector(const Vec &vec, bool nullable = false) { CHECK_FAIL_RETURN_STATUS(nullable || !vec.empty(), K_INVALID, "The keys are empty"); - size_t index = 0; - for (const auto &objectKey : vec) { - CHECK_FAIL_RETURN_STATUS( - !objectKey.empty(), K_INVALID, FormatString("The objectKey at position %d is empty", index)); - RETURN_IF_NOT_OK(CheckValidObjectKey(objectKey)); - index++; + if (!vec.empty()) { + CHECK_FAIL_RETURN_STATUS(!vec.begin()->empty(), K_INVALID, + FormatString("The objectKey at position %d is empty", 0)); + RETURN_IF_NOT_OK(CheckValidObjectKey(*vec.begin())); } return Status::OK(); } @@ -1003,14 +1007,15 @@ private: * @param[in] vals The values for the keys. * @param[in] existence Whether set if some key exists. * @param[out] outFailedKeys out failed keys. - * @param[out] kvs key and values for set. + * @param[out] deduplicateKeys the deduplicate key . + * @param[out] deduplicateVals the deduplicate vals . 
* @return K_OK on success; the error code otherwise. */ Status CheckMultiSetInputParamValidationNtx(const std::vector &keys, const std::vector &vals, std::vector &outFailedKeys, - std::map &kv); - + std::vector &deduplicateKeys, + std::vector &deduplicateVals); /** * @brief Check the validation of the input parameter of the multiple set. * @param[in] keys The keys to be set. @@ -1069,12 +1074,28 @@ private: void StartPerfThread(); void ShutdownPerfThread(); + /** + * @brief Memory copy in parallel or serial mode. + * @param[in] isParallel Enable parallel or not. + * @param[in] keys Object keys. + * @param[in] vals Object values. + * @param[in] creatParam The creating parameter of the buffer. + * @param[in] bufferList The buffer of the objects. + * @param[out] bufferInfoList The buffers information for creating buffers.. + * @return K_OK on success; the error code otherwise. + */ + Status MemoryCopyParallel(bool isParallel, const std::vector &keys, + const std::vector &vals, const FullParam &creatParam, + std::vector> &bufferList, + std::vector> &bufferInfoList); + std::string ipAddress_; RpcAuthKeys authKeys_; RpcCredential cred_; int32_t timeoutMs_; std::string tenantId_; bool enableCrossNodeConnection_ = false; + bool enableExclusiveConnection_ = false; std::unique_ptr signature_{ nullptr }; std::vector> workerApi_; std::atomic currentNode_{ LOCAL_WORKER }; @@ -1100,6 +1121,7 @@ private: std::shared_ptr asyncGetCopyPool_; std::shared_ptr asyncSwitchWorkerPool_; std::shared_ptr asyncDevDeletePool_; + std::shared_ptr asyncReleasePool_; // Listenworker needs to be placed at the bottom to ensure that it is destructed first. std::vector> listenWorker_; @@ -1113,6 +1135,7 @@ private: WaitPost switchPost_; bool clientEnableP2Ptransfer_ = false; + int parallismNum_ = 0; }; } // namespace object_cache } // namespace datasystem diff --git a/src/datasystem/client/stream_cache/producer_impl.cpp b/src/datasystem/client/stream_cache/producer_impl.cpp index 7357fc77fc381dc8eab7a3109ba6de0172aef09c..83eadd9df170624033fa154179de3fc1a8af463b 100644 --- a/src/datasystem/client/stream_cache/producer_impl.cpp +++ b/src/datasystem/client/stream_cache/producer_impl.cpp @@ -487,7 +487,6 @@ Status ProducerImpl::CreatePagePostProcessing(const ShmView &lastPageView, std:: writePage_ = std::move(page); RETURN_IF_NOT_OK(writePage_->RefPage(LogPrefix())); curView_ = writePage_->GetShmView(); - pageId_ = writePage_->GetPageId(); fixPageCount++; // We may (or may not) lock this page. But let's track it in the work area if (WorkAreaIsV2()) { @@ -495,7 +494,8 @@ Status ProducerImpl::CreatePagePostProcessing(const ShmView &lastPageView, std:: } } VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Acquire page id: %s, isSharedPage: %d, lastPageView: %s", - LogPrefix(), pageId_, writePage_->IsSharedPage(), lastPageView.ToStr()); + LogPrefix(), writePage_->GetPageId(), writePage_->IsSharedPage(), + lastPageView.ToStr()); return Status::OK(); } diff --git a/src/datasystem/client/stream_cache/producer_impl.h b/src/datasystem/client/stream_cache/producer_impl.h index 43af25cce130fbb0bef71e143bc7d37a96f13dbc..21ec4d33ee56a9817b095b100352822d9f0fbe16 100644 --- a/src/datasystem/client/stream_cache/producer_impl.h +++ b/src/datasystem/client/stream_cache/producer_impl.h @@ -217,7 +217,6 @@ private: std::unique_ptr unfixTimer_; std::unique_ptr unfixWaitPost_{ nullptr }; std::shared_ptr pageUnit_; - std::string pageId_; ShmView curView_; std::mutex flushMutex_; // Guarantee FIFO, single on-the-fly flush. 
std::atomic pageDirty_{ false }; diff --git a/src/datasystem/common/CMakeLists.txt b/src/datasystem/common/CMakeLists.txt index 321130c158c981d9624e9d2014d54e45e450f489..b0cc34750e9a6bc08cf44590862485d945e21cf7 100644 --- a/src/datasystem/common/CMakeLists.txt +++ b/src/datasystem/common/CMakeLists.txt @@ -20,3 +20,4 @@ add_subdirectory(l2cache) add_subdirectory(flags) add_subdirectory(signal) add_subdirectory(rdma) +add_subdirectory(parallel) diff --git a/src/datasystem/common/device/ascend/ffts_dispatcher.cpp b/src/datasystem/common/device/ascend/ffts_dispatcher.cpp index 3465000269a3f92aa35afcf1832abf614b6fef1c..29211e3aa9c2d9b4bbf342115c8695f7c5743637 100644 --- a/src/datasystem/common/device/ascend/ffts_dispatcher.cpp +++ b/src/datasystem/common/device/ascend/ffts_dispatcher.cpp @@ -189,7 +189,6 @@ HcclResult FftsDispatcher::ConstructFftsSqe(rtFftsPlusSqe_t &fftsPlusSqe, uint16 // Identifies the communication task and optimizes the FFTS+ scheduling performance. (The RTS requires that the // AIV/AIC task be 0x5B. Otherwise, the task is 0x5A.) // 0x5A: identifies the communication task and optimizes the FFTS+ scheduling performance. - const uint8_t TASK_TYPE_AIV_AIC = 0x5B; const uint8_t TASK_TYPE_OTHER = 0x5A; fftsPlusSqe.subType = TASK_TYPE_OTHER; return HCCL_SUCCESS; diff --git a/src/datasystem/common/immutable_string/immutable_string.cpp b/src/datasystem/common/immutable_string/immutable_string.cpp index f60d8c6d97cc4d2cc1da208232e894e200f40803..8ba1f7e7d112ed83d6e290a67e7a80b467215f3d 100644 --- a/src/datasystem/common/immutable_string/immutable_string.cpp +++ b/src/datasystem/common/immutable_string/immutable_string.cpp @@ -22,61 +22,61 @@ #include "datasystem/common/immutable_string/immutable_string_pool.h" namespace datasystem { -ImmutableString::ImmutableString(const std::string &val) noexcept +ImmutableStringImpl::ImmutableStringImpl(const std::string &val) noexcept { if (!val.empty()) { ImmutableStringPool::Instance().Intern(val, strHandle_); } } -ImmutableString::ImmutableString(const char *cStr) : ImmutableString(std::string(cStr)) +ImmutableStringImpl::ImmutableStringImpl(const char *cStr) : ImmutableStringImpl(std::string(cStr)) { } -std::ostream &operator<<(std::ostream &os, const ImmutableString &obj) +std::ostream &operator<<(std::ostream &os, const ImmutableStringImpl &obj) { os << obj.ToString(); return os; } -size_t ImmutableString::GetHash() const +size_t ImmutableStringImpl::GetHash() const { return strHandle_.ToRefCountStr().GetHash(); } -const RefCountString &ImmutableString::ToRefCountStr() const +const RefCountString &ImmutableStringImpl::ToRefCountStr() const { return strHandle_.ToRefCountStr(); } -const std::string &ImmutableString::ToString() const +const std::string &ImmutableStringImpl::ToString() const { return strHandle_.ToStr(); } -bool ImmutableString::operator==(const ImmutableString &rhs) const +bool ImmutableStringImpl::operator==(const ImmutableStringImpl &rhs) const { const auto &lhsRCString = strHandle_.ToRefCountStr(); const auto &rhsRCString = rhs.strHandle_.ToRefCountStr(); return &lhsRCString == &rhsRCString || lhsRCString == rhsRCString; } -bool ImmutableString::operator!=(const ImmutableString &rhs) const +bool ImmutableStringImpl::operator!=(const ImmutableStringImpl &rhs) const { return this != &rhs && ToString() != rhs.ToString(); } -bool ImmutableString::operator<(const ImmutableString &rhs) const +bool ImmutableStringImpl::operator<(const ImmutableStringImpl &rhs) const { return ToString() < rhs.ToString(); } -const char* 
ImmutableString::Data() const +const char *ImmutableStringImpl::Data() const { return ToString().data(); } -std::string::size_type ImmutableString::Size() const +std::string::size_type ImmutableStringImpl::Size() const { return ToString().size(); } @@ -84,19 +84,19 @@ std::string::size_type ImmutableString::Size() const } // namespace datasystem namespace std { -size_t hash::operator()(const datasystem::ImmutableString &str) const +size_t hash::operator()(const datasystem::ImmutableStringImpl &str) const { return str.GetHash(); } -bool equal_to::operator()(const datasystem::ImmutableString &lhs, - const datasystem::ImmutableString &rhs) const +bool equal_to::operator()(const datasystem::ImmutableStringImpl &lhs, + const datasystem::ImmutableStringImpl &rhs) const { return lhs == rhs; } -bool less::operator()(const datasystem::ImmutableString &lhs, - const datasystem::ImmutableString &rhs) const +bool less::operator()(const datasystem::ImmutableStringImpl &lhs, + const datasystem::ImmutableStringImpl &rhs) const { return lhs < rhs; } diff --git a/src/datasystem/common/immutable_string/immutable_string.h b/src/datasystem/common/immutable_string/immutable_string.h index 3cd5a02e986236677d9fbab1a658cb6dc8879764..979c77b1fbcf588bac490384c74abf8a74627da6 100644 --- a/src/datasystem/common/immutable_string/immutable_string.h +++ b/src/datasystem/common/immutable_string/immutable_string.h @@ -25,18 +25,19 @@ #include "datasystem/common/immutable_string/ref_count_string.h" namespace datasystem { -class ImmutableString { +using ImmutableString = std::string; +class ImmutableStringImpl { public: - ImmutableString() = default; - ImmutableString(const ImmutableString &val) = default; - ImmutableString &operator=(const ImmutableString &other) = default; - ~ImmutableString() = default; + ImmutableStringImpl() = default; + ImmutableStringImpl(const ImmutableStringImpl &val) = default; + ImmutableStringImpl &operator=(const ImmutableStringImpl &other) = default; + ~ImmutableStringImpl() = default; - ImmutableString(const std::string &val) noexcept; - ImmutableString(const char *cStr); + ImmutableStringImpl(const std::string &val) noexcept; + ImmutableStringImpl(const char *cStr); /** - * @brief Get the hash of ImmutableString. - * @return The hash of ImmutableString. + * @brief Get the hash of ImmutableStringImpl. + * @return The hash of ImmutableStringImpl. */ size_t GetHash() const; @@ -52,14 +53,14 @@ public: */ const std::string &ToString() const; - bool operator==(const ImmutableString &rhs) const; + bool operator==(const ImmutableStringImpl &rhs) const; - bool operator!=(const ImmutableString &rhs) const; + bool operator!=(const ImmutableStringImpl &rhs) const; - bool operator<(const ImmutableString &rhs) const; + bool operator<(const ImmutableStringImpl &rhs) const; /** - * @brief The operator to convert a ImmutableString to std::string. + * @brief The operator to convert a ImmutableStringImpl to std::string. * @return The the const reference of std::string. 
*/ operator const std::string &() const @@ -74,22 +75,22 @@ public: private: RefCountStringHandle strHandle_; }; -std::ostream &operator<<(std::ostream &os, const ImmutableString &obj); +std::ostream &operator<<(std::ostream &os, const ImmutableStringImpl &obj); } // namespace datasystem namespace tbb { template <> #if TBB_INTERFACE_VERSION >= 12050 -struct detail::d1::tbb_hash_compare { +struct detail::d1::tbb_hash_compare { #else -struct tbb_hash_compare { +struct tbb_hash_compare { #endif - static size_t hash(const datasystem::ImmutableString &a) + static size_t hash(const datasystem::ImmutableStringImpl &a) { return a.GetHash(); } - static size_t equal(const datasystem::ImmutableString &a, const datasystem::ImmutableString &b) + static size_t equal(const datasystem::ImmutableStringImpl &a, const datasystem::ImmutableStringImpl &b) { return a == b; } @@ -98,18 +99,18 @@ struct tbb_hash_compare { namespace std { template <> -struct hash { - size_t operator()(const datasystem::ImmutableString &str) const; +struct hash { + size_t operator()(const datasystem::ImmutableStringImpl &str) const; }; template <> -struct equal_to { - bool operator()(const datasystem::ImmutableString &lhs, const datasystem::ImmutableString &rhs) const; +struct equal_to { + bool operator()(const datasystem::ImmutableStringImpl &lhs, const datasystem::ImmutableStringImpl &rhs) const; }; template <> -struct less { - bool operator()(const datasystem::ImmutableString &lhs, const datasystem::ImmutableString &rhs) const; +struct less { + bool operator()(const datasystem::ImmutableStringImpl &lhs, const datasystem::ImmutableStringImpl &rhs) const; }; } // namespace std #endif \ No newline at end of file diff --git a/src/datasystem/common/log/access_recorder.h b/src/datasystem/common/log/access_recorder.h index 4156b86a1f961fa2e517f98a97be8cf38520c295..56370873bb6cbc0da86a0ffacce6ce5e92d48918 100644 --- a/src/datasystem/common/log/access_recorder.h +++ b/src/datasystem/common/log/access_recorder.h @@ -133,7 +133,6 @@ public: void Record(StatusCode code, const StreamRequestParam &reqParam, const StreamResponseParam &respParam); private: - using clock = std::chrono::steady_clock; std::chrono::time_point beg_; std::string handleName_; @@ -229,7 +228,7 @@ std::string objectKeysToString(const std::vector &keys); std::string objectKeysToString(const char **cKey, size_t keyLen); template -std::string objectKeysToAbbrStr(T &&keys) +std::string ObjectKeysToAbbrStr(T &&keys) { if (keys.empty()) { return "+count:0"; diff --git a/src/datasystem/common/object_cache/object_base.h b/src/datasystem/common/object_cache/object_base.h index 7f69fd8118f53fc159655928632245c9eb5f7d2b..560ea41f9d6168b7d9dd7d0f9895e7a4e7b62bef 100644 --- a/src/datasystem/common/object_cache/object_base.h +++ b/src/datasystem/common/object_cache/object_base.h @@ -30,6 +30,7 @@ #include "datasystem/common/object_cache/object_bitmap.h" #include "datasystem/common/rpc/rpc_message.h" #include "datasystem/common/shared_memory/shm_unit.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/object_client.h" #include "datasystem/object/object_enum.h" #include "datasystem/utils/optional.h" @@ -318,7 +319,7 @@ struct ObjectInterface { struct ObjectBufferInfo { std::string objectKey; - std::string shmId; + ShmKey shmId; uint8_t *pointer; uint64_t dataSize; uint64_t metadataSize; diff --git a/src/datasystem/common/object_cache/object_ref_info.cpp b/src/datasystem/common/object_cache/object_ref_info.cpp index 
1a6267768ca8575d3bc280d64ecb0d212872d1be..9380ae5b85cf71fc4cb2ad1ccc801998df7242ee 100644 --- a/src/datasystem/common/object_cache/object_ref_info.cpp +++ b/src/datasystem/common/object_cache/object_ref_info.cpp @@ -21,8 +21,10 @@ #include +#include "datasystem/common/immutable_string/immutable_string.h" #include "datasystem/common/log/log.h" #include "datasystem/common/shared_memory/allocator.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/format.h" #include "datasystem/common/util/strings_util.h" #include "datasystem/common/util/raii.h" @@ -32,102 +34,6 @@ namespace datasystem { namespace object_cache { -bool ObjectRefInfo::AddRef(const std::string &objectKey, uint32_t ref) -{ - std::shared_lock lock(objectKeyMapMutex_); - TbbObjKeyTable::accessor objAccessor; - VLOG(1) << "add object key " << objectKey << " ref:" << ref; - bool res = objectKeys_.emplace(objAccessor, objectKey, ref); - if (res) { - return true; - } - if (isUniqueCnt_) { - return false; - } - objAccessor->second += ref; - return true; -} - -uint32_t ObjectRefInfo::GetRefCount(const std::string &objectKey) -{ - std::shared_lock lock(objectKeyMapMutex_); - TbbObjKeyTable::const_accessor objAccessor; - if (objectKeys_.find(objAccessor, objectKey)) { - return objAccessor->second; - } - return 0; -} - -Status ObjectRefInfo::UpdateRefCount(const std::string &objectKey, int count) -{ - if (count < 0) { - RETURN_STATUS(StatusCode::K_INVALID, FormatString("[ObjectId %s] Invalid count: %d", objectKey, count)); - } - std::shared_lock lock(objectKeyMapMutex_); - TbbObjKeyTable::accessor objAccessor; - if (objectKeys_.find(objAccessor, objectKey)) { - if (isUniqueCnt_ && count > 1) { - RETURN_STATUS(StatusCode::K_DUPLICATED, "object key is marked to be unique"); - } - objAccessor->second = static_cast(count); - return Status::OK(); - } - auto result = objectKeys_.emplace(objAccessor, objectKey, count); - if (!result) { - RETURN_STATUS(StatusCode::K_RUNTIME_ERROR, "emplace on objectKeys_ failed."); - } - return Status::OK(); -} - -bool ObjectRefInfo::RemoveRef(const std::string &objectKey) -{ - std::shared_lock lock(objectKeyMapMutex_); - TbbObjKeyTable::accessor objAccessor; - if (!objectKeys_.find(objAccessor, objectKey)) { - return false; - } - if (isUniqueCnt_) { - auto result = objectKeys_.erase(objAccessor); - return result > 0; - } - objAccessor->second -= 1; - if (objAccessor->second == 0) { - (void)objectKeys_.erase(objAccessor); - } - return true; -} - -bool ObjectRefInfo::Contains(const std::string &objectKey) const -{ - std::shared_lock lock(objectKeyMapMutex_); - return objectKeys_.count(objectKey) == 1; -} - -void ObjectRefInfo::GetRefIds(std::vector &objectKeys) const -{ - std::lock_guard lock(objectKeyMapMutex_); - std::transform(objectKeys_.begin(), objectKeys_.end(), std::back_inserter(objectKeys), - [](auto &kv) { return kv.first; }); -} - -bool ObjectRefInfo::CheckIsNoneRef(const std::string &objectKey) const -{ - std::shared_lock lock(objectKeyMapMutex_); - TbbObjKeyTable::const_accessor objAccessor; - if (!objectKeys_.find(objAccessor, objectKey)) { - return true; - } else if (objAccessor->second == 0) { - return true; - } - return false; -} - -bool ObjectRefInfo::CheckIsRefIdsEmpty() const -{ - std::shared_lock lock(objectKeyMapMutex_); - return objectKeys_.empty(); -} - Status ObjectGlobalRefTable::GIncreaseRef(const std::string &clientId, const std::vector &objectKeys, std::vector &failedIncIds, std::vector &firstIncIds, bool isRemoteClient) @@ -136,7 +42,7 @@ Status 
ObjectGlobalRefTable::GIncreaseRef(const std::string &clientId, const std TbbClientRefTable::const_accessor clientAccessor; while (!clientRefTable_.find(clientAccessor, clientId)) { TbbClientRefTable::accessor accessor; - auto clientInfo = std::make_shared(); + auto clientInfo = std::make_shared>(); // std::string -> ObjectKey clientRefTable_.emplace(accessor, clientId, std::move(clientInfo)); // In the off-cloud reference counting scenario, if remoteClient appears for the first time, record it to // remoteClientIdTable_. @@ -273,7 +179,7 @@ void ObjectGlobalRefTable::GetRemoteClientIds(std::unordered_set &r remoteClientIds.clear(); for (TbbFirstRemoteClientTable::const_iterator it = remoteClientIdTable_.begin(); it != remoteClientIdTable_.end(); ++it) { - (void)remoteClientIds.insert(it->first.ToString()); + (void)remoteClientIds.insert(it->first); } } @@ -320,9 +226,9 @@ void ObjectGlobalRefTable::GetObjRefIds(const std::string &objectKey, std::vecto } } -Status SharedMemoryRefTable::GetShmUnit(const std::string &shmId, std::shared_ptr &shmUnit) +Status SharedMemoryRefTable::GetShmUnit(const ShmKey &shmId, std::shared_ptr &shmUnit) { - TbbMemoryObjectRefTable::accessor shmAccessor; + TbbMemoryObjectRefTable::const_accessor shmAccessor; auto found = shmRefTable_.find(shmAccessor, shmId); if (!found) { RETURN_STATUS(K_NOT_FOUND, FormatString("Get a not found shm: %s ", shmId)); @@ -338,7 +244,7 @@ void SharedMemoryRefTable::AddShmUnit(const std::string &clientId, std::shared_p TbbMemoryObjectRefTable::accessor objectAccessor; if (!clientRefTable_.find(clientAccessor, clientId)) { - auto clientInfo = std::make_shared(); + auto clientInfo = std::make_shared>(); clientRefTable_.emplace(clientAccessor, clientId, std::move(clientInfo)); } if (clientAccessor->second->AddRef(shmId)) { @@ -358,7 +264,39 @@ void SharedMemoryRefTable::AddShmUnit(const std::string &clientId, std::shared_p VLOG(1) << "AddShmUnit for shmid: " << shmUnit->id << " client id: " << clientId; } -Status SharedMemoryRefTable::RemoveShmUnit(const std::string &clientId, const std::string &shmId) +void SharedMemoryRefTable::AddShmUnits(const std::string &clientId, std::vector> &shmUnits) +{ + TbbMemoryClientRefTable::accessor clientAccessor; + if (!clientRefTable_.find(clientAccessor, clientId)) { + auto clientInfo = std::make_shared>(); + clientRefTable_.emplace(clientAccessor, clientId, std::move(clientInfo)); + } + TbbMemoryObjectRefTable::accessor objectAccessor; + for (auto &shmUnit : shmUnits) { + if (shmUnit == nullptr) { + continue; + } + const auto &shmId = shmUnit->GetId(); + if (clientAccessor->second->AddRef(shmId)) { + shmUnit->IncrementRefCount(); + } + if (!shmRefTable_.find(objectAccessor, shmId)) { + shmRefTable_.emplace(objectAccessor, shmId, + std::make_pair(shmUnit, std::unordered_set{ clientId })); + } else { + objectAccessor->second.second.emplace(clientId); + } + objectAccessor.release(); + + if (shmUnit->GetRefCount() == 1) { + datasystem::memory::Allocator::Instance()->ChangeNoRefPageCount(-1); + datasystem::memory::Allocator::Instance()->ChangeRefPageCount(1); + } + VLOG(1) << "AddShmUnit for shmid: " << shmUnit->id << " client id: " << clientId; + } +} + +Status SharedMemoryRefTable::RemoveShmUnit(const std::string &clientId, const ShmKey &shmId) { TbbMemoryClientRefTable::accessor clientAccessor; TbbMemoryObjectRefTable::accessor shmAccessor; @@ -392,7 +330,7 @@ void SharedMemoryRefTable::RemoveShmUnitDetail(const std::string &clientId, shmUnit->DecrementRefCount(); } else { LOG(WARNING) << 
"RemoveShmUnit: The value of refCount is 0 and cannot be decreased. id:" - << BytesUuidToString(shmId); + << BytesUuidToString(shmId.ToString()); } } if (clientAccessor->second->CheckIsRefIdsEmpty()) { @@ -415,7 +353,8 @@ void SharedMemoryRefTable::RemoveShmUnitDetail(const std::string &clientId, } } -bool SharedMemoryRefTable::Contains(const std::string &clientId, const std::string &shmId) const +#ifdef WITH_TESTS +bool SharedMemoryRefTable::Contains(const std::string &clientId, const ShmKey &shmId) const { TbbMemoryClientRefTable::accessor accessor; if (clientRefTable_.find(accessor, clientId)) { @@ -423,8 +362,9 @@ bool SharedMemoryRefTable::Contains(const std::string &clientId, const std::stri } return false; } +#endif -void SharedMemoryRefTable::GetClientRefIds(const std::string &clientId, std::vector &shmIds) const +void SharedMemoryRefTable::GetClientRefIds(const std::string &clientId, std::vector &shmIds) const { TbbMemoryClientRefTable::accessor accessor; if (clientRefTable_.find(accessor, clientId)) { @@ -438,7 +378,7 @@ Status SharedMemoryRefTable::RemoveClient(const std::string &clientId) if (!clientRefTable_.find(clientAccessor, clientId)) { return Status::OK(); } - std::vector shmIds; + std::vector shmIds; clientAccessor->second->GetRefIds(shmIds); for (const auto &shmId : shmIds) { TbbMemoryObjectRefTable::accessor shmAccessor; diff --git a/src/datasystem/common/object_cache/object_ref_info.h b/src/datasystem/common/object_cache/object_ref_info.h index faff48740830d3371548a8ac74d92b3dabe8f0ba..ab54b65dcef44e1c698f0d3b5db99d50ffff56cf 100644 --- a/src/datasystem/common/object_cache/object_ref_info.h +++ b/src/datasystem/common/object_cache/object_ref_info.h @@ -30,6 +30,7 @@ #include "datasystem/common/immutable_string/immutable_string.h" #include "datasystem/common/object_cache/object_base.h" #include "datasystem/common/shared_memory/shm_unit.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/net_util.h" #include "datasystem/common/object_cache/safe_object.h" #include "datasystem/utils/status.h" @@ -37,8 +38,9 @@ namespace datasystem { namespace object_cache { -using TbbObjKeyTable = tbb::concurrent_hash_map; +template class ObjectRefInfo { +using TbbObjKeyTable = tbb::concurrent_hash_map; public: explicit ObjectRefInfo(bool isUniqueCnt = true) : isUniqueCnt_(isUniqueCnt) { @@ -52,28 +54,28 @@ public: * @param[in] ref ref num need to add. * @return True on success, false otherwise. */ - bool AddRef(const std::string &objectKey, uint32_t ref = 1); + bool AddRef(const T &objectKey, uint32_t ref = 1); /** * @brief Remove the reference to the object. * @param[in] objectKey The object key to remove, it cannot be empty. * @return True on success, false otherwise. */ - bool RemoveRef(const std::string &objectKey); + bool RemoveRef(const T &objectKey); /** * @brief Check if the id is contains. * @param[in] objectKey The id of object. * @return True on success, false otherwise. */ - bool Contains(const std::string &objectKey) const; + bool Contains(const T &objectKey) const; /** * @brief Get number of references for a object key. * @param[in] objectKey The id of object. * @return reference count if Id is present, 0 otherwise. */ - uint32_t GetRefCount(const std::string &objectKey); + uint32_t GetRefCount(const T &objectKey); /** * @brief Used to update reference count for a objectKey during recovery @@ -81,20 +83,20 @@ public: * @param[in] count The reference count for the object. * @return Status of the call. 
*/ - Status UpdateRefCount(const std::string &objectKey, int count); + Status UpdateRefCount(const T &objectKey, int count); /** * @brief Get all ref ids. * @param[out] objectKeys The object keys. */ - void GetRefIds(std::vector &objectKeys) const; + void GetRefIds(std::vector &objectKeys) const; /** * @brief Check if the obj is dependent on other objs. * @param[in] objectKey The id of object. * @return Whether it is no ref. */ - bool CheckIsNoneRef(const std::string &objectKey) const; + bool CheckIsNoneRef(const T &objectKey) const; /** * @brief Check is objectKeys_ are empty. @@ -110,7 +112,8 @@ private: bool isUniqueCnt_; }; -using TbbClientRefTable = tbb::concurrent_hash_map>; +using TbbClientRefTable = + tbb::concurrent_hash_map>>; // std::string -> ObjectKey using TbbObjRefTable = tbb::concurrent_hash_map>; using TbbFirstRemoteClientTable = tbb::concurrent_hash_map; class ObjectGlobalRefTable { @@ -246,9 +249,9 @@ private: std::function removeFromKvStore_; }; -using TbbMemoryClientRefTable = tbb::concurrent_hash_map>; +using TbbMemoryClientRefTable = tbb::concurrent_hash_map>>; using TbbMemoryObjectRefTable = - tbb::concurrent_hash_map, std::unordered_set>>; + tbb::concurrent_hash_map, std::unordered_set>>; class SharedMemoryRefTable { public: SharedMemoryRefTable() = default; @@ -262,13 +265,20 @@ public: */ void AddShmUnit(const std::string &clientId, std::shared_ptr &shmUnit); + /** + * @brief Add shared memory units reference to the client table. + * @param[in] clientId uuid of client. + * @param[in] shmUnits The safe objects. + */ + void AddShmUnits(const std::string &clientId, std::vector> &shmUnits); + /** * @brief Check one shared memory unit whether be referred by client. * @param[in] objectKey Shared memory unit id of object. * @param[out] shmUnit Shared memory unit shared ptr. * @return true on success, false otherwise. */ - Status GetShmUnit(const std::string &shmId, std::shared_ptr &shmUnit); + Status GetShmUnit(const ShmKey &shmId, std::shared_ptr &shmUnit); /** * @brief Remove a client from the client table. @@ -276,7 +286,7 @@ public: * @param[in] shmId The shared memory id. * @return Status of the call */ - Status RemoveShmUnit(const std::string &clientId, const std::string &shmId); + Status RemoveShmUnit(const std::string &clientId, const ShmKey &shmId); /** * @brief Remove a client from the client table. @@ -285,20 +295,22 @@ public: */ Status RemoveClient(const std::string &clientId); +#ifdef WITH_TESTS /** * @brief Check one object whether be referred by client. * @param[in] clientId uuid of client. * @param[in] shmId Shared memory unit id of object. * @return true on success, false otherwise. */ - bool Contains(const std::string &clientId, const std::string &shmId) const; + bool Contains(const std::string &clientId, const ShmKey &shmId) const; +#endif /** * @brief Get all share memory unit ids referred by client. * @param[in] clientId uuid of client. * @param[out] shmIds Shared memory unit id of object. 
*/ - void GetClientRefIds(const std::string &clientId, std::vector &shmIds) const; + void GetClientRefIds(const std::string &clientId, std::vector &shmIds) const; private: /** @@ -317,6 +329,110 @@ private: mutable std::shared_timed_mutex mutex_; }; + +template +bool ObjectRefInfo::AddRef(const T &objectKey, uint32_t ref) +{ + std::shared_lock lock(objectKeyMapMutex_); + typename TbbObjKeyTable::accessor objAccessor; + VLOG(1) << "add object key " << objectKey << " ref:" << ref; + bool res = objectKeys_.emplace(objAccessor, objectKey, ref); + if (res) { + return true; + } + if (isUniqueCnt_) { + return false; + } + objAccessor->second += ref; + return true; +} + +template +uint32_t ObjectRefInfo::GetRefCount(const T &objectKey) +{ + std::shared_lock lock(objectKeyMapMutex_); + typename TbbObjKeyTable::const_accessor objAccessor; + if (objectKeys_.find(objAccessor, objectKey)) { + return objAccessor->second; + } + return 0; +} + +template +Status ObjectRefInfo::UpdateRefCount(const T &objectKey, int count) +{ + if (count < 0) { + RETURN_STATUS(StatusCode::K_INVALID, FormatString("[ObjectId %s] Invalid count: %d", objectKey, count)); + } + std::shared_lock lock(objectKeyMapMutex_); + typename TbbObjKeyTable::accessor objAccessor; + if (objectKeys_.find(objAccessor, objectKey)) { + if (isUniqueCnt_ && count > 1) { + RETURN_STATUS(StatusCode::K_DUPLICATED, "object key is marked to be unique"); + } + objAccessor->second = static_cast(count); + return Status::OK(); + } + auto result = objectKeys_.emplace(objAccessor, objectKey, count); + if (!result) { + RETURN_STATUS(StatusCode::K_RUNTIME_ERROR, "emplace on objectKeys_ failed."); + } + return Status::OK(); +} + +template +bool ObjectRefInfo::RemoveRef(const T &objectKey) +{ + std::shared_lock lock(objectKeyMapMutex_); + typename TbbObjKeyTable::accessor objAccessor; + if (!objectKeys_.find(objAccessor, objectKey)) { + return false; + } + if (isUniqueCnt_) { + auto result = objectKeys_.erase(objAccessor); + return result > 0; + } + objAccessor->second -= 1; + if (objAccessor->second == 0) { + (void)objectKeys_.erase(objAccessor); + } + return true; +} + +template +bool ObjectRefInfo::Contains(const T &objectKey) const +{ + std::shared_lock lock(objectKeyMapMutex_); + return objectKeys_.count(objectKey) == 1; +} + +template +void ObjectRefInfo::GetRefIds(std::vector &objectKeys) const +{ + std::lock_guard lock(objectKeyMapMutex_); + std::transform(objectKeys_.begin(), objectKeys_.end(), std::back_inserter(objectKeys), + [](auto &kv) { return kv.first; }); +} + +template +bool ObjectRefInfo::CheckIsNoneRef(const T &objectKey) const +{ + std::shared_lock lock(objectKeyMapMutex_); + typename TbbObjKeyTable::const_accessor objAccessor; + if (!objectKeys_.find(objAccessor, objectKey)) { + return true; + } else if (objAccessor->second == 0) { + return true; + } + return false; +} + +template +bool ObjectRefInfo::CheckIsRefIdsEmpty() const +{ + std::shared_lock lock(objectKeyMapMutex_); + return objectKeys_.empty(); +} } // namespace object_cache } // namespace datasystem #endif // DATASYSTEM_COMMON_OBJECT_CACHE_OBJECTREFINFO_H diff --git a/src/datasystem/common/parallel/CMakeLists.txt b/src/datasystem/common/parallel/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..204a9d5a971419e1f891c42a5bc1c69cb7d496b7 --- /dev/null +++ b/src/datasystem/common/parallel/CMakeLists.txt @@ -0,0 +1,5 @@ +set(PARALLEL_SRCS + detail/parallel_for_local.cpp) + +add_library(common_parallel STATIC ${PARALLEL_SRCS}) 
+target_link_libraries(common_parallel PRIVATE common_util) diff --git a/src/datasystem/common/parallel/detail/barrier.h b/src/datasystem/common/parallel/detail/barrier.h new file mode 100644 index 0000000000000000000000000000000000000000..4643737c482dafc55fd1f0bd0a520e219b7948a2 --- /dev/null +++ b/src/datasystem/common/parallel/detail/barrier.h @@ -0,0 +1,80 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: A generic barrier implementation for thread synchronization. + */ +#ifndef DATASYSTEM_COMMON_PARALLEL_BARRIER_H +#define DATASYSTEM_COMMON_PARALLEL_BARRIER_H + +#include +#include + +#define DS_UNLIKELY(x) __builtin_expect(!!(x), 0) + +namespace datasystem { +namespace Parallel { + +template class Barrier { +public: + Barrier(const Barrier &) = delete; + Barrier(Barrier &&) = delete; + Barrier &operator = (const Barrier &) = delete; + Barrier &operator = (Barrier &&) = delete; + + Barrier() + { + semData = new T(); + } + + ~Barrier() + { + if (semData != nullptr) { + semData->SemDestroy(); + delete semData; + semData = nullptr; + } + } + + inline void ForkBarrier(uint32_t initCnt) + { + // just master thread call it + semData->SemInit(0); + awaited = initCnt; + } + + inline void JoinBarrier(bool isMaster) + { + if (DS_UNLIKELY(isMaster)) { + semData->SemPend(); + // after wait, master should destroy it + awaited = 0; + return; + } + // workers meet join barrier point + if (DS_UNLIKELY(awaited.fetch_add(-1, std::memory_order_relaxed) == 1)) { + semData->SemPost(); + } + } + +private: + T *semData; + std::atomic awaited { 0 }; +}; +} +} + +#endif \ No newline at end of file diff --git a/src/datasystem/common/parallel/detail/native_sem.h b/src/datasystem/common/parallel/detail/native_sem.h new file mode 100644 index 0000000000000000000000000000000000000000..74493553cbfda5602531cff0799b58d11320b082 --- /dev/null +++ b/src/datasystem/common/parallel/detail/native_sem.h @@ -0,0 +1,74 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
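The Barrier above implements a fork/join rendezvous: the master arms it with ForkBarrier(chunkCount), each worker signals completion through JoinBarrier(false), and the master blocks in JoinBarrier(true) until the last decrement posts the semaphore. A usage sketch, assuming Barrier is instantiated with the NativeSem wrapper introduced next (the pairing is an assumption; the flattened text hides the template arguments):

```cpp
#include <cstdint>
#include <thread>
#include <vector>

#include "datasystem/common/parallel/detail/barrier.h"
#include "datasystem/common/parallel/detail/native_sem.h"

void ForkJoinExample()
{
    datasystem::Parallel::Barrier<datasystem::Parallel::NativeSem> barrier;
    const uint32_t chunkCount = 4;
    barrier.ForkBarrier(chunkCount);  // master arms the barrier with the chunk count

    std::vector<std::thread> workers;
    for (uint32_t i = 0; i < chunkCount; ++i) {
        workers.emplace_back([&barrier] {
            // ... process one chunk ...
            barrier.JoinBarrier(false);  // the final decrement posts the semaphore
        });
    }
    barrier.JoinBarrier(true);  // master pends until every chunk has joined
    for (auto &t : workers) {
        t.join();
    }
}
```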
+ */ + +/** + * Description: A RAII wrapper for POSIX semaphore operations + */ +#ifndef DATASYSTEM_COMMON_PARALLEL_NATIVE_SEM_H +#define DATASYSTEM_COMMON_PARALLEL_NATIVE_SEM_H + +#include +#include + +namespace datasystem { +namespace Parallel { +class NativeSem { +public: + NativeSem(const NativeSem &) = delete; + NativeSem(NativeSem &&) = delete; + NativeSem &operator = (const NativeSem &) = delete; + NativeSem &operator = (NativeSem &&) = delete; + + NativeSem() + { + sem = (sem_t *)malloc(sizeof(sem_t)); + } + + ~NativeSem() + { + if (sem != nullptr) { + free(sem); + sem = nullptr; + } + } + + inline void SemInit(int32_t initCnt) + { + (void)sem_init(sem, 0, initCnt); + } + + inline void SemDestroy() + { + (void)sem_destroy(sem); + } + + inline void SemPend() + { + (void)sem_wait(sem); + } + + inline void SemPost() + { + (void)sem_post(sem); + } + +private: + sem_t *sem; +}; +} +} + +#endif \ No newline at end of file diff --git a/src/datasystem/common/parallel/detail/parallel_for_local.cpp b/src/datasystem/common/parallel/detail/parallel_for_local.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2fb6d0edbe6ff2688ae7c66eec87f6eba0fdc48e --- /dev/null +++ b/src/datasystem/common/parallel/detail/parallel_for_local.cpp @@ -0,0 +1,48 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: A high-performance multi-threaded parallel computing framework. + */ + +#include "datasystem/common/parallel/detail/parallel_for_local.h" +#include "datasystem/common/util/thread_pool.h" + +namespace datasystem { +namespace Parallel { + +ParallelThreadPool *ParallelThreadPool::Instance() +{ + static ParallelThreadPool threadPool; + return &threadPool; +} + +void ParallelThreadPool::InitThreadPool(int minThreadNum, int maxThreadNum) +{ + bool expected = false; + if (isInit_.compare_exchange_strong(expected, true)) { + threadPool_ = std::make_unique(minThreadNum, maxThreadNum, "parallel_for"); + threadPool_->SetWarnLevel(ThreadPool::WarnLevel::NO_WARN); + threadNum_ = maxThreadNum == 0 ? minThreadNum : maxThreadNum; + } +} + +void ParallelThreadPool::LocalSubmit(std::function &&func) +{ + (void)threadPool_->Submit(func); +} +} +} \ No newline at end of file diff --git a/src/datasystem/common/parallel/detail/parallel_for_local.h b/src/datasystem/common/parallel/detail/parallel_for_local.h new file mode 100644 index 0000000000000000000000000000000000000000..236d912bf14855cdded790d8d0c33c554e570d25 --- /dev/null +++ b/src/datasystem/common/parallel/detail/parallel_for_local.h @@ -0,0 +1,187 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
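ParallelThreadPool (parallel_for_local.cpp above) is a process-wide singleton: InitThreadPool is guarded by a compare-exchange, so only the first call wires up the pool and every later call is a no-op. A minimal initialization sketch under those assumptions (the thread counts are arbitrary):

```cpp
#include "datasystem/common/parallel/detail/parallel_for_local.h"

void InitComputePool()
{
    using datasystem::Parallel::ParallelThreadPool;
    // Idempotent: the CAS inside InitThreadPool makes repeated calls no-ops.
    ParallelThreadPool::Instance()->InitThreadPool(/* minThreadNum */ 4, /* maxThreadNum */ 8);
    // With maxThreadNum != 0, GetThreadNum() now reports 8 worker threads,
    // and LocalSubmit() hands closures to the shared "parallel_for" pool.
}
```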
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: A high-performance multi-threaded parallel computing framework. + */ +#ifndef DATASYSTEM_COMMON_PARALLEL_PARALLEL_FOR_LOCAL_H +#define DATASYSTEM_COMMON_PARALLEL_PARALLEL_FOR_LOCAL_H + +#include +#include + +#include "datasystem/common/parallel/detail/barrier.h" +#include "datasystem/common/parallel/detail/native_sem.h" + +namespace datasystem { +class ThreadPool; +namespace Parallel { +extern thread_local int g_threadid; +static inline int GetThreadid() +{ + return g_threadid; +} + +/*! + @struct Context + @brief thread context info + */ +struct Context { + /*! + @brief thread identify id + */ + size_t id; +}; + +class ParallelThreadPool { +public: + static ParallelThreadPool *Instance(); + + void InitThreadPool(int minThreadNum, int maxThreadNum = 0); + + bool IsInitialized() + { + return isInit_; + } + + int GetThreadNum() const + { + return threadNum_; + } + + void LocalSubmit(std::function &&func); + +private: + std::atomic_bool isInit_{ false }; + std::unique_ptr threadPool_; + int threadNum_; +}; + +template +struct ParallelForLocal : public std::enable_shared_from_this> { +public: + ParallelForLocal(const ParallelForLocal &) = delete; + ParallelForLocal(ParallelForLocal &&) = delete; + ParallelForLocal &operator=(const ParallelForLocal &) = delete; + ParallelForLocal &operator=(ParallelForLocal &&) = delete; + + ParallelForLocal(const Index &start, const Index &end, const Handler &handler, const size_t &chunkSize) + : startIndex(start), endIndex(end), bodyHandler(handler), chunksize(chunkSize) + { + threadBarrier = new Barrier(); + } + + ~ParallelForLocal() + { + if (threadBarrier != nullptr) { + delete threadBarrier; + threadBarrier = nullptr; + } + } + + void DoParallelFor(const int ¶llelDegree) + { + size_t chunkCount = (endIndex - startIndex + chunksize - 1) / chunksize; + // master thread fork, this will init barrier + threadBarrier->ForkBarrier(chunkCount); + + // worker threads do work + for (int i = 0; i < parallelDegree - 1; i++) { + Context ctx; + ctx.id = i; + auto weak = this->weak_from_this(); + ParallelThreadPool::Instance()->LocalSubmit([weak, ctx]() { + if (auto ptr = weak.lock(); ptr) { + ptr->ParallelForLocal::ParallelForDynamicEntryTask(ctx); + } + }); + } + Context ctx; + ctx.id = parallelDegree - 1; + ParallelForDynamicEntryTask(ctx); + + // master thread join + threadBarrier->JoinBarrier(true); + } + + static constexpr bool HandlerTypeCheck() + { + if constexpr (std::is_invocable_v) { + return true; + } else if constexpr (std::is_invocable_v) { + return true; + } + return false; + } + + static void CallBodyHandler(Index start, Index end, const Handler &handler, const Context &ctx) + { + // match the argument format of Handler + if constexpr (std::is_invocable_v) { + handler(start, end); + } else if constexpr (std::is_invocable_v) { + handler(start, end, ctx); + } + } + +private: + inline void ParallelForDynamicEntryTask(const Context &ctx) + { + Index start; + Index end; + for (;;) { + bool hasNextSlice = GetNextSliceDynamic(&start, &end); + if (hasNextSlice) { + CallBodyHandler(start, end, bodyHandler, ctx); + 
threadBarrier->JoinBarrier(false); + } else { + break; + } + } + } + + // modify startIndex concurrency + inline bool GetNextSliceDynamic(Index *start, Index *end) + { + bool success; + do { + *start = startIndex; + + if (DS_UNLIKELY(*start >= endIndex)) { + return false; + } + + *end = *start + (Index)chunksize; + success = __sync_bool_compare_and_swap(&startIndex, *start, *end); + } while (!success); + + if (*end > endIndex) { + *end = endIndex; + } + + return true; + } + + Index startIndex; + const Index endIndex; + const Handler bodyHandler; + const size_t chunksize; + Barrier *threadBarrier; +}; +} +} + +#endif \ No newline at end of file diff --git a/src/datasystem/common/parallel/parallel_for.h b/src/datasystem/common/parallel/parallel_for.h new file mode 100644 index 0000000000000000000000000000000000000000..a45b69ddf5e3bf796e775c2165f76ffd3c8f3716 --- /dev/null +++ b/src/datasystem/common/parallel/parallel_for.h @@ -0,0 +1,109 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: A high-performance multi-threaded parallel computing framework. + */ +#ifndef DATASYSTEM_COMMON_PARALLEL_PARALLEL_FOR_H +#define DATASYSTEM_COMMON_PARALLEL_PARALLEL_FOR_H + +#include +#include +#include + +#include "datasystem/common/parallel/detail/parallel_for_local.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/utils/status.h" + +namespace datasystem { +namespace Parallel { + +static inline void InitParallelThreadPool(int minThreadNum, int maxThreadNum = 0) +{ + ParallelThreadPool::Instance()->InitThreadPool(minThreadNum, maxThreadNum); +} + +/** + * @brief ParallelFor is a function framework for parallel computing, enabling tasks to be executed in parallel across + * multiple threads to improve computational efficiency. ParallelFor internally implements parallel computing through + * task allocation and scheduling, automatically distributing tasks to available threads. + * @tparam Index The type of the iteration variable. + * @tparam Handler The type of the function to be called. + * @param start The starting value of the loop iteration range. + * @param end The ending value of the loop iteration range (exclusive). + * @param handler The function to be executed in the loop. The parameter list of the user-defined handler can be one of + * the following two types: + * 1. (Index, Index) + * 2. (Index, Index, const datasystem::Parallel::Context&) + * When the user's handler uses the datasystem::Parallel::Context parameter, the value of context.id in the handler will + * be in the range [0, parallelism). + * @param chunkSize The granularity of the task. + * @param workThreadSize The number of worker threads. If set to -1 (default), it will be set to the number of threads + * in the thread pool plus 1. + * @throws Exception + * 1. If ParallelFor is called before initialization, an exception "Assertion IsInitialized() failed !!!" will be + * thrown. + * 2. 
If the parameter list of the user-defined handler does not match the specified format, a compilation error will + * occur: "error: static assertion failed: handler must have 2 or 3 arguments. And arguments should be (Index, Index) or + * (Index, Index, const datasystem::Parallel::Context&)". + */ +template +Status ParallelFor(Index start, Index end, const Handler &handler, size_t chunkSize = 0, int workThreadSize = -1) +{ + if (end == start) { + return Status::OK(); + } + if (end < start) { + return Status(K_INVALID, "ParallelFor: end must be greater than start"); + } + if ((workThreadSize != -1 && workThreadSize < 1) || (Index)(INT_MAX - chunkSize) <= end) { + return Status(K_INVALID, "Parameter validation failed"); + } + if (!ParallelThreadPool::Instance()->IsInitialized()) { + return Status(K_INVALID, "ParallelThreadPool is not initialized"); + } + int parallelDegree; + if (workThreadSize == -1) { + parallelDegree = ParallelThreadPool::Instance()->GetThreadNum() + 1; + } else { + parallelDegree = std::min(workThreadSize, ParallelThreadPool::Instance()->GetThreadNum() + 1); + } + if (chunkSize == 0) { + static const int DEFAULT_CHUNK_COUNT_PER_THREAD_ON_AVERAGE = 4; // each thread processes 4 chunks on averge + chunkSize = std::max( + static_cast(end - start) / parallelDegree / DEFAULT_CHUNK_COUNT_PER_THREAD_ON_AVERAGE, 1); + } + auto taskNum = static_cast(ceil(static_cast(end - start) / static_cast(chunkSize))); + parallelDegree = std::min(taskNum, parallelDegree); + + constexpr bool typeCheck = ParallelForLocal::HandlerTypeCheck(); + static_assert(typeCheck, + "handler must have 2 or 3 arguments. And arguments should be (Index, Index) or (Index, " + "Index, const datasystem::Parallel::Context&)"); + + if (taskNum == 1) { + ParallelForLocal::CallBodyHandler(start, end, handler, Context{0}); + return Status::OK(); + } + // allocate workshare + auto local = std::make_shared>(start, end, handler, chunkSize); + local->DoParallelFor(parallelDegree); + return Status::OK(); +} +} +} + +#endif \ No newline at end of file diff --git a/src/datasystem/common/perf/perf_point.def b/src/datasystem/common/perf/perf_point.def index 5a5beb7443bdaf0c422ea51d28a8855e18d60fc5..8c7c768ad685594a9f347a5bbab8ce79085b90b4 100644 --- a/src/datasystem/common/perf/perf_point.def +++ b/src/datasystem/common/perf/perf_point.def @@ -208,6 +208,8 @@ PERF_KEY_DEF(WORKER_CONSTRUCT_BATCH_GET_REQ) PERF_KEY_DEF(WORKER_HANDLE_BATCH_SUB_RESP) PERF_KEY_DEF(WORKER_HANDLE_BATCH_SUB_RESP_PT_2) PERF_KEY_DEF(WORKER_CONSTRUCT_AND_SEND) +PERF_KEY_DEF(WORKER_URMA_GET_SEGMENT) + // stream worker PERF_KEY_DEF(WORKER_CREATE_PRODUCER_ALL) PERF_KEY_DEF(WORKER_CREATE_SUB_ALL) @@ -397,6 +399,7 @@ PERF_KEY_DEF(ZMQ_SVC_TO_ROUTER) PERF_KEY_DEF(ZMQ_ROUTER_TO_SVC) PERF_KEY_DEF(ZMQ_FRONTEND_TO_IOSVC) PERF_KEY_DEF(ZMQ_PAYLOAD_TRANSFER) +PERF_KEY_DEF(ZMQ_STUB_TO_EXCL_CONN) // stream client PERF_KEY_DEF(RPC_WORKER_CREATE_PRODUCER) PERF_KEY_DEF(RPC_WORKER_CREATE_SUBSCRIBE) diff --git a/src/datasystem/common/rdma/CMakeLists.txt b/src/datasystem/common/rdma/CMakeLists.txt index 0e43263dd4856845722e6ed342bf95c626d90bbb..03b0999ef427ff3acfbc63cbb18455760a1caece 100644 --- a/src/datasystem/common/rdma/CMakeLists.txt +++ b/src/datasystem/common/rdma/CMakeLists.txt @@ -11,13 +11,22 @@ set(URMA_DEPEND_LIBS if (BUILD_WITH_URMA) list(APPEND URMA_SRCS urma_manager.cpp urma_info.cpp) list(APPEND URMA_DEPEND_LIBS ${URMA_LIBRARY}) + list(APPEND URMA_DEPEND_LIBS ${TBB_LIBRARY}) set(URMA_STUB_SRCS urma_stub.cpp urma_info.cpp) add_library(common_stub_rdma STATIC 
${URMA_STUB_SRCS}) + target_link_libraries(common_stub_rdma PRIVATE ${TBB_LIBRARY}) endif() -add_library(common_rdma STATIC ${URMA_SRCS}) -target_link_libraries(common_rdma PRIVATE ${URMA_DEPEND_LIBS}) \ No newline at end of file +if (BUILD_WITH_RDMA) + list(APPEND RDMA_SRCS ucp_manager.cpp) + list(APPEND RDMA_DEPEND_LIBS ${UCX_LIBRARIES}) + + # TODO: STUB RDMA_STUB_SRCS +endif() + +add_library(common_rdma STATIC ${URMA_SRCS} ${RDMA_SRCS}) +target_link_libraries(common_rdma PRIVATE ${URMA_DEPEND_LIBS} ${RDMA_DEPEND_LIBS}) \ No newline at end of file diff --git a/src/datasystem/common/rdma/ucp_manager.cpp b/src/datasystem/common/rdma/ucp_manager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..96ab10633788d1f9db56b293166770e4f84ed8c4 --- /dev/null +++ b/src/datasystem/common/rdma/ucp_manager.cpp @@ -0,0 +1,20 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: UCX-UCP manager for ucp context, ucp worker, ucp endpoint, etc. + */ +#include "datasystem/common/rdma/ucp_manager.h" diff --git a/src/datasystem/common/rdma/ucp_manager.h b/src/datasystem/common/rdma/ucp_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..ace687126c96b87ba2ba1a4eda7c60d71214aa53 --- /dev/null +++ b/src/datasystem/common/rdma/ucp_manager.h @@ -0,0 +1,373 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: UCX-UCP manager for ucp context, ucp worker, ucp endpoint, etc. 
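Returning to the ParallelFor entry point a few files up: handlers may take either (Index, Index) or (Index, Index, const Context &), and per the DoParallelFor dispatch each concurrent worker receives a distinct ctx.id in [0, parallelism), which makes per-slot accumulators race-free. A usage sketch along those lines (the pool must be initialized first; the function and variable names are illustrative):

```cpp
#include <vector>

#include "datasystem/common/parallel/parallel_for.h"

datasystem::Status SumChunks(const std::vector<int> &data, long &total)
{
    namespace P = datasystem::Parallel;
    P::InitParallelThreadPool(4);  // idempotent; must precede ParallelFor
    std::vector<long> slots(P::ParallelThreadPool::Instance()->GetThreadNum() + 1, 0);

    // Three-argument handler: ctx.id picks this worker's private slot,
    // so no two workers write the same accumulator concurrently.
    RETURN_IF_NOT_OK(P::ParallelFor(size_t(0), data.size(),
                                    [&](size_t begin, size_t end, const P::Context &ctx) {
                                        for (size_t i = begin; i < end; ++i) {
                                            slots[ctx.id] += data[i];
                                        }
                                    }));
    total = 0;
    for (long s : slots) {
        total += s;
    }
    return datasystem::Status::OK();
}
```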
+ */ +#ifndef DATASYSTEM_COMMON_RPC_RDMA_MANAGER_H +#define DATASYSTEM_COMMON_RPC_RDMA_MANAGER_H + +#include +#include +#include +#include + +#include + +#include "datasystem/common/flags/flags.h" +#include "datasystem/common/perf/perf_manager.h" +#include "datasystem/common/rpc/rpc_channel.h" +#include "datasystem/common/util/lock_map.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/protos/meta_zmq.pb.h" +#include "datasystem/protos/utils.pb.h" +#include "datasystem/utils/status.h" + +DS_DECLARE_bool(enable_rdma); +DS_DECLARE_bool(rdma_register_whole_arena); + +namespace datasystem { +template +using custom_unique_ptr = std::unique_ptr>; + +template +custom_unique_ptr MakeCustomUnique(T *p, std::function custom_delete) +{ + if (p) { + return custom_unique_ptr(p, custom_delete); + } else { + LOG(WARNING) << "Input pointer is null"; + return nullptr; + } +} + +class UcpSegment { +public: + /** + * @brief Create a new UcpSegment object. + */ + UcpSegment(); + ~UcpSegment(); + + /** + * @brief Sets UcpSegment object + * @param[in] seg target UcpSegment object + * @param[in] local if local no need to unimport the UcpSegment + * @return Status of the call. + */ + void Set(ucp_mem_h *seg, bool local); + + void Clear(); + +private: + custom_unique_ptr segment_; + bool local_; +}; + +using UcpSegmentMap = LockMap; + +class UcpEndpoint { +public: + /** + * @brief Create a new UcpEndpoint object. + */ + UcpEndpoint(){}; + + ~UcpEndpoint(); + + /** + * @brief Get remote UcpSegment or import remote UcpSegment from the device + * @param[in] UrmaImportSegmentPb Pb with remote segment info + * @param[out] constAccessor Accessor in segment table + * @return Status of the call. + */ + Status GetOrImportRemoteUcpSeg(const RdmaImportSegmentPb &RdmaInfo, UcpSegmentMap::ConstAccessor &constAccessor); +}; + +using UcpEndpointMap = LockMap; + +class Event { +public: + /** + * @brief Create a new Event object. + */ + explicit Event(uint64_t requestId) : requestId_(requestId), ready_(false) + { + } + + /** + * @brief Wait on event until timeout or someone notify + * @param[in] timeout time in milliseconds to wait + * @return Status of the call. + */ + Status WaitFor(std::chrono::milliseconds timeout) + { + std::unique_lock lock(eventMutex_); + bool gotNotification = cv.wait_for(lock, timeout, [this] { return ready_; }); + if (!gotNotification && !ready_) { + // Return timeout + RETURN_STATUS_LOG_ERROR(K_RPC_DEADLINE_EXCEEDED, + FormatString("timedout waiting for request: %d", requestId_)); + } + return Status::OK(); + } + + /** + * @brief Notify all threads that are waiting for the event + */ + void NotifyAll() + { + std::unique_lock lock(eventMutex_); + ready_ = true; + cv.notify_all(); + } + + /** + * @brief Sets the event status as failed + */ + void set_failed() + { + failed_ = true; + } + + /** + * @brief Checks the event status + */ + bool is_failed() + { + return failed_; + } + +private: + std::condition_variable cv; + mutable std::mutex eventMutex_; + uint64_t requestId_; + bool ready_{ false }; + bool failed_{ false }; +}; + +class UcpWorker { +}; + +class UcpManager { +public: + /** + * @brief Singleton mode, obtaining instance. + * @return Reference of UcpManager + */ + static UcpManager &Instance(); + + ~UcpManager(); + + /** + * @brief Init a Rdma device + * @return Status of the call. 
+ */ + Status Init(); + + /** + * @brief Check whether we should register the whole arena upfront + * @return True if flag is set, else false + */ + static bool IsRegisterWholeArenaEnabled() + { + return FLAGS_rdma_register_whole_arena; + } + + /** + * @brief Check whether we should use event mode for interrupts + * @return True if flag is set, else false + */ + static bool IsEventModeEnabled(); + + /** + * @brief Register segment + * @param[in] segAddress Starting address of the segment + * @param[in] segSize Size of the segment + * @return Status of the call. + */ + Status RegisterSegment(const uint64_t &segAddress, const uint64_t &segSize); + + /** + * @brief Gets the segment if present, or + * registers the segment if the address is not already registered + * @param[in] segAddress Starting address of the segment + * @param[in] segSize Size of the segment + * @param[out] segVA Virtual address of the segment + * @param[out] segLen Size of the segment (==segSize) + * @param[out] segFlag Flags set for the segment + * @param[out] segTokenId Token provided for the segment + * @return Status of the call. + */ + Status GetSegmentInfo(const uint64_t &segAddress, const uint64_t &segSize, uint64_t &segVA, uint64_t &segLen, + uint32_t &segFlag, uint32_t &segTokenId); + + /** + * @brief Does an RDMA write to a remote worker memory location + * 1. Registers the segment if the address is not already registered + * 2. Imports the remote segment + * 3. Does a Ucp write + * @param[in] RdmaImportSegmentPb Protobuf contains remote worker RDMA info + * @param[in] localSegAddress Starting address of the segment (e.g. Arena + * start address) + * @param[in] localSegSize Total size of the segment (e.g. Arena size) + * @param[in] localObjectAddress Object address + * @param[in] readOffset Offset in the object to read + * @param[in] readSize Size of the object + * @param[in] metaDataSize Size of metadata (SHM metadata stored as part of + * object) + * @param[in] blocking Whether to block waiting for the ucp_put_nbx to finish. + * @param[out] keys The new request id to wait for if not blocking. + * @return Status of the call. + */ + Status ImportSegAndPutPayload(const RdmaImportSegmentPb &urmaInfo, const uint64_t &localSegAddress, + const uint64_t &localSegSize, const uint64_t &localObjectAddress, + const uint64_t &readOffset, const uint64_t &readSize, const uint64_t &metaDataSize, + bool blocking, std::vector &keys); + + /** + * @brief Remove remote endpoint and all associated segments + * @param[in] remoteAddress Remote worker address + * @return Status of the call. + */ + Status RemoveEndpoint(const HostPort &remoteAddress); + + /** + * @brief Ucp write operation waits on the CV to check completion status + * @param[in] requestId Unique id for the urma request (passed as user_ctx in + * urma_write) + * @param[in] timeoutMs Timeout waiting for the request to end + * @return Status of the call. + */ + Status WaitToFinish(uint64_t requestId, int64_t timeoutMs); + + /** + * @brief Handshake for RDMA purposes. + * @param[in] req RDMA handshake request. + * @param[out] rsp RDMA handshake response. + * @return Status of the call. + */ + Status UcpHandshake(const RdmaHandshakeReqPb &req, RdmaHandshakeRspPb &rsp); + +private: + UcpManager(); + + /** + * @brief Initialize ucp. + * @return Status of the call. + */ + Status UcpInit(); + + /** + * @brief Uninitialize ucp. + * @return Status of the call.
+ */ + Status UcpUninit(); + + /** + * @brief Creates the Ucp context + * @param[in] params Ucp context parameters + * @param[in] config Ucp configuration + * @return Status of the call. + */ + Status UcpCreateContext(ucp_params_t *params, ucp_config_t *config); + + /** + * @brief Deletes Ucp context object + * @return Status of the call. + */ + Status UcpDeleteContext(); + + /** + * @brief Continuously running event handler thread that polls JFC + * @return Status of the call. + */ + Status ServerEventHandleThreadMain(); + + /** + * @brief Register segment + * @param[in] segAddress Starting address of the segment + * @param[in] segSize Size of the segment + * @param[out] constAccessor Const accessor to the segment map + * @return Status of the call. + */ + Status GetOrRegisterSegment(const uint64_t &segAddress, const uint64_t &segSize, + UcpSegmentMap::ConstAccessor &constAccessor); + + /** + * @brief Unimport segment + * @param[in] remoteAddress Remote worker address + * @param[in] segmentAddress Segment start address + * @return Status of the call + */ + Status UnimportSegment(const HostPort &remoteAddress, uint64_t segmentAddress); + + /** + * @brief Stops the polling thread + * @return Status of the call. + */ + Status Stop(); + + /** + * @brief Checks if waiting requests are completed and notifies them + * @return Status of the call. + */ + Status CheckAndNotify(); + + /** + * @brief Gets event object of request id + * @param[in] requestId Unique id of the Urma request + * @param[out] event Event object for the request + * @return Status of the call. + */ + Status GetEvent(uint64_t requestId, std::shared_ptr &event); + + /** + * @brief Create Event object for the request + * @param[in] requestId Unique id for the Urma request + * @param[out] event Event object for the request + * @return Status of the call. + */ + Status CreateEvent(uint64_t requestId, std::shared_ptr &event); + + /** + * @brief Deletes the Event object for the request + * @param[in] requestId Unique id for the Urma request + */ + void DeleteEvent(uint64_t requestId); + + // Polling thread + std::unique_ptr serverEventThread_{ nullptr }; + + ucp_context_h *ucpContext_ = nullptr; + // ip -> ucp endpoint + std::unordered_map> endPointCache_; + std::vector> ucpWorkerVec_; + std::atomic requestId_{ 0 }; + + std::unique_ptr localSegmentMap_; + // Eid-to-segment map for remote segments.
+ std::unique_ptr remoteEndpointMap_; + mutable std::shared_timed_mutex eventMapMutex_; + std::unordered_map> eventMap_; + std::unordered_set finishedRequests_; + std::unordered_set failedRequests_; + std::atomic serverStop_{ false }; +}; + +} // namespace datasystem +#endif diff --git a/src/datasystem/common/rdma/urma_manager.cpp b/src/datasystem/common/rdma/urma_manager.cpp index 7684351bf640f05aef58a5724a705082a4fa3460..119188dcb97add0e6814ddfe993f240d64341b78 100644 --- a/src/datasystem/common/rdma/urma_manager.cpp +++ b/src/datasystem/common/rdma/urma_manager.cpp @@ -20,13 +20,14 @@ #include "datasystem/common/rdma/urma_manager.h" #include "datasystem/common/constants.h" -#include "datasystem/common/log/log.h" #include "datasystem/common/flags/flags.h" +#include "datasystem/common/log/log.h" #include "datasystem/common/perf/perf_manager.h" #include "datasystem/common/rdma/rdma_util.h" #include "datasystem/common/rdma/urma_manager_wrapper.h" #include "datasystem/common/rpc/rpc_constants.h" #include "datasystem/common/util/raii.h" +#include "datasystem/common/util/status_helper.h" #include "datasystem/common/util/thread_local.h" #include "datasystem/utils/status.h" #include "urma_opcode.h" @@ -69,7 +70,6 @@ UrmaManager::UrmaManager() #endif localSegmentMap_ = std::make_unique(); remoteDeviceMap_ = std::make_unique(); - eventMap_ = std::make_unique(); } UrmaManager::~UrmaManager() @@ -78,7 +78,7 @@ UrmaManager::~UrmaManager() VLOG(RPC_LOG_LEVEL) << "UrmaManager::~UrmaManager()"; remoteDeviceMap_.reset(); localSegmentMap_.reset(); - eventMap_.reset(); + tbbEventMap_.clear(); urmaJfrVec_.clear(); urmaJfsVec_.clear(); urmaJfc_.reset(); @@ -470,33 +470,28 @@ Status AddUbBondSegInfo(urma_context_t *ctx, urma_bond_add_remote_seg_info_in_t #endif } // namespace -Status UrmaManager::GetSegmentInfo(const uint64_t &segAddress, const uint64_t &segSize, const uint64_t &shmOffset, - const uint64_t &metaSz, const HostPort &localAddress, UrmaImportSegmentPb &segInfo) +Status UrmaManager::GetSegmentInfo(UrmaHandshakeReqPb &handshakeReq) { - SegmentMap::ConstAccessor constAccessor; - RETURN_IF_NOT_OK(GetOrRegisterSegment(segAddress, segSize, constAccessor)); - auto &localSegment = constAccessor.entry->data.segment_; - auto segPb = segInfo.mutable_seg(); - UrmaSeg::ToProto(localSegment->seg, *segPb); - LOG(INFO) << "local seg info: " << UrmaSeg::ToString(localSegment->seg); - - if (IsRegisterWholeArenaEnabled()) { - segInfo.set_seg_data_offset(shmOffset + metaSz); - } else { - segInfo.set_seg_data_offset(metaSz); - } - segInfo.mutable_request_address()->set_host(localAddress.Host()); - segInfo.mutable_request_address()->set_port(localAddress.Port()); + PerfPoint point(PerfKey::WORKER_URMA_GET_SEGMENT); + // Traverse the list of local registered segments. + std::unique_lock l(localMapMutex_); + for (auto iter = localSegmentMap_->begin(); iter != localSegmentMap_->end(); iter++) { + auto *segInfo = handshakeReq.add_seg_infos(); + auto &localSegment = iter->second.data.segment_; + auto segPb = segInfo->mutable_seg(); + UrmaSeg::ToProto(localSegment->seg, *segPb); + LOG(INFO) << "local seg info: " << UrmaSeg::ToString(localSegment->seg); #ifdef URMA_OVER_UB - // UB bond prehandling for segment. - if (GetUrmaMode() == UrmaMode::UB) { - UrmaBondSegInfo info; - RETURN_IF_NOT_OK(GetUbBondSegInfo(localSegment.get(), info.raw)); - auto *bondInfo = segInfo.mutable_bond_info(); - info.ToProto(*bondInfo); - LOG(INFO) << "local bond seg info: " << info.ToString(); - } + // UB bond prehandling for segment. 
+ if (GetUrmaMode() == UrmaMode::UB) { + UrmaBondSegInfo info; + RETURN_IF_NOT_OK(GetUbBondSegInfo(localSegment.get(), info.raw)); + auto *bondInfo = segInfo->mutable_bond_info(); + info.ToProto(*bondInfo); + LOG(INFO) << "local bond seg info: " << info.ToString(); + } #endif + } return Status::OK(); } @@ -591,19 +586,14 @@ Status UrmaManager::CheckAndNotify() void UrmaManager::DeleteEvent(uint64_t requestId) { - std::shared_lock lock(eventMapMutex_); - EventMap::Accessor accessor; - if (eventMap_->Find(accessor, requestId)) { - eventMap_->BlockingErase(accessor); - } + tbbEventMap_.erase(requestId); } Status UrmaManager::GetEvent(uint64_t requestId, std::shared_ptr &event) { - std::shared_lock lock(eventMapMutex_); - EventMap::Accessor accessor; - if (eventMap_->Find(accessor, requestId)) { - event = accessor.entry->data; + TbbEventMap::accessor mapAccessor; + if (tbbEventMap_.find(mapAccessor, requestId)) { + event = mapAccessor->second; return Status::OK(); } // Can happen if event is not yet inserted by sender thread. @@ -612,15 +602,13 @@ Status UrmaManager::GetEvent(uint64_t requestId, std::shared_ptr &event) Status UrmaManager::CreateEvent(uint64_t requestId, std::shared_ptr &event) { - std::shared_lock lock(eventMapMutex_); - EventMap::Accessor accessor; - auto res = eventMap_->Insert(accessor, requestId); + TbbEventMap::accessor mapAccessor; + auto res = tbbEventMap_.insert(mapAccessor, requestId); if (!res) { // If this happens that means requestId is duplicated. RETURN_STATUS_LOG_ERROR(K_DUPLICATED, FormatString("Request id %d already exists in event map", requestId)); } else { - event = std::make_shared(requestId); - accessor.entry->data = event; + mapAccessor->second = std::make_shared(requestId); } return Status::OK(); } @@ -740,8 +728,6 @@ Status UrmaManager::PollJfcWait(const custom_unique_ptr &jfc, const Status UrmaManager::ImportRemoteJfr(const UrmaJfrInfo &urmaInfo) { PerfPoint point1(PerfKey::URMA_CONNECT_WITH_REMOTE_DEVICE); - // Do not need to import jfr for the local node. - RETURN_OK_IF_TRUE(localUrmaInfo_.localAddress == urmaInfo.localAddress); const std::string remoteDeviceId = urmaInfo.localAddress.ToString(); std::shared_lock l(remoteMapMutex_); // Insert or update the import jfr (in case the sending worker restarts) @@ -807,15 +793,36 @@ Status UrmaManager::ImportRemoteJfr(const UrmaJfrInfo &urmaInfo) return Status::OK(); } -Status UrmaManager::ImportSegAndWritePayload(const UrmaImportSegmentPb &urmaInfo, const uint64_t &localSegAddress, - const uint64_t &localSegSize, const uint64_t &localObjectAddress, - const uint64_t &readOffset, const uint64_t &readSize, - const uint64_t &metaDataSize, bool blocking, std::vector &keys) +Status UrmaManager::ImportRemoteInfo(const UrmaHandshakeReqPb &req) +{ + const HostPort requestAddress(req.address().host(), req.address().port()); + const std::string remoteDeviceId = requestAddress.ToString(); + PerfPoint point1(PerfKey::URMA_CONNECT_WITH_REMOTE_DEVICE); + std::shared_lock l(remoteMapMutex_); + RemoteDeviceMap::ConstAccessor constAccessor; + // The comm layer (zmq) has already exchanged the jfr, and we should be able to locate the entry. 
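The event-map rework in this hunk drops the `LockMap` guarded by a `shared_timed_mutex` in favor of `tbb::concurrent_hash_map`, so `CreateEvent`/`GetEvent`/`DeleteEvent` rely on per-bucket accessor locking instead of a map-wide lock. A standalone sketch of the same accessor pattern follows; `Event` here is a simplified stand-in for the project's event type.

```cpp
#include <tbb/concurrent_hash_map.h>
#include <cstdint>
#include <iostream>
#include <memory>

struct Event {
    explicit Event(uint64_t id) : id(id) {}
    uint64_t id;
};

using TbbEventMap = tbb::concurrent_hash_map<uint64_t, std::shared_ptr<Event>>;

int main()
{
    TbbEventMap events;

    // CreateEvent: insert() returns false if the key already exists,
    // which is how a duplicated request id is detected.
    {
        TbbEventMap::accessor acc;  // write lock on the bucket
        if (!events.insert(acc, 42)) {
            std::cout << "request id 42 already exists\n";
            return 1;
        }
        acc->second = std::make_shared<Event>(42);
    }  // accessor released at scope exit

    // GetEvent: find() takes the same per-bucket lock; a miss simply means
    // the sender thread has not inserted the event yet.
    {
        TbbEventMap::accessor acc;
        if (events.find(acc, 42)) {
            std::cout << "found event " << acc->second->id << "\n";
        }
    }

    // DeleteEvent: erase by key, no explicit accessor needed.
    events.erase(42);
    return 0;
}
```

The accessor is the lock: holding it pins the bucket, which is why the `shared_timed_mutex` that previously guarded the whole map can be removed.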
+ auto res = remoteDeviceMap_->Find(constAccessor, remoteDeviceId); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(res, K_RUNTIME_ERROR, + FormatString("Failed to find jfr from %s", remoteDeviceId)); + point1.Record(); + PerfPoint point2(PerfKey::URMA_IMPORT_REMOTE_SEGMENT); + for (int i = 0; i < req.seg_infos_size(); i++) { + auto &segInfo = req.seg_infos(i); + RETURN_IF_NOT_OK(constAccessor.entry->data.ImportRemoteSeg(urmaContext_, segInfo)); + } + point2.Record(); + return Status::OK(); +} + +Status UrmaManager::UrmaWritePayload(const UrmaRemoteAddrPb &urmaInfo, const uint64_t &localSegAddress, + const uint64_t &localSegSize, const uint64_t &localObjectAddress, + const uint64_t &readOffset, const uint64_t &readSize, const uint64_t &metaDataSize, + bool blocking, std::vector &keys) { // Note that the returned keys only contain the new key(s). keys.clear(); PerfPoint point(PerfKey::URMA_IMPORT_AND_WRITE_PAYLOAD); - auto segVa = urmaInfo.seg().va(); + auto segVa = urmaInfo.seg_va(); const HostPort requestAddress(urmaInfo.request_address().host(), urmaInfo.request_address().port()); const std::string remoteDeviceId = requestAddress.ToString(); PerfPoint point1(PerfKey::URMA_CONNECT_WITH_REMOTE_DEVICE); @@ -828,7 +835,7 @@ Status UrmaManager::ImportSegAndWritePayload(const UrmaImportSegmentPb &urmaInfo point1.Record(); PerfPoint point2(PerfKey::URMA_IMPORT_REMOTE_SEGMENT); SegmentMap::ConstAccessor remoteSegAccessor; - RETURN_IF_NOT_OK(constAccessor.entry->data.GetOrImportRemoteSeg(urmaContext_, urmaInfo, remoteSegAccessor)); + RETURN_IF_NOT_OK(constAccessor.entry->data.GetRemoteSeg(segVa, remoteSegAccessor)); point2.Record(); PerfPoint point3(PerfKey::URMA_REGISTER_LOCAL_SEGMENT); @@ -922,7 +929,10 @@ Status UrmaManager::ExchangeJfr(const UrmaHandshakeReqPb &req, UrmaHandshakeRspP UrmaJfrInfo urmaInfo; RETURN_IF_NOT_OK(urmaInfo.FromProto(req)); LOG(INFO) << "Start import remote jfr, remote urma info: " << urmaInfo.ToString(); + // Do not need to import remote jfr or segment for the local node. + RETURN_OK_IF_TRUE(localUrmaInfo_.localAddress == urmaInfo.localAddress); LOG_IF_ERROR(mgr.ImportRemoteJfr(urmaInfo), "Error in import incoming jfr"); + LOG_IF_ERROR(mgr.ImportRemoteInfo(req), "Error in import remote segments"); // Do not need to fill in jfr response for urma_write scenario. } return Status::OK(); @@ -998,12 +1008,19 @@ void RemoteDevice::SetJfrs(std::vector &jetties) } } -Status RemoteDevice::GetOrImportRemoteSeg(urma_context_t *urmaContext, const UrmaImportSegmentPb &importSegmentInfo, - SegmentMap::ConstAccessor &constAccessor) +Status RemoteDevice::GetRemoteSeg(uint64_t segVa, SegmentMap::ConstAccessor &constAccessor) { + if (remoteSegments_.Find(constAccessor, segVa)) { + return Status::OK(); + } + RETURN_STATUS(K_NOT_FOUND, "Remote segment is not found"); +} + +Status RemoteDevice::ImportRemoteSeg(urma_context_t *urmaContext, const UrmaImportSegmentPb &importSegmentInfo) +{ + SegmentMap::Accessor accessor; auto segVa = importSegmentInfo.seg().va(); - if (!remoteSegments_.Find(constAccessor, segVa)) { - SegmentMap::Accessor accessor; + if (!remoteSegments_.Find(accessor, segVa)) { if (remoteSegments_.Insert(accessor, segVa)) { bool needErase = true; Raii eraseSegment([this, &accessor, &needErase]() { @@ -1031,10 +1048,6 @@ Status RemoteDevice::GetOrImportRemoteSeg(urma_context_t *urmaContext, const Urm accessor.entry->data.Set(segment, false); needErase = false; } - accessor.Release(); - // Switch to const accessor so it does not block the others. 
- CHECK_FAIL_RETURN_STATUS(remoteSegments_.Find(constAccessor, segVa), K_RUNTIME_ERROR, - "Failed to operate on remote segment map."); } return Status::OK(); } diff --git a/src/datasystem/common/rdma/urma_manager.h b/src/datasystem/common/rdma/urma_manager.h index d39837b2b039eacde502a4de891b53cb5939e445..eebb86089d4e6f909ac6f18889951dc63e07deb4 100644 --- a/src/datasystem/common/rdma/urma_manager.h +++ b/src/datasystem/common/rdma/urma_manager.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #ifdef URMA_OVER_UB @@ -104,14 +105,20 @@ public: void SetJfrs(std::vector &jetties); /** - * @brief Get remote segment or import remote segment from the device. - * @param[in] urmaContext The urma context. - * @param[in] UrmaImportSegmentPb Pb with remote segment info. - * @param[out] constAccessor Accessor in segment table. + * @brief Get remote segment from the device + * @param[in] segVa The remote segment address + * @param[out] constAccessor Accessor in segment table * @return Status of the call. */ - Status GetOrImportRemoteSeg(urma_context_t *urmaContext, const UrmaImportSegmentPb &urmaInfo, - SegmentMap::ConstAccessor &constAccessor); + Status GetRemoteSeg(uint64_t segVa, SegmentMap::ConstAccessor &constAccessor); + + /** + * @brief Import remote segment and keep record in the device + * @param[in] urmaContext The urma context + * @param[in] UrmaImportSegmentPb Pb with remote segment info + * @return Status of the call. + */ + Status ImportRemoteSeg(urma_context_t *urmaContext, const UrmaImportSegmentPb &urmaInfo); /** * @brief Unimport a remote segment @@ -192,6 +199,7 @@ private: }; using EventMap = LockMap>; +using TbbEventMap = tbb::concurrent_hash_map>; class UrmaManager { public: @@ -261,24 +269,25 @@ public: Status RegisterSegment(const uint64_t &segAddress, const uint64_t &segSize); /** - * @brief Fill segment info. Register the segment if not already registered. - * @param[in] segAddress Starting address of the segment. - * @param[in] segSize Size of the segment. - * @param[in] shmOffset The shared memory offset of the object. - * @param[in] metaSz The size of the shared memory metadata size. - * @param[in] localAddress The local worker hostport. - * @param[out] segInfo The urma segment info for import purposes. + * @brief Fill segment info into request + * @param[out] handshakeReq The protobuf to fill with segment info + * @return Status of the call. + */ + Status GetSegmentInfo(UrmaHandshakeReqPb &handshakeReq); + + /** + * @brief Import segment info from request + * @param[in] handshakeReq The protobuf to import segment info from * @return Status of the call. */ - Status GetSegmentInfo(const uint64_t &segAddress, const uint64_t &segSize, const uint64_t &shmOffset, - const uint64_t &metaSz, const HostPort &localAddress, UrmaImportSegmentPb &segInfo); + Status ImportRemoteInfo(const UrmaHandshakeReqPb &req); /** * @brief Does a RDMA write to remote worker memory location * 1. Registers the segment if address is not already registered * 2. Imports remote segment * 3. does a urma write - * @param[in] UrmaImportSegmentPb Protobuf contians remote worker URMA info + * @param[in] UrmaRemoteAddrPb Protobuf contians remote host address, remote urma segment address and data offset * @param[in] localSegAddress Starting address of the segment (e.g. Arena start address) * @param[in] localSegSize Total size of the segment (e.g. 
Arena size) * @param[in] localObjectAddress Object address @@ -289,10 +298,10 @@ public: * @param[out] keys The new request id to wait for if not blocking. * @return Status of the call. */ - Status ImportSegAndWritePayload(const UrmaImportSegmentPb &urmaInfo, const uint64_t &localSegAddress, - const uint64_t &localSegSize, const uint64_t &localObjectAddress, - const uint64_t &readOffset, const uint64_t &readSize, const uint64_t &metaDataSize, - bool blocking, std::vector &keys); + Status UrmaWritePayload(const UrmaRemoteAddrPb &urmaInfo, const uint64_t &localSegAddress, + const uint64_t &localSegSize, const uint64_t &localObjectAddress, + const uint64_t &readOffset, const uint64_t &readSize, const uint64_t &metaDataSize, + bool blocking, std::vector &keys); /** * @brief Remove Remote Device and all associated segments @@ -574,8 +583,7 @@ private: std::unique_ptr localSegmentMap_; // Eid to segment maps mapping for remote jfr and segment. std::unique_ptr remoteDeviceMap_; - mutable std::shared_timed_mutex eventMapMutex_; - std::unique_ptr eventMap_; + TbbEventMap tbbEventMap_; std::unordered_set finishedRequests_; std::unordered_set failedRequests_; std::atomic serverStop_{ false }; diff --git a/src/datasystem/common/rdma/urma_manager_wrapper.cpp b/src/datasystem/common/rdma/urma_manager_wrapper.cpp index cd5daddb5bc51e97dcec977c4740162dc8c033dc..86d916625f5902c768a4e0bf352d95b150678571 100644 --- a/src/datasystem/common/rdma/urma_manager_wrapper.cpp +++ b/src/datasystem/common/rdma/urma_manager_wrapper.cpp @@ -102,29 +102,9 @@ void GetSegmentInfoFromShmUnit(std::shared_ptr shmUnit, uint64_t memory #endif } -Status FillUrmaInfo(std::shared_ptr shmUnit, const HostPort &localAddress, uint64_t metaSz, - UrmaImportSegmentPb &urmaInfo) -{ - (void)shmUnit; - (void)localAddress; - (void)metaSz; - (void)urmaInfo; -#ifdef USE_URMA - if (UrmaManager::IsUrmaEnabled()) { - uint64_t segAddress; - uint64_t segSize; - GetSegmentInfoFromShmUnit(shmUnit, reinterpret_cast(shmUnit->GetPointer()), segAddress, segSize); - RETURN_IF_NOT_OK(UrmaManager::Instance().GetSegmentInfo(segAddress, segSize, shmUnit->GetOffset(), metaSz, - localAddress, urmaInfo)); - } -#endif - return Status::OK(); -} - -Status ImportSegAndWritePayload(const UrmaImportSegmentPb &urmaInfo, const uint64_t &localSegAddress, - const uint64_t &localSegSize, const uint64_t &localObjectAddress, - const uint64_t &readOffset, const uint64_t &readSize, const uint64_t &metaDataSize, - bool blocking, std::vector &keys) +Status UrmaWritePayload(const UrmaRemoteAddrPb &urmaInfo, const uint64_t &localSegAddress, const uint64_t &localSegSize, + const uint64_t &localObjectAddress, const uint64_t &readOffset, const uint64_t &readSize, + const uint64_t &metaDataSize, bool blocking, std::vector &keys) { (void)urmaInfo; (void)localSegAddress; @@ -136,10 +116,10 @@ Status ImportSegAndWritePayload(const UrmaImportSegmentPb &urmaInfo, const uint6 (void)blocking; (void)keys; #ifdef USE_URMA - RETURN_IF_NOT_OK(UrmaManager::Instance().ImportSegAndWritePayload(urmaInfo, localSegAddress, localSegSize, - localObjectAddress, readOffset, readSize, - metaDataSize, blocking, keys)); + RETURN_IF_NOT_OK(UrmaManager::Instance().UrmaWritePayload(urmaInfo, localSegAddress, localSegSize, + localObjectAddress, readOffset, readSize, metaDataSize, + blocking, keys)); #endif return Status::OK(); } -} // namespace datasystem \ No newline at end of file +} // namespace datasystem diff --git a/src/datasystem/common/rdma/urma_manager_wrapper.h 
b/src/datasystem/common/rdma/urma_manager_wrapper.h
index f612c147ac63fe9863fc7a29b1642c199f48b8d9..c78c7d8b957639640b477e622faddf5eeca7652d 100644
--- a/src/datasystem/common/rdma/urma_manager_wrapper.h
+++ b/src/datasystem/common/rdma/urma_manager_wrapper.h
@@ -82,7 +82,7 @@ void GetSegmentInfoFromShmUnit(std::shared_ptr shmUnit, uint64_t memory
 /**
 * @brief Trigger UrmaManager logic to import segment and write payload.
- * @param[in] UrmaImportSegmentPb Protobuf contians remote worker URMA info.
+ * @param[in] UrmaRemoteAddrPb Protobuf contains the remote host address, remote URMA segment address and data offset.
 * @param[in] localSegAddress Starting address of the segment (e.g. Arena start address).
 * @param[in] localSegSize Total size of the segment (e.g. Arena size).
 * @param[in] localObjectAddress Object address.
@@ -93,19 +93,9 @@ void GetSegmentInfoFromShmUnit(std::shared_ptr shmUnit, uint64_t memory
 * @param[out] keys The new request id to wait for if not blocking.
 * @return Status of the call.
 */
-Status ImportSegAndWritePayload(const UrmaImportSegmentPb &urmaInfo, const uint64_t &localSegAddress,
- const uint64_t &localSegSize, const uint64_t &localObjectAddress,
- const uint64_t &readOffset, const uint64_t &readSize, const uint64_t &metaDataSize,
- bool blocking, std::vector &keys);
+Status UrmaWritePayload(const UrmaRemoteAddrPb &urmaInfo, const uint64_t &localSegAddress, const uint64_t &localSegSize,
+ const uint64_t &localObjectAddress, const uint64_t &readOffset, const uint64_t &readSize,
+ const uint64_t &metaDataSize, bool blocking, std::vector &keys);
-/**
- * @brief Fill in import segment pb for URMA.
- * @param[in] shmUnit The shared memory unit.
- * @param[in] metaSz The metadata size of shared memory.
- * @param[out] urmaInfo Protobuf contians remote worker URMA info.
- * @return Status of the call.
- */ -Status FillUrmaInfo(std::shared_ptr shmUnit, const HostPort &localAddress, uint64_t metaSz, - UrmaImportSegmentPb &urmaInfo); } // namespace datasystem -#endif // DATASYSTEM_COMMON_RDMA_URMA_MANAGER_WRAPPER_H \ No newline at end of file +#endif // DATASYSTEM_COMMON_RDMA_URMA_MANAGER_WRAPPER_H diff --git a/src/datasystem/common/rdma/urma_stub.cpp b/src/datasystem/common/rdma/urma_stub.cpp index 81b16e83a4e713144ed5ece735bbcc5144a32ac1..5729edfd654914ba180ead9a8bb102747251263a 100644 --- a/src/datasystem/common/rdma/urma_stub.cpp +++ b/src/datasystem/common/rdma/urma_stub.cpp @@ -22,6 +22,12 @@ namespace datasystem { __attribute__((weak)) UrmaManager::UrmaManager() = default; __attribute__((weak)) UrmaManager::~UrmaManager() = default; +Status __attribute__((weak)) UrmaManager::GetSegmentInfo(UrmaHandshakeReqPb &handshakeReq) +{ + (void)handshakeReq; + return Status::OK(); +} + Status __attribute__((weak)) UrmaManager::ExchangeJfr(const UrmaHandshakeReqPb &req, UrmaHandshakeRspPb &rsp) { (void)req; diff --git a/src/datasystem/common/rpc/CMakeLists.txt b/src/datasystem/common/rpc/CMakeLists.txt index c4013b3d0a2a478eb37a797c16ac89667351e99b..3ab91bf501cd7dae77b0a98495017441e996474c 100644 --- a/src/datasystem/common/rpc/CMakeLists.txt +++ b/src/datasystem/common/rpc/CMakeLists.txt @@ -55,7 +55,9 @@ set(COMMON_RPC_ZMQ_SRCS zmq/zmq_unary_client_impl.h zmq/zmq_server_stream_base.cpp zmq/zmq_client_stream_base.cpp - zmq/zmq_stream_base.cpp) + zmq/zmq_stream_base.cpp + zmq/exclusive_conn_mgr.cpp + zmq/work_agent.cpp) set(COMMON_RPC_ZMQ_DEPENDS_LIBS protobuf::libprotobuf diff --git a/src/datasystem/common/rpc/plugin_generator/service_cpp_generator.cpp b/src/datasystem/common/rpc/plugin_generator/service_cpp_generator.cpp index 7a062547baf01ec1bd9e9a432b3e08e864b02324..ed4dc9514edf9bf4397b121ef109e0f3d5b8bea0 100644 --- a/src/datasystem/common/rpc/plugin_generator/service_cpp_generator.cpp +++ b/src/datasystem/common/rpc/plugin_generator/service_cpp_generator.cpp @@ -35,6 +35,7 @@ void ZmqRpcGenerator::CreateServiceCpp(const google::protobuf::FileDescriptor &f const std::string &svcName = svc->name(); GenerateInitMethodMapDef(printer, *svc, PREFIX, svcName); ImplementZmqCallMethodDef(printer, *svc, PREFIX, svcName); + ImplementZmqDirectCallMethodDef(printer, *svc, PREFIX, svcName); } printer.PrintRaw(namespaceEnd); @@ -57,7 +58,7 @@ void ZmqRpcGenerator::GenerateServiceCppPrologue(io::Printer &printer, void ZmqRpcGenerator::ImplementCallMethodNoStream(io::Printer &printer, const google::protobuf::MethodDescriptor &method, int methodIndex, - const std::string &indent) + const std::string &indent, bool enableMsgQ) { (void)indent; std::map vars; @@ -67,17 +68,20 @@ void ZmqRpcGenerator::ImplementCallMethodNoStream(io::Printer &printer, vars["outputTypeName"] = method.output_type()->name(); vars["optSendPayload1"] = HasPayloadSendOption(method) ? ", std::move(payload)" : ""; vars["optRecvPayload1"] = HasPayloadRecvOption(method) ? ", outPayload" : ""; - std::string impl = + std::string impl; + std::string sockArg = (enableMsgQ) ? 
"sock" : "nullptr"; + impl += " auto &methodObj = methodMap_.find($methodIndex$)->second;\n" " $inputTypeName$ rq;\n" " $outputTypeName$ reply;\n" " auto serverApi =\n" " std::make_unique<::datasystem::ServerUnaryWriterReaderImpl<$outputTypeName$, " "$inputTypeName$>>(\n" - " sock, meta, std::move(inMsg), methodObj->HasPayloadSendOption(),\n" + " " + sockArg + ", meta, std::move(inMsg), methodObj->HasPayloadSendOption(),\n" " methodObj->HasPayloadRecvOption());\n" " rc = serverApi->Read(rq);\n" " if (rc.IsError()) { break; }\n"; + if (HasPayloadSendOption(method)) { impl += " std::vector<::datasystem::RpcMessage> payload;\n" @@ -88,10 +92,17 @@ void ZmqRpcGenerator::ImplementCallMethodNoStream(io::Printer &printer, impl += " std::vector<::datasystem::RpcMessage> outPayload;\n"; } impl += - " rc = $methodName$(rq, reply$optSendPayload1$$optRecvPayload1$);\n" - " if (rc.IsError()) { rc = serverApi->SendStatus(rc); break; }\n" - " rc = serverApi->Write(reply);\n" - " if (rc.IsError()) { break; }\n"; + " rc = $methodName$(rq, reply$optSendPayload1$$optRecvPayload1$);\n"; + if (enableMsgQ) { + impl += + " if (rc.IsError()) { rc = serverApi->SendStatus(rc); break; }\n" + " rc = serverApi->Write(reply);\n" + " if (rc.IsError()) { break; }\n"; + } else { + impl += + " if (rc.IsError()) { outMsg.push_back(std::move(StatusToZmqMessage(rc))); break; }\n" + " rc = serverApi->ConstructWriteMsg(reply, outMsg);\n"; + } if (HasPayloadRecvOption(method)) { impl += " rc = serverApi->SendPayload(outPayload);\n" @@ -194,7 +205,7 @@ void ZmqRpcGenerator::ImplementCallMethodStream(io::Printer &printer, const goog void ZmqRpcGenerator::ImplementCallMethodUnarySocket(io::Printer &printer, const google::protobuf::MethodDescriptor &method, int methodIndex, - const std::string &indent) + const std::string &indent, bool enableMsgQ) { (void)indent; std::map vars; @@ -202,17 +213,24 @@ void ZmqRpcGenerator::ImplementCallMethodUnarySocket(io::Printer &printer, vars["methodName"] = method.name(); vars["inputTypeName"] = method.input_type()->name(); vars["outputTypeName"] = method.output_type()->name(); - std::string impl = + std::string impl; + std::string sockArg = (enableMsgQ) ? 
"sock" : "nullptr"; + impl += " auto &methodObj = methodMap_.find($methodIndex$)->second;\n" " auto pimpl =\n" " std::make_unique<::datasystem::ServerUnaryWriterReaderImpl<$outputTypeName$, " "$inputTypeName$>>(\n" - " sock, meta, std::move(inMsg), methodObj->HasPayloadSendOption(),\n" + " " + sockArg + ", meta, std::move(inMsg), methodObj->HasPayloadSendOption(),\n" " methodObj->HasPayloadRecvOption());\n" " auto serverApi = std::make_shared<::datasystem::ServerUnaryWriterReader<$outputTypeName$, " "$inputTypeName$>>(std::move(pimpl));\n" " rc = $methodName$(serverApi);\n" " if (rc.IsError()) { rc = serverApi->SendStatus(rc); break; }\n"; + if (!enableMsgQ && method.name() == "Get") { + impl += + " rc = serverApi->GetOutMsg(outMsg);" + " if (rc.IsError()) { rc = serverApi->SendStatus(rc); break; }\n"; + } printer.Print(vars, impl.c_str()); } @@ -268,4 +286,51 @@ void ZmqRpcGenerator::ImplementZmqCallMethodDef(io::Printer &printer, const goog "}\n"; printer.PrintRaw(endFunction); } + +void ZmqRpcGenerator::ImplementZmqDirectCallMethodDef(io::Printer &printer, + const google::protobuf::ServiceDescriptor &svc, const std::string &indent, const std::string &svcName) +{ + const std::string &level1Indent = indent; + const std::string level2Indent = level1Indent + indent; + const std::string level3Indent = level2Indent + indent; + std::map vars; + vars["svcName"] = svcName; + std::string startFunction = + "::datasystem::Status $svcName$::DirectCallMethod(::datasystem::MetaPb meta,\n" + " std::deque<::datasystem::ZmqMessage> &&inMsg, int64_t seqNo,\n" + " std::deque<::datasystem::ZmqMessage> &outMsg) {\n" + " datasystem::Status rc;\n" + " (void)seqNo;\n" + " switch(meta.method_index()) {\n"; + + printer.Print(vars, startFunction.c_str()); + for (auto j = 0; j < svc.method_count(); ++j) { + if (svc.method(j) == nullptr) { + continue; + } + auto &method = *(svc.method(j)); + vars["methodName"] = method.name(); + vars["methodIndex"] = std::to_string(j); + printer.Print(vars, " case $methodIndex$: { // $methodName$\n"); + if (!method.client_streaming() && !method.server_streaming()) { + if (UnarySocketNeeded(method)) { + ImplementCallMethodUnarySocket(printer, method, j, level3Indent, false); + } else { + ImplementCallMethodNoStream(printer, method, j, level3Indent, false); + } + } + printer.PrintRaw(" break;\n" + " } // case\n"); + } + std::string endFunction = + " default: {\n" + " rc = datasystem::Status(datasystem::StatusCode::K_UNKNOWN_ERROR, __LINE__, __FILE__,\n" + " \"Unknown method\");\n" + " break;\n" + " }\n" + " } // switch\n" + " return rc;\n" + "}\n"; + printer.PrintRaw(endFunction); +} } // namespace datasystem diff --git a/src/datasystem/common/rpc/plugin_generator/service_header_generator.cpp b/src/datasystem/common/rpc/plugin_generator/service_header_generator.cpp index b014c386033a5fdfdec0eb626aa9bc4d644e95a9..36f8e3eec22cd19f99660316d314703863de8fff 100644 --- a/src/datasystem/common/rpc/plugin_generator/service_header_generator.cpp +++ b/src/datasystem/common/rpc/plugin_generator/service_header_generator.cpp @@ -198,5 +198,9 @@ void ZmqRpcGenerator::ImplementZmqCallMethodDecl(io::Printer &printer) printer.PrintRaw( " ::datasystem::Status CallMethod(std::shared_ptr<::datasystem::ZmqServerMsgQueRef> sock, " "::datasystem::MetaPb meta, std::deque<::datasystem::ZmqMessage> &&inMsg, int64_t seqNo) override;\n"); + printer.PrintRaw( + " ::datasystem::Status DirectCallMethod(::datasystem::MetaPb meta," + " std::deque<::datasystem::ZmqMessage> &&inMsg, int64_t seqNo," + " 
std::deque<::datasystem::ZmqMessage> &outMsg) override;\n"); } } // namespace datasystem diff --git a/src/datasystem/common/rpc/plugin_generator/stub_cpp_generator.cpp b/src/datasystem/common/rpc/plugin_generator/stub_cpp_generator.cpp index 52c803e32bc6a1e4f8cb9bdf84fbd47e21184c8e..67c2aa2c53354b35f7d0d5901a48eefa702763d4 100644 --- a/src/datasystem/common/rpc/plugin_generator/stub_cpp_generator.cpp +++ b/src/datasystem/common/rpc/plugin_generator/stub_cpp_generator.cpp @@ -147,10 +147,15 @@ void ZmqRpcGenerator::ImplementGenericStubOtherFuncDef(io::Printer &printer, con "Status $stub$::GetInitStatus() {\n" " return stub_->GetInitStatus();\n" "}\n"; + const std::string setExclusiveConnInfo = + "void $stub$::SetExclusiveConnInfo(const std::optional &exclusiveId, const std::string &sockPath) {\n" + " return stub_->SetExclusiveConnInfo(exclusiveId, sockPath);\n" + "}\n"; printer.Print(vars, forgetRequest.c_str()); printer.Print(vars, isPeerAlive.c_str()); printer.Print(vars, cacheSession.c_str()); printer.Print(vars, getInitStatus.c_str()); + printer.Print(vars, setExclusiveConnInfo.c_str()); } void ZmqRpcGenerator::ImplementGenericStubConstructor(io::Printer &printer, @@ -631,11 +636,26 @@ void ZmqRpcGenerator::ImplementStubNoStreamDefHelper(std::string &impl, impl += " ::datasystem::Status rc;\n" " auto &methodObj = methodMap_.find($methodIndex$)->second;\n" - " std::shared_ptr<::datasystem::ZmqMsgQueRef> sock;\n" - " ::datasystem::RpcOptions o(opt);\n" - " o.SetHWM(2);\n" - " rc = pimpl_->CreateMsgQ(sock, serviceName_, o);\n" - " if (rc.IsError()) { return rc; }\n"; + " std::unique_ptr> clientApi;\n" + " if (!exclusiveId_.has_value()) {\n" + " std::shared_ptr<::datasystem::ZmqMsgQueRef> sock;\n" + " ::datasystem::RpcOptions o(opt);\n" + " o.SetHWM(2);\n" + " rc = pimpl_->CreateMsgQ(sock, serviceName_, o);\n" + " if (rc.IsError()) { return rc; }\n" + " clientApi =\n" + " std::make_unique>(\n" + " std::move(sock), ServiceName(), methodObj->MethodIndex(),\n" + " methodObj->HasPayloadSendOption(), methodObj->HasPayloadRecvOption());\n" + " } else {\n" + " // Exclusive connection mode has different constructor to set it up and requires an init.\n" + " clientApi =\n" + " std::make_unique>(\n" + " exclusiveId_.value(), ServiceName(), methodObj->MethodIndex(),\n" + " methodObj->HasPayloadSendOption(), methodObj->HasPayloadRecvOption());\n" + " rc = clientApi->InitExclusiveConnection(exclusiveSockPath_, opt.GetTimeout());\n" + " if (rc.IsError()) { return rc; }\n" + " }\n"; if (criticalFunc) { impl += @@ -643,10 +663,6 @@ void ZmqRpcGenerator::ImplementStubNoStreamDefHelper(std::string &impl, " PerfPoint point2(PerfKey::ZMQ_$upperMethodName$_RPC);\n"; } impl += - " auto clientApi =\n" - " std::make_unique>(\n" - " std::move(sock), ServiceName(), methodObj->MethodIndex(),\n" - " methodObj->HasPayloadSendOption(), methodObj->HasPayloadRecvOption());\n" " rc = clientApi->Write(rq);\n" " if (rc.IsError()) { return rc; }\n"; diff --git a/src/datasystem/common/rpc/plugin_generator/stub_header_generator.cpp b/src/datasystem/common/rpc/plugin_generator/stub_header_generator.cpp index 159cae9c8c31084fc27ecdb7a277a6862da8ef85..a530c33a22b9e3cc5ebb7a622dcc5f80fad0d5c4 100644 --- a/src/datasystem/common/rpc/plugin_generator/stub_header_generator.cpp +++ b/src/datasystem/common/rpc/plugin_generator/stub_header_generator.cpp @@ -166,7 +166,8 @@ void ZmqRpcGenerator::ImplementGenericStubOtherFuncDecl(io::Printer &printer) const std::string otherFuncDecl = " void ForgetRequest(int64_t tagId);\n" " bool 
IsPeerAlive(uint32_t threshold);\n" - " void CacheSession(bool cache);\n"; + " void CacheSession(bool cache);\n" + " void SetExclusiveConnInfo(const std::optional &exclusiveId, const std::string &sockPath);\n"; printer.PrintRaw(otherFuncDecl); } diff --git a/src/datasystem/common/rpc/plugin_generator/zmq_rpc_generator.h b/src/datasystem/common/rpc/plugin_generator/zmq_rpc_generator.h index 96d24d86096c3806b025f837efc6b583c4431aa4..d6622321f10c35e41df505f337a783cfb0ef05b8 100644 --- a/src/datasystem/common/rpc/plugin_generator/zmq_rpc_generator.h +++ b/src/datasystem/common/rpc/plugin_generator/zmq_rpc_generator.h @@ -156,8 +156,10 @@ private: static void ImplementZmqCallMethodDecl(io::Printer &printer); static void ImplementZmqCallMethodDef(io::Printer &printer, const google::protobuf::ServiceDescriptor &svc, const std::string &indent, const std::string &svcName); + static void ImplementZmqDirectCallMethodDef(io::Printer &printer, const google::protobuf::ServiceDescriptor &svc, + const std::string &indent, const std::string &svcName); static void ImplementCallMethodNoStream(io::Printer &printer, const google::protobuf::MethodDescriptor &method, - int methodIndex, const std::string &indent); + int methodIndex, const std::string &indent, bool enableMsgQ = true); static void ImplementCallMethodClientStream(io::Printer &printer, const google::protobuf::MethodDescriptor &method, int methodIndex, const std::string &indent); static void ImplementCallMethodServerStream(io::Printer &printer, const google::protobuf::MethodDescriptor &method, @@ -165,7 +167,7 @@ private: static void ImplementCallMethodStream(io::Printer &printer, const google::protobuf::MethodDescriptor &method, int methodIndex, const std::string &indent); static void ImplementCallMethodUnarySocket(io::Printer &printer, const google::protobuf::MethodDescriptor &method, - int methodIndex, const std::string &indent); + int methodIndex, const std::string &indent, bool enableMsgQ = true); /** * @brief Implement stub api for both sides streaming. 
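The new `DirectCallMethod` generated above gives each service a queue-free entry point: instead of replying through a `ZmqServerMsgQueRef`, it switches on `meta.method_index()` and serializes the reply into the caller-supplied `outMsg` deque (via `ConstructWriteMsg` in the earlier hunk), so the exclusive-connection work agent can send the reply on its own socket. Here is a simplified standalone analogue of that dispatch shape, with the project types replaced by plain strings; it is not the actual generated code.

```cpp
#include <deque>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Stand-ins for the generated service: each method consumes request frames
// and appends reply frames to outMsg instead of writing to a socket queue.
using Frames = std::deque<std::string>;
using Handler = std::function<bool(Frames &&in, Frames &out)>;

bool DirectCallMethod(const std::unordered_map<int, Handler> &methods, int methodIndex,
                      Frames &&inMsg, Frames &outMsg)
{
    auto it = methods.find(methodIndex);
    if (it == methods.end()) {
        outMsg.push_back("status: unknown method");  // mirrors the default: case
        return false;
    }
    return it->second(std::move(inMsg), outMsg);
}

int main()
{
    std::unordered_map<int, Handler> methods;
    methods[0] = [](Frames &&in, Frames &out) {  // case 0: an echo-style method
        out.push_back("reply: " + in.front());
        return true;
    };

    Frames out;
    DirectCallMethod(methods, 0, Frames{ "hello" }, out);
    std::cout << out.front() << "\n";  // reply: hello
    return 0;
}
```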
diff --git a/src/datasystem/common/rpc/rpc_channel.cpp b/src/datasystem/common/rpc/rpc_channel.cpp index 353d0667c7b5aaec8b64af2be59a975a38a7d26e..6a939a1fe31c4bb5bac8b1328c7ab2c0d4909ea9 100644 --- a/src/datasystem/common/rpc/rpc_channel.cpp +++ b/src/datasystem/common/rpc/rpc_channel.cpp @@ -20,7 +20,6 @@ */ #include #include "datasystem/common/rpc/rpc_channel.h" -#include "datasystem/common/rdma/urma_manager_wrapper.h" namespace datasystem { RpcChannel::RpcChannel(std::string zmqEndPoint, const RpcCredential &cred) diff --git a/src/datasystem/common/rpc/rpc_server_stream_base.h b/src/datasystem/common/rpc/rpc_server_stream_base.h index f0f514ee022c5cd02617d65894dffc70133d3b41..b0ddd960fd0af60c93840ff1a998269e847ba3ba 100644 --- a/src/datasystem/common/rpc/rpc_server_stream_base.h +++ b/src/datasystem/common/rpc/rpc_server_stream_base.h @@ -269,6 +269,26 @@ public: return std::visit([&payload](auto &pimpl) { return pimpl->SendPayload(payload); }, pimpl_); } + virtual Status GetOutMsg(ZmqMsgFrames &outMsg) + { + return std::visit([&outMsg](auto &pimpl) { return pimpl->GetOutMsg(outMsg); }, pimpl_); + } + + virtual bool EnableMsgQ() + { + return std::visit([](auto &pimpl) { return pimpl->EnableMsgQ(); }, pimpl_); + } + + void SetRequestInProgress() + { + return std::visit([](auto &pimpl) { return pimpl->SetRequestInProgress(); }, pimpl_); + } + + void SetRequestComplete() + { + return std::visit([](auto &pimpl) { return pimpl->SetRequestComplete(); }, pimpl_); + } + private: std::variant>> pimpl_; }; diff --git a/src/datasystem/common/rpc/unix_sock_fd.cpp b/src/datasystem/common/rpc/unix_sock_fd.cpp index 4733367d43d6be4736407e21c081982c564cf46b..1c9872619f4e89c8b9286aeff3a34cedaa253aa8 100644 --- a/src/datasystem/common/rpc/unix_sock_fd.cpp +++ b/src/datasystem/common/rpc/unix_sock_fd.cpp @@ -22,12 +22,12 @@ #include #include #include - #include "datasystem/common/flags/flags.h" #include "datasystem/common/util/fd_manager.h" #include "datasystem/common/util/file_util.h" #include "datasystem/common/util/format.h" #include "datasystem/common/util/strings_util.h" +#include "datasystem/common/util/timer.h" #include "datasystem/protos/meta_zmq.pb.h" #include "datasystem/protos/utils.pb.h" @@ -68,6 +68,17 @@ Status UnixSockFd::Poll(short event, int timeout) const } Status UnixSockFd::Recv(void *data, size_t size, bool blocking) const +{ + if (timeoutEnabled_) { + CHECK_FAIL_RETURN_STATUS(blocking, K_RUNTIME_ERROR, + "Receive with timeout is only supported for blocking receive!"); + return RecvWithTimeout(data, size); + } else { + return RecvNoTimeout(data, size, blocking); + } +} + +Status UnixSockFd::RecvNoTimeout(void *data, size_t size, bool blocking) const { PerfPoint point(PerfKey::ZMQ_SOCKET_FD_RECV); Status rc; @@ -101,6 +112,15 @@ Status UnixSockFd::Recv(void *data, size_t size, bool blocking) const } Status UnixSockFd::Send(MemView &buf) const +{ + if (timeoutEnabled_) { + return SendWithTimeout(buf); + } else { + return SendNoTimeout(buf); + } +} + +Status UnixSockFd::SendNoTimeout(MemView &buf) const { PerfPoint point(PerfKey::ZMQ_SOCKET_FD_SEND); Status rc; @@ -234,15 +254,16 @@ Status UnixSockFd::SetBlocking() const Status UnixSockFd::SetTimeout(int64_t timeout) const { - auto s = timeout / ONE_THOUSAND; - auto us = (timeout % ONE_THOUSAND) * ONE_THOUSAND; - struct timeval t { - .tv_sec = s, .tv_usec = us - }; - auto err = setsockopt(fd_, SOL_SOCKET, SO_RCVTIMEO, &t, sizeof(t)); - CHECK_FAIL_RETURN_STATUS(err != -1, K_RUNTIME_ERROR, FormatString("Socket set timeout error: errno = 
%d", errno)); - err = setsockopt(fd_, SOL_SOCKET, SO_SNDTIMEO, &t, sizeof(t)); - CHECK_FAIL_RETURN_STATUS(err != -1, K_RUNTIME_ERROR, FormatString("Socket set timeout error: errno = %d", errno)); + // Set both the send and recv timeout for this socket + RETURN_IF_NOT_OK(SetTimeout(TimeoutType::SendTimeout, timeout)); + RETURN_IF_NOT_OK(SetTimeout(TimeoutType::RecvTimeout, timeout)); + return Status::OK(); +} + +Status UnixSockFd::SetTimeoutEnforced(int64_t timeout) +{ + RETURN_IF_NOT_OK(SetTimeout(timeout)); + timeoutEnabled_ = true; return Status::OK(); } @@ -429,6 +450,22 @@ Status UnixSockFd::Connect(const std::string &ZmqEndPt) return Status::OK(); } +Status UnixSockFd::Accept(UnixSockFd &outSockFd) +{ + int newFd = accept(fd_, nullptr, nullptr); + if (newFd <= 0) { + Status rc = UnixSockFd::ErrnoToStatus(errno, fd_); + if (rc.IsError() && rc.GetCode() != K_TRY_AGAIN) { + VLOG(RPC_LOG_LEVEL) << FormatString("Spawn uds connection with listener fd %d failed with status %s", fd_, + rc.ToString()); + } + return rc; + } + + outSockFd = UnixSockFd(newFd); + return Status::OK(); +} + Status UnixSockFd::GetBindingHostPort(HostPort &out) const { sockaddr_in addr{}; @@ -445,4 +482,125 @@ Status UnixSockFd::GetBindingHostPort(HostPort &out) const return Status::OK(); } +Status UnixSockFd::SetTimeout(TimeoutType timeoutType, int64_t timeoutMs) const +{ + auto s = timeoutMs / ONE_THOUSAND; + auto us = (timeoutMs % ONE_THOUSAND) * ONE_THOUSAND; + struct timeval t { + .tv_sec = s, .tv_usec = us + }; + + // Note: std::to_underlying() is available in more recent C++ versions. static_cast for now. + auto err = setsockopt(fd_, SOL_SOCKET, static_cast(timeoutType), &t, sizeof(t)); + CHECK_FAIL_RETURN_STATUS(err != -1, K_RUNTIME_ERROR, FormatString("Socket set timeout error: errno = %d", errno)); + return Status::OK(); +} + +Status UnixSockFd::GetTimeout(TimeoutType timeoutType, int64_t &timeoutMs) const +{ + socklen_t tLen = sizeof(struct timeval); + struct timeval t; + + // Note: std::to_underlying() is available in more recent C++ versions. static_cast for now. + auto err = getsockopt(fd_, SOL_SOCKET, static_cast(timeoutType), &t, &tLen); + CHECK_FAIL_RETURN_STATUS(err != -1, K_RUNTIME_ERROR, FormatString("Socket get timeout error: errno = %d", errno)); + + // convert to ms for output + timeoutMs = (t.tv_sec * ONE_THOUSAND) + (t.tv_usec / ONE_THOUSAND); + return Status::OK(); +} + +Status UnixSockFd::RecvWithTimeout(void *data, size_t size) const +{ + PerfPoint point(PerfKey::ZMQ_SOCKET_FD_RECV); + int64_t startingTimeoutMs = 0; + int64_t timeRemainingMs = 0; + auto sizeRemain = static_cast(size); + + // Fetch the timeout from the fd. The timeout within the socket may continuously be adjusted as calls are made + // (and time has been consumed). + RETURN_IF_NOT_OK(GetTimeout(TimeoutType::RecvTimeout, startingTimeoutMs)); + + // Create a timer with the given amount of time remaining + Timer timer(startingTimeoutMs); + + while (sizeRemain > 0) { + ssize_t bytesReceived; + int err; + bytesReceived = recv(fd_, data, size, 0); + err = errno; + + timeRemainingMs = timer.GetRemainingTimeMs(); + // Regardless of success or fail of the call. If we ran out of time then return to the caller with error. 
+ CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(timeRemainingMs > 0, K_RPC_DEADLINE_EXCEEDED, "Socket recv timeout"); + + if (bytesReceived == -1) { + Status rc = ErrnoToStatus(err, fd_); + if (rc.GetCode() != K_TRY_AGAIN) { + VLOG(RPC_LOG_LEVEL) << "recv failed with rc: " << rc.ToString(); + } + RETURN_IF_NOT_OK_EXCEPT(rc, K_TRY_AGAIN); + } else if (bytesReceived == 0) { + RETURN_STATUS(StatusCode::K_RPC_CANCELLED, "bytesReceived is 0"); + } else { + // Record the received data so far + data = static_cast(data) + bytesReceived; + size -= bytesReceived; + sizeRemain -= bytesReceived; + } + + // Assign the timeout for the next receive call so that it has less time allowed than before, then reloop. + RETURN_IF_NOT_OK(SetTimeout(TimeoutType::RecvTimeout, timeRemainingMs)); + } + point.Record(); + // The recv timeout was set already naturally after the last recv. Update the send timeout to be the same. + RETURN_IF_NOT_OK(SetTimeout(TimeoutType::SendTimeout, timeRemainingMs)); + return Status::OK(); +} + +Status UnixSockFd::SendWithTimeout(MemView &buf) const +{ + PerfPoint point(PerfKey::ZMQ_SOCKET_FD_SEND); + Status rc; + int64_t startingTimeoutMs = 0; + int64_t timeRemainingMs = 0; + + // Fetch the timeout from the fd. The timeout within the socket may continuously be adjusted as calls are made + // (and time has been consumed). + RETURN_IF_NOT_OK(GetTimeout(TimeoutType::SendTimeout, startingTimeoutMs)); + + // Create a timer with the given amount of time remaining + Timer timer(startingTimeoutMs); + + auto sizeRemain = static_cast(buf.Size()); + while (sizeRemain > 0) { + ssize_t bytesSend; + int err; + bytesSend = send(fd_, buf.Data(), buf.Size(), MSG_NOSIGNAL); + err = errno; + + timeRemainingMs = timer.GetRemainingTimeMs(); + // Regardless of success or fail of the call. If we ran out of time then return to the caller with error. + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(timeRemainingMs > 0, K_RPC_DEADLINE_EXCEEDED, "Socket send timeout"); + + if (bytesSend == -1) { + rc = ErrnoToStatus(err, fd_); + if (rc.GetCode() != K_TRY_AGAIN) { + VLOG(RPC_LOG_LEVEL) << "send failed with rc: " << rc.ToString(); + } + RETURN_IF_NOT_OK_EXCEPT(rc, K_TRY_AGAIN); + } else { + buf += bytesSend; + sizeRemain -= bytesSend; + } + + // Assign the timeout for the next send call so that it has less time allowed than before, then reloop + RETURN_IF_NOT_OK(SetTimeout(TimeoutType::SendTimeout, timeRemainingMs)); + } + point.Record(); + + // The send timeout was set already naturally after the last send. Update the recv timeout to be the same. + RETURN_IF_NOT_OK(SetTimeout(TimeoutType::RecvTimeout, timeRemainingMs)); + return Status::OK(); +} } // namespace datasystem diff --git a/src/datasystem/common/rpc/unix_sock_fd.h b/src/datasystem/common/rpc/unix_sock_fd.h index 200dd28d24c1192a2d6678dcf473b6687382fb27..0a72b730d7a7f259491d1d30f1752fa15029c64d 100644 --- a/src/datasystem/common/rpc/unix_sock_fd.h +++ b/src/datasystem/common/rpc/unix_sock_fd.h @@ -167,8 +167,7 @@ public: // of the serialized protobuf followed by the protobuf itself. PerfPoint point(PerfKey::ZMQ_SOCK_RECV_PB); uint32_t sz; - // No need to block if nothing to read. - RETURN_IF_NOT_OK(Recv32(sz, false)); + RETURN_IF_NOT_OK(Recv32(sz, true)); std::unique_ptr wa; void *buf = nullptr; if (sz <= waSz_) { @@ -280,13 +279,32 @@ public: */ Status Connect(const std::string &ZmqEndPt); + /** + * @brief Uses this UnixSockFd as a listener socket and accepts incoming connection. Creates a new UnixSockFd as + * output as the connected socket for the caller. 
+ * @param[out] outSockFd The new socket that was created when the listening socket accepted the connection. + * @return Status of the call. + */ + Status Accept(UnixSockFd &outSockFd); + /** * @brief Set timeout on a socket. + * This timeout only provides a wakeup code. The caller codepath continuously retries. + * Does not cause a failure of K_RPC_DEADLINE_EXCEEDED. * @param[in] timeout Timeout in milliseconds. * @return Status of call. */ Status SetTimeout(int64_t timeout) const; + /** + * @brief Set timeout on a socket for exclusive connection mode. + * Same as the above SetTimeout, however this version of the timeout disables the continuous retry logic and + * enforces that if the timeout is exceeded, it will fail with Status of K_RPC_DEADLINE_EXCEEDED. + * @param[in] timeout Timeout in milliseconds. + * @return Status of call. + */ + Status SetTimeoutEnforced(int64_t timeout); + /** * @brief Set buf size on a socket * @param sz @@ -330,8 +348,79 @@ public: Status GetBindingHostPort(datasystem::HostPort &out) const; private: + // The underlying type of this enum (int) matches the third argument for getsockopt and setsockopt from sys/socket.h + // system calls. + enum class TimeoutType : int { SendTimeout = SO_SNDTIMEO, RecvTimeout = SO_RCVTIMEO }; + + /** + * @brief Sets the timeout for either send or recv. + * @param[in] timeoutType The type to set (either send or recv) + * @param[in] timeoutMs The amount of time in milliseconds to set for the timeout + * @return Status of the call + */ + Status SetTimeout(TimeoutType timeoutType, int64_t timeoutMs) const; + + /** + * @brief Gets the timeout for either send or recv. + * @param[in] timeoutType The type to get (either send or recv) + * @param[out] timeoutMs The amount of time in milliseconds to set for the timeout + * @return Status of the call + */ + Status GetTimeout(TimeoutType timeoutType, int64_t &timeoutMs) const; + + /** + * @brief Receive a raw buffer, has some retry logic but does not respect overall timeout + * @param[in] data -- Address for the receiving buffer + * @param[in] size -- Size of the receiving buffer + * @param[in] blocking. For non-blocking fd, force non-blocking if true. + * @return Status of call. + */ + Status RecvNoTimeout(void *data, size_t size, bool blocking) const; + + /** + * @brief Receive a raw buffer with timeout support + * @param[in] data -- Address for the receiving buffer + * @param[in] size -- Size of the receiving buffer + * @return Status of call. + */ + Status RecvWithTimeout(void *data, size_t size) const; + + /** + * @brief Send a raw buffer. + * @param[in] buf Zmq immutable buffer. + * @return Status of call. + */ + Status SendNoTimeout(MemView &buf) const; + + /** + * @brief Send a raw buffer. + * @param[in] buf Zmq immutable buffer. + * @return Status of call. + */ + Status SendWithTimeout(MemView &buf) const; + + /** + * @brief Compute how much time has elapsed and the determine if the timeout has been exceeded. + * @param[in] startTime The start time to compare with + * @param[in] startingTimeout The amount of overall time that is allowed + * @param[out] timeDiff The computed amount of time between current time now and the start + * @param[out] timeRemaining The computed amount of time allowed for future send/recv calls + * @return Status of the call. Return K_RPC_DEADLINE_EXCEEDED if the timeout has expired. 
+ */ + inline static Status CheckAndComputeTimeout(const std::chrono::time_point &startTime, + const int64_t &startingTimeoutMs, int64_t &timeDiff, + int64_t &timeRemainingMs) + { + std::chrono::time_point currentTime = std::chrono::steady_clock::now(); + timeDiff = std::chrono::duration_cast(currentTime - startTime).count(); + timeRemainingMs = startingTimeoutMs - timeDiff; + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(timeRemainingMs > 0, K_RPC_DEADLINE_EXCEEDED, "Socket send/recv timeout"); + return Status::OK(); + } + constexpr static int waSz_ = 64; int fd_; + bool timeoutEnabled_{ false }; char workArea_[waSz_]{}; }; /** diff --git a/src/datasystem/common/rpc/zmq/exclusive_conn_mgr.cpp b/src/datasystem/common/rpc/zmq/exclusive_conn_mgr.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d0a569ee3e5b6f158c43d47cf8c14cb015798cab --- /dev/null +++ b/src/datasystem/common/rpc/zmq/exclusive_conn_mgr.cpp @@ -0,0 +1,120 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Exclusive connection manager. + */ + +#include "datasystem/common/rpc/zmq/exclusive_conn_mgr.h" + +#include +#include +#include +#include + +#include "datasystem/common/log/log.h" + +// A glibc work-around for gettid(). Newer glibc has gettid() wrapper rather than direct system call. +#define gettid() syscall(SYS_gettid) + +namespace datasystem { + +// global variable in datasystem thread scope +// The instantiation only happens on first access from the given thread to this global var. +thread_local ExclusiveConnMgr gExclusiveConnMgr; + +ExclusiveConnMgr::ExclusiveConnMgr() : sockTable_(TABLE_SIZE) +{ + pid_ = getpid(); + tid_ = gettid(); + LOG(INFO) << "A user thread " << tid_ << " in process " << pid_ << " created an exclusive connection manager."; +} + +Status ExclusiveConnMgr::CreateExclusiveConnection(int32_t exclusiveId, int64_t timeoutMs, const std::string &sockPath) +{ + CHECK_FAIL_RETURN_STATUS(IsExclusiveIdInRange(exclusiveId), K_RUNTIME_ERROR, "Exclusive id out of range"); + + // All slots in the table are pre-allocated (fixed length table). However, they might contain nullptr, which means + // the connection has not been created yet (it needs to be created). + if (!sockTable_[exclusiveId]) { + // Create connection entry with a uds socket and connect it to the service + std::unique_ptr conn; + RETURN_IF_NOT_OK(CreateExclusiveConn(conn, sockPath)); + VLOG(RPC_LOG_LEVEL) << FormatString("Exclusive connection created. exclusiveId: %d, fd: %d, timeout: %d", + exclusiveId, conn->sockFd_.GetFd(), timeoutMs); + sockTable_[exclusiveId] = std::move(conn); + } + + // A previous usage of the socket fd changes the remaining time allowed for sends/recv. + // Update this timeout now so that it starts with a fresh value + sockTable_[exclusiveId]->sockFd_.SetTimeoutEnforced(timeoutMs); + VLOG(RPC_LOG_LEVEL) << "Exclusive connection initialized. 
Timeout: " << timeoutMs; + return Status::OK(); +} + +Status ExclusiveConnMgr::GetExclusiveConnDecoder(int32_t exclusiveId, ZmqMsgDecoder *&decoder) +{ + CHECK_FAIL_RETURN_STATUS(IsExclusiveIdInRange(exclusiveId), K_RUNTIME_ERROR, "Exclusive id out of range"); + + CHECK_FAIL_RETURN_STATUS(sockTable_[exclusiveId] != nullptr, K_RUNTIME_ERROR, + "Missing exclusive socket connection: " + std::to_string(exclusiveId)); + + decoder = sockTable_[exclusiveId]->decoder_.get(); + return Status::OK(); +} + +Status ExclusiveConnMgr::GetExclusiveConnEncoder(int32_t exclusiveId, ZmqMsgEncoder *&encoder) +{ + CHECK_FAIL_RETURN_STATUS(IsExclusiveIdInRange(exclusiveId), K_RUNTIME_ERROR, "Exclusive id out of range"); + + CHECK_FAIL_RETURN_STATUS(sockTable_[exclusiveId] != nullptr, K_RUNTIME_ERROR, + "Missing exclusive socket connection: " + std::to_string(exclusiveId)); + + encoder = sockTable_[exclusiveId]->encoder_.get(); + return Status::OK(); +} + +Status ExclusiveConnMgr::CreateExclusiveConn(std::unique_ptr &conn, const std::string &sockPath) +{ + VLOG(RPC_LOG_LEVEL) << "Creaing an exclusive UDS connection and associated encoders/decoders"; + conn = std::make_unique(); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(conn->sockFd_.Connect(FormatString("ipc://%s", sockPath)), + FormatString("Failed to create exclusive UDS connection with sockPath %s", sockPath)); + + // The sock fd is owned here in the exclusive connection manager. + // These encoder/decoder constructors will create encoder/decoders that directly reference this sockFd, + // via pointer. + conn->encoder_ = std::make_unique(&conn->sockFd_); + conn->decoder_ = std::make_unique(&conn->sockFd_); + + return Status::OK(); +} + +std::string ExclusiveConnMgr::GetExclusiveConnMgrName() const +{ + std::ostringstream ss; + ss << pid_ << ":" << tid_; + return ss.str(); +} + +Status ExclusiveConnMgr::CloseExclusiveConn(int32_t exclusiveId) +{ + if (IsExclusiveIdInRange(exclusiveId)) { + sockTable_[exclusiveId].reset(); + } + return Status::OK(); +} +} // namespace datasystem diff --git a/src/datasystem/common/rpc/zmq/exclusive_conn_mgr.h b/src/datasystem/common/rpc/zmq/exclusive_conn_mgr.h new file mode 100644 index 0000000000000000000000000000000000000000..5d149e54614b5d4aae3a0aabdf49c54e950f6029 --- /dev/null +++ b/src/datasystem/common/rpc/zmq/exclusive_conn_mgr.h @@ -0,0 +1,131 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+/**
+ * Description: Manager for exclusive connections in thread local memory
+ */
+#ifndef DATASYSTEM_COMMON_RPC_ZMQ_EXCLUSIVE_CONN_MGR_H
+#define DATASYSTEM_COMMON_RPC_ZMQ_EXCLUSIVE_CONN_MGR_H
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "datasystem/common/log/log.h"
+#include "datasystem/common/rpc/unix_sock_fd.h"
+#include "datasystem/common/rpc/zmq/zmq_msg_decoder.h"
+#include "datasystem/common/util/status_helper.h"
+
+namespace datasystem {
+
+/**
+ * @brief The common use case is that a single thread only uses one client, so it only needs to manage one
+ * connection in thread local memory.
+ * However, a given user thread may create more than one client. Since this class lives in
+ * thread local memory, we need an index or key to find the connection for a given client identifier (exclusiveId).
+ * This manager is performance-optimized for one or few connections. Instead of a map, which would incur key
+ * lookups in a performance-sensitive codepath, it uses a pre-allocated table (slots) with fast index dereference for
+ * lookup.
+ * The table is small; it can grow to support more entries, but that is highly unlikely to ever be needed.
+ */
+class ExclusiveConnMgr {
+public:
+ /**
+ * @brief Constructor. Creates connection table with default number of allocated slots.
+ */
+ ExclusiveConnMgr();
+
+ /**
+ * @brief Destructor
+ */
+ ~ExclusiveConnMgr() = default;
+
+ ExclusiveConnMgr(const ExclusiveConnMgr &) = delete;
+ ExclusiveConnMgr &operator=(const ExclusiveConnMgr &) = delete;
+
+ /**
+ * @brief Fetches the decoder for the exclusive connection
+ * @param[in] exclusiveId The identifier to indicate which slot to get the decoder from.
+ * @param[out] decoder Pointer to the decoder returned
+ * @return Status of the call
+ */
+ Status GetExclusiveConnDecoder(int32_t exclusiveId, ZmqMsgDecoder *&decoder);
+
+ /**
+ * @brief Fetches the encoder for the exclusive connection
+ * @param[in] exclusiveId The identifier to indicate which slot to get the encoder from.
+ * @param[out] encoder Pointer to the encoder returned
+ * @return Status of the call
+ */
+ Status GetExclusiveConnEncoder(int32_t exclusiveId, ZmqMsgEncoder *&encoder);
+
+ /**
+ * @brief Creates the exclusive connection and saves it for future use
+ * @param[in] exclusiveId The client identifier for the connection
+ * @param[in] timeoutMs The timeout to use for send/recv calls in the connection
+ * @param[in] sockPath The path used to connect with the server to create the socket connection
+ * @return Status of the call
+ */
+ Status CreateExclusiveConnection(int32_t exclusiveId, int64_t timeoutMs, const std::string &sockPath);
+
+ /**
+ * @brief Returns the name of this exclusive conn manager. Used for diagnostic purposes.
+ * @return The string name for this manager
+ */
+ std::string GetExclusiveConnMgrName() const;
+
+ /**
+ * @brief Closes and frees the exclusive connection.
+ * @param[in] exclusiveId The client identifier for the connection
+ * @return Status of the call
+ */
+ Status CloseExclusiveConn(int32_t exclusiveId);
+
+private:
+ class ExclusiveConn {
+ public:
+ ExclusiveConn() = default;
+ ~ExclusiveConn()
+ {
+ VLOG(RPC_LOG_LEVEL) << "Exclusive conn destructor will close socket fd: " << sockFd_.GetFd();
+ sockFd_.Close();
+ }
+ UnixSockFd sockFd_;
+ std::unique_ptr encoder_;
+ std::unique_ptr decoder_;
+ };
+ static const int TABLE_SIZE = 128;
+
+ Status CreateExclusiveConn(std::unique_ptr &conn, const std::string &sockPath);
+
+ /**
+ * @brief Returns true if the exclusiveId is within the valid range
+ * @param[in] exclusiveId The id to check
+ */
+ inline bool IsExclusiveIdInRange(int32_t exclusiveId)
+ {
+ return (exclusiveId >= 0 && static_cast(exclusiveId) < sockTable_.size());
+ }
+
+ // Fixed length table (for performance on lookups).
+ std::vector> sockTable_;
+ pid_t pid_;
+ pid_t tid_;
+};
+
+extern thread_local ExclusiveConnMgr gExclusiveConnMgr;
+
+} // namespace datasystem
+#endif // DATASYSTEM_COMMON_RPC_ZMQ_EXCLUSIVE_CONN_MGR_H
diff --git a/src/datasystem/common/rpc/zmq/work_agent.cpp b/src/datasystem/common/rpc/zmq/work_agent.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..cb1c3d30990b248cf54b3016910cb47ae63a2588
--- /dev/null
+++ b/src/datasystem/common/rpc/zmq/work_agent.cpp
@@ -0,0 +1,112 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "datasystem/common/rpc/zmq/work_agent.h"
+#include
+#include
+
+namespace datasystem {
+WorkAgent::WorkAgent(const UnixSockFd &fd, ZmqService *svc, bool uds)
+ : sockFd_(fd),
+ uds_(uds),
+ decoder_(std::make_unique(&sockFd_)),
+ encoder_(std::make_unique(&sockFd_)),
+ svc_(svc)
+{
+}
+
+Status WorkAgent::ClientToService(ZmqMetaMsgFrames &p)
+{
+ // *** Protocol FRAME 0 ***
+ // First, get the request header
+ // We can use V2 protocol to receive. It is compatible with V1
+ ZmqMsgFrames frames;
+ auto rc = decoder_->ReceiveMsgFramesV2(frames);
+ if (rc.IsError()) {
+ LOG(ERROR) << FormatString("Error detected in decoder on fd %d. %s", sockFd_.GetFd(), rc.ToString());
+ interrupted_.store(true, std::memory_order_release);
+ return rc;
+ }
+
+ VLOG(RPC_LOG_LEVEL) << "# of frames received " << frames.size();
+ // MetaPb is embedded in the incoming socket connection. For now, pass a fake one.
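Note the `extern thread_local ExclusiveConnMgr gExclusiveConnMgr;` declared above: every user thread lazily constructs its own manager on first touch, so slot lookups need no locking at all. A small standalone sketch of the pattern follows, with a simplified slot type; the 128-entry table mirrors `TABLE_SIZE`.

```cpp
#include <cstdio>
#include <memory>
#include <thread>
#include <vector>

struct Conn {
    int id;
};

class ConnMgr {
public:
    ConnMgr() : slots_(128) {}  // pre-allocated fixed table, like TABLE_SIZE

    Conn *GetOrCreate(int id)
    {
        if (id < 0 || static_cast<size_t>(id) >= slots_.size()) {
            return nullptr;  // range check, no locking needed
        }
        if (!slots_[id]) {
            slots_[id] = std::make_unique<Conn>(Conn{ id });  // created on first use
        }
        return slots_[id].get();
    }

private:
    std::vector<std::unique_ptr<Conn>> slots_;
};

// One instance per thread, constructed lazily on first access from that thread.
thread_local ConnMgr gMgr;

int main()
{
    auto worker = [] {
        // Each thread sees a distinct manager (distinct addresses printed).
        std::printf("mgr %p conn %p\n", static_cast<void *>(&gMgr),
                    static_cast<void *>(gMgr.GetOrCreate(0)));
    };
    std::thread t1(worker);
    std::thread t2(worker);
    t1.join();
    t2.join();
    return 0;
}
```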
diff --git a/src/datasystem/common/rpc/zmq/work_agent.cpp b/src/datasystem/common/rpc/zmq/work_agent.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..cb1c3d30990b248cf54b3016910cb47ae63a2588
--- /dev/null
+++ b/src/datasystem/common/rpc/zmq/work_agent.cpp
@@ -0,0 +1,112 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "datasystem/common/rpc/zmq/work_agent.h"
+#include <memory>
+#include <utility>
+
+namespace datasystem {
+WorkAgent::WorkAgent(const UnixSockFd &fd, ZmqService *svc, bool uds)
+    : sockFd_(fd),
+      uds_(uds),
+      decoder_(std::make_unique<ZmqMsgDecoder>(&sockFd_)),
+      encoder_(std::make_unique<ZmqMsgEncoder>(&sockFd_)),
+      svc_(svc)
+{
+}
+
+Status WorkAgent::ClientToService(ZmqMetaMsgFrames &p)
+{
+    // *** Protocol FRAME 0 ***
+    // First, get the request header
+    // We can use V2 protocol to receive. It is compatible with V1
+    ZmqMsgFrames frames;
+    auto rc = decoder_->ReceiveMsgFramesV2(frames);
+    if (rc.IsError()) {
+        LOG(ERROR) << FormatString("Error detected in decoder on fd %d. %s", sockFd_.GetFd(), rc.ToString());
+        interrupted_.store(true, std::memory_order_release);
+        return rc;
+    }
+
+    VLOG(RPC_LOG_LEVEL) << "# of frames received " << frames.size();
+    // MetaPb is embedded in the incoming socket connection. For now, pass a fake one.
+    p.first = MetaPb();
+    p.second = std::move(frames);
+    return Status::OK();
+}
+
+Status WorkAgent::ServiceToClient(ZmqMetaMsgFrames &p)
+{
+    MetaPb &meta = p.first;
+    CHECK_FAIL_RETURN_STATUS(meta.ticks_size() > 0, K_RUNTIME_ERROR,
+                             FormatString("Incomplete MetaPb:\n%s", meta.DebugString()));
+    PerfPoint::RecordElapsed(PerfKey::ZMQ_APP_WORKLOAD, GetLapTime(meta, "ZMQ_APP_WORKLOAD"));
+    ZmqMsgFrames &frames = p.second;
+    TraceGuard traceGuard = Trace::Instance().SetTraceNewID(meta.trace_id());
+    // No need to prepend the gateway if it is direct connection
+    RETURN_IF_NOT_OK(PushFrontProtobufToFrames(meta, frames));
+    RETURN_IF_NOT_OK(PushFrontStringToFrames(meta.client_id(), frames));
+    // We need to match the client protocol to be downward compatible
+    auto rc = encoder_->SendMsgFrames(static_cast<EventType>(meta.event_type()), frames);
+    if (rc.IsError()) {
+        LOG(ERROR) << FormatString("Error detected in encoder on fd %d. %s", sockFd_.GetFd(), rc.ToString());
+        interrupted_.store(true, std::memory_order_release);
+        return rc;
+    }
+    return Status::OK();
+}
+
+Status WorkAgent::DoWork()
+{
+    VLOG(RPC_LOG_LEVEL) << "Work Agent Doing work";
+    ZmqMetaMsgFrames inMsg;
+    ZmqMetaMsgFrames outMsg;
+    // receive msg from client
+    RETURN_IF_NOT_OK(ClientToService(inMsg));
+    // execute internal method and get the reply
+    EventType type = uds_ ? EventType::V1MTP : (decoder_->V2Client() ? EventType::V2MTP : EventType::V1MTP);
+    RETURN_IF_NOT_OK(svc_->DirectExecInternalMethod(sockFd_.GetFd(), type, inMsg, outMsg));
+    // send reply back to client
+    RETURN_IF_NOT_OK(ServiceToClient(outMsg));
+    return Status::OK();
+}
+
+Status WorkAgent::Run()
+{
+    CHECK_FAIL_RETURN_STATUS(!interrupted_.load(std::memory_order_acquire), K_RUNTIME_ERROR,
+                             FormatString("Wrong state in work agent. Interrupted flag should be false."));
+    while (!interrupted_.load(std::memory_order_acquire)) {
+        DoWork();
+    }
+    // Once the interrupt condition is triggered, close the socket.
+    CloseSocket();
+    return Status::OK();
+}
+
+Status WorkAgent::CloseSocket()
+{
+    VLOG(RPC_LOG_LEVEL) << "WorkAgent shuts down and closes socket fd " << sockFd_.GetFd();
+    sockFd_.Close();
+    return Status::OK();
+}
+
+Status WorkAgent::Stop()
+{
+    interrupted_.store(true, std::memory_order_release);
+    return Status::OK();
+}
+
+} // namespace datasystem
diff --git a/src/datasystem/common/rpc/zmq/work_agent.h b/src/datasystem/common/rpc/zmq/work_agent.h
new file mode 100644
index 0000000000000000000000000000000000000000..51691f0a55ff563abe8c2e81260c44d4c91853f5
--- /dev/null
+++ b/src/datasystem/common/rpc/zmq/work_agent.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Zmq server work agent for exclusive connections.
+ */
+#ifndef DATASYSTEM_WORKER_WORK_AGENT_H
+#define DATASYSTEM_WORKER_WORK_AGENT_H
+
+#include <atomic>
+#include <memory>
+#include <string>
+#include <thread>
+#include "datasystem/common/rpc/unix_sock_fd.h"
+#include "datasystem/common/rpc/zmq/zmq_msg_decoder.h"
+#include "datasystem/common/rpc/zmq/zmq_service.h"
+
+namespace datasystem {
+class ZmqService;
+
+class WorkAgent {
+public:
+    WorkAgent(const UnixSockFd &fd, ZmqService *svc, bool uds);
+    ~WorkAgent() = default;
+    Status Run();
+    Status DoWork();
+    Status Stop();
+    Status CloseSocket();
+
+private:
+    friend class ZmqService;
+    Status ClientToService(ZmqMetaMsgFrames &p);
+    Status ServiceToClient(ZmqMetaMsgFrames &p);
+    UnixSockFd sockFd_;
+    bool uds_;
+    std::atomic<bool> interrupted_{ false };
+    std::thread::id workerId_;
+
+    std::unique_ptr<ZmqMsgDecoder> decoder_;
+    std::unique_ptr<ZmqMsgEncoder> encoder_;
+    ZmqService *svc_;
+};
+
+} // namespace datasystem
+#endif // DATASYSTEM_WORKER_WORK_AGENT_H
diff --git a/src/datasystem/common/rpc/zmq/zmq_epoll.cpp b/src/datasystem/common/rpc/zmq/zmq_epoll.cpp
index c1ef5438597b30237f494a1c9e977964e2e2da55..6c2e36e03d178e54f7eb1cb4c4c917f385eff8ad 100644
--- a/src/datasystem/common/rpc/zmq/zmq_epoll.cpp
+++ b/src/datasystem/common/rpc/zmq/zmq_epoll.cpp
@@ -180,7 +180,7 @@ Status ZmqEpoll::HandleEvent(int timeout)
                 continue;
             }
             rc = pe->inEventFunc_(pe, ev.events);
-            VLOG_IF(RPC_LOG_LEVEL, rc.IsError()) << FormatString("%s fd %d", rc.ToString(), pe->fd_);
+            VLOG(RPC_LOG_LEVEL) << FormatString("%s fd %d", rc.ToString(), pe->fd_);
         }
         if (ev.events & EPOLLOUT) {
             // We need to check again the pe and fd again because we can get both
@@ -190,7 +190,7 @@ Status ZmqEpoll::HandleEvent(int timeout)
                 continue;
             }
             rc = pe->outEventFunc_(pe, ev.events);
-            VLOG_IF(RPC_LOG_LEVEL, rc.IsError()) << FormatString("%s fd %d", rc.ToString(), pe->fd_);
+            VLOG(RPC_LOG_LEVEL) << FormatString("%s fd %d", rc.ToString(), pe->fd_);
         }
     }
     HandlePendingClose();
diff --git a/src/datasystem/common/rpc/zmq/zmq_msg_decoder.cpp b/src/datasystem/common/rpc/zmq/zmq_msg_decoder.cpp
index fff42074092b35885b1cc705b1affcb6224fee56..3a3ad07497b91e0db8b965146633dfc38b8df493 100644
--- a/src/datasystem/common/rpc/zmq/zmq_msg_decoder.cpp
+++ b/src/datasystem/common/rpc/zmq/zmq_msg_decoder.cpp
@@ -62,9 +62,13 @@ Status ZmqMsgDecoder::Recv()
     // It is assumed the underlying file descriptor is non-blocking
     CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(K_WA_SIZE >= bytesReceived_, K_RUNTIME_ERROR,
                                          FormatString("Invalid bytesReceived_ %zu", bytesReceived_));
-    ssize_t bytesReceived = recv(fd_, buf + bytesReceived_, K_WA_SIZE - bytesReceived_, 0);
+
+    // This code makes a direct socket system call instead of going through the UnixSockFd interface.
+    // This may be a concern for the exclusive connection logic. Currently, this codepath does not seem to be
+    // called from the client side, so it is not an issue.
+ ssize_t bytesReceived = recv(pSockFd_->GetFd(), buf + bytesReceived_, K_WA_SIZE - bytesReceived_, 0); if (bytesReceived == -1) { - auto rc = UnixSockFd::ErrnoToStatus(errno, fd_); + auto rc = UnixSockFd::ErrnoToStatus(errno, pSockFd_->GetFd()); return rc; } if (bytesReceived == 0) { @@ -97,8 +101,7 @@ Status ZmqMsgDecoder::DecodeHdrLen(MsgState &state) { CHECK_FAIL_RETURN_STATUS(state == MsgState::HDR_LEN_READY, K_RUNTIME_ERROR, "Wrong state"); if (Empty()) { - UnixSockFd sock(fd_); - RETURN_IF_NOT_OK(sock.RecvProtobuf(hdr_)); + RETURN_IF_NOT_OK(pSockFd_->RecvProtobuf(hdr_)); // Move the state to detect if it is V1 or V2 state = MsgState::MTP_DETECT; return Status::OK(); @@ -146,11 +149,11 @@ Status ZmqMsgDecoder::DetectMTP(MsgState &state) // V2 sends an empty header newFormat_ = hdr_.msg_size_size() == 0; if (newFormat_) { - VLOG(RPC_LOG_LEVEL) << FormatString("V2 format detected for fd %d", fd_); + VLOG(RPC_LOG_LEVEL) << FormatString("V2 format detected for fd %d", pSockFd_->GetFd()); state = MsgState::FLAGS_READY; } else { // V1 format - VLOG(RPC_LOG_LEVEL) << FormatString("V1 format detected for fd %d", fd_); + VLOG(RPC_LOG_LEVEL) << FormatString("V1 format detected for fd %d", pSockFd_->GetFd()); state = MsgState::DOWNLEVEL_CLIENT; } return Status::OK(); @@ -163,7 +166,8 @@ Status ZmqMsgDecoder::V1Client(MsgState &state) // possibly some bytes in the work area. v1Frames_.clear(); const int numMsg = hdr_.msg_size_size(); - VLOG(RPC_LOG_LEVEL) << FormatString("Prepare to receive %d frames from fd %d using V1 format", numMsg, fd_); + VLOG(RPC_LOG_LEVEL) << FormatString("Prepare to receive %d frames from fd %d using V1 format", numMsg, + pSockFd_->GetFd()); for (auto i = 0; i < hdr_.msg_size_size(); ++i) { size_t msgReadSoFar = 0; ZmqMessage msg; @@ -172,10 +176,9 @@ Status ZmqMsgDecoder::V1Client(MsgState &state) RETURN_IF_NOT_OK(TransferFromWA(msg.Data(), sz, msgReadSoFar)); // For the rest we will simply read directly into the ZmqMessage. if (msgReadSoFar < sz) { - UnixSockFd sock(fd_); // We will block ourselves until we get all the data. RETURN_IF_NOT_OK( - sock.Recv(reinterpret_cast(msg.Data()) + msgReadSoFar, sz - msgReadSoFar, true)); + pSockFd_->Recv(reinterpret_cast(msg.Data()) + msgReadSoFar, sz - msgReadSoFar, true)); } VLOG(RPC_LOG_LEVEL) << "Frame (" << i << ") received. Size " << msg.Size() << " ... " << msg; v1Frames_.push_back(std::move(msg)); @@ -313,9 +316,8 @@ Status ZmqMsgDecoder::ReadMessage(MsgState &state, void *dest, size_t sz) // For the rest or large payload, we will simply read directly into the ZmqMessage. if (msgReadSoFar < msgSize_) { // We will block ourselves until we get all the data. 
- UnixSockFd sock(fd_); - RETURN_IF_NOT_OK( - sock.Recv(reinterpret_cast(inProcess_.Data()) + msgReadSoFar, msgSize_ - msgReadSoFar, true)); + RETURN_IF_NOT_OK(pSockFd_->Recv(reinterpret_cast(inProcess_.Data()) + msgReadSoFar, + msgSize_ - msgReadSoFar, true)); } chgState(); return Status::OK(); @@ -375,18 +377,17 @@ Status ZmqMsgDecoder::GetMessage(ZmqMessage &outMsg, bool &more) return rc; } -Status ZmqMsgDecoder::ReceiveMsgFramesV1(ZmqMsgFrames &frames) const +Status ZmqMsgDecoder::ReceiveMsgFramesV1(ZmqMsgFrames &frames) { - UnixSockFd sock(fd_); MultiMsgHdrPb hdr; - RETURN_IF_NOT_OK(sock.RecvProtobuf(hdr)); + RETURN_IF_NOT_OK(pSockFd_->RecvProtobuf(hdr)); const int numMsg = hdr.msg_size_size(); VLOG(RPC_LOG_LEVEL) << FormatString("Prepare to receive %d frames from fd %d using V1 format", numMsg, - sock.GetFd()); + pSockFd_->GetFd()); for (int i = 0; i < numMsg; ++i) { ZmqMessage msg; RETURN_IF_NOT_OK(msg.AllocMem(hdr.msg_size(i))); - RETURN_IF_NOT_OK(sock.Recv(msg.Data(), msg.Size(), true)); + RETURN_IF_NOT_OK(pSockFd_->Recv(msg.Data(), msg.Size(), true)); VLOG(RPC_LOG_LEVEL) << "Frame (" << i << ") received. Size " << msg.Size() << " ... " << msg; frames.push_back(std::move(msg)); } @@ -465,7 +466,22 @@ Status ZmqMsgDecoder::ReceivePayloadIntoMemory(void *dest, size_t sz) } ZmqMsgDecoder::ZmqMsgDecoder(int fd) - : fd_(fd), + : sockFd_(fd), + curFrame_(0), + msgState_(MsgState::HDR_LEN_READY), + flag_(MTP_PROTOCOL::MTP_NONE), + bytesReceived_(0), + pos_(0), + msgSize_(0), + rpcHdrSz_(0), + newFormat_(true) +{ + wa_ = std::make_unique(K_WA_SIZE); + pSockFd_ = &sockFd_; +} + +ZmqMsgDecoder::ZmqMsgDecoder(UnixSockFd *sockFdRef) + : pSockFd_(sockFdRef), curFrame_(0), msgState_(MsgState::HDR_LEN_READY), flag_(MTP_PROTOCOL::MTP_NONE), @@ -505,20 +521,18 @@ Status ZmqMsgEncoder::SendMessage(const ZmqMessage &msg, bool more) const if (type == ZmqMessage::ZmqMsgType::DECODER) { hdr.flag_ |= MTP_DECODER; } - UnixSockFd sock(fd_); const int SHORT_LENGTH = 2; MemView buf(&hdr, (hdr.flag_ & MTP_LONG) ? K_EIGHT_BYTE + 1 : SHORT_LENGTH); - RETURN_IF_NOT_OK(sock.Send(buf)); + RETURN_IF_NOT_OK(pSockFd_->Send(buf)); if (sz > 0) { buf = MemView(msg.Data(), sz); - RETURN_IF_NOT_OK(sock.Send(buf)); + RETURN_IF_NOT_OK(pSockFd_->Send(buf)); } return Status::OK(); } -Status ZmqMsgEncoder::SendMsgFramesV1(ZmqMsgFrames &que) const +Status ZmqMsgEncoder::SendMsgFramesV1(ZmqMsgFrames &que) { - UnixSockFd sock(fd_); MultiMsgHdrPb hdr; auto it = que.begin(); while (it != que.end()) { @@ -526,28 +540,28 @@ Status ZmqMsgEncoder::SendMsgFramesV1(ZmqMsgFrames &que) const ++it; } VLOG(RPC_LOG_LEVEL) << FormatString("Prepare to send %d frames to fd %d using V1 format", hdr.msg_size_size(), - sock.GetFd()); - RETURN_IF_NOT_OK_PRINT_ERROR_MSG(sock.SendProtobuf(hdr), FormatString("Errno = %d", errno)); + pSockFd_->GetFd()); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(pSockFd_->SendProtobuf(hdr), FormatString("Errno = %d", errno)); int i = 0; while (!que.empty()) { auto &msg = que.front(); MemView buf(msg.Data(), msg.Size()); - RETURN_IF_NOT_OK(sock.Send(buf)); + RETURN_IF_NOT_OK(pSockFd_->Send(buf)); VLOG(RPC_LOG_LEVEL) << "Frame (" << i++ << ") sent. Size " << msg.Size() << " ... 
" << msg; que.pop_front(); } return Status::OK(); } -Status ZmqMsgEncoder::SendMsgFramesV2(ZmqMsgFrames &que) const +Status ZmqMsgEncoder::SendMsgFramesV2(ZmqMsgFrames &que) { - VLOG(RPC_LOG_LEVEL) << FormatString("Prepare to send %d frames to fd %d using V2 format", que.size(), fd_); + VLOG(RPC_LOG_LEVEL) << FormatString("Prepare to send %d frames to fd %d using V2 format", que.size(), + pSockFd_->GetFd()); // We will send an empty MultiMsgHdrPb to be compatible with V1, // and remote peer can distinguish if it is V1 or V2 format. { - UnixSockFd sock(fd_); MultiMsgHdrPb hdr; - RETURN_IF_NOT_OK_PRINT_ERROR_MSG(sock.SendProtobuf(hdr), FormatString("Errno = %d", errno)); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(pSockFd_->SendProtobuf(hdr), FormatString("Errno = %d", errno)); } int i = 0; bool more = true; @@ -566,7 +580,7 @@ Status ZmqMsgEncoder::SendMsgFrames(EventType type, ZmqMsgFrames &frames) if (type == V2MTP) { return SendMsgFramesV2(frames); } else if (type == V1MTP) { - VLOG(RPC_LOG_LEVEL) << FormatString("Fall back to V1 format for fd %d", fd_); + VLOG(RPC_LOG_LEVEL) << FormatString("Fall back to V1 format for fd %d", pSockFd_->GetFd()); return SendMsgFramesV1(frames); } RETURN_STATUS(K_INVALID, FormatString("Unsupported type %d", type)); diff --git a/src/datasystem/common/rpc/zmq/zmq_msg_decoder.h b/src/datasystem/common/rpc/zmq/zmq_msg_decoder.h index ab66144ae8b64ea272ba758ab7bbb6920e1d93e4..8456e44f43a83e01884e6a49786e195943493a01 100644 --- a/src/datasystem/common/rpc/zmq/zmq_msg_decoder.h +++ b/src/datasystem/common/rpc/zmq/zmq_msg_decoder.h @@ -61,7 +61,20 @@ public: EIGHT_BYTE_SIZE_READY, MESSAGE_READY }; + + /** + * @brief This constructor builds a decoder that operates on a file descriptor used by UnixSockFd + * @param[in] fd The fd to use for the internal UnixSockFd + */ explicit ZmqMsgDecoder(int fd); + + /** + * @brief This constructor builds a decoder that operates on a file descriptor referenced by the pointer to the + * existing UnixSockFd. + * @param[in] sockFdRef The pointer for the externally managed UnixSockFd + */ + explicit ZmqMsgDecoder(UnixSockFd *sockFdRef); + ~ZmqMsgDecoder(); /** @@ -70,7 +83,7 @@ public: * @return * @note Doesn't support write into user provided buffers */ - Status ReceiveMsgFramesV1(ZmqMsgFrames &frames) const; + Status ReceiveMsgFramesV1(ZmqMsgFrames &frames); /** * Version 2 protocol of receiving ZmqMessages @@ -94,7 +107,7 @@ public: */ auto GetFd() const { - return fd_; + return sockFd_.GetFd(); } /** @@ -129,7 +142,8 @@ public: private: constexpr static int K_WA_SIZE = 1024; - int fd_; + UnixSockFd sockFd_; + UnixSockFd *pSockFd_; int curFrame_; MsgState msgState_; MTP_PROTOCOL flag_; @@ -166,9 +180,24 @@ private: */ class ZmqMsgEncoder { public: - explicit ZmqMsgEncoder(int fd) : fd_(fd) + /** + * @brief This constructor builds an encoder that operates on a file descriptor used by UnixSockFd + * @param[in] fd The fd to use for the internal UnixSockFd + */ + explicit ZmqMsgEncoder(int fd) : sockFd_(fd) { + pSockFd_ = &sockFd_; } + + /** + * @brief This constructor builds a decoder that operates on a file descriptor referenced by the pointer to the + * existing UnixSockFd. 
+ * @param[in] sockFdRef The pointer for the externally managed UnixSockFd + */ + ZmqMsgEncoder(UnixSockFd *sockFdRef) : pSockFd_(sockFdRef) + { + } + ~ZmqMsgEncoder() = default; /** @@ -177,7 +206,8 @@ public: Status SendMsgFrames(EventType type, ZmqMsgFrames &frames); private: - int fd_; + UnixSockFd sockFd_; + UnixSockFd *pSockFd_; /** * Version 1 protocol of sending ZmqMessages @@ -185,14 +215,14 @@ private: * @return * @note Doesn't support write into user provided buffers */ - Status SendMsgFramesV1(ZmqMsgFrames &que) const; + Status SendMsgFramesV1(ZmqMsgFrames &que); /** * Version 2 protocol of sending ZmqMessages * @param que * @return */ - Status SendMsgFramesV2(ZmqMsgFrames &que) const; + Status SendMsgFramesV2(ZmqMsgFrames &que); Status SendMessage(const ZmqMessage &msg, bool more) const; }; diff --git a/src/datasystem/common/rpc/zmq/zmq_server_impl.cpp b/src/datasystem/common/rpc/zmq/zmq_server_impl.cpp index f4e097a4f10ac94ee29bb8b5bac317449ec562a3..7d78cdbe0c04bb3ccf74707c026677716be8711c 100644 --- a/src/datasystem/common/rpc/zmq/zmq_server_impl.cpp +++ b/src/datasystem/common/rpc/zmq/zmq_server_impl.cpp @@ -163,7 +163,7 @@ Status ParseMsgFrames(ZmqMsgFrames &frames, MetaPb &meta, int fd, const EventTyp PerfPoint::RecordElapsed(PerfKey::ZMQ_NETWORK_TRANSFER, lapTime); // Set up the return address. meta.set_gateway_id(ZmqMessageToString(gatewayId)); - meta.set_routing_fd(std::to_string(fd)); + meta.set_route_fd(fd); meta.set_event_type(type); RecalculateMetaTimeout(meta, lapTime); return Status::OK(); @@ -582,9 +582,7 @@ Status IOService::ServiceToClient(ZmqPollEntry *pe, EventsVal events) PerfPoint::RecordElapsed(PerfKey::ZMQ_FRONTEND_TO_IOSVC, GetLapTime(meta, "ZMQ_FRONTEND_TO_IOSVC")); ZmqMsgFrames &frames = p.second; TraceGuard traceGuard = Trace::Instance().SetTraceNewID(meta.trace_id()); - int fd; - CHECK_FAIL_RETURN_STATUS(StringToInt(meta.routing_fd(), fd), K_RUNTIME_ERROR, - "String convert to int failed, service to client failed"); + int fd = meta.route_fd(); // No need to prepend the gateway if it is direct connection RETURN_IF_NOT_OK(PushFrontProtobufToFrames(meta, frames)); RETURN_IF_NOT_OK(PushFrontStringToFrames(meta.client_id(), frames)); diff --git a/src/datasystem/common/rpc/zmq/zmq_server_stream_base.h b/src/datasystem/common/rpc/zmq/zmq_server_stream_base.h index 03a052a5f635b6f7d032dd4202b1cf6a80bb774b..bd9a6c0fecd729f97db383a81e0cff7c35f0cb0e 100644 --- a/src/datasystem/common/rpc/zmq/zmq_server_stream_base.h +++ b/src/datasystem/common/rpc/zmq/zmq_server_stream_base.h @@ -25,7 +25,8 @@ #include #include - +#include +#include #include "datasystem/common/rpc/rpc_message.h" #include "datasystem/common/rpc/zmq/zmq_service.h" #include "datasystem/common/rpc/zmq/zmq_stream_base.h" @@ -266,11 +267,17 @@ class ServerUnaryWriterReaderImpl : public StreamBase { public: explicit ServerUnaryWriterReaderImpl(std::shared_ptr mQue, const MetaPb &meta, ZmqMsgFrames &&inMsg, bool sendPayload, bool recvPayload) - : StreamBase::StreamBase(sendPayload, recvPayload), mQue_(std::move(mQue)), writeOnce_(false), readOnce_(false) + : StreamBase::StreamBase(sendPayload, recvPayload), + mQue_(std::move(mQue)), + writeOnce_(false), + readOnce_(false), + requestComplete_(true) { meta_ = meta; inMsg_ = std::move(inMsg); + enableMsgQ_ = (mQue_ != nullptr); } + ~ServerUnaryWriterReaderImpl() override = default; virtual Status SendStatus(const Status &rc) @@ -292,6 +299,49 @@ public: } } + Status GetOutMsg(ZmqMsgFrames &outMsg) + { + // Most codepaths do not have async handling, and 
+        // requestComplete_ will always be true.
+        // If it does have async, then requestComplete_ will be initialized to false and we as the parent need to wait
+        // for the child thread to inform us when it is safe to continue processing after the request is done.
+        if (!requestComplete_) {
+            VLOG(RPC_LOG_LEVEL) << "Work agent needs to wait for async request to complete before returning results.";
+            std::unique_lock<std::mutex> lock(requestCompleteMtx_);
+            requestCompleteCond_.wait(lock, [this] { return (requestComplete_ == true); });
+            // There is no need to flip requestComplete_ to false again because this class is instantiated for every
+            // request and is not re-used.
+            VLOG(RPC_LOG_LEVEL) << "Work agent was notified that the request completed.";
+        }
+        outMsg = std::move(outMsg_);
+        return Status::OK();
+    }
+
+    virtual bool EnableMsgQ()
+    {
+        return enableMsgQ_;
+    }
+
+    void SetRequestInProgress()
+    {
+        if (!enableMsgQ_) {
+            requestComplete_ = false;
+        }
+        // No-op if message queues were used. This call is only relevant for exclusive connection mode.
+        // requestComplete_ remains true in that case.
+    }
+
+    void SetRequestComplete()
+    {
+        if (!enableMsgQ_) {
+            // Signal that the request is done.
+            std::unique_lock<std::mutex> lock(requestCompleteMtx_);
+            requestComplete_ = true;
+            lock.unlock();
+            requestCompleteCond_.notify_one();
+        }
+        // No-op if message queues were used. This call is only relevant for exclusive connection mode.
+    }
+
     Status SendAll(ZmqSendFlags flags) override
     {
         PerfPoint::RecordElapsed(PerfKey::ZMQ_APP_WORKLOAD, GetLapTime(meta_, "ZMQ_APP_WORKLOAD"));
@@ -310,7 +360,29 @@ public:
             outMsg_.push_back(std::move(rcMsg));
             RETURN_IF_NOT_OK(PushBackProtobufToFrames(pb, outMsg_));
             RETURN_OK_IF_TRUE(HasRecvPayloadOp());
-            return SendAll(ZmqSendFlags::NONE);
+            if (enableMsgQ_) {
+                return SendAll(ZmqSendFlags::NONE);
+            }
+            return Status::OK();
+        } else {
+            RETURN_STATUS(StatusCode::K_RUNTIME_ERROR, "ServerUnaryWriterReaderImpl is only supposed to be used once!");
+        }
+    }
+
+    virtual Status ConstructWriteMsg(const W &pb, ZmqMsgFrames &outMsg)
+    {
+        CHECK_FAIL_RETURN_STATUS(!enableMsgQ_, StatusCode::K_RUNTIME_ERROR,
+                                 "Invoke ConstructWriteMsg() only if enableMsgQ_ flag is off.");
+        bool expected = false;
+        if (writeOnce_.compare_exchange_strong(expected, true)) {
+            VLOG(RPC_LOG_LEVEL) << "Server uses unary socket sending rc " << Status::OK() << " message "
+                                << LogHelper::IgnoreSensitive(pb) << " back to client " << meta_.client_id()
+                                << std::endl;
+            ZmqMessage rcMsg = StatusToZmqMessage(Status::OK());
+            outMsg.push_back(std::move(rcMsg));
+            RETURN_IF_NOT_OK(PushBackProtobufToFrames(pb, outMsg));
+            RETURN_OK_IF_TRUE(HasRecvPayloadOp());
+            return Status::OK();
         } else {
             RETURN_STATUS(StatusCode::K_RUNTIME_ERROR, "ServerUnaryWriterReaderImpl is only supposed to be used once!");
         }
@@ -335,7 +407,10 @@ public:
                                     "Server uses unary socket to send %zu payload bytes to Service %s Method"
                                     "%d to client %s",
                                     bufSz, meta_.svc_name(), meta_.method_index(), meta_.client_id());
-        return SendAll(ZmqSendFlags::NONE);
+        if (enableMsgQ_) {
+            return SendAll(ZmqSendFlags::NONE);
+        }
+        return Status::OK();
     }
 
     virtual Status SendPayload(std::vector &buffer)
@@ -428,6 +503,10 @@ protected:
 private:
     std::atomic<bool> writeOnce_;
    std::atomic<bool> readOnce_;
+    bool enableMsgQ_;
+    std::atomic<bool> requestComplete_;
+    std::mutex requestCompleteMtx_;
+    std::condition_variable requestCompleteCond_;
 };
 } // namespace datasystem
 #endif // DATASYSTEM_COMMON_RPC_ZMQ_STREAM_SERVER_H
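The `requestComplete_` flag added above lets the work agent block in `GetOutMsg()` until an asynchronously handled request calls `SetRequestComplete()`. A minimal, self-contained sketch of that wait/notify handshake (a plain `bool` under the mutex here, where the patch pairs `std::atomic<bool>` with the same mutex/condition variable):

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// Stand-alone sketch of the requestComplete_ handshake: a parent marks the
// request in progress, hands it to an async helper, then blocks the same way
// GetOutMsg() does until SetComplete() signals.
class RequestGate {
public:
    void SetInProgress() { done_ = false; }

    void SetComplete()
    {
        {
            std::lock_guard<std::mutex> lock(mtx_);
            done_ = true;
        }
        cond_.notify_one();
    }

    void WaitUntilComplete()
    {
        std::unique_lock<std::mutex> lock(mtx_);
        cond_.wait(lock, [this] { return done_; });  // no-op if already complete
    }

private:
    bool done_{ true };  // like requestComplete_: true unless async work is pending
    std::mutex mtx_;
    std::condition_variable cond_;
};

int main()
{
    RequestGate gate;
    gate.SetInProgress();
    std::thread async([&gate] { gate.SetComplete(); });  // the "child thread"
    gate.WaitUntilComplete();                            // parent blocks, as in GetOutMsg()
    async.join();
    std::cout << "request complete\n";
}
```

As the patch comment notes, the flag never flips back to false because each writer/reader object serves exactly one request; the sketch keeps that one-shot semantic.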
diff --git a/src/datasystem/common/rpc/zmq/zmq_service.cpp b/src/datasystem/common/rpc/zmq/zmq_service.cpp
index 72cf00469ba758db1239dc59d7968882b075472b..72f8e31a08455ec32fa8e385cf88e9f9a30186ab 100644
--- a/src/datasystem/common/rpc/zmq/zmq_service.cpp
+++ b/src/datasystem/common/rpc/zmq/zmq_service.cpp
@@ -36,11 +36,13 @@ DS_DECLARE_string(unix_domain_socket_dir);
 
 namespace datasystem {
+static const int MAX_EXCLUSIVE_CONNECTIONS_LIMIT = 128;
 ZmqService::ZmqService()
     : proxy_(nullptr),
       outfd_(ZMQ_NO_FILE_FD),
       infd_(ZMQ_NO_FILE_FD),
       tcpfd_(ZMQ_NO_FILE_FD),
+      exclListenFd_(ZMQ_NO_FILE_FD),
       nextWorker_(0),
       globalInterrupt_(false),
       streamSupport_(false),
@@ -82,6 +84,15 @@ ZmqService::~ZmqService()
     if (payloadBank_ != nullptr) {
         payloadBank_.reset();
     }
+    for (auto &agent : workAgents_) {
+        try {
+            agent->Stop();
+            agent->CloseSocket();
+        } catch (const std::exception &e) {
+            LOG(ERROR) << "A work agent got an exception during Stop(): " << e.what();
+        }
+    }
+    workAgentThreadPool_.reset();
 }
 
 Status ZmqService::CreateWorkerCBs()
@@ -190,6 +201,20 @@ Status ZmqService::BindUnixPath()
         RETURN_IF_NOT_OK(AddListenFd(listenFd));
         sockPath_.emplace_back(sockPath);
     }
+
+    // create an exclusive connection listener fd.
+    sockaddr_un addr{};
+    auto exclSockPath = FormatString("%s/%s", path, "exclusiveConn");
+    unlink(exclSockPath.data());
+    UnixSockFd tempsockfd;
+    RETURN_IF_NOT_OK(tempsockfd.CreateUnixSocket());
+    RETURN_IF_NOT_OK(UnixSockFd::SetUpSockPath(exclSockPath, addr));
+    RETURN_IF_NOT_OK(tempsockfd.Bind(addr, RPC_SOCK_MODE));
+    RETURN_IF_NOT_OK(tempsockfd.SetNonBlocking());
+    exclListenFd_ = tempsockfd.GetFd();
+    RETURN_IF_NOT_OK(AddListenFd(exclListenFd_));
+    sockPath_.emplace_back(exclSockPath);
+    exclSockPath_ = exclSockPath;
     return Status::OK();
 }
 
@@ -528,7 +553,7 @@ Status ZmqService::WorkerCB::ProcessStreamRpcRq(const MetaPb &meta, ZmqMsgFrames
     CHECK_FAIL_RETURN_STATUS(methodObj->ClientStreaming() || methodObj->ServerStreaming(), K_RUNTIME_ERROR,
                              "Not streaming method");
     m.set_gateway_id(meta.gateway_id());
-    m.set_routing_fd(meta.routing_fd());
+    m.set_route_fd(meta.route_fd());
     m.set_event_type(meta.event_type());
     std::string workerId;
     PerfPoint point(PerfKey::ZMQ_GET_STREAM_WORKER);
@@ -554,7 +579,7 @@ Status ZmqService::WorkerCB::ProcessHandshakeRq(const MetaPb &meta, ZmqMsgFrames
     RETURN_IF_NOT_OK(ParseFromZmqMessage(inMsg.front(), m));
     inMsg.pop_front();
     m.set_gateway_id(meta.gateway_id());
-    m.set_routing_fd(meta.routing_fd());
+    m.set_route_fd(meta.route_fd());
     m.set_event_type(meta.event_type());
     // Create a bank entry
     HandshakeTokenPb reply;
@@ -639,9 +664,7 @@ Status ZmqService::WorkerCB::HandleInternalRq(int fd, const MetaPb &meta, ZmqMsg
 Status ZmqService::WorkerCB::WorkerEntryImpl(MetaPb &meta, ZmqMsgFrames &inMsg, ZmqMsgFrames &replyMsg)
 {
-    int fd;
-    CHECK_FAIL_RETURN_STATUS(StringToInt(meta.routing_fd(), fd), K_RUNTIME_ERROR,
-                             "String convert to int failed, zmq service worker entry failed");
+    int fd = meta.route_fd();
     const int idx = meta.method_index();
     if (idx >= 0) {
         auto remainingTime = reqTimeoutDuration.CalcRealRemainingTime();
@@ -695,6 +718,32 @@ Status ZmqService::WorkerCB::WorkerEntry()
     return Status::OK();
 }
 
+Status ZmqService::WorkerCB::WorkerEntryWithoutMsgQ(ZmqMetaMsgFrames &inMsg, ZmqMetaMsgFrames &outMsg)
+{
+    ReadLock rlock(&inUse_);
+    MetaPb &meta = inMsg.first;
+    // Check point.
+    if (impl_->globalInterrupt_) {
+        return Status::OK();
+    }
+    VLOG(RPC_LOG_LEVEL) << FormatString("Worker %s started for service '%s' Method %d serving %s", GetWorkerId(),
+                                        meta.svc_name(), meta.method_index(), meta.client_id());
+    ZmqMsgFrames replyMsg;
+    Status rc;
+    const int idx = meta.method_index();
+    if (idx >= 0) {
+        // There is one more protobuf after, but we can't parse it (yet) and leave
+        // it to the lower level to decode. Also note if the client side is streaming,
+        // it will be handled by StreamWorkEntry.
+        rc = impl_->DirectCallMethod(meta, std::move(inMsg.second), 0, replyMsg);
+    }
+    VLOG(RPC_LOG_LEVEL) << "Service '" << impl_->ServiceName() << "' Method " << meta.method_index() << " rc "
+                        << rc.ToString();
+    outMsg.first = std::move(meta);
+    outMsg.second = std::move(replyMsg);
+    return Status::OK();
+}
+
 Status ZmqService::WorkerCB::StreamWorkerEntryImpl()
 {
     Status rc;
@@ -824,7 +873,7 @@ Status ZmqService::ProcessPayloadGetRq(MetaPb &meta, ZmqMsgFrames &inMsg, ZmqMsg
     if (meta.event_type() == EventType::ZMQ) {
         int fd = std::get<int>(entry);
         meta.set_event_type(EventType::V2MTP);
-        meta.set_routing_fd(std::to_string(fd));
+        meta.set_route_fd(fd);
         VLOG(RPC_KEY_LOG_LEVEL) << FormatString("Choosing fd %d for V2MTP for service %s client %s", fd,
                                                 meta.svc_name(), meta.client_id());
     }
@@ -888,8 +937,7 @@ Status ZmqService::ParkPayloadIfNeeded(ZmqMetaMsgFrames &p, ZmqMsgFrames &payloa
         CHECK_FAIL_RETURN_STATUS(routeIt != routes_.end(), K_NOT_FOUND, "No other routes");
         fd = *(routeIt->second.begin());
     } else {
-        CHECK_FAIL_RETURN_STATUS(StringToInt(meta.routing_fd(), fd), K_RUNTIME_ERROR,
-                                 "String convert to int failed, service to client failed");
+        fd = meta.route_fd();
     }
 
     // At this point, we know the client can support V2MTP. We will use FLAGS_payload_nocopy_threshold
@@ -939,13 +987,11 @@ Status ZmqService::ServiceToClient(ZmqMetaMsgFrames &p)
         rpc2.first = std::move(m);
         rpc2.second = std::move(payload);
     }
-    int fd;
+    int fd = meta.route_fd();
     if (meta.event_type() == EventType::ZMQ) {
         RETURN_IF_NOT_OK(replyQueue_->Put(std::move(p)));
         eventfd_write(outfd_, 1);
     } else {
-        CHECK_FAIL_RETURN_STATUS(StringToInt(meta.routing_fd(), fd), K_RUNTIME_ERROR,
-                                 "String convert to int failed, service to client failed");
         RETURN_IF_NOT_OK(io_->ServiceToClient(fd, p));
     }
     if (offlineRpc) {
@@ -1135,31 +1181,80 @@ Status ZmqService::FrontendToBackend(int fd, const EventType type, ZmqMetaMsgFra
 
 Status ZmqService::ProcessAccept(int listenFd)
 {
-    static std::map<int, bool> hasBeenLogged;
     bool isTcp = (listenFd == tcpfd_);
-    auto fd = accept(listenFd, nullptr, nullptr);
-    if (fd > 0) {
-        UnixSockFd sock(fd);
-        hasBeenLogged[listenFd] = false;
+    UnixSockFd listenSockFd(listenFd);
+    UnixSockFd connectedSockFd;
+    RETURN_IF_NOT_OK(listenSockFd.Accept(connectedSockFd));
+    if (listenFd == exclListenFd_) {
+        VLOG(RPC_LOG_LEVEL) << FormatString("Spawn new work agent for exclusive connection, sock_fd: %d",
+                                            connectedSockFd.GetFd());
+        if (!workAgentThreadPool_) {
+            RETURN_IF_EXCEPTION_OCCURS(workAgentThreadPool_ =
+                                           std::make_unique<ThreadPool>(1, MAX_EXCLUSIVE_CONNECTIONS_LIMIT));
+        }
+        auto newAgent = std::make_unique<WorkAgent>(connectedSockFd, this, !isTcp);
+        auto ptr = newAgent.get();
+        workAgentThreadPool_->Execute([this, ptr] { ptr->Run(); });
+        workAgents_.push_back(std::move(newAgent));
+    } else {
         // Make it non-blocking
-        RETURN_IF_NOT_OK(sock.SetNonBlocking());
+        RETURN_IF_NOT_OK(connectedSockFd.SetNonBlocking());
         if (isTcp) {
-            RETURN_IF_NOT_OK(sock.SetNoDelay());
+            RETURN_IF_NOT_OK(connectedSockFd.SetNoDelay());
         }
         // Assign it
to the next io service - RETURN_IF_NOT_OK(io_->AddFd(this, fd, !isTcp)); - VLOG(RPC_KEY_LOG_LEVEL) << FormatString("Spawn %s connection %d for service %s", isTcp ? "tcp" : "uds", fd, - serviceName_); - } else { - Status rc = UnixSockFd::ErrnoToStatus(errno, listenFd); - auto it = hasBeenLogged.find(listenFd); - if (rc.IsError() && rc.GetCode() != K_TRY_AGAIN && (it == hasBeenLogged.end() || !it->second)) { - hasBeenLogged[listenFd] = true; - LOG(ERROR) << FormatString("Spawn uds connection %d failed for service %s:", listenFd, serviceName_) - << " with status:" << rc.ToString(); - return rc; + RETURN_IF_NOT_OK(io_->AddFd(this, connectedSockFd.GetFd(), !isTcp)); + VLOG(RPC_KEY_LOG_LEVEL) << FormatString("Spawn %s connection %d for service %s", isTcp ? "tcp" : "uds", + connectedSockFd.GetFd(), serviceName_); + } + + return Status::OK(); +} + +Status ZmqService::DirectExecInternalMethod(int fd, EventType type, ZmqMetaMsgFrames &inFrames, + ZmqMetaMsgFrames &outFrames) +{ + if (type == EventType::V2MTP || type == EventType::V1MTP) { + MetaPb meta; + ZmqCurveUserId userId; + // A direct connection from stub needs to be parsed. + Status rc = ParseMsgFrames(inFrames.second, meta, fd, type, userId); + if (rc.IsError()) { + // Log the message for anything else, and move on. + RETURN_STATUS_LOG_ERROR(StatusCode::K_OK, "Incompatible rpc request. Ignore"); } + inFrames.first = std::move(meta); + } + + CHECK_FAIL_RETURN_STATUS(cfg_.numRegularSockets_ != 0, K_RUNTIME_ERROR, + "Get id failed, regular sockets as divisor is 0, route to reg back end failed"); + MetaPb &meta = inFrames.first; + auto id = nextWorker_.fetch_add(1) % cfg_.numRegularSockets_; + auto worker = workerCBs_[id]; + auto workerId = worker->GetWorkerId(); + meta.set_worker_id(workerId); + + auto traceID = meta.trace_id(); + auto timeout = meta.timeout(); + auto dbName = meta.db_name(); + + Timer timer; + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID); + if (timeout > 0) { + int64_t elapsed = timer.ElapsedMilliSecond(); + reqTimeoutDuration.Init(timeout - elapsed); + scTimeoutDuration.Init(timeout - elapsed); + } else { + reqTimeoutDuration.Init(); + scTimeoutDuration.Init(); } + g_MetaRocksDbName = dbName; + LOG_IF_ERROR(worker->WorkerEntryWithoutMsgQ(inFrames, outFrames), "worker entry failed"); + g_MetaRocksDbName.clear(); + workerOperationTimeCost.Clear(); + masterOperationTimeCost.Clear(); + VLOG(RPC_LOG_LEVEL) << FormatString("Routing request %s to %s", meta.client_id(), workerId); + return Status::OK(); } @@ -1173,6 +1268,7 @@ Status ZmqService::HandleRqFromProxy() RETURN_IF_NOT_OK(rqQueue_->Take(&p)); PerfPoint::RecordElapsed(PerfKey::ZMQ_ROUTER_TO_SVC, GetLapTime(p.first, "ZMQ_ROUTER_TO_SVC")); RETURN_IF_NOT_OK(FrontendToBackend(ZMQ_NO_FILE_FD, EventType::ZMQ, p, false)); + if (globalInterrupt_) { RETURN_STATUS(StatusCode::K_SHUTTING_DOWN, "Shutting down the socket."); } diff --git a/src/datasystem/common/rpc/zmq/zmq_service.h b/src/datasystem/common/rpc/zmq/zmq_service.h index 86eb3ab4969cecb1126fe005dfa8a4aecb04f0d6..a3d954ac2875d61d5245dd7154466690c12cd31e 100644 --- a/src/datasystem/common/rpc/zmq/zmq_service.h +++ b/src/datasystem/common/rpc/zmq/zmq_service.h @@ -47,12 +47,14 @@ #include "datasystem/common/util/queue/queue.h" #include "datasystem/common/util/thread_pool.h" #include "datasystem/common/util/status_helper.h" +#include "datasystem/common/rpc/zmq/work_agent.h" namespace datasystem { typedef MsgQueRef ZmqServerMsgQueRef; typedef MsgQueMgr ZmqServerMsgMgr; typedef decltype(epoll_event::events) 
EventsVal; class SockEventService; +class WorkAgent; /** * @brief An abstract class for RPC service. * The ZMQ plugin will generate a subclass, and the user will supply the virtual method implementation. @@ -73,6 +75,7 @@ public: */ virtual Status CallMethod(std::shared_ptr sock, MetaPb meta, ZmqMsgFrames &&inMsg, int64_t seqNo) = 0; + virtual Status DirectCallMethod(MetaPb meta, ZmqMsgFrames &&inMsg, int64_t seqNo, ZmqMsgFrames &outMsg) = 0; Status Init(RpcServiceCfg cfg, void *proxy); @@ -113,6 +116,12 @@ public: return outfd_; } + Status GetExclConnSockPath(std::string &sockPath) + { + sockPath = exclSockPath_; + return Status::OK(); + } + Status ServiceRequest(MetaPb &&meta, ZmqMsgFrames &&msgs); Status ServiceReply(ZmqMetaMsgFrames *msg); @@ -135,6 +144,8 @@ public: return serviceName_; } + Status DirectExecInternalMethod(int fd, EventType type, ZmqMetaMsgFrames &inFrames, ZmqMetaMsgFrames &outFrames); + protected: /** * This map is populated by zmq_plugin generated code. @@ -164,6 +175,7 @@ private: * Entry function for each thread. */ Status WorkerEntry(); + Status WorkerEntryWithoutMsgQ(ZmqMetaMsgFrames &inMsg, ZmqMetaMsgFrames &outMsg); Status StreamWorkerEntry(); void CloseSocket() @@ -258,12 +270,14 @@ private: int outfd_; // For replyQueue_. int infd_; // For rqQueue_. int tcpfd_; // If bypass ZmqServiceImpl + int exclListenFd_; HostPort tcpHostPort_; // If bypass ZmqServiceImpl std::atomic nextWorker_; std::atomic globalInterrupt_; bool streamSupport_; bool multiDestinations_; std::vector sockPath_; + std::string exclSockPath_; bool unlinkSocketPathOnExit_; bool tcpDirect_; std::shared_ptr backendMgr_{ nullptr }; @@ -276,6 +290,9 @@ private: WriterPrefRWLock routeMux_; std::map> routes_; std::map fdToGateway_; + + std::vector> workAgents_; + std::unique_ptr workAgentThreadPool_ { nullptr }; }; } // namespace datasystem #endif // DATASYSTEM_COMMON_RPC_ZMQ_SERVICE_H diff --git a/src/datasystem/common/rpc/zmq/zmq_stub.h b/src/datasystem/common/rpc/zmq/zmq_stub.h index c63e2ac66d925cf3620686c7c257873bc739509b..65d61cdff63ed8e20569b8cbb8bd7a8a47bd777d 100644 --- a/src/datasystem/common/rpc/zmq/zmq_stub.h +++ b/src/datasystem/common/rpc/zmq/zmq_stub.h @@ -77,6 +77,17 @@ public: Status GetInitStatus(); + /** + * @brief Set exclusive connection by assigning required fields + * @param[in] exclusiveId The exclusive id + * @param[in] exclusiveSockPath The socket path for exclusive sock connect + */ + void SetExclusiveConnInfo(const std::optional &exclusiveId, const std::string &sockPath) + { + exclusiveSockPath_ = sockPath; + exclusiveId_ = exclusiveId; + } + protected: /** * @brief Initialization. If requesting uds connection, the connection will be established asynchronously. 
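`SetExclusiveConnInfo` above models the exclusive mode with `std::optional<int32_t>`: an engaged id routes the stub onto the dedicated socket path, while `std::nullopt` keeps the regular message-queue path. A toy, self-contained illustration of that toggle (`StubSketch` is a hypothetical stand-in, not the real stub class):

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Mirrors the two fields the diff adds to the stub (exclusiveSockPath_, exclusiveId_).
class StubSketch {
public:
    void SetExclusiveConnInfo(const std::optional<int32_t> &exclusiveId, const std::string &sockPath)
    {
        exclusiveSockPath_ = sockPath;
        exclusiveId_ = exclusiveId;
    }

    // Same test IsExclusiveConnection() performs on the client side.
    bool IsExclusive() const { return exclusiveId_.has_value(); }

private:
    std::string exclusiveSockPath_;
    std::optional<int32_t> exclusiveId_;
};

int main()
{
    StubSketch stub;
    stub.SetExclusiveConnInfo(3, "/tmp/uds/exclusiveConn");  // path name mirrors BindUnixPath()
    std::cout << std::boolalpha << stub.IsExclusive() << '\n';  // true
    stub.SetExclusiveConnInfo(std::nullopt, "");                // back to shared message queues
    std::cout << stub.IsExclusive() << '\n';                    // false
}
```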
@@ -86,6 +97,8 @@ protected: std::map> methodMap_; std::string serviceName_; + std::string exclusiveSockPath_; + std::optional exclusiveId_; int channelNo_; std::unique_ptr pimpl_; }; diff --git a/src/datasystem/common/rpc/zmq/zmq_stub_conn.cpp b/src/datasystem/common/rpc/zmq/zmq_stub_conn.cpp index 926af6221d8b912e90170025a236097c17aa3a8b..0fbe770383951f5a6e929152deeeeb2835559922 100644 --- a/src/datasystem/common/rpc/zmq/zmq_stub_conn.cpp +++ b/src/datasystem/common/rpc/zmq/zmq_stub_conn.cpp @@ -104,6 +104,7 @@ Status ZmqFrontend::ExchangeJfr() // Send our own jfr UrmaHandshakeReqPb rq; UrmaManager::Instance().GetLocalUrmaInfo().ToProto(rq); + RETURN_IF_NOT_OK(UrmaManager::Instance().GetSegmentInfo(rq)); MetaPb meta = CreateMetaData("", ZMQ_EXCHANGE_JFR_METHOD, ZMQ_INVALID_PAYLOAD_INX, GetStringUuid()); ZmqMsgFrames p; RETURN_IF_NOT_OK(PushFrontProtobufToFrames(meta, p)); diff --git a/src/datasystem/common/rpc/zmq/zmq_unary_client_impl.h b/src/datasystem/common/rpc/zmq/zmq_unary_client_impl.h index 4bf1647cff542e96b009116bb1908752f7c22b3f..ec13224bdb9dba349691e072bfef813668af1dfa 100644 --- a/src/datasystem/common/rpc/zmq/zmq_unary_client_impl.h +++ b/src/datasystem/common/rpc/zmq/zmq_unary_client_impl.h @@ -38,6 +38,7 @@ #include "datasystem/common/util/gflag/common_gflags.h" #include "datasystem/common/util/raii.h" #include "datasystem/common/util/status_helper.h" +#include "datasystem/common/rpc/zmq/exclusive_conn_mgr.h" namespace datasystem { template @@ -56,9 +57,39 @@ public: meta_ = CreateMetaData(svcName, methodIndex, sendPayload ? ZMQ_EMBEDDED_PAYLOAD_INX : ZMQ_INVALID_PAYLOAD_INX, mQue_->GetId()); } + + // Alternate constructor for exclusive connection mode + ClientUnaryWriterReaderImpl(int32_t exclusiveId, const std::string &svcName, int32_t methodIndex, bool sendPayload, + bool recvPayload) + : StreamBase::StreamBase(sendPayload, recvPayload), + mQue_(nullptr), + writeOnce_(false), + readOnce_(false), + v2Server_(false), + payloadId_(-1), + payloadSz_(0), + exclusiveId_(exclusiveId) + { + // In non-exclusive connection mode, the mQue_->GetId() is used for the clientId metadata. This is not the + // actual client id, but the ZmqMsgQueue id. This field is not used in exclusive connection mode, but we can + // populate it with a name for diagnostic purposes. + std::string clientId = gExclusiveConnMgr.GetExclusiveConnMgrName(); + meta_ = CreateMetaData(svcName, methodIndex, sendPayload ? ZMQ_EMBEDDED_PAYLOAD_INX : ZMQ_INVALID_PAYLOAD_INX, + clientId); + } + + Status InitExclusiveConnection(const std::string &exclusiveSockPath, int64_t timeoutMs) + { + RETURN_IF_NOT_OK(gExclusiveConnMgr.CreateExclusiveConnection(exclusiveId_.value(), timeoutMs, + exclusiveSockPath)); + return Status::OK(); + } + ~ClientUnaryWriterReaderImpl() override { - mQue_->Close(); + if (mQue_) { + mQue_->Close(); + } } Status Write(const W &pb) @@ -111,15 +142,14 @@ public: rq.set_id(payloadId_); rq.set_error_code(K_NOT_READY); RETURN_IF_NOT_OK(RequestPayload(rq)); - RETURN_IF_NOT_OK(mQue_->ClientReceiveMsg(reply, ZmqRecvFlags::NONE)); + RETURN_IF_NOT_OK(RecvConnReply(reply, ZmqRecvFlags::NONE, true)); // Just like other reply, the first one is a Status rc auto &frames = reply.second; auto rcMsg = std::move(frames.front()); frames.pop_front(); RETURN_IF_NOT_OK(ZmqMessageToStatus(rcMsg)); } else if (payloadId_ == ZMQ_OFFLINE_PAYLOAD_INX) { - // This is a continuation of the original rpc but sent separately. 
- RETURN_IF_NOT_OK(mQue_->ClientReceiveMsg(reply, ZmqRecvFlags::NONE)); + RETURN_IF_NOT_OK(RecvConnReply(reply, ZmqRecvFlags::NONE, true)); } auto &frames = reply.second; while (!frames.empty()) { @@ -163,11 +193,7 @@ public: RETURN_IF_NOT_OK(RequestPayload(rq)); // Now we wait for underlying framework to write the payload directly into the memory provided. ZmqMetaMsgFrames reply; - RETURN_IF_NOT_OK(mQue_->ClientReceiveMsg(reply, ZmqRecvFlags::NONE)); - auto &meta = reply.first; - GetLapTime(meta, "ZMQ_PAYLOAD_TRANSFER"); - auto elapsed = GetTotalTime(meta); - PerfPoint::RecordElapsed(PerfKey::ZMQ_PAYLOAD_TRANSFER, elapsed); + RETURN_IF_NOT_OK(RecvConnReply(reply, ZmqRecvFlags::NONE, true)); // Verify the response auto &frames = reply.second; ZmqMessage msg; @@ -184,9 +210,6 @@ public: "Client %s use unary socket to receive %d payload bytes from Service %s" " Method %d.", meta_.client_id(), sz, meta_.svc_name(), meta_.method_index()); - const int NANO_TO_MS = 1'000'000; - VLOG(RPC_KEY_LOG_LEVEL) << FormatString("Time to transfer payload size %d : %6lf milliseconds", sz, - (float)elapsed / (float)NANO_TO_MS); return Status::OK(); } @@ -203,19 +226,20 @@ public: { // Send metadata first. StartTheClock(meta_); - ZmqMetaMsgFrames p(meta_, std::move(outMsg_)); - return mQue_->SendMsg(p, flags); + ZmqMsgFrames frames = std::move(outMsg_); + RETURN_IF_NOT_OK(SendConnMsg(meta_, frames, flags)); + return Status::OK(); } Status ReadAll(ZmqRecvFlags flags) override { inMsg_.clear(); ZmqMetaMsgFrames reply; - RETURN_IF_NOT_OK(mQue_->ClientReceiveMsg(reply, flags)); - PerfPoint::RecordElapsed(PerfKey::ZMQ_STUB_FRONT_TO_BACK, GetLapTime(reply.first, "ZMQ_STUB_FRONT_TO_BACK")); + RETURN_IF_NOT_OK(RecvConnReply(reply, flags, false)); inMsg_ = std::move(reply.second); payloadId_ = reply.first.payload_index(); v2Server_ = payloadId_ >= 0; + return Status::OK(); } @@ -239,6 +263,11 @@ public: return v2Server_; } + bool IsExclusiveConnection() + { + return exclusiveId_.has_value(); + } + protected: std::shared_ptr mQue_; @@ -317,10 +346,96 @@ private: MetaPb meta = meta_; meta.set_method_index(ZMQ_PAYLOAD_GET_METHOD); StartTheClock(meta); - ZmqMetaMsgFrames p; - p.first = std::move(meta); - RETURN_IF_NOT_OK(PushBackProtobufToFrames(rq, p.second)); - RETURN_IF_NOT_OK(mQue_->SendMsg(p)); + + ZmqMsgFrames frames; + RETURN_IF_NOT_OK(PushBackProtobufToFrames(rq, frames)); + RETURN_IF_NOT_OK(SendConnMsg(meta, frames, ZmqSendFlags::NONE)); + return Status::OK(); + } + + Status SendConnMsg(MetaPb &meta, ZmqMsgFrames &frames, ZmqSendFlags flags) + { + if (!IsExclusiveConnection()) { + ZmqMetaMsgFrames p(meta, std::move(frames)); + RETURN_IF_NOT_OK(mQue_->SendMsg(p, flags)); + } else { + ZmqMsgEncoder *encoder; + RETURN_IF_NOT_OK(gExclusiveConnMgr.GetExclusiveConnEncoder(exclusiveId_.value(), encoder)); + + // This is the last perf point before sending data to worker. In non-exclusive mode, this would be + // ZMQ_STUB_TO_BACK_TO_FRONT that gets recorded in the frontend sending codepath. + // Here in exclusive conn mode, it directly sends right here because there is no frontend. + // Thus, record a perf point here so that the next point, the server-side ZMQ_NETWORK_TRANSFER, will more + // accurately show the transfer time. 
+ TraceGuard traceGuard = Trace::Instance().SetTraceNewID(meta.trace_id()); + PerfPoint::RecordElapsed(PerfKey::ZMQ_STUB_TO_EXCL_CONN, GetLapTime(meta, "ZMQ_STUB_TO_EXCL_CONN")); + + // Add the meta to the frames + RETURN_IF_NOT_OK(PushFrontProtobufToFrames(meta, frames)); + + // The following field is part of the protocol and seems to be the first frame sent. + // In exclusive connection mode, I doubt this frame is needed. For now, add it anyway as an empty + // string so that client/server follow agreed protocol. + // This could be a candidate for removal from the protocol later. + std::string gatewayId; + RETURN_IF_NOT_OK(PushFrontStringToFrames(gatewayId, frames)); + RETURN_IF_NOT_OK(encoder->SendMsgFrames(EventType::V1MTP, frames)); + } + return Status::OK(); + } + + Status RecvConnReply(ZmqMetaMsgFrames &reply, ZmqRecvFlags flags, bool isPayload) + { + if (!IsExclusiveConnection()) { + RETURN_IF_NOT_OK(mQue_->ClientReceiveMsg(reply, flags)); + // Regular receive counts the perf point using ZMQ_STUB_FRONT_TO_BACK. Special case for the payload version + // of a receive + if (isPayload) { + PerfPoint::RecordElapsed(PerfKey::ZMQ_PAYLOAD_TRANSFER, + GetLapTime(reply.first, "ZMQ_PAYLOAD_TRANSFER")); + } else { + PerfPoint::RecordElapsed(PerfKey::ZMQ_STUB_FRONT_TO_BACK, + GetLapTime(reply.first, "ZMQ_STUB_FRONT_TO_BACK")); + } + } else { + ZmqMsgDecoder *decoder; + RETURN_IF_NOT_OK(gExclusiveConnMgr.GetExclusiveConnDecoder(exclusiveId_.value(), decoder)); + + ZmqMsgFrames replyFrames; + // v2 decode supports v1 format. Use the v2 version of the decode. + Status rc = decoder->ReceiveMsgFramesV2(replyFrames); + if (rc.IsError()) { + // If there was a timeout in the communication codepaths, then the client-side may have done error exit + // while the service is still working on the request. Later, the service might send a reply back to this + // connection but that is stale/old data. + // The solution is that this connection should not be used anymore. Close it. + // A future request from this thread can make a new connection again. + // Note: Not all errors would require connection reset. For example, some normal error coming back from + // the server doesn't mean the connection needs to be cleaned up. + // Future code here can decide to close only for specific error codes. + LOG_IF_ERROR(gExclusiveConnMgr.CloseExclusiveConn(exclusiveId_.value()), + "Error closing exclusive conn during error path"); + return rc; + } + + const size_t msgFrameMinSize = 2; + CHECK_FAIL_RETURN_STATUS(replyFrames.size() >= msgFrameMinSize, StatusCode::K_INVALID, + "Invalid msg: frames.size() = " + std::to_string(replyFrames.size())); + std::string receiver = ZmqMessageToString(replyFrames.front()); + replyFrames.pop_front(); + + ZmqMessage metaHdr = std::move(replyFrames.front()); + replyFrames.pop_front(); + MetaPb meta; + RETURN_IF_NOT_OK(ParseFromZmqMessage(metaHdr, meta)); + // In exclusive mode, there is no front/backend. There is only a receive. Do not record any stub perf + // points, and just record the network transfer here. + // Note that in non-exclusive mode, the ZMQ_NETWORK_TRANSFER was recorded in the frontend codepath, but that + // code did not run here so this is the appropriate time to capture this stat. 
+ TraceGuard traceGuard = Trace::Instance().SetTraceNewID(meta.trace_id()); + PerfPoint::RecordElapsed(PerfKey::ZMQ_NETWORK_TRANSFER, GetLapTime(meta, "ZMQ_NETWORK_TRANSFER (SOCKET)")); + reply = std::make_pair(meta, std::move(replyFrames)); + } return Status::OK(); } @@ -329,6 +444,7 @@ private: bool v2Server_; int64_t payloadId_; size_t payloadSz_; + std::optional exclusiveId_; }; } // namespace datasystem #endif // DATASYSTEM_COMMON_RPC_ZMQ_STREAM_H diff --git a/src/datasystem/common/shared_memory/shm_unit.cpp b/src/datasystem/common/shared_memory/shm_unit.cpp index c3c4c3860872a90a3b38a1315fafb445d743a735..3728a87489b86ad83fc908552e68308cd1184168 100644 --- a/src/datasystem/common/shared_memory/shm_unit.cpp +++ b/src/datasystem/common/shared_memory/shm_unit.cpp @@ -27,6 +27,7 @@ #include "datasystem/common/inject/inject_point.h" #include "datasystem/common/shared_memory/allocator.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/format.h" #include "datasystem/common/util/status_helper.h" #include "datasystem/utils/status.h" @@ -36,7 +37,7 @@ ShmUnit::ShmUnit(int fd, uint64_t mmapSz) : ShmUnitInfo(fd, mmapSz) { } -ShmUnit::ShmUnit(std::string id, ShmView shmView, void *pointer) : ShmUnitInfo(std::move(id), shmView, pointer) +ShmUnit::ShmUnit(ShmKey id, ShmView shmView, void *pointer) : ShmUnitInfo(std::move(id), shmView, pointer) { } diff --git a/src/datasystem/common/shared_memory/shm_unit.h b/src/datasystem/common/shared_memory/shm_unit.h index dd08db48bcda1739040139bd196170d89b19db89..d7ed1b0ed2aeb1403983cef78d2d56e83b5f47d8 100644 --- a/src/datasystem/common/shared_memory/shm_unit.h +++ b/src/datasystem/common/shared_memory/shm_unit.h @@ -55,7 +55,7 @@ public: * @param[in] pointer The pointer to allocated data for the ShmUnit (This ShmUnit shall be responsible to free it * during it's destructor. */ - ShmUnit(std::string id, ShmView shmView, void *pointer); + ShmUnit(ShmKey id, ShmView shmView, void *pointer); /** * @brief Destructor. ShmUnit own memory and will clean themself up and free the memory that they own. 
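The shared-memory changes above switch ids from `std::string` to the interned `ShmKey`, so identical page ids, such as the `"F:..-M:..-O:..-S:.."` strings built by `StreamPageBase::CreatePageId`, share a single underlying entity instead of one heap copy per `ShmUnitInfo`. The toy pool below is not the project's `StringPool`; it only demonstrates the dedup idea:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Minimal intern pool: Intern() returns a shared handle, and equal strings
// yield the same handle, the property ShmKey::Intern() relies on.
class ToyInternPool {
public:
    std::shared_ptr<const std::string> Intern(const std::string &s)
    {
        auto it = pool_.find(s);
        if (it == pool_.end()) {
            it = pool_.emplace(s, std::make_shared<const std::string>(s)).first;
        }
        return it->second;
    }
    std::size_t Size() const { return pool_.size(); }

private:
    std::unordered_map<std::string, std::shared_ptr<const std::string>> pool_;
};

int main()
{
    ToyInternPool pool;
    auto a = pool.Intern("F:3-M:4096-O:0-S:128");
    auto b = pool.Intern("F:3-M:4096-O:0-S:128");  // same entity comes back
    std::cout << (a.get() == b.get()) << " pool size " << pool.Size() << '\n';  // 1, 1
}
```

The real `StringRef`/`ShmKey` adds reference counting (`IncRef`/`DecRef`) so entries can be evicted once the last holder releases them; the shared_ptr here plays that role implicitly.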
diff --git a/src/datasystem/common/shared_memory/shm_unit_info.cpp b/src/datasystem/common/shared_memory/shm_unit_info.cpp index 0d001f627cd2006caf5e0ddadee1b08d024a17ce..0c3281b2c90cf0a438f4d8ae6507078b0b2cf1c8 100644 --- a/src/datasystem/common/shared_memory/shm_unit_info.cpp +++ b/src/datasystem/common/shared_memory/shm_unit_info.cpp @@ -20,6 +20,7 @@ */ #include "datasystem/common/shared_memory/shm_unit_info.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/memory.h" namespace datasystem { @@ -28,14 +29,14 @@ ShmUnitInfo::ShmUnitInfo(int fd, uint64_t mmapSz) : fd(fd), mmapSize(mmapSz) { } -ShmUnitInfo::ShmUnitInfo(ImmutableString id, ShmView shmView, void *pointer) +ShmUnitInfo::ShmUnitInfo(ShmKey id, ShmView shmView, void *pointer) : fd(shmView.fd), mmapSize(shmView.mmapSz), pointer(pointer), offset(shmView.off), size(shmView.sz), refCount(0), - id(id) + id(std::move(id)) { } diff --git a/src/datasystem/common/shared_memory/shm_unit_info.h b/src/datasystem/common/shared_memory/shm_unit_info.h index 6f28e40f6903f916c839f297bb28fcdc7d62b8c3..7c2b610c10f774e360031a591e2b81e00c806771 100644 --- a/src/datasystem/common/shared_memory/shm_unit_info.h +++ b/src/datasystem/common/shared_memory/shm_unit_info.h @@ -29,6 +29,7 @@ #include #include "datasystem/common/log/log.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/status_helper.h" #include "datasystem/common/util/thread_pool.h" #include "datasystem/common/immutable_string/immutable_string.h" @@ -101,7 +102,7 @@ public: * @param[in] pointer The pointer to allocated data for the ShmUnitInfo (This ShmUnitInfo will not free it during * it's destructor. */ - ShmUnitInfo(ImmutableString id, ShmView shmView, void *pointer); + ShmUnitInfo(ShmKey id, ShmView shmView, void *pointer); /** * @brief Destructor. Client-side ShmUnitInfo's do not own memory and will not free memory. @@ -112,7 +113,7 @@ public: * @brief Returns shm id. * @return id of shared memory. 
*/ - std::string GetId() const + ShmKey GetId() const { return id; } @@ -237,7 +238,7 @@ public: std::atomic refCount = { 0 }; // uuid - ImmutableString id; + ShmKey id; }; } // namespace datasystem diff --git a/src/datasystem/common/stream_cache/stream_data_page.cpp b/src/datasystem/common/stream_cache/stream_data_page.cpp index c8c6f05ae13e78541f46f1e5bafbc268203480d8..beb2633b4f901731ceeda06a0d91b34c78f8d450 100644 --- a/src/datasystem/common/stream_cache/stream_data_page.cpp +++ b/src/datasystem/common/stream_cache/stream_data_page.cpp @@ -22,6 +22,7 @@ #include #include +#include "datasystem/common/string_intern/string_ref.h" #include "securec.h" #include "datasystem/common/util/raii.h" @@ -262,9 +263,10 @@ Status DataVerificationHeader::ExtractHeader(DataElement &element, ElementHeader return Status::OK(); } -std::string StreamPageBase::CreatePageId(const std::shared_ptr &shmInfo) +ShmKey StreamPageBase::CreatePageId(const std::shared_ptr &shmInfo) { - return FormatString("F:%zu-M:%zu-O:%zu-S:%zu", shmInfo->fd, shmInfo->mmapSize, shmInfo->offset, shmInfo->size); + return ShmKey::Intern( + FormatString("F:%zu-M:%zu-O:%zu-S:%zu", shmInfo->fd, shmInfo->mmapSize, shmInfo->offset, shmInfo->size)); } StreamPageBase::StreamPageBase(std::shared_ptr shmInfo) diff --git a/src/datasystem/common/stream_cache/stream_data_page.h b/src/datasystem/common/stream_cache/stream_data_page.h index f926c15ad05d202c5e3fe23ce48b9cada79ff2fb..ca50e599753f14d6ce11eb02a46d9023e9bb80fd 100644 --- a/src/datasystem/common/stream_cache/stream_data_page.h +++ b/src/datasystem/common/stream_cache/stream_data_page.h @@ -24,6 +24,7 @@ #include "datasystem/common/shared_memory/shm_unit.h" #include "datasystem/common/stream_cache/cursor.h" #include "datasystem/common/stream_cache/stream_meta_shm.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/bitmask_enum.h" #include "datasystem/common/util/raii.h" #include "datasystem/stream/element.h" @@ -199,13 +200,13 @@ public: return pageUnit_->GetSize(); } - static std::string CreatePageId(const std::shared_ptr &pageUnit); + static ShmKey CreatePageId(const std::shared_ptr &pageUnit); /** * @brief Return the page id * @return */ - std::string GetPageId() const + ShmKey GetPageId() const { return pageUnit_->GetId(); } diff --git a/src/datasystem/common/string_intern/string_entity.cpp b/src/datasystem/common/string_intern/string_entity.cpp index 9d93e1cbc7d823d1a6c3b8c8c82edb298def659f..8ce7a42fa9e14d40de92216bd1dbc24b8757e451 100644 --- a/src/datasystem/common/string_intern/string_entity.cpp +++ b/src/datasystem/common/string_intern/string_entity.cpp @@ -66,7 +66,7 @@ int32_t StringEntity::IncRef() const bool StringEntity::DecRef() const { - return (--countRef_ == 0); + return --countRef_ == 0; } void StringEntity::IncDelRef() const diff --git a/src/datasystem/common/string_intern/string_pool.h b/src/datasystem/common/string_intern/string_pool.h index 87cf0f42575247803d5752f452c774b6a44d126a..f32de19756ed367ef3819bc31a974dfb5992e5f4 100644 --- a/src/datasystem/common/string_intern/string_pool.h +++ b/src/datasystem/common/string_intern/string_pool.h @@ -95,7 +95,7 @@ public: * @brief Return the size ofStringPool * @return The size ofStringPool */ - size_t Size() + size_t Size() const { return pool_.size(); } diff --git a/src/datasystem/common/string_intern/string_ref.h b/src/datasystem/common/string_intern/string_ref.h index 822c410890a9a0e1f100929bbc2399bd100fd01c..fa2d416c123469e595ac7240c0a5432f46d5e0aa 100644 --- 
a/src/datasystem/common/string_intern/string_ref.h
+++ b/src/datasystem/common/string_intern/string_ref.h
@@ -46,7 +46,7 @@ public:
     StringRef &operator=(const StringRef &other) noexcept
     {
         if (this != &other) {
-            handle_ = other.ptr_;
+            handle_ = other.handle_;
             handle_.IncRef();
         }
         return *this;
@@ -142,6 +142,11 @@ public:
         return ToString().size();
     }
 
+    bool Empty() const
+    {
+        return Size() == 0;
+    }
+
 private:
     StringPtr handle_;
 };
diff --git a/src/datasystem/common/util/gflag/common_gflags.cpp b/src/datasystem/common/util/gflag/common_gflags.cpp
index ead11baba8a00ccfa9511ccfcddf538a4a0b2005..b08b60e3023b7537aaa21a284b5bdc3385c003e2 100644
--- a/src/datasystem/common/util/gflag/common_gflags.cpp
+++ b/src/datasystem/common/util/gflag/common_gflags.cpp
@@ -76,6 +76,21 @@ bool ValidateUrmaMode(const char *flagName, const std::string &value)
     return true;
 #endif
 }
+
+bool ValidateEnableRdma(const char *flagName, bool value)
+{
+    (void)flagName;
+#ifdef USE_RDMA
+    (void)value;
+    return true;
+#else
+    if (value) {
+        LOG(ERROR) << FormatString("Worker not built with UCX RDMA framework, but %s set to true", flagName);
+        return false;
+    }
+    return true;
+#endif
+}
 } // namespace
 
 DS_DEFINE_bool(enable_urma, false, "Option to turn on urma for OC worker to worker data transfer, default false.");
@@ -88,3 +103,8 @@ DS_DEFINE_bool(urma_register_whole_arena, true,
                "Register the whole arena as segment during init, otherwise, register each object as a segment.");
 DS_DEFINE_bool(urma_event_mode, false, "Uses interrupt mode to poll completion events.");
 DS_DEFINE_bool(enable_worker_worker_batch_get, false, "Enable worker->worker OC batch get, default false.");
+
+DS_DEFINE_bool(enable_rdma, false, "Option to turn on rdma for OC worker to worker data transfer, default false.");
+DS_DEFINE_validator(enable_rdma, &ValidateEnableRdma);
+DS_DEFINE_bool(rdma_register_whole_arena, true,
+               "Register the whole arena as segment during init, otherwise, register each object as a segment.");
diff --git a/src/datasystem/common/util/gflag/common_gflags.h b/src/datasystem/common/util/gflag/common_gflags.h
index 4268299beb8ed30e2cb34d5f12a3d8d9248bf5c9..5bb440cf5ba74c81e6a7561026db50600d7b064c 100644
--- a/src/datasystem/common/util/gflag/common_gflags.h
+++ b/src/datasystem/common/util/gflag/common_gflags.h
@@ -29,4 +29,5 @@ DS_DECLARE_bool(enable_multi_stubs);
 DS_DECLARE_bool(enable_tcp_direct_for_multi_stubs);
 DS_DECLARE_bool(log_monitor);
 DS_DECLARE_bool(enable_worker_worker_batch_get);
+DS_DECLARE_bool(urma_register_whole_arena);
 #endif // DATASYSTEM_COMMON_UTIL_COMMON_GFLAGS_H
diff --git a/src/datasystem/common/util/queue/shm_circular_queue.cpp b/src/datasystem/common/util/queue/shm_circular_queue.cpp
index 3ae6581dcd513e3facbc194820de497afa09e59c..55ab2c108f92ccbd0e83baa44352928e787dd949 100644
--- a/src/datasystem/common/util/queue/shm_circular_queue.cpp
+++ b/src/datasystem/common/util/queue/shm_circular_queue.cpp
@@ -19,6 +19,7 @@
  */
 
 #include "datasystem/common/util/queue/shm_circular_queue.h"
+#include "datasystem/common/string_intern/string_ref.h"
 
 namespace datasystem {
 
@@ -291,7 +292,7 @@ void ShmCircularQueue::WriteUnlock()
     return queueLock_->UnWLatch();
 }
 
-Status ShmCircularQueue::GetQueueShmUnit(int &fd, uint64_t &mmapSize, ptrdiff_t &offset, std::string &id)
+Status ShmCircularQueue::GetQueueShmUnit(int &fd, uint64_t &mmapSize, ptrdiff_t &offset, ShmKey &id)
 {
     RETURN_RUNTIME_ERROR_IF_NULL(circularQueueUnit_);
     fd = circularQueueUnit_->GetFd();
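The `RequestTable` rework below replaces a table-wide `std::shared_timed_mutex` with `tbb::concurrent_hash_map`, whose `accessor`/`const_accessor` objects hold a per-entry lock for exactly as long as they stay in scope. A small, runnable illustration of that pattern, using a simplified value type rather than the table's real `std::vector<std::shared_ptr<T>>`:

```cpp
#include <tbb/concurrent_hash_map.h>
#include <iostream>
#include <string>
#include <vector>

using Table = tbb::concurrent_hash_map<std::string, std::vector<int>>;

int main()
{
    Table table;
    {
        Table::accessor acc;            // exclusive (write) lock on the entry
        table.insert(acc, "object-1");  // creates the entry if absent, like AddRequest()
        acc->second.push_back(42);      // mutate while exclusively held
    }                                   // lock released when acc goes out of scope
    {
        Table::const_accessor acc;      // shared (read) lock
        if (table.find(acc, "object-1")) {
            std::cout << "requests waiting: " << acc->second.size() << '\n';
        }
    }
    table.erase("object-1");            // erase by key, as in EraseSub()
}
```

The design upside mirrored in the diff: threads touching different object keys no longer serialize on one mutex, and the lock scope is tied to the accessor's lifetime instead of manual lock/unlock calls.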
diff --git a/src/datasystem/common/util/queue/shm_circular_queue.h b/src/datasystem/common/util/queue/shm_circular_queue.h
index aa89db9bd5a47221641914823bc27dfd736ac724..f96b6fd9dd804ded73f92557c5faea3ace71872c 100644
--- a/src/datasystem/common/util/queue/shm_circular_queue.h
+++ b/src/datasystem/common/util/queue/shm_circular_queue.h
@@ -260,7 +260,7 @@ public:
      * @param[out] id The id of this shmUnit.
      * @return Status of the call.
      */
-    Status GetQueueShmUnit(int &fd, uint64_t &mmapSize, ptrdiff_t &offset, std::string &id);
+    Status GetQueueShmUnit(int &fd, uint64_t &mmapSize, ptrdiff_t &offset, ShmKey &id);
 
     /**
      * @brief Check futex result.
diff --git a/src/datasystem/common/util/request_table.h b/src/datasystem/common/util/request_table.h
index 21788a354551983950282f12473c93ea5feb801f..004bc7ff0dc10c14e42b66f7f7bbaff22f8acce9 100644
--- a/src/datasystem/common/util/request_table.h
+++ b/src/datasystem/common/util/request_table.h
@@ -34,17 +34,20 @@ namespace datasystem {
 template <typename T>
 class RequestTable {
 public:
+    using TbbRequestTable = tbb::concurrent_hash_map<std::string, std::vector<std::shared_ptr<T>>>;
+
     /**
      * @brief Add request to Worker/MasterRequestManager.
      * @param[in] objectKey The object key.
      * @param[in] request The request that is waiting on the object key.
      * @return Status of the call.
      */
-    Status AddRequest(const std::string &objectKey, std::shared_ptr<T> &request)
+    Status AddRequest(const std::string &objectKey, const std::shared_ptr<T> &request)
     {
         RETURN_RUNTIME_ERROR_IF_NULL(request);
-        std::lock_guard<std::shared_timed_mutex> lck(mutex_);
-        requestTable_[objectKey].push_back(request);
+        typename TbbRequestTable::accessor acc;
+        requestTable_.insert(acc, objectKey);
+        acc->second.emplace_back(request);
         return Status::OK();
     }
 
@@ -55,23 +58,21 @@ public:
      */
     bool ObjectInRequest(const std::string &objectKey)
     {
-        std::shared_lock<std::shared_timed_mutex> lck(mutex_);
-        return requestTable_.find(objectKey) != requestTable_.end();
+        return requestTable_.count(objectKey) != 0;
     }
 
     /**
     * @brief Remove the request from the waiting requests table.
      * @param[in] request The request need to remove.
      */
-    void RemoveRequest(std::shared_ptr<T> &request)
+    void RemoveRequest(const std::shared_ptr<T> &request)
     {
-        std::lock_guard<std::shared_timed_mutex> locker(mutex_);
-        for (auto &objectKey : request->deduplicatedObjectKeys_) {
-            auto iter = requestTable_.find(objectKey);
-            if (iter == requestTable_.end()) {
+        for (const auto &objectKey : request->GetUniqueObjectkeys()) {
+            typename TbbRequestTable::accessor acc;
+            if (!requestTable_.find(acc, objectKey)) {
                 continue;
             }
-            auto &requestsOnObject = iter->second;
+            auto &requestsOnObject = acc->second;
             // Erase request from the list.
             auto it = std::find(requestsOnObject.begin(), requestsOnObject.end(), request);
             if (it == requestsOnObject.end()) {
@@ -80,7 +81,7 @@ public:
             requestsOnObject.erase(it);
             // If the vector is empty, remove the object key from the map.
if (requestsOnObject.empty()) { - requestTable_.erase(iter); + requestTable_.erase(acc); } } } @@ -91,8 +92,7 @@ public: */ void EraseSub(const std::string &key) { - std::lock_guard locker(mutex_); - (void)requestTable_.erase(key); + requestTable_.erase(key); } /** @@ -111,19 +111,19 @@ public: const std::string &objectKey, std::shared_ptr entryParam, Status lastRc, std::function)> doneRequestCallBack, const std::shared_ptr &specRequset = nullptr, bool isUpdateSubRecvEventRequest = false, - std::function)> checkOffsetMatch = nullptr) + std::function &req)> checkOffsetMatch = nullptr) { std::vector> requests; { - std::shared_lock lck(mutex_); - auto it = requestTable_.find(objectKey); - RETURN_OK_IF_TRUE(it == requestTable_.end()); + typename TbbRequestTable::const_accessor acc; + RETURN_OK_IF_TRUE(!requestTable_.find(acc, objectKey)); LOG(INFO) << FormatString("Update request for objectKey: %s, status:%s", objectKey, lastRc.ToString()); // Avoid acquiring locks for both WorkerRequestManager/MasterDevReqManager and xxRequest at the same time. - requests = it->second; + requests = acc->second; } std::vector> completedRequests; + completedRequests.reserve(requests.size()); for (auto &req : requests) { std::lock_guard locker(req->mutex_); if (specRequset != nullptr && specRequset != req) { @@ -168,19 +168,43 @@ public: */ std::vector> GetRequestsByObject(const std::string &objKey) { - std::shared_lock lck(mutex_); - auto it = requestTable_.find(objKey); - if (it != requestTable_.end()) { - return it->second; + typename TbbRequestTable::const_accessor acc; + if (requestTable_.find(acc, objKey)) { + return acc->second; } return {}; } -private: - std::shared_timed_mutex mutex_; + template + Status NotifyPendingGetRequest(const std::string &objectKey, std::unique_ptr params) + { + std::vector> requests; + { + typename TbbRequestTable::const_accessor accessor; + if (!requestTable_.find(accessor, objectKey)) { + return Status::OK(); + } + requests = accessor->second; + } + LOG(INFO) << FormatString("Update request for objectKey: %s", objectKey); + size_t requestCount = requests.size(); + // happy path + if (requestCount == 1) { + return requests[0]->MarkSuccessForNotify(objectKey, std::move(params)); + } + Status lastRc; + for (auto &req : requests) { + Status rc = req->MarkSuccessForNotify(objectKey, params->Clone()); + if (rc.IsError()) { + lastRc = rc; + } + } + return lastRc; + } +private: // A hash table that maps object key to a vector of requests, which are waiting for objects to be ready. 
- std::unordered_map>> requestTable_; + TbbRequestTable requestTable_; }; template @@ -269,6 +293,11 @@ public: } } + const std::unordered_set &GetUniqueObjectkeys() const + { + return deduplicatedObjectKeys_; + } + // The rpc request info Req requestInfo_; diff --git a/src/datasystem/common/util/status.cpp b/src/datasystem/common/util/status.cpp index 8403d8930132a9950ce4f4d492d236cd1bd774b1..11c589786db58138567a0e9a41f84bac8123f5e0 100644 --- a/src/datasystem/common/util/status.cpp +++ b/src/datasystem/common/util/status.cpp @@ -27,52 +27,54 @@ #include "datasystem/common/log/trace.h" namespace datasystem { -Status::Status() noexcept : code_(StatusCode::K_OK) +Status::Status() noexcept : state_(nullptr) { } -Status &Status::operator=(const Status &other) +Status::Status(const Status &other) noexcept { - if (this == &other) { - return *this; - } - code_ = other.code_; - errMsg_ = other.errMsg_; + Assign(other); +} + +Status &Status::operator=(const Status &other) noexcept +{ + Assign(other); return *this; } Status::Status(Status &&other) noexcept { - code_ = other.code_; - other.code_ = StatusCode::K_OK; - errMsg_ = std::move(other.errMsg_); + std::swap(state_, other.state_); } Status &Status::operator=(Status &&other) noexcept { - if (this == &other) { - return *this; - } - code_ = other.code_; - other.code_ = StatusCode::K_OK; - errMsg_ = std::move(other.errMsg_); + std::swap(state_, other.state_); return *this; } -Status::Status(StatusCode code, std::string msg) : code_(code), errMsg_(std::move(msg)) +Status::Status(StatusCode code, std::string msg) noexcept { + if (code == StatusCode::K_OK) { + return; + } auto traceId = Trace::Instance().GetTraceID(); - if (code_ != K_OK && !traceId.empty() && errMsg_.find(traceId) == std::string::npos) { - if (!errMsg_.empty() && errMsg_.back() == '.') { - errMsg_.pop_back(); + if (code != K_OK && !traceId.empty()) { + if (!msg.empty() && msg.back() == '.') { + msg.pop_back(); } - errMsg_ += ", traceId: " + traceId; + msg += ", traceId: " + traceId; } + state_ = std::make_unique(); + state_->code = code; + state_->errMsg = std::move(msg); } Status::Status(StatusCode code, int lineOfCode, const std::string &fileName, const std::string &extra) { - this->code_ = code; + if (code == StatusCode::K_OK) { + return; + } std::ostringstream ss; ss << "Thread ID " << std::this_thread::get_id() << " " << StatusCodeName(code) << ". "; if (!extra.empty()) { @@ -85,10 +87,12 @@ Status::Status(StatusCode code, int lineOfCode, const std::string &fileName, con ss << "File : " << fileName.substr(position, fileName.length() - position) << std::endl; } auto traceId = Trace::Instance().GetTraceID(); - if (code_ != K_OK && !traceId.empty() && errMsg_.find(traceId) == std::string::npos) { + if (code != K_OK && !traceId.empty()) { ss << "traceId : " << traceId; } - errMsg_ = ss.str(); + state_ = std::make_unique(); + state_->code = code; + state_->errMsg = ss.str(); } std::ostream &operator<<(std::ostream &os, const Status &s) @@ -99,24 +103,41 @@ std::ostream &operator<<(std::ostream &os, const Status &s) std::string Status::ToString() const { - return "code: [" + StatusCodeName(code_) + "], msg: [" + errMsg_ + "]"; + return "code: [" + StatusCodeName(GetCode()) + "], msg: [" + GetMsg() + "]"; } StatusCode Status::GetCode() const { - return code_; + return state_ == nullptr ? K_OK : state_->code; } std::string Status::GetMsg() const { - return errMsg_; + return state_ == nullptr ? 
"" : state_->errMsg; } void Status::AppendMsg(const std::string &appendMsg) { - errMsg_ += (!errMsg_.empty() && errMsg_.back() != '.' ? ". " : " ") + appendMsg; + if (IsOk()) { + return; + } + auto &errMsg = state_->errMsg; + errMsg += (!errMsg.empty() && errMsg.back() != '.' ? ". " : " ") + appendMsg; } +void Status::Assign(const Status &other) noexcept +{ + if (other.IsOk()) { + state_ = nullptr; + return; + } + if (state_ == nullptr) { + state_ = std::make_unique(); + } + *state_ = *other.state_; +} + + // clang-format off. #define STATUS_CODE_DEF(code, msg) \ case StatusCode::code: \ @@ -138,8 +159,7 @@ std::string Status::StatusCodeName(StatusCode code) #undef STATUS_CODE_DEF #define STATUS_CODE_DEF(code, msg) \ - else if (name == #code) \ - { \ + else if (name == #code) { \ statusCode = StatusCode::code; \ } StatusCode GetStatusCodeByName(const std::string &name) diff --git a/src/datasystem/common/util/thread_pool.cpp b/src/datasystem/common/util/thread_pool.cpp index 406a980f14379180fba62dc287bb36e98eca57e7..5cf6780255e355c09027a2ce1570273374c23a84 100644 --- a/src/datasystem/common/util/thread_pool.cpp +++ b/src/datasystem/common/util/thread_pool.cpp @@ -91,9 +91,9 @@ void ThreadPool::DoThreadWork() void ThreadPool::AddThread() { - std::lock_guard workerLock(workersMtx_); auto thread = Thread([this] { this->DoThreadWork(); }); thread.set_name(name_); + std::lock_guard workerLock(workersMtx_); workers_[thread.get_id()] = std::move(thread); } diff --git a/src/datasystem/master/CMakeLists.txt b/src/datasystem/master/CMakeLists.txt index 659cd598b267944e1ee2878aaaf24915c4b22b19..f5a9c3f0fef031773d0878391ee8623bf64a3f84 100644 --- a/src/datasystem/master/CMakeLists.txt +++ b/src/datasystem/master/CMakeLists.txt @@ -20,6 +20,7 @@ set(MASTER_DEPEND_LIBS common_rocksdb common_rpc_zmq common_util + string_ref ds_server master_heartbeat_protos master_object_cache diff --git a/src/datasystem/master/meta_addr_info.h b/src/datasystem/master/meta_addr_info.h index fdbdcf352393038052df0d4306f572aec613e433..4148db4bfaf84540905d76114e346c6d5905c83f 100644 --- a/src/datasystem/master/meta_addr_info.h +++ b/src/datasystem/master/meta_addr_info.h @@ -100,6 +100,11 @@ public: isFromOtherAz_ = false; } + bool Empty() const + { + return addr_.Empty(); + } + private: HostPort addr_; std::string dbName_; diff --git a/src/datasystem/master/metadata_redirect_helper.h b/src/datasystem/master/metadata_redirect_helper.h index c4716172428617445410e640b494b09b47b673ac..e2d138e5a8dcc8e435fb3873fd067beb317e30f1 100644 --- a/src/datasystem/master/metadata_redirect_helper.h +++ b/src/datasystem/master/metadata_redirect_helper.h @@ -40,6 +40,8 @@ #include "datasystem/protos/utils.pb.h" #include "datasystem/worker/cluster_manager/etcd_cluster_manager.h" +DS_DECLARE_bool(enable_redirect); + namespace datasystem { namespace master { using TbbMigratingTable = tbb::concurrent_hash_map; @@ -146,7 +148,7 @@ protected: void FillRedirectResponseInfo(Rsp &response, const std::string &id, bool &redirect) { std::string newMetaAddr; - if (!redirect) { + if (!redirect || !FLAGS_enable_redirect) { VLOG(1) << "receive redirect object: " << id; return; } @@ -188,7 +190,7 @@ protected: template void FillRedirectResponseInfos(Rsp &rsp, std::vector &ids, bool redirect) { - if (!redirect) { + if (!redirect || !FLAGS_enable_redirect) { return; } std::string localAddr; diff --git a/src/datasystem/master/object_cache/delete_object_mediator.cpp b/src/datasystem/master/object_cache/delete_object_mediator.cpp index 
4e39e3e5dc1d01adf1856b2d9474269dd574f29c..f0fe4779ca41bf5bb21b507980d20aa40fb77b28 100644 --- a/src/datasystem/master/object_cache/delete_object_mediator.cpp +++ b/src/datasystem/master/object_cache/delete_object_mediator.cpp @@ -140,12 +140,12 @@ const std::vector &DeleteObjectMediator::GetToBeNotifiedNestedRefs( return toBeNotifiedNestedRefs_; } -void DeleteObjectMediator::SetObjKey2Version(std::unordered_map &&objKey2Version) +void DeleteObjectMediator::SetObjKey2Version(std::unordered_map &&objKey2Version) { objKey2Version_ = std::move(objKey2Version); } -bool DeleteObjectMediator::CheckIfExpired(const std::string &objKey, int64_t currVersion) +bool DeleteObjectMediator::CheckIfExpired(const std::string &objKey, uint64_t currVersion) { auto iter = objKey2Version_.find(objKey); return iter != objKey2Version_.end() && currVersion > iter->second; diff --git a/src/datasystem/master/object_cache/delete_object_mediator.h b/src/datasystem/master/object_cache/delete_object_mediator.h index 9a3ef7ec0a7e1555c8e6c9440954567f8047fcb0..e8c6619e48f338c0f51b6adfb8f4f9253ab80a42 100644 --- a/src/datasystem/master/object_cache/delete_object_mediator.h +++ b/src/datasystem/master/object_cache/delete_object_mediator.h @@ -165,7 +165,7 @@ public: * @brief If the object in the delete request carries a version, store the version here. * @param[in] objKey2Version The object2Version massage in delete request. */ - void SetObjKey2Version(std::unordered_map &&objKey2Version); + void SetObjKey2Version(std::unordered_map &&objKey2Version); /** * @brief Check whether the version of the object in the delete request is expired. @@ -173,7 +173,7 @@ public: * @param[in] currVersion The version of this object in this node. * @return T/F */ - bool CheckIfExpired(const std::string &objKey, int64_t currVersion); + bool CheckIfExpired(const std::string &objKey, uint64_t currVersion); /** * @brief Get the version of the object in the delete request. @@ -200,7 +200,7 @@ private: std::unordered_set failedIds_; std::unordered_set successDelIds_; std::unordered_set hashObjsWithoutMeta_; - std::unordered_map objKey2Version_; + std::unordered_map objKey2Version_; std::vector outdatedObjs_; /** * The key of idsNeedToNotifyWorker_ is objectKeys, value is set of worker addresses which master need to notify. 
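Stepping back to the Status rework a few hunks up: code and message now live behind one heap pointer, so a successful Status is a null state_ and the hot path never allocates or copies a string. A compressed sketch of the pattern under the patch's field names (state_, code, errMsg); tracing, AppendMsg, and the code-name tables are omitted:

```cpp
#include <memory>
#include <string>
#include <utility>

enum StatusCode { K_OK = 0, K_INVALID = 1 };

class Status {
public:
    Status() noexcept = default;              // OK: state_ stays null, no allocation
    Status(StatusCode code, std::string msg)
    {
        if (code == K_OK) {
            return;                           // success never allocates
        }
        state_ = std::make_unique<State>();
        state_->code = code;
        state_->errMsg = std::move(msg);
    }
    Status(const Status &other) { Assign(other); }
    Status &operator=(const Status &other)
    {
        Assign(other);
        return *this;
    }
    Status(Status &&other) noexcept { std::swap(state_, other.state_); }
    Status &operator=(Status &&other) noexcept
    {
        std::swap(state_, other.state_);
        return *this;
    }
    bool IsOk() const { return state_ == nullptr; }
    StatusCode GetCode() const { return state_ == nullptr ? K_OK : state_->code; }
    std::string GetMsg() const { return state_ == nullptr ? "" : state_->errMsg; }

private:
    struct State {
        StatusCode code = K_OK;
        std::string errMsg;
    };
    void Assign(const Status &other)
    {
        if (other.state_ == nullptr) {
            state_ = nullptr;                 // copying an OK status is a pointer reset
            return;
        }
        if (state_ == nullptr) {
            state_ = std::make_unique<State>();
        }
        *state_ = *other.state_;              // reuse our allocation when possible
    }
    std::unique_ptr<State> state_;
};
```

Moves swap pointers and Assign reuses an existing allocation where it can, mirroring the patch.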
diff --git a/src/datasystem/master/object_cache/device/master_dev_npu_events.cpp b/src/datasystem/master/object_cache/device/master_dev_npu_events.cpp index 637546c40c43a2e33624cbb0b7f3196984416f04..0e88c68d423e9c43c89c8cc8be3efa912fddccfb 100644 --- a/src/datasystem/master/object_cache/device/master_dev_npu_events.cpp +++ b/src/datasystem/master/object_cache/device/master_dev_npu_events.cpp @@ -96,9 +96,10 @@ void NpuEventsTable::PutNpuEvent(const std::string &srcNpuId, std::shared_ptrobjectKey, srcNpuId, static_cast(npuEvent->eventType)); - (void)npuPendingGetReqsTable_.insert(npuPendingGetReqsAccess, - std::make_pair( - srcNpuId, std::make_shared>>())); + std::string srcId = srcNpuId; + (void)npuPendingGetReqsTable_.insert( + npuPendingGetReqsAccess, std::make_pair( + std::move(srcId), std::make_shared>>())); npuPendingGetReqsAccess->second->Push(npuEvent); } diff --git a/src/datasystem/master/object_cache/expired_object_manager.cpp b/src/datasystem/master/object_cache/expired_object_manager.cpp index bf18f0fd2be432200bf53158c509d22d8e333efc..dc6506d9d9e96160b5d8699f0dc3719ee065ec43 100644 --- a/src/datasystem/master/object_cache/expired_object_manager.cpp +++ b/src/datasystem/master/object_cache/expired_object_manager.cpp @@ -140,6 +140,11 @@ void ExpiredObjectManager::ReloadExpireObjects(const std::vector lock(mutex_); masterOperationTimeCost.Append("InsertObject", timer.ElapsedMilliSecond()); @@ -147,12 +152,6 @@ Status ExpiredObjectManager::InsertObject(const std::string &objectKey, const ui // updated data will be deleted soon. RETURN_IF_NOT_OK(CheckObjectInAsyncDelete(objectKey, K_RUNTIME_ERROR)); RemoveObjectIfExistUnlock(objectKey); - // If objectKey in timeObj_ and we insert the same objectKey again, it means the object is being updated. - // If ttl is not zero, we should remove the object key and old expire time, and insert with new expire time again. - // If ttl is zero, we should remove the object key to keep the new data won't be deleted. - if (!acceptZero && ttlSecond == 0) { - return Status::OK(); - } return InsertObjectUnlock(objectKey, version, ttlSecond); } @@ -163,7 +162,7 @@ Status ExpiredObjectManager::InsertObjectUnlock(const std::string &objectKey, co auto iter = timedObj_.insert({ expiredTime, objectKey }); obj2Timed_[objectKey] = iter; VLOG(1) << FormatString("Insert the object %s with version %llu, ttl second %u, expireTime %llu, remain time %llu", - objectKey, version, ttlSecond, expiredTime, expiredTime - GetSteadyClockTimeStampUs()); + objectKey, version, ttlSecond, expiredTime, expiredTime - GetSystemClockTimeStampUs()); statisticsInfo_.IncreaseObj(); return Status::OK(); } @@ -202,7 +201,7 @@ void ExpiredObjectManager::AddSucceedObject(const std::unordered_map &objectKe uint64_t newTtlSecond = (UINT64_MAX - 1) / failedObjects_[objectKey] < RETRY_WAIT_TIME ? 
UINT64_MAX : static_cast(RETRY_WAIT_TIME) * failedObjects_[objectKey] + 1; - uint64_t expiredTime = CalcExpireTime(GetSteadyClockTimeStampUs(), newTtlSecond); + uint64_t expiredTime = CalcExpireTime(GetSystemClockTimeStampUs(), newTtlSecond); auto iter = timedObj_.insert({ expiredTime, objectKey }); obj2Timed_[objectKey] = iter; LOG(INFO) << FormatString( @@ -238,7 +237,7 @@ void ExpiredObjectManager::AddFailedObject(const std::set &objectKe std::unordered_map ExpiredObjectManager::GetExpiredObject() { std::unordered_map expiredObject; - uint64_t currentTime = static_cast(GetSteadyClockTimeStampUs()); + uint64_t currentTime = static_cast(GetSystemClockTimeStampUs()); std::lock_guard lock(mutex_); for (auto iter = timedObj_.begin(); iter != timedObj_.end() && iter->first <= currentTime && expiredObject.size() < MAX_DEL_BATCH_NUM;) { @@ -246,7 +245,7 @@ std::unordered_map ExpiredObjectManager::GetExpiredObject currentTime); expiredObject[iter->second] = iter->first; uint64_t delayTimeSecond = - (GetSteadyClockTimeStampUs() - iter->first) / TIME_UNIT_CONVERSION / TIME_UNIT_CONVERSION; + (GetSystemClockTimeStampUs() - iter->first) / TIME_UNIT_CONVERSION / TIME_UNIT_CONVERSION; statisticsInfo_.IncreaseDelayGetObj(delayTimeSecond); (void)readyExpiredObjects_.emplace(iter->second); (void)obj2Timed_.erase(iter->second); @@ -261,7 +260,7 @@ Status ExpiredObjectManager::AsyncDelete(std::unordered_map requestObjectKeyMap; std::transform(expiredObjMap.begin(), expiredObjMap.end(), @@ -274,6 +273,7 @@ Status ExpiredObjectManager::AsyncDelete(std::unordered_mapFindNeedDeleteIds(mediator); std::unordered_set hashObjsWithoutMeta = mediator.GetHashObjsWithoutMeta(); @@ -297,7 +297,7 @@ Status ExpiredObjectManager::AsyncDelete(std::unordered_mapfirst > currentUs ? obj2Timed_[objectKey]->first - currentUs : 0; remainTimeSecond = remainUs / TIME_UNIT_CONVERSION / TIME_UNIT_CONVERSION; RemoveObjectIfExistUnlock(objectKey); diff --git a/src/datasystem/master/object_cache/oc_metadata_manager.cpp b/src/datasystem/master/object_cache/oc_metadata_manager.cpp index d0853a31f022c77761214e39f10474ebe4ac1ebf..69364706b13eb9d1f21969631b0763a4c60bf641 100644 --- a/src/datasystem/master/object_cache/oc_metadata_manager.cpp +++ b/src/datasystem/master/object_cache/oc_metadata_manager.cpp @@ -75,6 +75,8 @@ DS_DEFINE_string(rocksdb_store_dir, "~/datasystem/rocksdb", "specify in rocksdb scenario. 
The rocksdb database is used to persistently store the metadata " "in the master, so that the metadata before the restart can be re-obtained when the master restarts."); DS_DEFINE_validator(rocksdb_store_dir, &Validator::ValidatePathString); +DS_DEFINE_bool(enable_redirect, true, + "enable query meta redirect during scale up or voluntary scale down, default is true"); DS_DECLARE_string(etcd_address); DS_DECLARE_bool(async_delete); @@ -105,10 +107,11 @@ OCMetadataManager::OCMetadataManager(std::shared_ptr akSkManager, R persistApi_(persistApi), newNode_(newNode) { + bool isEnabled = FLAGS_rocksdb_write_mode != "none" || FLAGS_oc_io_from_l2cache_need_metadata; if (FLAGS_enable_meta_replica && !etcdCM_->IsCentralized()) { - objectStore_ = std::make_shared(rocksStore, nullptr, true); + objectStore_ = std::make_shared(rocksStore, nullptr, isEnabled); } else { - objectStore_ = std::make_shared(rocksStore, etcdStore, true); + objectStore_ = std::make_shared(rocksStore, etcdStore, isEnabled); } if (etcdCM_ != nullptr && !etcdCM_->IsCentralized()) { dbName_ = dbName; @@ -444,11 +447,10 @@ void OCMetadataManager::SetMetaInfo(const ObjectMetaPb &newMeta, const std::stri } Status OCMetadataManager::NotifyOtherAzNodeRemoveMeta(const std::string &objectKey, int64_t version, - const ObjectMetaPb &newMeta) + ObjectMetaStore::WriteType type) { - if (!FLAGS_other_cluster_names.empty()) { - LOG(INFO) << "Notify nodes in other clusters to remove meta for object: " << objectKey; - } + RETURN_OK_IF_TRUE(FLAGS_other_cluster_names.empty()); + LOG(INFO) << "Notify nodes in other clusters to remove meta for object: " << objectKey; std::unordered_map metaAddrInfos; RETURN_IF_NOT_OK(etcdCM_->GetAllNodesInOtherAzsByHash(objectKey, metaAddrInfos, true)); for (const auto &item : metaAddrInfos) { @@ -459,8 +461,7 @@ Status OCMetadataManager::NotifyOtherAzNodeRemoveMeta(const std::string &objectK LOG(WARNING) << "Fail to notify other az's node to remove meta: " << rc.ToString(); // DFX LOG_IF_ERROR(notifyWorkerManager_->InsertAsyncWorkerOp( - "", objectKey, { NotifyWorkerOpType::REMOVE_META, version, { item.first } }, true, - WriteMode2MetaType(newMeta.config().write_mode())), + "", objectKey, { NotifyWorkerOpType::REMOVE_META, version, { item.first } }, true, type), "Insert remote meta notification to AsyncWorkerOpTable failed, obj: " + objectKey); } } @@ -472,6 +473,7 @@ Status OCMetadataManager::CreateMetaFirstTime(const ObjectMetaPb &newMeta, const TbbMetaTable::accessor &accessor) { const std::string &objectKey = newMeta.object_key(); + ObjectMetaStore::WriteType type = WriteMode2MetaType(newMeta.config().write_mode()); ObjectMeta metaCache; SetMetaInfo(newMeta, address, version, metaCache); accessor->second = metaCache; @@ -479,12 +481,11 @@ Status OCMetadataManager::CreateMetaFirstTime(const ObjectMetaPb &newMeta, const std::string serializedStr; RETURN_IF_NOT_OK(objectStore_->CreateSerializedStringForMeta(objectKey, accessor->second.meta, serializedStr)); // Create meta info in rocksDB. - RETURN_IF_NOT_OK(objectStore_->CreateOrUpdateMeta(objectKey, serializedStr, - WriteMode2MetaType(metaCache.meta.config().write_mode()))); + RETURN_IF_NOT_OK(objectStore_->CreateOrUpdateMeta(objectKey, serializedStr, type)); accessor.release(); if (!HasWorkerId(objectKey)) { - RETURN_IF_NOT_OK(NotifyOtherAzNodeRemoveMeta(objectKey, version, newMeta)); + RETURN_IF_NOT_OK(NotifyOtherAzNodeRemoveMeta(objectKey, version, type)); } // Update subscribeCache. if multiset_state == pending, create not finish, don't update subscribe.
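A note on the GetSteadyClockTimeStampUs-to-GetSystemClockTimeStampUs swaps running through expired_object_manager.cpp and this file: TTL deadlines are written to RocksDB and exchanged between nodes, and a steady clock's epoch is arbitrary per process, so a stored steady-clock deadline is meaningless after a restart. A sketch of the distinction (the helper names exist in the codebase; these bodies are assumptions for illustration):

```cpp
#include <chrono>
#include <cstdint>

// Wall-clock microseconds: comparable across processes and restarts,
// which is what a deadline persisted to RocksDB needs.
uint64_t GetSystemClockTimeStampUs()
{
    using namespace std::chrono;
    return static_cast<uint64_t>(
        duration_cast<microseconds>(system_clock::now().time_since_epoch()).count());
}

// Steady-clock microseconds: monotonic within one process lifetime, but the
// epoch is arbitrary (often boot time), so a stored value means nothing to
// another process or after a restart.
uint64_t GetSteadyClockTimeStampUs()
{
    using namespace std::chrono;
    return static_cast<uint64_t>(
        duration_cast<microseconds>(steady_clock::now().time_since_epoch()).count());
}

uint64_t CalcExpireTime(uint64_t nowUs, uint64_t ttlSecond)
{
    return nowUs + ttlSecond * 1000 * 1000;  // assumes TTLs small enough not to overflow
}
```

The cost is that wall time can step under NTP adjustment, so an expiry may fire a little early or late; the patch appears to accept that in exchange for restart-safe deadlines.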
@@ -595,7 +596,7 @@ Status OCMetadataManager::CreatePendingMeta(const ObjectMetaPb &newMeta, const s // If the timestamp of the object does not exceed multiSetTimestamp, return K_TRY_AGAIN. // Except for the same address, we can refresh meta. if (address != accessor->second.meta.primary_address() - && GetSteadyClockTimeStampUs() < accessor->second.multiSetTimestamp) { + && GetSystemClockTimeStampUs() < accessor->second.multiSetTimestamp) { return rc; } LOG(INFO) << FormatString("[ObjectKey %s] PreCommit changed from %s to %s", objectKey, @@ -655,7 +656,6 @@ Status OCMetadataManager::CreateMetaForBinaryFormat(const ObjectMetaPb &newMeta, // it is not allowed to double Set. RETURN_IF_NOT_OK(CheckExistenceOpt(accessor->second, objectKey, newMeta.existence(), firstOne)); version = static_cast(GetSystemClockTimeStampUs()); - uint64_t versionForTTL = static_cast(GetSteadyClockTimeStampUs()); RaiiPlus raiiP; if (!firstOne && !HasWorkerId(objectKey)) { @@ -692,12 +692,12 @@ Status OCMetadataManager::CreateMetaForBinaryFormat(const ObjectMetaPb &newMeta, if (!nestedObjectKeys.empty() && nestedRefManager_->IsNestedKeysDiff(objectKey, nestedObjectKeys)) { RETURN_IF_NOT_OK(nestedRefManager_->IncreaseNestedRefCnt(objectKey, nestedObjectKeys)); } - RETURN_IF_NOT_OK(expiredObjectManager_->InsertObject(objectKey, versionForTTL, newMeta.ttl_second())); + RETURN_IF_NOT_OK(expiredObjectManager_->InsertObject(objectKey, version, newMeta.ttl_second())); return s; } // Case 2: first time creating meta. RETURN_IF_NOT_OK(CreateMetaFirstTime(newMeta, address, version, nestedObjectKeys, accessor)); - RETURN_IF_NOT_OK(expiredObjectManager_->InsertObject(objectKey, versionForTTL, newMeta.ttl_second())); + RETURN_IF_NOT_OK(expiredObjectManager_->InsertObject(objectKey, version, newMeta.ttl_second())); VLOG(1) << FormatString("[ObjectKey %s] CreateMeta finished: objectKey: %s, worker address: %s", objectKey, objectKey, address); return Status::OK(); @@ -719,44 +719,152 @@ Status OCMetadataManager::CreateMultiMeta(const CreateMultiMetaReqPb &req, Creat return CreateMultiMetaNtx(req, rsp); } +Status OCMetadataManager::UpdateMeta(ObjectMeta &meta, const ObjectMetaPb &newMeta, const std::string &address, + int64_t &version) +{ + const std::string &objectKey = newMeta.object_key(); + // In the NX Set scenario, when the worker restarts, if there is data in the L2 cache, + // it is not allowed to double Set. + bool firstOne = false; + RETURN_IF_NOT_OK(CheckExistenceOpt(meta, objectKey, newMeta.existence(), firstOne)); + RaiiPlus raiiP; + if (!HasWorkerId(objectKey)) { + MarkUpdatingAndUpdateRemoveMetaNotification(objectKey, version, raiiP); + } + CHECK_FAIL_RETURN_STATUS( + meta.multiSetState != PENDING, K_TRY_AGAIN, + FormatString("update meta failed, multi meta objectKey(%s) is creating, wait and try again", objectKey)); + BinaryFormatParamsStruct newMetaData = { .writeMode = newMeta.config().write_mode(), + .dataFormat = newMeta.config().data_format(), + .consistencyType = newMeta.config().consistency_type(), + .cacheType = newMeta.config().cache_type(), + .isReplica = newMeta.config().is_replica() }; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CheckBinaryFormatParamMatch(objectKey, meta, newMetaData), "Check format failed"); + + // Cache invalidation logic.
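The `{ .newAddress = ... }` argument assembled just below, like the BinaryFormatParamsStruct above, uses C++20 designated initializers to pass named parameters through an aggregate. A standalone illustration (InvalidationParams and Invalidate are hypothetical stand-ins, not the project's types):

```cpp
#include <cstdint>
#include <string>

struct InvalidationParams {
    std::string newAddress;
    int64_t newVersion = 0;
    uint64_t newDataSz = 0;
    uint32_t newLifeState = 0;
};

void Invalidate(const InvalidationParams &params)
{
    (void)params;  // stub body; a real implementation would act on the fields
}

void Example()
{
    // Initializers must follow declaration order; omitted fields keep their
    // in-class defaults (newLifeState stays 0 here).
    Invalidate({ .newAddress = "10.0.0.1:9000", .newVersion = 42, .newDataSz = 4096 });
}
```

Names stay visible at the call site, which is what keeps these calls readable as the parameter structs grow.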
+ Status s = DoBinaryCacheInvalidationUnlocked(objectKey, meta, + { .newAddress = address, + .newVersion = version, + .newDataSz = newMeta.data_size(), + .newLifeState = newMeta.life_state(), + .newBlobSizes = newMeta.device_info().blob_sizes() }); + if (s.IsError()) { + // If the cache invalid processing fails, delete the address from the meta. + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(objectStore_->RemoveObjectLocation(objectKey, address), + "Remove location failed from rocksdb."); + (void)meta.locations.erase(address); + } + + RETURN_IF_NOT_OK(expiredObjectManager_->InsertObject(objectKey, version, newMeta.ttl_second())); + return s; +} + +Status OCMetadataManager::CreateMeta(const std::string &objectKey, ObjectMeta &newMeta, const std::string &address, + int64_t &version, bool &firstOne) +{ + INJECT_POINT("master.create_meta_failure"); + auto &metaPb = newMeta.meta; + const auto ttl = metaPb.ttl_second(); + ObjectMetaStore::WriteType type = WriteMode2MetaType(metaPb.config().write_mode()); + std::shared_lock lck(metaTableMutex_); + TbbMetaTable::accessor accessor; + firstOne = metaTable_.insert(accessor, objectKey); + if (!firstOne) { + return UpdateMeta(accessor->second, metaPb, address, version); + } + if (objectStore_->IsPersistenceEnabled()) { + std::string serializedStr; + RETURN_IF_NOT_OK(objectStore_->CreateSerializedStringForMeta(objectKey, metaPb, serializedStr)); + RETURN_IF_NOT_OK(objectStore_->CreateOrUpdateMeta(objectKey, serializedStr, type)); + } + accessor->second = std::move(newMeta); + accessor.release(); + if (!FLAGS_other_cluster_names.empty() && !HasWorkerId(objectKey)) { + RETURN_IF_NOT_OK(NotifyOtherAzNodeRemoveMeta(objectKey, version, type)); + } + return expiredObjectManager_->InsertObject(objectKey, version, ttl); +} + +void OCMetadataManager::ConstructMetaInfo(const CreateMultiMetaReqPb &req, const ObjectBaseInfoPb &info, + int64_t version, ObjectMetaPb &meta) +{ + meta.set_object_key(info.object_key()); + meta.set_data_size(info.data_size()); + meta.set_version(version); + meta.set_life_state(req.life_state()); + *meta.mutable_config() = req.config(); + meta.set_primary_address(req.address()); + meta.set_ttl_second(req.ttl_second()); + meta.set_existence(req.existence()); + if (info.has_device_info()) { + *meta.mutable_device_info() = info.device_info(); + } +} Status OCMetadataManager::CreateMultiMetaNtx(const CreateMultiMetaReqPb &req, CreateMultiMetaRspPb &rsp) { std::vector rollBackIds; Status lastRc; - int64_t version = 0; - int64_t failVersion = 0; - for (const auto &metaInfo : req.metas()) { - if (metaInfo.object_key().empty() || req.address().empty()) { - rsp.add_failed_object_keys(metaInfo.object_key()); - rsp.add_version(failVersion); - lastRc = Status(K_INVALID, "CreateMeta: Cannot CreateMeta with empty objectKey or server address."); + if (req.address().empty()) { + return Status(K_INVALID, "CreateMeta: Cannot CreateMeta with empty server address."); + } + std::vector objsFirst; + objsFirst.reserve(req.metas_size()); + int64_t version = static_cast(GetSystemClockTimeStampUs()); + std::vector newMetas; + newMetas.reserve(req.metas_size()); + for (int i = 0; i < req.metas_size(); i++) { + const ObjectBaseInfoPb &info = req.metas(i); + ObjectMeta &meta = newMetas.emplace_back(); + meta.locations.emplace(req.address()); + ConstructMetaInfo(req, info, version, meta.meta); + } + for (int i = 0; i < req.metas_size(); i++) { + const auto &objectKey = req.metas(i).object_key(); + if (objectKey.empty()) {
rsp.add_failed_object_keys(objectKey); + lastRc = Status(K_INVALID, "CreateMeta: Cannot CreateMeta with empty objectKey."); + continue; } bool firstOne = false; - auto status = CreateMeta(metaInfo, req.address(), {}, version, firstOne); + auto status = CreateMeta(objectKey, newMetas[i], req.address(), version, firstOne); + if (firstOne) { + objsFirst.emplace_back(objectKey); + } if (status.IsError()) { // meta maybe already insert to metatable. if not first one, no need delete old meta. if (firstOne) { - rollBackIds.emplace_back(metaInfo.object_key()); + rollBackIds.emplace_back(objectKey); } - rsp.add_failed_object_keys(metaInfo.object_key()); - rsp.add_version(failVersion); + rsp.add_failed_object_keys(objectKey); lastRc = status; - } else { - rsp.add_version(version); } } + ExecuteAsyncTask([this, objsFirst]() { + for (const auto &objKey : objsFirst) { + std::shared_lock lck(metaTableMutex_); + TbbMetaTable::const_accessor accessor; + if (!metaTable_.find(accessor, objKey)) { + LOG(WARNING) << "Object " << objKey << " can't be found in metaTable, notify subscribe failed"; + continue; + } + ObjectMeta metaCache = accessor->second; + accessor.release(); + UpdateSubscribeCache(objKey, metaCache); + } + }); RollBackMultiMetaWhenCreateFailed(rollBackIds, req.address()); rsp.mutable_last_rc()->set_error_msg(lastRc.GetMsg()); rsp.mutable_last_rc()->set_error_code(lastRc.GetCode()); + rsp.set_version(version); return Status::OK(); } Status OCMetadataManager::CreateMultiMetaTx(const CreateMultiMetaReqPb &req, CreateMultiMetaRspPb &rsp) { std::vector successIds; - int64_t pendingTtl = GetSteadyClockTimeStampUs() + MSET_PENDING_TTL_US; + int64_t pendingTtl = GetSystemClockTimeStampUs() + MSET_PENDING_TTL_US; INJECT_POINT("master.CreateMultiMetaTx.pendingTtl", [&pendingTtl](int ttlUs) { - pendingTtl = GetSteadyClockTimeStampUs() + ttlUs; + pendingTtl = GetSystemClockTimeStampUs() + ttlUs; return Status::OK(); }); for (const auto &metaInfo : req.metas()) { @@ -764,8 +873,10 @@ Status OCMetadataManager::CreateMultiMetaTx(const CreateMultiMetaReqPb &req, Cre RollBackMultiMetaWhenCreateFailed(successIds, req.address()); RETURN_STATUS(K_INVALID, "CreateMeta: Cannot CreateMeta with empty objectKey or server address."); } + ObjectMetaPb meta; + ConstructMetaInfo(req, metaInfo, 0, meta); bool firstOne = false; - auto status = CreatePendingMeta(metaInfo, req.address(), pendingTtl, firstOne); + auto status = CreatePendingMeta(meta, req.address(), pendingTtl, firstOne); if (status.IsError()) { // meta maybe already insert to metatable. if not first one, no need delete old meta.
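CreateMeta above folds create and update into one atomic step: TbbMetaTable::insert reports through its return value whether the key was new, and that bit becomes the firstOne flag this loop uses to decide whether a failed create needs rollback. Distilled into a self-contained sketch before the error handling resumes below (Meta is a stand-in for the real ObjectMeta):

```cpp
#include <tbb/concurrent_hash_map.h>

#include <cstdint>
#include <string>

struct Meta {
    uint64_t dataSize = 0;
};
using MetaTable = tbb::concurrent_hash_map<std::string, Meta>;

// insert() tells us atomically whether the key was new, so there is no
// find-then-insert race between the create and update paths.
bool CreateOrUpdate(MetaTable &table, const std::string &key, const Meta &incoming)
{
    MetaTable::accessor acc;
    bool firstOne = table.insert(acc, key);        // true => freshly inserted
    if (firstOne) {
        acc->second = incoming;                    // create path
    } else {
        acc->second.dataSize = incoming.dataSize;  // update path, in place
    }
    return firstOne;
}
```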
if (firstOne) { @@ -781,7 +892,7 @@ Status OCMetadataManager::CreateMultiMetaTx(const CreateMultiMetaReqPb &req, Cre if (req.is_pre_commit()) { return Status::OK(); } - auto type = WriteMode2MetaType(req.metas().begin()->config().write_mode()); + auto type = WriteMode2MetaType(req.config().write_mode()); uint64_t version = static_cast(GetSystemClockTimeStampUs()); auto status = PublishMultiMeta(successIds, req.address(), type, version, rsp); if (status.IsError()) { @@ -830,7 +941,6 @@ Status OCMetadataManager::PublishMultiMeta(const std::vector &objec ObjectMetaStore::WriteType type, uint64_t version, CreateMultiMetaRspPb &rsp) { std::unordered_map metaInfos; - uint64_t versionForTTL = static_cast(GetSteadyClockTimeStampUs()); for (const auto &objKey : objectKeys) { std::shared_lock lck(metaTableMutex_); TbbMetaTable::accessor accessor; @@ -846,12 +956,12 @@ Status OCMetadataManager::PublishMultiMeta(const std::vector &objec if (objectMeta.config().data_format() != (uint64_t)DataFormat::HASH_MAP) { UpdateSubscribeCache(objKey, accessor->second); } - RETURN_IF_NOT_OK(expiredObjectManager_->InsertObject(objKey, versionForTTL, objectMeta.ttl_second())); + RETURN_IF_NOT_OK(expiredObjectManager_->InsertObject(objKey, version, objectMeta.ttl_second())); std::string serializedStr; RETURN_IF_NOT_OK(objectStore_->CreateSerializedStringForMeta(objKey, objectMeta, serializedStr)); metaInfos.emplace(objKey, serializedStr); - rsp.add_version(version); } + rsp.set_version(version); return objectStore_->CreateOrUpdateBatchMeta(metaInfos, type); } @@ -1638,7 +1748,7 @@ void OCMetadataManager::DeleteAllCopyMetaImpl( { const std::string &sourceWorker = request.address(); std::vector objectKeys = { request.object_keys().begin(), request.object_keys().end() }; - std::unordered_map objKey2Version; + std::unordered_map objKey2Version; for (const auto &objWithVersion : request.ids_with_version()) { objectKeys.emplace_back(objWithVersion.id()); objKey2Version.emplace(objWithVersion.id(), objWithVersion.version()); @@ -2076,8 +2186,7 @@ Status OCMetadataManager::UpdateMetaByState(const UpdateMetaReqPb &request, Obje RETURN_IF_NOT_OK(nestedRefManager_->IncreaseNestedRefCnt(objectKey, nestedObjectKeys)); } LOG(INFO) << "UpdateMeta finished"; - uint64_t versionForTTL = static_cast(GetSteadyClockTimeStampUs()); - return expiredObjectManager_->InsertObject(objectKey, versionForTTL, request.ttl_second()); + return expiredObjectManager_->InsertObject(objectKey, version, request.ttl_second()); } Status OCMetadataManager::UpdateMeta(const UpdateMetaReqPb &request, UpdateMetaRspPb &response) @@ -2187,8 +2296,8 @@ void OCMetadataManager::InsertExpireObjects(ObjectMetaPb &metaPb, { if (metaPb.ttl_second() > 0) { INJECT_POINT("master.LoadMeta.steadyClockIsDifferent", [&metaPb]() { metaPb.set_version(0); }); - long curSteadyClock = GetSteadyClockTimeStampUs(); - expireObjects.emplace_back(metaPb.object_key(), curSteadyClock, metaPb.ttl_second()); + long curSystemClock = GetSystemClockTimeStampUs(); + expireObjects.emplace_back(metaPb.object_key(), curSystemClock, metaPb.ttl_second()); } } @@ -2810,7 +2919,7 @@ Status OCMetadataManager::RecoverMasterAppRef(std::functionNeedRedirect(objectKey, masterAddr); if (!needRedirect) { @@ -3584,8 +3696,8 @@ void OCMetadataManager::AsyncDeleteByExpired(DeleteObjectMediator &mediator) // For those objs that do not have metadata on this node, we will notify other az masters in the asynchronous queue // to delete the metadata. 
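With PublishMultiMeta and CreateMultiMetaNtx now stamping a whole batch with a single version, the response carries one uint64 instead of a repeated field. A hedged sketch of consuming the new shape, assuming the generated message lives in the datasystem namespace:

```cpp
#include <cstdint>
#include <string>
#include <unordered_set>

#include "datasystem/protos/master_object.pb.h"  // generated from master_object.proto

// One version now covers the whole batch; only failures are reported per key.
void HandleCreateMultiMetaRsp(const datasystem::CreateMultiMetaRspPb &rsp)
{
    const uint64_t batchVersion = rsp.version();  // was: repeated uint64 version
    const std::unordered_set<std::string> failed(rsp.failed_object_keys().begin(),
                                                 rsp.failed_object_keys().end());
    // Every requested key absent from `failed` was created at `batchVersion`;
    // a real client would record that version against each successful key.
    (void)batchVersion;
    (void)failed;
}
```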
for (auto &objectKey : mediator.GetObjKeys()) { - uint64_t versionForTTL = static_cast(GetSteadyClockTimeStampUs()); - Status rc = expiredObjectManager_->InsertObject(objectKey, versionForTTL, MIN_TTL_SECOND, true); + uint64_t version = static_cast(GetSystemClockTimeStampUs()); + Status rc = expiredObjectManager_->InsertObject(objectKey, version, MIN_TTL_SECOND, true); // if object is being delete, don't need to insert again. if (rc.IsOk() || rc.GetCode() == K_TRY_AGAIN) { mediator.AddSuccessDelId(objectKey); @@ -3660,7 +3772,7 @@ bool OCMetadataManager::SaveOneMeta(const MetaForMigrationPb &objMeta, Status &s return false; } - uint64_t currentTime = static_cast(GetSteadyClockTimeStampUs()); + uint64_t currentTime = static_cast(GetSystemClockTimeStampUs()); // Theoretically, inserts don't fail. (void)expiredObjectManager_->InsertObject(objectKey, currentTime, objMeta.remain_ttl_second(), objMeta.enable_ttl()); @@ -3892,7 +4004,7 @@ void OCMetadataManager::HandleMetaDataMigrationFailed( const MetaForMigrationPb &objMeta, const std::unordered_map>> &asyncMap) { - expiredObjectManager_->InsertObject(objMeta.object_key(), GetSteadyClockTimeStampUs(), objMeta.remain_ttl_second(), + expiredObjectManager_->InsertObject(objMeta.object_key(), GetSystemClockTimeStampUs(), objMeta.remain_ttl_second(), objMeta.enable_ttl()); for (auto &async_op : objMeta.async_ops()) { // FillMetadataForMigration does not delete data from etcd. Therefore, type is set to ROCKS_ONLY. @@ -4357,8 +4469,8 @@ Status OCMetadataManager::Expire(const ExpireReqPb &req, ExpireRspPb &rsp) } accessor->second.meta.set_ttl_second(req.ttl_second()); } - uint64_t versionForTTL = static_cast(GetSteadyClockTimeStampUs()); - auto rc = expiredObjectManager_->InsertObject(objectKey, versionForTTL, req.ttl_second()); + uint64_t version = static_cast(GetSystemClockTimeStampUs()); + auto rc = expiredObjectManager_->InsertObject(objectKey, version, req.ttl_second()); if (rc.IsError()) { LOG(WARNING) << "Faied to insert object[" << objectKey << "] with new ttl second."; rsp.add_failed_object_keys(objectKey); diff --git a/src/datasystem/master/object_cache/oc_metadata_manager.h b/src/datasystem/master/object_cache/oc_metadata_manager.h index 4d444557023522241517ccfe130b39d592074fcd..134e92882d48d88f0092f52bb24fabce5f114375 100644 --- a/src/datasystem/master/object_cache/oc_metadata_manager.h +++ b/src/datasystem/master/object_cache/oc_metadata_manager.h @@ -293,6 +293,16 @@ public: */ void SetMetaInfo(const ObjectMetaPb &newMeta, const std::string &address, int64_t version, ObjectMeta &metaCache); + /** + * @brief Construct meta info for meta + * @param[in] req request info + * @param[in] info object base info + * @param[in] version object version + * @param[out] meta out meta info. + */ + void ConstructMetaInfo(const CreateMultiMetaReqPb &req, const ObjectBaseInfoPb &info, int64_t version, + ObjectMetaPb &meta); + /** * @brief Insert to etcd memory table. * @param[in] objectKey The key of object. 
@@ -556,7 +566,7 @@ public: } }); response.set_ref_is_moving(false); - if (!redirect) { + if (!redirect || !FLAGS_enable_redirect) { VLOG(1) << "receive redirect req"; return; } @@ -1097,7 +1107,7 @@ public: std::shared_ptr GetDeviceOcManager(); #ifdef WITH_TESTS - OCNestedManager *CheckIsNoneNestedRefById() + OCNestedManager *GetNestedRefManager() { return nestedRefManager_.get(); } @@ -1222,6 +1232,28 @@ private: Status CreatePendingMeta(const ObjectMetaPb &newMeta, const std::string &address, int64_t pendingTtl, bool &firstOne); + /** + * @brief Update meta info in cache and rocksdb. + * @param[in] meta The meta info + * @param[in] newMeta The new meta info + * @param[in] address The request address + * @param[in] version The object version + * @return Status of the call. + */ + Status UpdateMeta(ObjectMeta &meta, const ObjectMetaPb &newMeta, const std::string &address, int64_t &version); + + /** + * @brief Create meta info in cache and rocksdb. + * @param[in] objectKey The object key. + * @param[in] newMeta The new meta info + * @param[in] address The request address + * @param[in] version The object version + * @param[out] firstOne Whether the meta was created for the first time. + * @return Status of the call. + */ + Status CreateMeta(const std::string &objectKey, ObjectMeta &newMeta, const std::string &address, int64_t &version, + bool &firstOne); + /** * @brief Recovery object locations * @param[in] objLocMap The map record object and locations. @@ -1662,10 +1693,10 @@ private: * Note: Hash type keys will care about this. * @param[in] objectKey The object key. * @param[in] version The new meta's version. - * @param[in] newMeta Metadata of object. + * @param[in] type The write type. * @return Status of the call. */ - Status NotifyOtherAzNodeRemoveMeta(const std::string &objectKey, int64_t version, const ObjectMetaPb &newMeta); + Status NotifyOtherAzNodeRemoveMeta(const std::string &objectKey, int64_t version, ObjectMetaStore::WriteType type); /** * @brief Process remove meta notification from other az. diff --git a/src/datasystem/master/object_cache/oc_nested_manager.h b/src/datasystem/master/object_cache/oc_nested_manager.h index f398479a3689465faab285dabeabb1712974d60f..442e84a9f7135e7dd8f28d8ef2573362e41e4775 100644 --- a/src/datasystem/master/object_cache/oc_nested_manager.h +++ b/src/datasystem/master/object_cache/oc_nested_manager.h @@ -61,7 +61,8 @@ namespace master { class OCNestedManager { public: OCNestedManager(std::shared_ptr objectRockStore, EtcdClusterManager *cm) - : objectStore_(std::move(objectRockStore)), nestedRef_(std::make_unique(false)) + : objectStore_(std::move(objectRockStore)), + nestedRef_(std::make_unique>(false)) { etcdCM_ = cm; } @@ -177,7 +178,7 @@ public: private: std::unordered_map> dependencyTable_; std::shared_ptr objectStore_; - std::unique_ptr nestedRef_; + std::unique_ptr> nestedRef_; // std::string -> ObjectKey std::shared_timed_mutex mutex_; EtcdClusterManager *etcdCM_ = nullptr; }; diff --git a/src/datasystem/master/object_cache/store/object_meta_store.h b/src/datasystem/master/object_cache/store/object_meta_store.h index c8887f5db10ebd9007c8d042acdca27fe649c3c2..6619cf92f0d99dfd64b15794fc385666dd806bec 100644 --- a/src/datasystem/master/object_cache/store/object_meta_store.h +++ b/src/datasystem/master/object_cache/store/object_meta_store.h @@ -96,6 +96,15 @@ public: */ Status Init(); + /** + * @brief Whether metadata persistence is enabled. + * @return True if persistence is enabled.
+ */ + bool IsPersistenceEnabled() const + { + return isPersistenceEnabled_; + } + /** * @brief Create the serialized string of object meta. * @param[in] objectKey id of the object meta. diff --git a/src/datasystem/protos/master_object.proto b/src/datasystem/protos/master_object.proto index 4dd2c0fd93373be81eaa408022f2754fd2da96cc..30569daa655cfa33df9b31aa603f1c4034349499 100644 --- a/src/datasystem/protos/master_object.proto +++ b/src/datasystem/protos/master_object.proto @@ -60,12 +60,18 @@ message CreateMetaRspPb { } message CreateMultiMetaReqPb { - repeated ObjectMetaPb metas = 1; + repeated ObjectBaseInfoPb metas = 1; string address = 2; int64 timeout = 3; bool isTx = 4; bool is_pre_commit = 5; bool redirect = 6; + ConfigPb config = 7; + uint64 version = 8; // version control + // For the meaning of the value, see 'ObjectLifeState' enum class. + uint32 life_state = 9; + uint32 ttl_second = 10; + ExistenceOptPb existence = 11; // put to the end, the previous data is used to generate AK and SK signatures. uint64 timestamp = 100; @@ -74,7 +80,7 @@ message CreateMultiMetaReqPb { } message CreateMultiMetaRspPb { - repeated uint64 version = 1; + uint64 version = 1; repeated string failed_object_keys = 2; ErrorInfoPb last_rc = 3; repeated RedirectMetaInfo info = 4; diff --git a/src/datasystem/protos/meta_zmq.proto b/src/datasystem/protos/meta_zmq.proto index 160659abfa2c1b64dc9a31e2a53e6e92e8fce123..984927acf67d190f1f11f9df5d202374303a9fe5 100644 --- a/src/datasystem/protos/meta_zmq.proto +++ b/src/datasystem/protos/meta_zmq.proto @@ -34,12 +34,13 @@ message MetaPb { string client_id = 4; string worker_id = 5; string gateway_id = 6; - string routing_fd = 7; - int32 event_type = 8; + string routing_fd = 7; // Deprecated + int32 event_type = 8; repeated TickPb ticks = 9; string trace_id = 10; int64 timeout = 11; string db_name = 12; + int32 route_fd = 13; // New version of routing_fd // ak/sk required. string tenant_id = 98; @@ -104,8 +105,14 @@ message UrmaHandshakeReqPb { repeated uint32 jfr_ids = 3; HostPortPb address = 4; repeated JfrBondInfo bond_infos = 5; + repeated UrmaImportSegmentPb seg_infos = 6; } -// Exchange of jfr. Remote's jfr message UrmaHandshakeRspPb { } + +message RdmaHandshakeReqPb { +} + +message RdmaHandshakeRspPb { +} diff --git a/src/datasystem/protos/object_posix.proto b/src/datasystem/protos/object_posix.proto index 7f776d3acae83e6f7be26f0d5795c8c412645736..c36dd0c4b4ecfce087d21b03aca246eb610760c8 100644 --- a/src/datasystem/protos/object_posix.proto +++ b/src/datasystem/protos/object_posix.proto @@ -154,7 +154,8 @@ message GetReqPb { string tenant_id = 5; repeated uint64 read_offset_list = 6; repeated uint64 read_size_list = 7; - bool no_query_l2cache = 8 ; + bool no_query_l2cache = 8 ; + bool return_object_index = 9; // put to the end, the previous data is used to generate AK and SK signatures. 
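Before GetReqPb's trailing timestamp field below, a note on the meta_zmq.proto hunk above: it is a wire-compatible field migration. routing_fd keeps tag 7 and is only marked deprecated, while the re-typed value arrives as the new field route_fd = 13; re-typing tag 7 in place would have broken old peers. A sketch of reading with fallback during the transition (the fallback policy here is an assumption, not necessarily what the worker does, and malformed legacy values are not handled):

```cpp
#include <string>

#include "datasystem/protos/meta_zmq.pb.h"  // generated MetaPb

// Prefer the new int32 route_fd; fall back to the legacy string routing_fd
// for messages from peers that predate this change.
int ResolveRouteFd(const datasystem::MetaPb &meta)
{
    if (meta.route_fd() != 0) {
        return meta.route_fd();               // new field; 0 means "unset"
    }
    if (!meta.routing_fd().empty()) {
        return std::stoi(meta.routing_fd());  // deprecated string form
    }
    return -1;                                // no routing information
}
```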
uint64 timestamp = 100; @@ -177,6 +178,7 @@ message GetRspPb { uint32 consistency_type = 11; string shm_id = 12; uint32 cache_type = 13; + uint32 object_index = 14; } message PayloadInfoPb { string object_key = 1; @@ -188,6 +190,7 @@ message GetRspPb { uint32 consistency_type = 7; repeated uint32 part_index = 8; uint32 cache_type = 9; + uint32 object_index = 10; } repeated ObjectInfoPb objects = 1; repeated PayloadInfoPb payload_info = 2; diff --git a/src/datasystem/protos/share_memory.proto b/src/datasystem/protos/share_memory.proto index 7f341c6dec772a0de6534b310621542f777a5f72..baa8472d046c3ae961b5fe7a9da425d057910dbf 100644 --- a/src/datasystem/protos/share_memory.proto +++ b/src/datasystem/protos/share_memory.proto @@ -45,6 +45,8 @@ message RegisterClientReqPb { string tenant_id = 8; bool enable_cross_node = 9; string pod_name = 10; + bool enable_exclusive_connection = 11; + // DFX required, need to send something to worker when reconnect. repeated google.protobuf.Any extend = 50; @@ -76,6 +78,7 @@ message RegisterClientRspPb { repeated string available_workers = 18; bool enable_p2p_transfer = 19; uint64 client_reconnect_wait_s = 20; + string exclusive_conn_sockpath = 21; // put to the end, the previous data is used to generate AK and SK signatures. uint64 timestamp = 100; diff --git a/src/datasystem/protos/utils.proto b/src/datasystem/protos/utils.proto index dae691a2c21ce37b980b3bf6572c24eb523ebe37..3eb5b84ab9ef06fc86814011fb6d5251c2bf7335 100644 --- a/src/datasystem/protos/utils.proto +++ b/src/datasystem/protos/utils.proto @@ -51,9 +51,15 @@ message UrmaBondSegInfoPb{ } message UrmaImportSegmentPb { - HostPortPb request_address = 1; - /* segment */ - UrmaSegPb seg = 2; - uint64 seg_data_offset = 3; - UrmaBondSegInfoPb bond_info = 4; + UrmaSegPb seg = 1; + UrmaBondSegInfoPb bond_info = 2; +} + +message UrmaRemoteAddrPb { + uint64 seg_va = 1; + uint64 seg_data_offset = 2; + HostPortPb request_address = 3; +} + +message RdmaImportSegmentPb { } diff --git a/src/datasystem/protos/worker_object.proto b/src/datasystem/protos/worker_object.proto index 371db4499cf1a96143e8891b9a234fcc94e14a6c..5174e5d5a60dbc94576314800019c954d8547f7c 100644 --- a/src/datasystem/protos/worker_object.proto +++ b/src/datasystem/protos/worker_object.proto @@ -31,10 +31,10 @@ message GetObjectRemoteReqPb { string request_id = 2; bool try_lock = 3; uint64 version = 4; - uint64 read_offset = 5; // new - uint64 read_size = 6; // new + uint64 read_offset = 5; + uint64 read_size = 6; uint64 data_size = 7; - UrmaImportSegmentPb urma_info = 8; + UrmaRemoteAddrPb urma_info = 8; // put to the end, the previous data is used to generate AK and SK signatures. uint64 timestamp = 100; @@ -51,16 +51,7 @@ message GetObjectRemoteRspPb { } message BatchGetObjectRemoteReqPb { - message GetObjectRemoteBaseReqPb { - string object_key = 1; - bool try_lock = 2; - uint64 version = 3; - uint64 read_offset = 4; - uint64 read_size = 5; - uint64 data_size = 6; - UrmaImportSegmentPb urma_info = 7; - } - repeated GetObjectRemoteBaseReqPb requests = 1; + repeated GetObjectRemoteReqPb requests = 1; // put to the end, the previous data is used to generate AK and SK signatures. uint64 timestamp = 100; @@ -247,25 +238,31 @@ message ChangePrimaryCopyRspPb { repeated string success_ids = 1; } -message ObjectMetaPb { - message ConfigPb { - // The config of object +message ConfigPb { + // The config of object - // For the meaning of the value, see 'WriteMode' enum class. 
- uint32 write_mode = 1; + // For the meaning of the value, see 'WriteMode' enum class. + uint32 write_mode = 1; - // For the meaning of the value, see 'DataFormat' enum class. - uint32 data_format = 2; + // For the meaning of the value, see 'DataFormat' enum class. + uint32 data_format = 2; - // For the meaning of the value, see 'ConsistencyType' enum class. - uint32 consistency_type = 3; + // For the meaning of the value, see 'ConsistencyType' enum class. + uint32 consistency_type = 3; - uint32 cache_type = 4; + uint32 cache_type = 4; - // If true, this won't invalidate existing copies. - bool is_replica = 5; - } + // If true, this won't invalidate existing copies. + bool is_replica = 5; +} +message ObjectBaseInfoPb { + string object_key = 1; + uint64 data_size = 2; + DeviceMetaInfoPb device_info = 3; +} + +message ObjectMetaPb { string object_key = 1; uint64 data_size = 2; uint64 version = 3; // version control diff --git a/src/datasystem/pybind_api/pybind_register_object.cpp b/src/datasystem/pybind_api/pybind_register_object.cpp index 55232c5dc6dbbb586e6b7fbe7c5081753369d0ce..b5139ac12d4e21d5259161ddf0017315bde95246 100644 --- a/src/datasystem/pybind_api/pybind_register_object.cpp +++ b/src/datasystem/pybind_api/pybind_register_object.cpp @@ -126,7 +126,7 @@ PybindDefineRegisterer g_pybind_define_f_Client("ObjectClient", PRIORITY_LOW, [] .def(py::init([](const std::string &host, int32_t port, int32_t connectTimeoutMs, const std::string &clientPublicKey, const std::string &clientPrivateKey, const std::string &serverPublicKey, const std::string &accessKey, const std::string &secretKey, - const std::string &tenantId) { + const std::string &tenantId, const bool enableExclusiveConnection) { ConnectOptions connectOpts{ .host = host, .port = port, .connectTimeoutMs = connectTimeoutMs, @@ -135,8 +135,8 @@ PybindDefineRegisterer g_pybind_define_f_Client("ObjectClient", PRIORITY_LOW, [] .serverPublicKey = serverPublicKey, .accessKey = accessKey, .secretKey = secretKey, - .tenantId = tenantId, - .enableCrossNodeConnection = false }; + .tenantId = tenantId}; + connectOpts.enableExclusiveConnection = enableExclusiveConnection; return std::make_unique(connectOpts); })) diff --git a/src/datasystem/server/common_server.h b/src/datasystem/server/common_server.h index 7e4fe91d44d0e6445dc336a14c6ec52ae6886cc6..e7c4ec2c257024400a933327189983487ee0e64f 100644 --- a/src/datasystem/server/common_server.h +++ b/src/datasystem/server/common_server.h @@ -30,6 +30,7 @@ #include "datasystem/common/rpc/rpc_helper.h" #include "datasystem/common/rpc/rpc_server.h" #include "datasystem/common/rpc/rpc_channel.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/status_helper.h" #ifdef WITH_TESTS @@ -113,8 +114,9 @@ public: * @param[out] id The id of this shmUnit. * @return Status of the call. 
*/ - virtual Status GetShmQueueUnit(uint32_t lockId, int &fd, uint64_t &mmapSize, ptrdiff_t &offset, - std::string &id) = 0; + virtual Status GetShmQueueUnit(uint32_t lockId, int &fd, uint64_t &mmapSize, ptrdiff_t &offset, ShmKey &id) = 0; + + virtual Status GetExclConnSockPath(std::string &sockPath) = 0; protected: /** diff --git a/src/datasystem/worker/CMakeLists.txt b/src/datasystem/worker/CMakeLists.txt index 2057b59a35a875be64ed073ac2132088b26e9a08..ed2ad971fd90b0227ff21112948ea540819bea01 100644 --- a/src/datasystem/worker/CMakeLists.txt +++ b/src/datasystem/worker/CMakeLists.txt @@ -29,6 +29,7 @@ set(WORKER_DEPEND_LIBS common_util common_metrics common_signal + string_ref ds_master ds_server httpclient diff --git a/src/datasystem/worker/authenticate.cpp b/src/datasystem/worker/authenticate.cpp index f37e19b38ac1820d5aabb06211edea572b8b4d3d..d04d88d296a4e3aa7972698d614414340f427196 100644 --- a/src/datasystem/worker/authenticate.cpp +++ b/src/datasystem/worker/authenticate.cpp @@ -20,6 +20,8 @@ #include "datasystem/worker/authenticate.h" +DS_DEFINE_bool(skip_authenticate, false, "skip authentication of worker requests (for debug/test only)"); + namespace datasystem { namespace worker { @@ -63,4 +65,4 @@ Status AuthenticateMessageInternal(std::shared_ptr akSkManager, con } } -} \ No newline at end of file +} diff --git a/src/datasystem/worker/authenticate.h b/src/datasystem/worker/authenticate.h index b8b327550a22e595d07b5d8395a2c3d272a91446..5f5dc6f6ac099021084df825211c49b7ad5d497a 100644 --- a/src/datasystem/worker/authenticate.h +++ b/src/datasystem/worker/authenticate.h @@ -31,19 +31,32 @@ #include "datasystem/common/ak_sk/ak_sk_manager.h" #include "datasystem/common/iam/tenant_auth_manager.h" +#include "datasystem/common/util/gflag/common_gflags.h" #include "datasystem/common/util/thread_local.h" #include "datasystem/worker/client_manager/client_manager.h" +DS_DECLARE_bool(skip_authenticate); + namespace datasystem { namespace worker { Status AuthenticateMessageInternal(std::shared_ptr akSkManager, const std::string &reqTenantId, const std::string &token, std::string &tenantId); +inline Status CheckTenantId(const std::string &reqTenantId) +{ + CHECK_FAIL_RETURN_STATUS(reqTenantId.empty(), K_INVALID, + "Requests must not carry a tenantId when skip_authenticate is enabled"); + return Status::OK(); +} + template Status AuthenticateRequest(std::shared_ptr akSkManager, const ReqType &req, const std::string &reqTenantId, std::string &tenantId) { + if (FLAGS_skip_authenticate) { + return CheckTenantId(reqTenantId); + } if (!g_ReqAk.empty() && !g_ReqSignature.empty() && !g_SerializedMessage.Empty()) { return worker::AuthenticateMessageInternal(akSkManager, reqTenantId, req.token(), tenantId); } @@ -75,9 +88,12 @@ Status AuthenticateRequest(std::shared_ptr akSkManager, const ReqTy } template -Status Authenticate(std::shared_ptr akSkManager, ReqType req, std::string &tenantId, +Status Authenticate(std::shared_ptr akSkManager, const ReqType &req, std::string &tenantId, const std::string &clientId) { + if (FLAGS_skip_authenticate) { + return CheckTenantId(req.tenant_id()); + } Timer timer; std::string authTenantId = req.tenant_id(); auto clientInfo = worker::ClientManager::Instance().GetClientInfo(clientId); @@ -92,9 +108,12 @@ Status Authenticate(std::shared_ptr akSkManager, ReqType req, std:: } template -Status AuthenticateMessage(std::shared_ptr akSkManager, ReqType req, const std::string &clientId, +Status AuthenticateMessage(std::shared_ptr akSkManager, const ReqType &req, const std::string &clientId, std::string
&tenantId) { + if (FLAGS_skip_authenticate) { + return CheckTenantId(req.tenant_id()); + } Timer timer; std::string authTenantId = req.tenant_id(); auto clientInfo = worker::ClientManager::Instance().GetClientInfo(clientId); @@ -109,8 +128,11 @@ Status AuthenticateMessage(std::shared_ptr akSkManager, ReqType req } template -Status Authenticate(std::shared_ptr akSkManager, ReqType req, std::string &tenantId) +Status Authenticate(std::shared_ptr akSkManager, const ReqType &req, std::string &tenantId) { + if (FLAGS_skip_authenticate) { + return CheckTenantId(req.tenant_id()); + } return !g_ReqAk.empty() && !g_ReqSignature.empty() && !g_SerializedMessage.Empty() ? AuthenticateMessage(akSkManager, req, req.client_id(), tenantId) : Authenticate(akSkManager, req, tenantId, req.client_id()); diff --git a/src/datasystem/worker/client_manager/client_info.cpp b/src/datasystem/worker/client_manager/client_info.cpp index f8019c483ca8972bbaf68f4d2072ab533daaed2a..ccb1e175777de88c7ef27de2a08713af87ba024e 100644 --- a/src/datasystem/worker/client_manager/client_info.cpp +++ b/src/datasystem/worker/client_manager/client_info.cpp @@ -24,6 +24,7 @@ #include "datasystem/common/flags/flags.h" #include "datasystem/common/inject/inject_point.h" #include "datasystem/common/shared_memory/allocator.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/uuid_generator.h" DS_DEFINE_uint64(client_dead_timeout_s, 120, @@ -149,13 +150,13 @@ bool ClientInfo::RemoveShmUnit(const std::shared_ptr &shmUnit) } } else { LOG(WARNING) << "RemoveShmUnit: The value of refCount is 0 and cannot be decreased. id:" - << BytesUuidToString(shmUnit->id); + << BytesUuidToString(shmUnit->id.ToString()); } return true; } auto itr = shmUnitIds_.find(shmUnit->id); if (itr == shmUnitIds_.end()) { - LOG(WARNING) << "RemoveShmUnit: The ID does not exist. id:" << BytesUuidToString(shmUnit->id); + LOG(WARNING) << "RemoveShmUnit: The ID does not exist. id:" << BytesUuidToString(shmUnit->id.ToString()); return false; } if (shmUnit->refCount > 0) { @@ -166,7 +167,7 @@ bool ClientInfo::RemoveShmUnit(const std::shared_ptr &shmUnit) } } else { LOG(WARNING) << "RemoveShmUnit: The value of refCount is 0 and cannot be decreased. 
id:" - << BytesUuidToString(shmUnit->id); + << BytesUuidToString(shmUnit->id.ToString()); } itr->second -= 1; if (itr->second == 0) { @@ -178,15 +179,17 @@ bool ClientInfo::RemoveShmUnit(const std::shared_ptr &shmUnit) return true; } -bool ClientInfo::Contains(const std::string &uuid) const +bool ClientInfo::Contains(const ShmKey &shmId) const { - return shmUnitIds_.find(uuid) != shmUnitIds_.end(); + return shmUnitIds_.find(shmId) != shmUnitIds_.end(); } -void ClientInfo::GetShmUnitIds(std::unordered_map &shmUnitIds) +#ifdef WITH_TESTS +void ClientInfo::GetShmUnitIds(std::unordered_map &shmUnitIds) { shmUnitIds = shmUnitIds_; } +#endif void ClientInfo::GetReaderSessionIds(std::unordered_set &sessionIds) const { diff --git a/src/datasystem/worker/client_manager/client_info.h b/src/datasystem/worker/client_manager/client_info.h index 9c819ee90ba0253957d9377a9192f93e09c0a4a3..7ab9d3bec41fafa8fce484ab87c3e759ffd956fd 100644 --- a/src/datasystem/worker/client_manager/client_info.h +++ b/src/datasystem/worker/client_manager/client_info.h @@ -28,6 +28,7 @@ #include #include "datasystem/common/shared_memory/shm_unit.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/format.h" #include "datasystem/common/util/status_helper.h" #include "datasystem/common/util/timer.h" @@ -98,16 +99,18 @@ public: /** * @brief Check one object whether be referred by client. - * @param[in] uuid Shared memory unit id of object. + * @param[in] shmId Shared memory unit id of object. * @return True if contains uuid. */ - bool Contains(const std::string &uuid) const; + bool Contains(const ShmKey &shmId) const; +#ifdef WITH_TESTS /** * @brief Get all share memory unit ids referred by client. * @param[out] shmUnitIds Shared memory unit id of objects. */ - void GetShmUnitIds(std::unordered_map &shmUnitIds); + void GetShmUnitIds(std::unordered_map &shmUnitIds); +#endif /** * @brief Add the reader session id used by the client. @@ -249,7 +252,7 @@ private: Timer lastHeartbeat_; // Time received the last heartbeat. 
std::atomic heartbeatType_{ HeartbeatType::NO_HEARTBEAT }; std::unordered_set fds_; - std::unordered_map shmUnitIds_; + std::unordered_map shmUnitIds_; std::unordered_set readerSessionTable_; // session id of readers std::unordered_set writerSessionTable_; // session id of writers std::string tenantId_; diff --git a/src/datasystem/worker/cluster_manager/etcd_cluster_manager.h b/src/datasystem/worker/cluster_manager/etcd_cluster_manager.h index 6ad7289e416e2eba44e82d007c65b8a3456c0528..68ea9fac030374ced872a4507f0e197f0bd98e28 100644 --- a/src/datasystem/worker/cluster_manager/etcd_cluster_manager.h +++ b/src/datasystem/worker/cluster_manager/etcd_cluster_manager.h @@ -22,6 +22,7 @@ #define DATASYSTEM_WORKER_OBJECT_CACHE_CLUSTER_MANGER_H #include +#include #include #include #include @@ -289,33 +290,23 @@ public: * @param[in] objectKeys Container(Vector or list) of objectkeys * @param[out] objKeysGrpByMaster map with master as key and objectkeys belong to the master as value * @param[out] objKeysUndecidedMaster IDs without known master in hashring - * @return Status */ template - Status GroupObjKeysByMasterHostPort( + void GroupObjKeysByMasterHostPort( const container &objectKeys, std::unordered_map> &objKeysGrpByMaster, std::unordered_map> &objKeysUndecidedMaster) { - MetaAddrInfo emptyInfo; objKeysGrpByMaster = GroupObjKeysByMasterHostPort(objectKeys); - auto it = objKeysGrpByMaster.find(emptyInfo); - if (it != objKeysGrpByMaster.end()) { - LOG(INFO) << "Some objectKeys can't find address."; - auto objKeys = std::move(it->second); - it = objKeysGrpByMaster.erase(it); - for (const auto &objKey : objKeys) { - std::string workerId; - (void)TrySplitWorkerIdFromObjecId(objKey, workerId); - auto iter = objKeysUndecidedMaster.find(workerId); - if (iter == std::end(objKeysUndecidedMaster)) { - std::unordered_set objectKeyList({ objKey }); - objKeysUndecidedMaster.insert(std::make_pair(workerId, std::move(objectKeyList))); - } else { - iter->second.emplace(objKey); - } - } + auto emptyIt = objKeysGrpByMaster.find(MetaAddrInfo()); + if (emptyIt == objKeysGrpByMaster.end()) { + return; } - return Status::OK(); + for (auto &objKey : emptyIt->second) { + auto workerId = SplitWorkerIdFromObjecId(objKey); + auto &con = objKeysUndecidedMaster.try_emplace(std::move(workerId)).first->second; + (void)con.emplace(std::move(objKey)); + } + (void)objKeysGrpByMaster.erase(emptyIt); } /** @@ -328,9 +319,9 @@ public: { // go through objectKeys and group them by master and db name. std::unordered_map> objKeysGrpByMaster; - std::unordered_map errInfos; Timer timer; - GroupObjKeysByMasterHostPortWithStatus(objectKeys, objKeysGrpByMaster, errInfos); + std::optional> emptyOption; + GroupObjKeysByMasterHostPortWithStatus(objectKeys, objKeysGrpByMaster, emptyOption); auto elapsedMs = static_cast(std::round(timer.ElapsedMilliSecond())); workerOperationTimeCost.Append("GroupObjKeys", elapsedMs); return objKeysGrpByMaster; @@ -345,23 +336,59 @@ public: template void GroupObjKeysByMasterHostPortWithStatus( const container &objectKeys, std::unordered_map> &objKeysGrpByMaster, - std::unordered_map &errInfos) + std::optional> &errInfos) { // go through objectKeys and group them by master and db name. 
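The rewritten grouping loop below leans on try_emplace(...).first->second to fetch-or-create a bucket in one expression, replacing the old find-then-insert branching. The idiom in isolation:

```cpp
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// try_emplace default-constructs the bucket only when the key is new and
// always returns an iterator to it, so one expression replaces the
// find-then-insert dance.
std::unordered_map<std::string, std::vector<std::string>> GroupByOwner(
    std::vector<std::pair<std::string, std::string>> &&keyedItems)
{
    std::unordered_map<std::string, std::vector<std::string>> groups;
    for (auto &[owner, item] : keyedItems) {
        auto &bucket = groups.try_emplace(std::move(owner)).first->second;
        bucket.emplace_back(std::move(item));
    }
    return groups;
}
```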
- for (const std::string &objectKey : objectKeys) { + for (auto &objectKey : objectKeys) { MetaAddrInfo metaAddrInfo; - auto rc = GetMetaAddress(objectKey, metaAddrInfo); - if (rc.IsError()) { - (void)errInfos.emplace(objectKey, rc); - VLOG(1) << FormatString("objKey[%s] can not find master, status: %s", objectKey, rc.ToString()); + auto rc = GetMetaAddressNotCheckConnection(objectKey, metaAddrInfo); + auto &con = objKeysGrpByMaster.try_emplace(std::move(metaAddrInfo)).first->second; + con.emplace_back(std::move(objectKey)); + if (rc.IsOk()) { + continue; } - auto iter = objKeysGrpByMaster.find(metaAddrInfo); - if (iter == std::end(objKeysGrpByMaster)) { - std::vector objectKeyList({ objectKey }); - objKeysGrpByMaster.insert(std::make_pair(metaAddrInfo, std::move(objectKeyList))); + if (errInfos) { + (void)errInfos->emplace(objectKey, rc); + } + VLOG(1) << FormatString("objKey[%s] can not find master, status: %s", objectKey, rc.ToString()); + } + static const auto checkConnectionFunc = [](EtcdClusterManager *ptr, + const MetaAddrInfo &metaAddrInfo) -> Status { + const auto &masterAddr = metaAddrInfo.GetAddress(); + if (metaAddrInfo.IsFromOtherAz()) { + CHECK_FAIL_RETURN_STATUS(ptr->CheckIfOtherAzNodeConnected(masterAddr), K_RPC_UNAVAILABLE, + FormatString("The other az node %s disconnected.", masterAddr.ToString())); } else { - iter->second.push_back(objectKey); + return ptr->CheckConnection(masterAddr); + } + return Status::OK(); + }; + auto emptyIt = objKeysGrpByMaster.end(); // Iterator for the key of the target node not found. + for (auto it = objKeysGrpByMaster.begin(); it != objKeysGrpByMaster.end();) { + const auto &metaAddrInfo = it->first; + const auto &objectKeys = it->second; + if (metaAddrInfo.Empty()) { + emptyIt = it; + ++it; + continue; + } + Status rc = checkConnectionFunc(this, metaAddrInfo); + if (rc.IsOk()) { + ++it; + continue; + } + if (errInfos) { + for (const auto &objectKey : objectKeys) { + (void)errInfos->emplace(objectKey, rc); + } } + if (emptyIt == objKeysGrpByMaster.end()) { + emptyIt = objKeysGrpByMaster.try_emplace(MetaAddrInfo()).first; + } + auto &con = emptyIt->second; + con.insert(con.end(), std::make_move_iterator(it->second.begin()), + std::make_move_iterator(it->second.end())); + it = objKeysGrpByMaster.erase(it); } } diff --git a/src/datasystem/worker/object_cache/CMakeLists.txt b/src/datasystem/worker/object_cache/CMakeLists.txt index 0254af33e2cb394ca5cf29666271b2f4c5675801..08a9a3843470b679093bbf56e80785178c78f6fe 100644 --- a/src/datasystem/worker/object_cache/CMakeLists.txt +++ b/src/datasystem/worker/object_cache/CMakeLists.txt @@ -42,6 +42,7 @@ set(WORKER_OBJECT_CACHE_DEPEND_LIBS common_shared_memory common_immutable_string common_persistence_api + common_parallel master_heartbeat_protos posix_protos worker_object_protos diff --git a/src/datasystem/worker/object_cache/async_rollback_manager.h b/src/datasystem/worker/object_cache/async_rollback_manager.h index 8a9d45c7a1efaa72916bcceb327864317b7adf01..86431e6158860e1b4efecabe667e9e202cc84092 100644 --- a/src/datasystem/worker/object_cache/async_rollback_manager.h +++ b/src/datasystem/worker/object_cache/async_rollback_manager.h @@ -26,6 +26,7 @@ #include "datasystem/common/util/thread.h" #include "datasystem/worker/cluster_manager/etcd_cluster_manager.h" #include "datasystem/worker/object_cache/worker_master_oc_api.h" +#include "datasystem/worker/object_cache/worker_request_manager.h" namespace datasystem { namespace object_cache { @@ -83,6 +84,17 @@ public: return false; } + void 
UpdateIsRollback(std::unordered_map &objectKeys) + { + std::shared_lock lock(mutex_); + for (auto &[objectKey, objectInfo] : objectKeys) { + if (pendingObject_.find(objectKey) != pendingObject_.end() + || processingObject_.find(objectKey) != processingObject_.end()) { + objectInfo.isRollBack = true; + } + } + } + private: /** * @brief Rollback thread responsible for rollback metadata of pendingObject_. diff --git a/src/datasystem/worker/object_cache/device/worker_device_oc_manager.cpp b/src/datasystem/worker/object_cache/device/worker_device_oc_manager.cpp index e8634548107bc53537c34baa9dd7eb2e9e150de2..9710e86f9168bb6574fb7c33202808f2de59c41a 100644 --- a/src/datasystem/worker/object_cache/device/worker_device_oc_manager.cpp +++ b/src/datasystem/worker/object_cache/device/worker_device_oc_manager.cpp @@ -23,6 +23,7 @@ #include "datasystem/common/device/device_helper.h" #include "datasystem/common/object_cache/object_bitmap.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/rpc_util.h" #include "datasystem/protos/master_object.pb.h" #include "datasystem/protos/worker_object.pb.h" @@ -58,7 +59,8 @@ Status WorkerDeviceOcManager::PublishDeviceObject(const std::string &devObjectKe devObj->SetOffset(req.offset()); devObj->stateInfo.SetDataFormat(DataFormat::HETERO); entry->SetRealObject(std::move(devObj)); - workerOcImpl_->publishProc_->AttachShmUnitToObject(req.client_id(), req.dev_object_key(), req.shm_id(), + workerOcImpl_->publishProc_->AttachShmUnitToObject(WorkerOcServiceCreateImpl::ClientShmEnabled(req.client_id()), + req.dev_object_key(), ShmKey::Intern(req.shm_id()), req.data_size(), *entry); } else { CHECK_FAIL_RETURN_STATUS(!(*entry)->IsHetero(), K_INVALID, @@ -94,7 +96,7 @@ Status WorkerDeviceOcManager::CreateDeviceMetaToMaster(const ObjectKV &objectKV) metadata->set_object_key(objectKey); metadata->set_data_size(safeObj->GetDataSize()); - ObjectMetaPb::ConfigPb *configPb = metadata->mutable_config(); + ConfigPb *configPb = metadata->mutable_config(); configPb->set_data_format(static_cast(safeObj->stateInfo.GetDataFormat())); metaReq.set_address(workerOcImpl_->localAddress_.ToString()); @@ -220,14 +222,14 @@ Status WorkerDeviceOcManager::TryGetDeviceObjectFromRemote(const int64_t subTime std::shared_ptr &request, std::vector &objectsNeedGetRemote) { - std::vector needGetIds; + std::set needGetIds; for (const auto &id : objectsNeedGetRemote) { - needGetIds.emplace_back(ReadKey(id)); + (void)needGetIds.emplace(ReadKey(id)); } if (!objectsNeedGetRemote.empty()) { PerfPoint pointRemote(PerfKey::WORKER_PROCESS_GET_OBJECT_REMOTE); std::unordered_set failedIds; - std::vector needRetryIds; + std::set needRetryIds; Status status; do { needRetryIds.clear(); @@ -235,18 +237,19 @@ Status WorkerDeviceOcManager::TryGetDeviceObjectFromRemote(const int64_t subTime workerOcImpl_->getProc_->ProcessObjectsNotExistInLocal(needGetIds, subTimeout, failedIds, needRetryIds); if (status.IsOk()) { break; - } else if (status.GetCode() == K_OUT_OF_MEMORY || reqTimeoutDuration.CalcRealRemainingTime() <= 0) { + } + if (status.GetCode() == K_OUT_OF_MEMORY || reqTimeoutDuration.CalcRealRemainingTime() <= 0) { std::for_each(needRetryIds.begin(), needRetryIds.end(), - [&](ReadKey &key) { failedIds.emplace(key.objectKey); }); + [&failedIds](const ReadKey &key) { failedIds.emplace(key.objectKey); }); break; - } else { - needGetIds.swap(needRetryIds); } + needGetIds.swap(needRetryIds); } while (!needGetIds.empty()); pointRemote.Record(); if (status.GetCode() == K_OUT_OF_MEMORY) 
{ return deviceReqManager_.ReturnFromGetDeviceObjectRequest(request, status); - } else if (status.IsError()) { + } + if (status.IsError()) { // If the error is RPC error, return them directly, other error would be covered up as RUNTIME_ERROR. Status lastRc = IsRpcTimeoutOrTryAgain(status) ? status : Status(K_RUNTIME_ERROR, status.GetMsg()); for (const auto &id : failedIds) { @@ -310,4 +313,4 @@ Status WorkerDeviceOcManager::ProcessGetDataInfoRequest( return workerMasterApi->GetDataInfo(req, serverApi, subTimeout, workerOcImpl_->asyncRpcManager_); } } // namespace object_cache -} // namespace datasystem \ No newline at end of file +} // namespace datasystem diff --git a/src/datasystem/worker/object_cache/eviction_list.cpp b/src/datasystem/worker/object_cache/eviction_list.cpp index a7001d0fe6985c323596c0fd300b815e057ba87b..09fd8a5ac77abde59dce73600608060876a0c9d6 100644 --- a/src/datasystem/worker/object_cache/eviction_list.cpp +++ b/src/datasystem/worker/object_cache/eviction_list.cpp @@ -15,6 +15,7 @@ */ #include "datasystem/worker/object_cache/eviction_list.h" + #include "datasystem/common/log/log.h" #include "datasystem/common/perf/perf_manager.h" #include "datasystem/common/util/status_helper.h" @@ -28,20 +29,17 @@ EvictionList::EvictionList() : oldest_(list_.end()) void EvictionList::Add(const std::string &objectKey, uint8_t counter) { PerfPoint point(PerfKey::WORKER_EVICT_LIST_ADD); - std::lock_guard lck(listMutex_); - auto iter = indexTable_.find(objectKey); - if (iter == indexTable_.end()) { - Node node(objectKey, counter); - auto newest = list_.insert(oldest_, node); + TBBIndexMap::accessor accessor; + if (indexTable_.insert(accessor, objectKey)) { + tbb::spin_rw_mutex::scoped_lock wlock(listMutex_, true); + auto newest = list_.emplace(oldest_, objectKey, counter); if (list_.size() == 1) { oldest_ = newest; } - (void)indexTable_.emplace(objectKey, newest); - point.Record(); - return; + accessor->second = newest; } - // Object exist, refresh it - auto &nodePtr = iter->second; + + auto &nodePtr = accessor->second; if (nodePtr->curCounter < nodePtr->maxCounter) { nodePtr->curCounter++; } @@ -51,12 +49,13 @@ void EvictionList::Add(const std::string &objectKey, uint8_t counter) Status EvictionList::Erase(const std::string &objectKey) { PerfPoint point(PerfKey::WORKER_EVICT_LIST_ERASE); - std::lock_guard lck(listMutex_); - auto iter = indexTable_.find(objectKey); - if (iter == indexTable_.end()) { + TBBIndexMap::accessor accessor; + if (!indexTable_.find(accessor, objectKey)) { VLOG(1) << "Object " + objectKey + " does not exist in EvictionList"; RETURN_STATUS(StatusCode::K_NOT_FOUND, "Object " + objectKey + " does not exist in EvictionList."); } + + tbb::spin_rw_mutex::scoped_lock wlock(listMutex_, true); bool reassign = false; if (oldest_->objectKey == objectKey) { ++oldest_; @@ -64,8 +63,8 @@ Status EvictionList::Erase(const std::string &objectKey) reassign = true; } } - list_.erase(iter->second); - indexTable_.erase(objectKey); + list_.erase(accessor->second); + indexTable_.erase(accessor); if (reassign) { oldest_ = list_.begin(); } @@ -75,22 +74,21 @@ Status EvictionList::Erase(const std::string &objectKey) size_t EvictionList::Size() { - std::shared_lock lck(listMutex_); + tbb::spin_rw_mutex::scoped_lock rlock(listMutex_, false); return list_.size(); } Status EvictionList::FindEvictCandidate(std::string &candidateObjKey) { PerfPoint point(PerfKey::WORKER_EVICT_LIST_FIND); - std::lock_guard lck(listMutex_); + tbb::spin_rw_mutex::scoped_lock wlock(listMutex_, true); 
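// A writer lock is required even though this is logically a lookup: the
// scan below decays curCounter on every node it visits and advances
// oldest_ circularly, so each call mutates the list (in effect a
// second-chance / CLOCK-style sweep).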
CHECK_FAIL_RETURN_STATUS(!list_.empty(), StatusCode::K_RUNTIME_ERROR, "EvictionList is empty."); while (true) { if (oldest_->curCounter == 0) { candidateObjKey = oldest_->objectKey; break; - } else { - oldest_->curCounter--; } + oldest_->curCounter--; if (++oldest_ == list_.end()) { oldest_ = list_.begin(); } @@ -101,17 +99,18 @@ Status EvictionList::FindEvictCandidate(std::string &candidateObjKey) Status EvictionList::GetObjectInfo(const std::string &objectKey, Node &node) { - std::lock_guard lck(listMutex_); - if (indexTable_.find(objectKey) == indexTable_.end()) { + TBBIndexMap::const_accessor readAccessor; + if (!indexTable_.find(readAccessor, objectKey)) { RETURN_STATUS_LOG_ERROR(StatusCode::K_NOT_FOUND, "Object " + objectKey + " does not exist"); } - node = *(indexTable_[objectKey]); + tbb::spin_rw_mutex::scoped_lock rlock(listMutex_, false); + node = *(readAccessor->second); return Status::OK(); } Status EvictionList::GetOldestObjectInfo(Node &node) { - std::lock_guard lck(listMutex_); + tbb::spin_rw_mutex::scoped_lock rlock(listMutex_, false); CHECK_FAIL_RETURN_STATUS(!list_.empty(), StatusCode::K_RUNTIME_ERROR, "EvictionList is empty."); node.objectKey = oldest_->objectKey; node.curCounter = oldest_->curCounter; @@ -121,7 +120,7 @@ Status EvictionList::GetOldestObjectInfo(Node &node) Status EvictionList::GetAllObjectsInfo(std::vector &res, EvictionList::Node &oldest) { - std::lock_guard lck(listMutex_); + tbb::spin_rw_mutex::scoped_lock rlock(listMutex_, false); if (list_.empty()) { return Status::OK(); } @@ -145,8 +144,7 @@ Status EvictionList::GetAllObjectsInfo(std::vector &res, Evi bool EvictionList::Exist(const std::string &objectKey) { - std::shared_lock lck(listMutex_); return indexTable_.count(objectKey) > 0; } } // namespace object_cache -} // namespace datasystem \ No newline at end of file +} // namespace datasystem diff --git a/src/datasystem/worker/object_cache/eviction_list.h b/src/datasystem/worker/object_cache/eviction_list.h index e42d8b658c449fe997368afa03d625ebc1165f82..dc84ebf155f70fb7f9ed5f65bef9d1745f371b95 100644 --- a/src/datasystem/worker/object_cache/eviction_list.h +++ b/src/datasystem/worker/object_cache/eviction_list.h @@ -48,6 +48,7 @@ public: uint8_t curCounter; uint8_t maxCounter; }; + using TBBIndexMap = tbb::concurrent_hash_map::iterator>; /** * @brief Construct EvictionList. 
@@ -114,16 +115,12 @@ public: bool Exist(const std::string &objectKey); private: - std::shared_timed_mutex listMutex_; - + mutable tbb::spin_rw_mutex listMutex_; std::list list_; - - // unordered_map - std::unordered_map::iterator> indexTable_; - std::list::iterator oldest_; + TBBIndexMap indexTable_; }; } // namespace object_cache } // namespace datasystem -#endif \ No newline at end of file +#endif diff --git a/src/datasystem/worker/object_cache/obj_cache_shm_unit.cpp b/src/datasystem/worker/object_cache/obj_cache_shm_unit.cpp index b1092f8299786fb3620e2319aa9ccfc86284265b..2876102b9bb21b0e7f794b5e2147fcef8a009a80 100644 --- a/src/datasystem/worker/object_cache/obj_cache_shm_unit.cpp +++ b/src/datasystem/worker/object_cache/obj_cache_shm_unit.cpp @@ -22,6 +22,7 @@ #include "datasystem/common/constants.h" #include "datasystem/common/iam/tenant_auth_manager.h" #include "datasystem/common/perf/perf_manager.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/status_helper.h" #include "datasystem/common/util/strings_util.h" #include "datasystem/common/util/uuid_generator.h" @@ -38,6 +39,8 @@ LOG(INFO) << "try get shm for payload " << currCnt_ << " times"; \ } while (0) +DS_DECLARE_uint64(oc_worker_aggregate_merge_size); + namespace datasystem { namespace object_cache { @@ -234,12 +237,10 @@ Status AggregateAllocate( std::vector &shmIndexMapping) { // Pre-allocate aggregated chunks of shared memory as ShmOwner, to reduce the number of allocation calls. - // 1. Only for URMA case, non-URMA case allocates the memory when response is handled. - // 2. Aggregate only if all the objects are small objects (< 1MB size), and batch up to 1024 keys and 2MB size. - // 3. Multi-tenancy not yet supported. + // Aggregate only for small objects (< 1MB size), and batch up to 1024 keys and 2MB size. const uint64_t batchLimitKeys = 1024; const uint64_t batchLimitSingleSize = 1024 * 1024; - const uint64_t batchLimitTotalSize = 2 * 1024 * 1024; + const uint64_t batchLimitTotalSize = FLAGS_oc_worker_aggregate_merge_size; bool needAggregate = false; std::vector aggreatedSizes; @@ -247,7 +248,7 @@ Status AggregateAllocate( uint64_t currentKeyCount = 0; std::function aggregateCollector = - [&](uint64_t dataSz, uint64_t shmSize, uint32_t objectId) { + [&](uint64_t dataSz, uint64_t shmSize, uint32_t objectIndex) { // Skip any object that has size beyond 1MB. if (dataSz >= batchLimitSingleSize) { return; @@ -262,7 +263,7 @@ Status AggregateAllocate( // Record the size and num, and also map from object key to ShmOwners index. 
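// aggreatedSizes.size() is the index of the batch currently being filled,
// so the mapping below routes this object to its pre-allocated chunk when
// memory is distributed later. Objects skipped above keep their sentinel
// index (callers appear to initialize the mapping to a numeric_limits
// max() sentinel) and fall back to an individual allocation.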
currentBatchSize += shmSize; currentKeyCount++; - shmIndexMapping[objectId] = aggreatedSizes.size(); + shmIndexMapping[objectIndex] = aggreatedSizes.size(); }; traversalHelper(aggregateCollector, needAggregate); @@ -292,7 +293,7 @@ Status AllocateNewShmUnit(const std::string &objectKey, uint64_t dataSize, uint6 shmUnit = std::make_shared(); RETURN_IF_NOT_OK( AllocateMemoryForObject(objectKey, dataSize, metadataSize, populate, evictionManager, *shmUnit, cacheType)); - shmUnit->id = GetStringUuid(); + shmUnit->id = ShmKey::Intern(GetStringUuid()); return Status::OK(); } @@ -322,7 +323,7 @@ Status LoadSpilledObjectToMemory(ReadObjectKV &objectKV, std::shared_ptrstateInfo.SetIncompleted(true); return Status(K_OUT_OF_MEMORY, "out of memory"); }); - newShmUnit->id = GetStringUuid(); + newShmUnit->id = ShmKey::Intern(GetStringUuid()); objShmUnit->SetShmUnit(newShmUnit); } bool isOffsetRead = objectKV.IsOffsetRead(); @@ -380,7 +381,7 @@ Status SaveBinaryObjectToMemory(ObjectKV &objectKV, const std::vector(); RETURN_IF_NOT_OK(AllocateMemoryForObject(objectKey, payloadSz, metaSz, false, evictionManager, *shmUnit, entry->modeInfo.GetCacheType())); - shmUnit->id = GetStringUuid(); + shmUnit->id = ShmKey::Intern(GetStringUuid()); entry->SetShmUnit(shmUnit); } // There is no need to latch buffer because client can't access the buffer this moment. diff --git a/src/datasystem/worker/object_cache/object_kv.cpp b/src/datasystem/worker/object_cache/object_kv.cpp index 1e986c0f5b587ae7e9bd23540ddf5035b0944188..00cef7cedf84cfaba7ec98fbf26a74ae079be1d4 100644 --- a/src/datasystem/worker/object_cache/object_kv.cpp +++ b/src/datasystem/worker/object_cache/object_kv.cpp @@ -79,7 +79,7 @@ void SetDeviceObjEntry(const ObjectMetaPb &meta, uint64_t metaDataSize, SafeObjT void SetObjectEntryAccordingToMeta(const ObjectMetaPb &meta, uint64_t metaDataSize, SafeObjType &entry) { const std::string &objectKey = meta.object_key(); - const ObjectMetaPb::ConfigPb &configPb = meta.config(); + const ConfigPb &configPb = meta.config(); auto dataFormat = static_cast(configPb.data_format()); if (dataFormat == DataFormat::HETERO) { SetDeviceObjEntry(meta, metaDataSize, entry); diff --git a/src/datasystem/worker/object_cache/object_kv.h b/src/datasystem/worker/object_cache/object_kv.h index 037044dd6e95d7cad96241dea294322ce592e027..2b13c30a3980120e03488d1e56f15c88f0b73fcd 100644 --- a/src/datasystem/worker/object_cache/object_kv.h +++ b/src/datasystem/worker/object_cache/object_kv.h @@ -45,8 +45,9 @@ public: /** * @brief Construct ObjectKV. */ - ObjectKV(const std::string &objectKey, SafeObjType &entry) - : objectKey_(objectKey), entry_(entry) {} + ObjectKV(const std::string &objectKey, SafeObjType &entry) : objectKey_(objectKey), entry_(entry) + { + } ObjectKV(const std::string &objectKey, std::nullptr_t) = delete; // Disable all copy and move constructors. 
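// Aside: the ReadKey changes in the hunk below give the struct a strict
// weak ordering (operator< over objectKey) and a non-owning
// `const std::string &objectKey` member. A minimal standalone sketch of
// the same shape; MiniReadKey is an invented name, purely for illustration:
#include <cstdint>
#include <set>
#include <string>

struct MiniReadKey {
    explicit MiniReadKey(const std::string &key, uint64_t offset = 0, uint64_t size = 0)
        : objectKey(key), readOffset(offset), readSize(size)
    {
    }
    const std::string &objectKey;  // non-owning: the referenced string must outlive this key
    uint64_t readOffset;
    uint64_t readSize;
    bool operator<(const MiniReadKey &other) const
    {
        return objectKey < other.objectKey;
    }
};
// Ordering by objectKey alone is what lets the retry containers switch
// from std::vector to std::set: two reads of the same object collapse into
// one retry entry, e.g.
//   std::string key = "obj-1";
//   std::set<MiniReadKey> retries;
//   retries.emplace(key, 0, 16);
//   retries.emplace(key, 16, 16);  // same objectKey, deduplicated
// Because the member is a reference, a key must never outlive the string
// it was built from; in the patch those strings live in the request and
// meta protobufs, which appear to outlive the retry sets.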
@@ -84,8 +85,12 @@ private: }; struct ReadKey : public OffsetInfo { - explicit ReadKey(std::string objectKey, uint64_t offset = 0, uint64_t size = 0) - : OffsetInfo(offset, size), objectKey(std::move(objectKey)) + explicit ReadKey(const std::string &objectKey, uint64_t offset = 0, uint64_t size = 0) + : OffsetInfo(offset, size), objectKey(objectKey) + { + } + + ReadKey(const std::string &objectKey, OffsetInfo offsetInfo) : OffsetInfo(offsetInfo), objectKey(objectKey) { } @@ -97,7 +102,18 @@ struct ReadKey : public OffsetInfo { } return out; } - std::string objectKey; + + bool operator<(const ReadKey &other) const + { + return objectKey < other.objectKey; + } + + OffsetInfo GetOffsetInfo() const + { + return OffsetInfo(readOffset, readSize); + } + + const std::string &objectKey; }; class ReadObjectKV : public ObjectKV, protected OffsetInfo { diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_batch_get_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_batch_get_impl.cpp index e52b3d60c3293586072a3def47d3f9221049376e..8da58d592f654f215529aaddb114254b980e5abe 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_batch_get_impl.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_batch_get_impl.cpp @@ -63,7 +63,7 @@ Status WorkerOcServiceGetImpl::BatchGetRetrieveRemotePayload(uint64_t completeDa auto shmUnit = std::make_shared(); RETURN_IF_NOT_OK(AllocateMemoryForObject(objectKey, completeDataSize, metaSz, false, evictionManager_, *shmUnit, entry->modeInfo.GetCacheType())); - shmUnit->id = GetStringUuid(); + shmUnit->id = ShmKey::Intern(GetStringUuid()); entry->SetShmUnit(shmUnit); } PerfPoint copyPoint(PerfKey::WORKER_MEMORY_COPY); @@ -100,12 +100,14 @@ void WorkerOcServiceGetImpl::HandleGetFailureHelper(const std::string &objectKey } } -Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereBatched( - std::vector &queryMetas, const std::map &readKeys, - const std::shared_ptr &request, std::vector &payloads, - std::map, bool>> &lockedEntries, - std::unordered_set &failedIds, std::vector &needRetryIds) +Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereBatched(std::vector &queryMetas, + const std::shared_ptr &request, + std::vector &payloads, + std::map &lockedEntries, + std::unordered_set &failedIds, + std::set &needRetryIds) { + RETURN_RUNTIME_ERROR_IF_NULL(workerBatchRemoteGetThreadPool_); Status lastRc = Status::OK(); std::vector successIds; successIds.reserve(queryMetas.size()); @@ -115,6 +117,7 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereBatched( std::vector payloadIndexMetas; for (auto &queryMeta : queryMetas) { const auto &meta = queryMeta.meta(); + const auto &objectKey = meta.object_key(); const auto dataFormat = static_cast(meta.config().data_format()); if (dataFormat != DataFormat::BINARY && dataFormat != DataFormat::HETERO) { lastRc = Status(K_INVALID, "object data format not match."); @@ -122,29 +125,25 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereBatched( LOG(ERROR) << lastRc; continue; } - auto iter = lockedEntries.find(meta.object_key()); + auto iter = lockedEntries.find(ReadKey(objectKey)); if (iter == lockedEntries.end()) { LOG(ERROR) << FormatString("[ObjectKey %s] QueryMeta exist but lock entry absent, should not happen", - meta.object_key()); + objectKey); lastRc = Status(K_UNKNOWN_ERROR, "QueryMeta exist but lock entry absent, should not happen"); continue; } - if (readKeys.find(meta.object_key()) == readKeys.end()) { - LOG(ERROR) << FormatString("[ObjectKey %s] cant 
find offset and size to get", meta.object_key()); - lastRc = Status(K_UNKNOWN_ERROR, "Can not find offset or size to get object"); - continue; - } + auto &safeObj = iter->second.safeObj; if (queryMeta.payload_indexs_size() != 0) { payloadIndexMetas.emplace_back(queryMeta); } else { GroupQueryMeta(queryMeta, groupedQueryMetas); - SetObjectEntryAccordingToMeta(meta, GetMetadataSize(), *(lockedEntries.at(meta.object_key()).first)); + SetObjectEntryAccordingToMeta(meta, GetMetadataSize(), *safeObj); } } // For the ones that already got their payload from queried meta, fallback to existing logic. - lastRc = GetObjectsFromAnywhereSerially(payloadIndexMetas, readKeys, request, payloads, lockedEntries, failedIds, - needRetryIds); + lastRc = + GetObjectsFromAnywhereSerially(payloadIndexMetas, request, payloads, lockedEntries, failedIds, needRetryIds); // And then deal with the requests that can be batched. std::vector> futures; @@ -153,25 +152,28 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereBatched( std::vector> tempSuccessIds(groupedQueryMetas.size()); std::vector> tempNeedRetryIds(groupedQueryMetas.size()); std::vector> tempFailedIds(groupedQueryMetas.size()); - int index = 0; - auto workerBatchThreadPool_ = std::make_shared(1, WORKER_BATCH_THREAD_NUM, "OcWorkerBatch"); + size_t index = 0; auto traceId = Trace::Instance().GetTraceID(); - for (auto queryMeta = groupedQueryMetas.begin(); queryMeta != groupedQueryMetas.end(); ++queryMeta, ++index) { auto &address = queryMeta->first; auto &metaList = queryMeta->second; - futures.emplace_back(workerBatchThreadPool_->Submit([this, &lastRc, address, &metaList, readKeys, &request, - &lockedEntries, &tempSuccessIds, &tempNeedRetryIds, - &tempFailedIds, &tempFailedMetas, index, traceId] { + + auto func = [this, &lastRc, address, &metaList, &request, &lockedEntries, &tempSuccessIds, + &tempNeedRetryIds, &tempFailedIds, &tempFailedMetas, index, traceId] { for (auto &metaPair : metaList) { auto &metas = metaPair.first; TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); - lastRc = BatchGetObjectFromRemoteOnLock(address, metas, readKeys, request, lockedEntries, + lastRc = BatchGetObjectFromRemoteOnLock(address, metas, request, lockedEntries, tempSuccessIds[index], tempNeedRetryIds[index], tempFailedIds[index], tempFailedMetas[index]); } return lastRc; - })); + }; + if (index + 1 == groupedQueryMetas.size()) { + LOG_IF_ERROR(func(), "BatchGetObjectFromRemoteOnLock failed"); + } else { + futures.emplace_back(workerBatchRemoteGetThreadPool_->Submit(std::move(func))); + } } for (auto &fut : futures) { if (!fut.get().IsOk()) { @@ -185,7 +187,7 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereBatched( std::make_move_iterator(tempSuccessIds[i].end())); } if (!tempNeedRetryIds[i].empty()) { - needRetryIds.insert(needRetryIds.end(), std::make_move_iterator(tempNeedRetryIds[i].begin()), + needRetryIds.insert(std::make_move_iterator(tempNeedRetryIds[i].begin()), std::make_move_iterator(tempNeedRetryIds[i].end())); } if (!tempFailedIds[i].empty()) { @@ -196,11 +198,11 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereBatched( } } auto metaIter = failedMetas.begin(); - for (int i = 0; metaIter != failedMetas.end(); i++) { + while (metaIter != failedMetas.end()) { auto &objectKey = (*metaIter)->object_key(); - auto &pair = lockedEntries.at(objectKey); - auto &entry = pair.first; - bool isInsert = pair.second; + auto &pair = lockedEntries.at(ReadKey(objectKey)); + auto &entry = pair.safeObj; + bool isInsert = pair.insert; 
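// lockedEntries is now keyed by ReadKey (ordered by objectKey), and each
// mapped value appears to bundle the write-locked entry (safeObj) with an
// `insert` flag recording whether this request created the table entry;
// that flag tells the failure handler below whether a freshly inserted
// entry also has to be rolled back.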
HandleGetFailureHelper(objectKey, (*metaIter)->version(), entry, isInsert); metaIter++; } @@ -235,25 +237,25 @@ void WorkerOcServiceGetImpl::GroupQueryMeta( splitList.back().second += meta.data_size(); } -void WorkerOcServiceGetImpl::BatchGetObjectHandleIndividualStatus(Status &status, const std::string &objectKey, - ReadKey readKey, std::vector &successIds, +void WorkerOcServiceGetImpl::BatchGetObjectHandleIndividualStatus(Status &status, const ReadKey &readKey, + std::vector &successIds, std::vector &needRetryIds, std::unordered_set &failedIds) { if (status.IsOk()) { - successIds.emplace_back(objectKey); + successIds.emplace_back(readKey.objectKey); } else if (status.GetCode() == K_WORKER_PULL_OBJECT_NOT_FOUND) { - LOG(INFO) << FormatString("[ObjectKey %s] Object not found in remote worker.", objectKey); + LOG(INFO) << FormatString("[ObjectKey %s] Object not found in remote worker.", readKey.objectKey); status = Status::OK(); needRetryIds.emplace_back(readKey); } else if (status.GetCode() == K_OC_REMOTE_GET_NOT_ENOUGH) { // Note that it gets retried at BatchGetObjectFromRemoteWorker, so do not need to add to needRetryIds. - LOG(INFO) << FormatString("[ObjectKey %s] Object size changed, needs retry.", objectKey); + LOG(INFO) << FormatString("[ObjectKey %s] Object size changed, needs retry.", readKey.objectKey); } else if (status.GetCode() == K_OUT_OF_MEMORY) { - LOG(INFO) << FormatString("[ObjectKey %s] Out of memory, get remote abort.", objectKey); + LOG(INFO) << FormatString("[ObjectKey %s] Out of memory, get remote abort.", readKey.objectKey); } else { - LOG(ERROR) << FormatString("[ObjectKey %s] Get from remote failed: %s.", objectKey, status.ToString()); - failedIds.emplace(objectKey); + LOG(ERROR) << FormatString("[ObjectKey %s] Get from remote failed: %s.", readKey.objectKey, status.ToString()); + failedIds.emplace(readKey.objectKey); } } @@ -321,8 +323,7 @@ void WorkerOcServiceGetImpl::HandleBatchSubResponsePart2(Status &subRc, const st Status WorkerOcServiceGetImpl::ProcessBatchResponse( const std::string &address, Status &checkConnectStatus, std::list &metas, - const std::map &readKeys, const std::shared_ptr &request, - std::map, bool>> &lockedEntries, const Status &status, + const std::shared_ptr &request, std::map &lockedEntries, const Status &status, BatchGetObjectRemoteRspPb &rspPb, std::vector &payloads, std::vector &successIds, std::vector &needRetryIds, std::unordered_set &failedIds, std::list &failedMetas, bool &dataSizeChange) @@ -332,9 +333,12 @@ Status WorkerOcServiceGetImpl::ProcessBatchResponse( auto metaIter = metas.begin(); for (int i = 0; metaIter != metas.end(); i++) { auto &objectKey = (*metaIter)->object_key(); - auto &pair = lockedEntries.at(objectKey); - auto &entry = pair.first; - auto &readKey = readKeys.at(objectKey); + auto iter = lockedEntries.find(ReadKey(objectKey)); + if (iter == lockedEntries.cend()) { + continue; + } + auto &entry = iter->second.safeObj; + const auto &readKey = iter->first; ReadObjectKV objectKV(readKey, *entry); Status subRc = status; bool tryGetFromElsewhere = true; @@ -371,7 +375,7 @@ Status WorkerOcServiceGetImpl::ProcessBatchResponse( objectKey)); } if (subRc.IsOk()) { - subRc = UpdateRequestForSuccessNotReturnForClient(objectKV, request); + subRc = UpdateRequestForSuccess(objectKV, request); } if (!dataSizeChanged) { if (subRc.IsError()) { @@ -382,18 +386,17 @@ Status WorkerOcServiceGetImpl::ProcessBatchResponse( dataSizeChange = true; metaIter++; } - BatchGetObjectHandleIndividualStatus(subRc, objectKey, readKey, successIds, 
needRetryIds, failedIds); + BatchGetObjectHandleIndividualStatus(subRc, readKey, successIds, needRetryIds, failedIds); lastRc = subRc; } return lastRc; } Status WorkerOcServiceGetImpl::BatchGetObjectFromRemoteWorker( - const std::string &address, std::list &metas, const std::map &readKeys, - const std::shared_ptr &request, - std::map, bool>> &lockedEntries, - std::vector &successIds, std::vector &needRetryIds, - std::unordered_set &failedIds, std::list &failedMetas) + const std::string &address, std::list &metas, const std::shared_ptr &request, + std::map &lockedEntries, std::vector &successIds, + std::vector &needRetryIds, std::unordered_set &failedIds, + std::list &failedMetas) { bool dataSizeChange; Status lastRc; @@ -418,8 +421,8 @@ Status WorkerOcServiceGetImpl::BatchGetObjectFromRemoteWorker( CHECK_FAIL_RETURN_STATUS(checkConnectStatus.IsOk(), K_RUNTIME_ERROR, FormatString("Fail to get objects from remote worker, no object copy exists.")); INJECT_POINT("worker.before_GetObjectFromRemoteWorkerAndDump"); - RETURN_IF_NOT_OK(ConstructBatchGetRequest(address, metas, readKeys, lockedEntries, successIds, needRetryIds, - failedIds, reqPb)); + RETURN_IF_NOT_OK( + ConstructBatchGetRequest(address, metas, lockedEntries, successIds, needRetryIds, failedIds, reqPb)); INJECT_POINT("worker.remote_get_failed"); std::shared_ptr workerStub; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CreateRemoteWorkerApi(address, akSkManager_, workerStub), @@ -448,18 +451,17 @@ Status WorkerOcServiceGetImpl::BatchGetObjectFromRemoteWorker( PerfPoint point(PerfKey::WORKER_CONSTRUCT_AND_SEND); Status rc = constructAndSend(); point.Record(); - lastRc = ProcessBatchResponse(address, checkConnectStatus, metas, readKeys, request, lockedEntries, rc, rspPb, - payloads, successIds, needRetryIds, failedIds, failedMetas, dataSizeChange); + lastRc = ProcessBatchResponse(address, checkConnectStatus, metas, request, lockedEntries, rc, rspPb, payloads, + successIds, needRetryIds, failedIds, failedMetas, dataSizeChange); } while (dataSizeChange); return lastRc; } Status WorkerOcServiceGetImpl::BatchGetObjectFromRemoteOnLock( - const std::string &address, std::list &metas, const std::map &readKeys, - const std::shared_ptr &request, - std::map, bool>> &lockedEntries, - std::vector &successIds, std::vector &needRetryIds, - std::unordered_set &failedIds, std::list &failedMetas) + const std::string &address, std::list &metas, const std::shared_ptr &request, + std::map &lockedEntries, std::vector &successIds, + std::vector &needRetryIds, std::unordered_set &failedIds, + std::list &failedMetas) { PerfPoint point(PerfKey::WORKER_PULL_REMOTE_DATA); // Unlock entries at exit. @@ -467,34 +469,33 @@ Status WorkerOcServiceGetImpl::BatchGetObjectFromRemoteOnLock( for (auto &meta : metas) { const auto &objectKey = meta->object_key(); RemoveInRemoteGetObject(objectKey); - lockedEntries.at(objectKey).first->WUnlock(); + lockedEntries.at(ReadKey(objectKey)).safeObj->WUnlock(); } }); // Construct and send request for batch remote get. 
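// (The Raii guard above runs no matter how the call below returns,
// unlocking every entry and clearing its in-remote-get mark, so early
// error paths cannot leak write locks.)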
Timer endToEndTimer; - Status rc = BatchGetObjectFromRemoteWorker(address, metas, readKeys, request, lockedEntries, successIds, - needRetryIds, failedIds, failedMetas); + Status rc = BatchGetObjectFromRemoteWorker(address, metas, request, lockedEntries, successIds, needRetryIds, + failedIds, failedMetas); LOG(INFO) << FormatString("object get from remote finish, use %f millisecond.", endToEndTimer.ElapsedMilliSecond()); return rc; } -Status WorkerOcServiceGetImpl::AggregateAllocateHelper( - const std::list &metas, - std::map, bool>> &lockedEntries, - std::vector> &shmOwners, std::vector &shmIndexMapping) +Status WorkerOcServiceGetImpl::AggregateAllocateHelper(const std::list &metas, + std::map &lockedEntries, + std::vector> &shmOwners, + std::vector &shmIndexMapping) { std::function, bool &)> traversalHelper = - [&metas, &lockedEntries](std::function collector, - bool &needAggregate) { + [&metas, &lockedEntries](std::function collector, bool &needAggregate) { needAggregate = metas.size() > 1; - uint32_t objectId = 0; - for (auto metaIter = metas.begin(); metaIter != metas.end() && needAggregate; metaIter++, objectId++) { + uint32_t objectIndex = 0; + for (auto metaIter = metas.begin(); metaIter != metas.end() && needAggregate; metaIter++, objectIndex++) { auto &meta = *metaIter; auto dataSz = meta->data_size(); const auto &objectKey = meta->object_key(); - auto &pair = lockedEntries.at(objectKey); - auto &entry = *(pair.first); + auto &lockedEntity = lockedEntries.at(ReadKey(objectKey)); + auto &entry = *(lockedEntity.safeObj); auto metaSz = entry->GetMetadataSize(); uint64_t shmSize = dataSz + metaSz; @@ -504,12 +505,11 @@ Status WorkerOcServiceGetImpl::AggregateAllocateHelper( if (!szChanged) { continue; } - collector(dataSz, shmSize, objectId); + collector(dataSz, shmSize, objectIndex); } }; auto firstObjectKey = metas.front()->object_key(); - RETURN_IF_NOT_OK( - AggregateAllocate(firstObjectKey, traversalHelper, evictionManager_, shmOwners, shmIndexMapping)); + RETURN_IF_NOT_OK(AggregateAllocate(firstObjectKey, traversalHelper, evictionManager_, shmOwners, shmIndexMapping)); return Status::OK(); } } // namespace object_cache diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.cpp index 1ad6d41a225c93cb9cedbdf90ac9130f91e3db33..fdc6b02764240edf35c949c7a4345ba8a26f8c73 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.cpp @@ -21,8 +21,10 @@ #include "datasystem/common/log/log.h" #include "datasystem/common/iam/tenant_auth_manager.h" +#include "datasystem/common/parallel/parallel_for.h" #include "datasystem/common/perf/perf_manager.h" #include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/format.h" #include "datasystem/common/util/status_helper.h" #include "datasystem/common/util/uuid_generator.h" @@ -94,7 +96,7 @@ Status WorkerOcServiceCreateImpl::CreateImpl(const std::string &tenantId, const std::string shmUnitId; IndexUuidGenerator(shmIdCounter.fetch_add(1), shmUnitId); - shmUnit->id = shmUnitId; + shmUnit->id = ShmKey::Intern(std::move(shmUnitId)); memoryRefTable_->AddShmUnit(clientId, shmUnit); // Construct CreateRespPb. 
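// Aside: CreateImpl and the allocation helpers above now assign
// `shmUnit->id = ShmKey::Intern(...)` instead of a raw uuid string. A
// minimal sketch of what an intern table of that shape can look like;
// InternTable and Handle are invented names for illustration, and the
// real semantics live in string_intern/string_ref.h:
#include <mutex>
#include <string>
#include <unordered_set>

class InternTable {
public:
    using Handle = const std::string *;  // stable address of the pooled string
    Handle Intern(std::string s)
    {
        std::lock_guard<std::mutex> lock(mu_);
        // unordered_set nodes do not move on rehash, so the returned
        // address stays valid; identical text resolves to the same node.
        return &*pool_.emplace(std::move(s)).first;
    }

private:
    std::mutex mu_;
    std::unordered_set<std::string> pool_;
};
// Under this scheme, equality checks on shm ids reduce to pointer-sized
// comparisons, and each uuid is stored once however many ClientInfo
// tables and ShmUnits reference it.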
@@ -107,26 +109,100 @@ Status WorkerOcServiceCreateImpl::CreateImpl(const std::string &tenantId, const return Status::OK(); } -Status WorkerOcServiceCreateImpl::MultiCreate(const MultiCreateReqPb &req, MultiCreateRspPb &resp) +Status WorkerOcServiceCreateImpl::AggregateAllocateHelper(const MultiCreateReqPb &req, + std::vector> &shmOwners, + std::vector &shmIndexMapping) +{ + const size_t metaSz = GetMetadataSize(); + std::function, bool &)> traversalHelper = + [&req, &metaSz](const std::function &collector, bool &needAggregate) { + needAggregate = req.object_key_size() > 1; + for (int i = 0; i < req.object_key_size(); i++) { + collector(req.data_size(i), req.data_size(i) + metaSz, i); + } + }; + const auto &firstObjectKey = *req.object_key().begin(); + return AggregateAllocate(firstObjectKey, traversalHelper, evictionManager_, shmOwners, shmIndexMapping); +} + +Status WorkerOcServiceCreateImpl::MultiCreateImpl(const MultiCreateReqPb &req, const std::string &tenantId, + MultiCreateRspPb &resp) +{ + int objectSize = req.object_key_size(); + std::vector shmIndexMapping(req.object_key_size(), std::numeric_limits::max()); + std::vector> shmOwners; + AggregateAllocateHelper(req, shmOwners, shmIndexMapping); + std::vector subRsp(objectSize); + std::vector results(objectSize); + + auto createMeta = [&] (int start, int end) { + std::vector> shmUnits(end - start + 1); + for (int i = start, j = 0; i < end; i++, j++) { + if (!req.skip_check_existence() && resp.exists(i)) { + continue; + } + const auto &objectKey = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.object_key(i)); + + std::shared_ptr shmOwner = nullptr; + if (shmIndexMapping.size() > static_cast(i) && shmOwners.size() > shmIndexMapping[i]) { + shmOwner = shmOwners[shmIndexMapping[i]]; + } + // Given size, construct shmUnit, generate shm uuid and add client's reference on shmUnit. + auto shmUnit = std::make_shared(); + auto metadataSize = GetMetadataSize(); + auto dataSize = req.data_size(i); + if (shmOwner) { + results[i] = DistributeMemoryForObject(objectKey, dataSize, metadataSize, true, shmOwner, *shmUnit); + } else { + results[i] = + AllocateMemoryForObject(objectKey, dataSize, metadataSize, true, evictionManager_, *shmUnit); + } + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(results[i], "worker allocate memory failed"); + + std::string shmUnitId; + IndexUuidGenerator(shmIdCounter.fetch_add(1), shmUnitId); + shmUnit->id = ShmKey::Intern(shmUnitId); + shmUnits[j] = shmUnit; + + // Construct CreateRespPb. 
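// (Each per-object response is staged in subRsp[i] and only moved into
// resp after every entry's status in results[] has been checked, so a
// batch that partially fails never hands partial results to the client;
// MultiCreate then rolls back the allocated shm units and clears resp.)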
+ subRsp[i].set_store_fd(shmUnit->GetFd()); + subRsp[i].set_mmap_size(shmUnit->GetMmapSize()); + subRsp[i].set_offset(shmUnit->GetOffset()); + subRsp[i].set_shm_id(shmUnit->GetId()); + subRsp[i].set_metadata_size(metadataSize); + } + memoryRefTable_->AddShmUnits(req.client_id(), shmUnits); + return Status::OK(); + }; + + static const int parallelThreshold = 128; + static const int parallelism = 4; + if (objectSize > parallelThreshold) { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Parallel::ParallelFor(0, objectSize, createMeta, 0, parallelism), + "ParallelFor failed"); + } else { + createMeta(0, objectSize); + } + + resp.mutable_results()->Reserve(objectSize); + for (int i = 0; i < objectSize; i++) { + RETURN_IF_NOT_OK(results[i]); + resp.mutable_results()->Add(std::move(subRsp[i])); + } + return Status::OK(); +} + +void WorkerOcServiceCreateImpl::CheckExistence(const MultiCreateReqPb &req, const std::string &tenantId, + MultiCreateRspPb &resp) { - CHECK_FAIL_RETURN_STATUS(etcdCM_ != nullptr, StatusCode::K_NOT_READY, "ETCD cluster manager is not provided."); - std::string tenantId; - RETURN_IF_NOT_OK_PRINT_ERROR_MSG(worker::Authenticate(akSkManager_, req, tenantId), "Authenticate failed."); - CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(req.object_key_size()), - StatusCode::K_INVALID, "invalid object size"); - CHECK_FAIL_RETURN_STATUS(req.object_key_size() == req.data_size_size(), K_INVALID, - FormatString("object key count %zu not match with data size count %zu", - req.object_key_size(), req.data_size_size())); - auto lastRc = Status::OK(); - auto totalSize = 0u; for (int i = 0; i < req.object_key().size(); i++) { - auto objectKey = req.object_key(i); - auto dataSize = req.data_size(i); + const auto &objectKey = req.object_key(i); // Check whether the object is in local. { auto key = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, objectKey); std::shared_ptr entry; - if (!req.skip_check_existence() && objectTable_->Get(key, entry).IsOk() && entry->RLock(false).IsOk()) { + if (objectTable_->Get(key, entry).IsOk() && entry->RLock(false).IsOk()) { Raii unlock([&entry]() { entry->RUnlock(); }); if ((*entry)->IsBinary() && !(*entry)->IsInvalid()) { resp.add_exists(true); @@ -135,39 +211,32 @@ } } resp.add_exists(false); - totalSize += dataSize; } - for (int i = 0; i < req.object_key().size(); i++) { - if (resp.exists(i) || totalSize < FLAGS_oc_shm_transfer_threshold_kb * KB) { - resp.add_results(); - continue; - } - const auto &objectKey = req.object_key(i); - auto dataSize = req.data_size(i); - // If some buffer create failed, need to rollback, remove shm-unit. 
- CreateRspPb subResp; - Status rc = CreateImpl(tenantId, req.client_id(), objectKey, dataSize, subResp); - INJECT_POINT("WorkerOCServiceImpl.MultiCreate.Allocate", [&i, &rc](int failedIndex) { - if (failedIndex == i) { - rc = Status(StatusCode::K_RUNTIME_ERROR, "Set runtime error"); - } - return Status::OK(); - }); - resp.mutable_results()->Add(std::move(subResp)); - if (rc.IsError()) { - lastRc = rc; - break; - } +} + +Status WorkerOcServiceCreateImpl::MultiCreate(const MultiCreateReqPb &req, MultiCreateRspPb &resp) +{ + CHECK_FAIL_RETURN_STATUS(etcdCM_ != nullptr, StatusCode::K_NOT_READY, "ETCD cluster manager is not provided."); + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(worker::Authenticate(akSkManager_, req, tenantId), "Authenticate failed."); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(req.object_key_size()), StatusCode::K_INVALID, + "invalid object size"); + CHECK_FAIL_RETURN_STATUS(req.object_key_size() == req.data_size_size(), K_INVALID, + FormatString("object key count %zu not match with data size count %zu", + req.object_key_size(), req.data_size_size())); + if (!req.skip_check_existence()) { + CheckExistence(req, tenantId, resp); } - // Rollback all memory if failed. - if (lastRc.IsError()) { + Status rc = MultiCreateImpl(req, tenantId, resp); + if (rc.IsError()) { + // Rollback all memory if failed. const auto &clientId = req.client_id(); for (auto &subResp : resp.results()) { - memoryRefTable_->RemoveShmUnit(clientId, subResp.shm_id()); + memoryRefTable_->RemoveShmUnit(clientId, ShmKey::Intern(subResp.shm_id())); } resp.Clear(); } - return lastRc; + return rc; } } // namespace object_cache } // namespace datasystem diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.h b/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.h index e228f0a870124e62ff4255db02bd06033c37ba6e..a4ddd3bac309998e0227c6b093b617b32b3bbad4 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.h +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.h @@ -61,6 +61,33 @@ private: Status CreateImpl(const std::string &tenantId, const std::string &clientId, const std::string &rawObjectKey, size_t dataSize, CreateRspPb &resp, CacheType cacheType = CacheType::MEMORY); + /** + * @brief Helper function to allocate aggregated memory for objects creation. + * @param[in] req The multi create request. + * @param[out] shmOwners The allocated shared memory chunks. + * @param[out] shmIndexMapping The object key to shmOwners index mapping. + * @return Status of the call. + */ + Status AggregateAllocateHelper(const MultiCreateReqPb &req, std::vector> &shmOwners, + std::vector &shmIndexMapping); + + /** + * @brief The implementation of MultiCreate. + * @param[in] req The MultiCreate request. + * @param[in] tenantId The tenant id. + * @param[out] resp The MultiCreate response. + * @return Status of the call. + */ + Status MultiCreateImpl(const MultiCreateReqPb &req, const std::string &tenantId, MultiCreateRspPb &resp); + + /** + * @brief Check existence of object. + * @param[in] req The MultiCreate request. + * @param[in] tenantId The tenant id. + * @param[in] resp The MultiCreate response. 
+ */ + void CheckExistence(const MultiCreateReqPb &req, const std::string &tenantId, MultiCreateRspPb &resp); + EtcdClusterManager *etcdCM_{ nullptr }; // back pointer to the cluster manager std::atomic shmIdCounter{0}; diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_crud_common_api.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_crud_common_api.cpp index adac8341803f4dc672b149bc0a54c4c2ec813555..17c820264100bb95933b21734862d1a757b16880 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_crud_common_api.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_crud_common_api.cpp @@ -22,6 +22,7 @@ #include +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/status_helper.h" #include "datasystem/common/util/thread_local.h" #include "datasystem/common/log/log.h" @@ -135,29 +136,15 @@ Status WorkerOcServiceCrudCommonApi::SaveBinaryObjectToPersistence(ObjectKV &obj return Status::OK(); } -Status WorkerOcServiceCrudCommonApi::UpdateRequestForSuccess(ReadObjectKV &objectKV) +Status WorkerOcServiceCrudCommonApi::UpdateRequestForSuccess(ReadObjectKV &objectKV, + const std::shared_ptr &request) { const auto dataFormat = objectKV.GetObjEntry()->stateInfo.GetDataFormat(); if (dataFormat == DataFormat::BINARY) { - return workerRequestManager_.UpdateRequestForSuccess(objectKV, memoryRefTable_, false); - } - if (dataFormat == DataFormat::HETERO) { - return workerDevOcManager_->UpdateRequestForSuccess(objectKV); - } - RETURN_STATUS(K_INVALID, "The dataformat is neither BINARY nor HETERO"); -} - -Status WorkerOcServiceCrudCommonApi::UpdateRequestForSuccessNotReturnForClient( - ReadObjectKV &objectKV, const std::shared_ptr &request) -{ - const auto dataFormat = objectKV.GetObjEntry()->stateInfo.GetDataFormat(); - if (dataFormat == DataFormat::BINARY) { - if (objectKV.GetObjEntry()->stateInfo.IsIncomplete()) { - // for not complete obj, only return for this request. - return workerRequestManager_.UpdateRequestForSuccess(objectKV, memoryRefTable_, true, request); - } else { - return workerRequestManager_.UpdateRequestForSuccess(objectKV, memoryRefTable_, true); + if (request != nullptr) { + return request->MarkSuccess(objectKV.GetObjKey(), objectKV.GetObjEntry()); } + return workerRequestManager_.NotifyPendingGetRequest(objectKV); } if (dataFormat == DataFormat::HETERO) { return workerDevOcManager_->UpdateRequestForSuccess(objectKV); @@ -165,20 +152,6 @@ Status WorkerOcServiceCrudCommonApi::UpdateRequestForSuccessNotReturnForClient( RETURN_STATUS(K_INVALID, "The dataformat is neither BINARY nor HETERO"); } -void WorkerOcServiceCrudCommonApi::ReturnToClientByRequest(const std::shared_ptr &request) -{ - if (request != nullptr) { - LOG_IF_ERROR(workerRequestManager_.ReturnFromGetRequest(request, memoryRefTable_), "return to client failed"); - // Avoid timeCostPoint destruct after traceGuard. 
- request->accessRecorderPoint_.reset(); - } -} - -void WorkerOcServiceCrudCommonApi::ReturnToClientByObjectKey(const std::string &objectKey) -{ - workerRequestManager_.CheckAndReturnToClient(objectKey, memoryRefTable_); -} - Status WorkerOcServiceCrudCommonApi::DeleteObjectFromDisk(ObjectKV &objectKV) { const auto &objectKey = objectKV.GetObjKey(); @@ -222,13 +195,13 @@ size_t WorkerOcServiceCrudCommonApi::GetMetadataSize() const return metadataSize_; } -Status WorkerOcServiceCrudCommonApi::AttachShmUnitToObject(const std::string &clientId, const std::string &objectKey, - const std::string &shmUnitId, uint64_t dataSize, +Status WorkerOcServiceCrudCommonApi::AttachShmUnitToObject(const bool &shmEnabled, const std::string &objectKey, + const ShmKey &shmUnitId, uint64_t dataSize, SafeObjType &entry) { INJECT_POINT("AttachShmUnitToObject.error"); std::shared_ptr shmUnit; - if (ClientShmEnabled(clientId) && ShmEnable() && !shmUnitId.empty()) { + if (shmEnabled && ShmEnable() && !shmUnitId.Empty()) { RETURN_IF_NOT_OK(memoryRefTable_->GetShmUnit(shmUnitId, shmUnit)); } else { // non-shm case, create first @@ -242,12 +215,13 @@ Status WorkerOcServiceCrudCommonApi::AttachShmUnitToObject(const std::string &cl } Status WorkerOcServiceCrudCommonApi::CheckShmUnitByTenantId(const std::string &tenantId, const std::string &clientId, - std::vector &shmUnitIds, + std::vector &shmUnitIds, std::shared_ptr memoryRefTable) { + RETURN_OK_IF_TRUE(!ClientShmEnabled(clientId)); for (const auto &shmUnitId : shmUnitIds) { std::shared_ptr shmUnit; - if (ClientShmEnabled(clientId) && !shmUnitId.empty()) { + if (!shmUnitId.Empty()) { RETURN_IF_NOT_OK(memoryRefTable->GetShmUnit(shmUnitId, shmUnit)); if (tenantId != shmUnit->GetTenantId()) { LOG(ERROR) << FormatString("req tenantId: %s is not equal shmUnit tenantId: %s", tenantId, diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_crud_common_api.h b/src/datasystem/worker/object_cache/service/worker_oc_service_crud_common_api.h index 6b552999a1d8f1301a96f39fffa71f0d291ab38f..c0087581891b6104a4ca1ee5169382b950c2838f 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_crud_common_api.h +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_crud_common_api.h @@ -20,6 +20,7 @@ #ifndef DATASYSTEM_OBJECT_CACHE_WORKER_SERVICE_CRUD_COMMON_API_H #define DATASYSTEM_OBJECT_CACHE_WORKER_SERVICE_CRUD_COMMON_API_H +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/utils/status.h" #include "datasystem/common/l2cache/persistence_api.h" @@ -170,43 +171,22 @@ public: /** * @brief Attach shmUnit to object entry - * @param[in] clientId The client id. + * @param[in] shmEnabled Enable shm or not. * @param[in] objectKey The object key * @param[in] shmUnitId The shm unit id. * @param[in] dataSize The size of data * @param[out] entry The object entry * @return OK if attach success. */ - Status AttachShmUnitToObject(const std::string &clientId, const std::string &objectKey, - const std::string &shmUnitId, uint64_t dataSize, SafeObjType &entry); + Status AttachShmUnitToObject(const bool &shmEnabled, const std::string &objectKey, const ShmKey &shmUnitId, + uint64_t dataSize, SafeObjType &entry); /** * @brief Update the request if object is getting success. * @param[in] objectKV The key-value of the object. * @return OK if update success. */ - Status UpdateRequestForSuccess(ReadObjectKV &objectKV); - - /** - * @brief only update request for success not return for client. 
- * @param objectKV The key-value of the object. - * @param request request from client - * @return Status - */ - Status UpdateRequestForSuccessNotReturnForClient(ReadObjectKV &objectKV, - const std::shared_ptr &request = nullptr); - - /** - * @brief Return request to client - * @param[in] request request from client - */ - void ReturnToClientByRequest(const std::shared_ptr &request); - - /** - * @brief Return request to client - * @param[in] object object key to check return. - */ - void ReturnToClientByObjectKey(const std::string &objectKey); + Status UpdateRequestForSuccess(ReadObjectKV &objectKV, const std::shared_ptr &request = nullptr); /** * @brief CheckShmUnitByTenantId @@ -216,7 +196,7 @@ public: * @return Status */ static Status CheckShmUnitByTenantId(const std::string &tenantId, const std::string &clientId, - std::vector &shmUnitIds, + std::vector &shmUnitIds, std::shared_ptr memoryRefTable); /** diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_delete_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_delete_impl.cpp index 1f3d1ec41649cf5a59e8dd4fcedd3b4c293650a3..dcf48f3a92af55f165dd1038257d17d1d205175b 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_delete_impl.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_delete_impl.cpp @@ -52,7 +52,7 @@ Status WorkerOcServiceDeleteImpl::DeleteAllCopy(const DeleteAllCopyReqPb &req, D uint64_t deletedSize = 0; Status rc = DeleteAllCopyImpl(req, resp, deletedSize); RequestParam reqParam; - reqParam.objectKey = objectKeysToAbbrStr(req.object_keys()); + reqParam.objectKey = ObjectKeysToAbbrStr(req.object_keys()); posixPoint.Record(rc.GetCode(), std::to_string(deletedSize), reqParam, rc.GetMsg()); workerOperationTimeCost.Append("Total DeleteAllCopy", timer.ElapsedMilliSecond()); LOG(INFO) << FormatString("The operations of DeleteAllCopy %s", workerOperationTimeCost.GetInfo()); @@ -280,14 +280,15 @@ Status WorkerOcServiceDeleteImpl::DeleteAllCopyMetaFromMaster(const std::vector< Status lastRc; // Group ObjectKeys by masterId std::unordered_map> objKeysGrpByMasterId; - std::unordered_map errInfos; + std::optional> errInfos; + errInfos.emplace(); etcdCM_->GroupObjKeysByMasterHostPortWithStatus(needDeleteObjectKey, objKeysGrpByMasterId, errInfos); std::unordered_map> crossAzOfflineWorkerIdKeys; // map - ExtractCrossAzOfflineWorkerIdKeyWithEmptyAddress(objKeysGrpByMasterId, errInfos, crossAzOfflineWorkerIdKeys); + ExtractCrossAzOfflineWorkerIdKeyWithEmptyAddress(objKeysGrpByMasterId, *errInfos, crossAzOfflineWorkerIdKeys); DeleteCrossAzKeyWhenMasterFailed(crossAzOfflineWorkerIdKeys); - for (const auto &kv : errInfos) { + for (const auto &kv : *errInfos) { // If objectKey don't belong to any master, just ignore it. 
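// (K_NOT_FOUND here means the key simply has no master in the hash ring,
// which is tolerated; any other recorded status is a genuine lookup
// failure, so the key is surfaced through failedObjectKeys.)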
if (kv.second.GetCode() != K_NOT_FOUND) { failedObjectKeys.emplace(kv.first); diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_expire_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_expire_impl.cpp index 648e425d983eeb8e813c1760f8d9c2bb90160c37..e809eff2208f454ccbddbab22ed926fddfea7b92 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_expire_impl.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_expire_impl.cpp @@ -69,7 +69,7 @@ Status WorkerOcServiceExpireImpl::Expire(const ExpireReqPb &req, ExpireRspPb &rs std::unordered_map> objKeysGrpByMaster; std::unordered_map> objKeysUndecidedMaster; - RETURN_IF_NOT_OK(etcdCM_->GroupObjKeysByMasterHostPort(objectKeys, objKeysGrpByMaster, objKeysUndecidedMaster)); + etcdCM_->GroupObjKeysByMasterHostPort(objectKeys, objKeysGrpByMaster, objKeysUndecidedMaster); std::unordered_set objKeysExpireFailed; std::vector absentObjectKeys; diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.cpp index 95871f02fa8c7b5c85130f3ed05cba18fae18809..7d914f87391157adf604f5b5437d339142d574df 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.cpp @@ -30,6 +30,7 @@ #include "datasystem/common/log/access_recorder.h" #include "datasystem/common/log/log.h" #include "datasystem/common/perf/perf_manager.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/master/object_cache/master_worker_oc_api.h" #include "datasystem/object/object_enum.h" #include "datasystem/common/rdma/urma_manager_wrapper.h" @@ -49,6 +50,7 @@ #include "datasystem/worker/authenticate.h" #include "datasystem/common/util/id_tool.h" #include "datasystem/worker/object_cache/object_kv.h" +#include "datasystem/worker/object_cache/worker_request_manager.h" #include "datasystem/worker/object_cache/worker_worker_oc_api.h" DS_DECLARE_string(other_cluster_names); @@ -85,13 +87,17 @@ WorkerOcServiceGetImpl::WorkerOcServiceGetImpl(WorkerOcServiceCrudParam &initPar } } } + workerBatchQueryMetaThreadPool_ = std::make_unique(1, FLAGS_rpc_thread_num, "BatchQureyMeta"); + if (FLAGS_enable_worker_worker_batch_get) { + workerBatchRemoteGetThreadPool_ = std::make_unique(1, FLAGS_rpc_thread_num, "BatchRemoteGet"); + } } Status WorkerOcServiceGetImpl::Get(std::shared_ptr> &serverApi) { workerOperationTimeCost.Clear(); Timer timer; - std::shared_ptr posixPoint = std::make_shared(AccessRecorderKey::DS_POSIX_GET); + auto request = std::make_shared(AccessRecorderKey::DS_POSIX_GET); INJECT_POINT("WorkerOCServiceImpl.Get.Retry", [&serverApi]() { return serverApi->SendStatus(Status(K_TRY_AGAIN, "test get retry")); }); PerfPoint point(PerfKey::WORKER_GET_OBJECT); @@ -101,10 +107,7 @@ Status WorkerOcServiceGetImpl::Get(std::shared_ptr(req.read_offset_list_size()); - uint64_t readSizeCount = static_cast(req.read_size_list_size()); - CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(objectsCount == readOffsetCount || readOffsetCount == 0, K_INVALID, - "invalid read offset"); - CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(objectsCount == readSizeCount || readSizeCount == 0, K_INVALID, - "invalid read size"); - std::unordered_map offsetInfos; - for (size_t i = 0; i < readOffsetCount; i++) { - std::string objectKey = objectKeys[i]; - offsetInfos.emplace(std::move(objectKey), OffsetInfo(req.read_offset_list(i), 
req.read_size_list(i))); - } + + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(request->Init(tenantId, req, memoryRefTable_, serverApi), + "GetRequest Init failed"); timer.Reset(); std::string traceID = Trace::Instance().GetTraceID(); auto cost = workerOperationTimeCost; - threadPool_->Execute([=]() mutable { + if (serverApi->EnableMsgQ()) { + threadPool_->Execute([=]() mutable { + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID); + workerOperationTimeCost = cost; + auto elapsed = static_cast(timer.ElapsedMilliSecond()); + LOG(INFO) << "Process Get from client: " << clientId + << ", objects: " << VectorToString(request->GetRawObjectKeys()) + << ", get threads Statistics: " << threadPool_->GetStatistics() << ", elapsed ms: " << elapsed + << ", remainingTime: " << timeout; + if (elapsed >= timeout) { + LOG(ERROR) << "RPC timeout. time elapsed " << elapsed << ", subTimeout:" << subTimeout + << ", get threads Statistics: " << threadPool_->GetStatistics(); + LOG_IF_ERROR(serverApi->SendStatus(Status(K_RUNTIME_ERROR, "Rpc timeout")), "Send status failed"); + } else { + reqTimeoutDuration.Init(timeout - elapsed); + auto newSubTimeout = std::max(subTimeout - elapsed, 0); + LOG_IF_ERROR(ProcessGetObjectRequest(newSubTimeout, request), "Process Get failed"); + workerOperationTimeCost.Append("ProcessGetObjectRequest", + static_cast(timer.ElapsedMilliSecond())); + LOG(INFO) << FormatString( + "Process Get done, clientId: %s, objectKeys: %s, get threads Statistics: %s." + "The operations of worker Get %s", + clientId, VectorToString(request->GetRawObjectKeys()), threadPool_->GetStatistics(), + workerOperationTimeCost.GetInfo()); + } + }); + } else { TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID); workerOperationTimeCost = cost; - int64_t elapsed = timer.ElapsedMilliSecond(); - LOG(INFO) << "Process Get from client: " << clientId << ", objects: " << VectorToString(objectKeys) - << ", get threads Statistics: " << threadPool_->GetStatistics() << ", elapsed ms: " << elapsed - << ", remainingTime: " << timeout; - if (elapsed >= timeout) { - LOG(ERROR) << "RPC timeout. time elapsed " << elapsed << ", subTimeout:" << subTimeout - << ", get threads Statistics: " << threadPool_->GetStatistics(); - LOG_IF_ERROR(serverApi->SendStatus(Status(K_RUNTIME_ERROR, "Rpc timeout")), "Send status failed"); - } else { - reqTimeoutDuration.Init(timeout - elapsed); - auto newSubTimeout = std::max(subTimeout - elapsed, 0); - LOG_IF_ERROR( - ProcessGetObjectRequest(objectKeys, offsetInfos, serverApi, newSubTimeout, clientId, posixPoint, req), - "Process Get failed"); - workerOperationTimeCost.Append("ProcessGetObjectRequest", timer.ElapsedMilliSecond()); - LOG(INFO) << FormatString( - "Process Get done, clientId: %s, objectKeys: %s, get threads Statistics: %s." 
- "The operations of worker Get %s", - clientId, VectorToString(objectKeys), threadPool_->GetStatistics(), workerOperationTimeCost.GetInfo()); - } - posixPoint.reset(); - }); + reqTimeoutDuration.Init(timeout); + LOG_IF_ERROR(ProcessGetObjectRequest(subTimeout, request), "Process Get failed"); + workerOperationTimeCost.Append("ProcessGetObjectRequest", timer.ElapsedMilliSecond()); + } + return Status::OK(); } @@ -184,9 +186,7 @@ Status WorkerOcServiceGetImpl::GetObjectFromAnywhere(const ReadKey &readKey, con } else { evictionManager_->Add(objectKey); } - RETURN_IF_NOT_OK(UpdateRequestForSuccess(objectKV)); - ReturnToClientByObjectKey(objectKey); - return Status::OK(); + return UpdateRequestForSuccess(objectKV, nullptr); } SetObjectEntryAccordingToMeta(meta, GetMetadataSize(), *entry); @@ -194,7 +194,7 @@ Status WorkerOcServiceGetImpl::GetObjectFromAnywhere(const ReadKey &readKey, con ReadObjectKV objectKV(readKeyAfterSet, *entry); Status status = queryMeta.payload_indexs_size() == 0 ? GetObjectFromRemoteOnLock(meta, nullptr, address, queryMeta.single_copy(), objectKV) - : GetObjectFromQueryMetaResultOnLock(queryMeta, payloads, objectKV); + : GetObjectFromQueryMetaResultOnLock(nullptr, queryMeta, payloads, objectKV); if (status.IsError()) { (void)RemoveLocation(objectKey, meta.version()); if (entry->Get() != nullptr && entry->Get()->GetShmUnit() != nullptr) { @@ -207,8 +207,6 @@ Status WorkerOcServiceGetImpl::GetObjectFromAnywhere(const ReadKey &readKey, con entry->Get()->SetLifeState(ObjectLifeState::OBJECT_INVALID); entry->Get()->stateInfo.SetCacheInvalid(true); } - } else { - ReturnToClientByObjectKey(objectKey); } return status; } @@ -247,103 +245,67 @@ Status WorkerOcServiceGetImpl::GetDataFromL2CacheForPrimaryCopy(const std::strin return Status(StatusCode::K_NOT_FOUND, "Object not found"); } -Status WorkerOcServiceGetImpl::ProcessGetObjectRequest( - const std::vector &objectKeys, const std::unordered_map &offsetInfos, - std::shared_ptr<::datasystem::ServerUnaryWriterReader> serverApi, const int64_t subTimeout, - const std::string &clientId, std::shared_ptr accessRecorderPoint, const GetReqPb &getReqPb) +Status WorkerOcServiceGetImpl::ProcessGetObjectRequest(int64_t subTimeout, std::shared_ptr &request) { + (void)subTimeout; INJECT_POINT("worker.Get.asyncGetStart", [](int timeout) { reqTimeoutDuration.Init(timeout); return Status::OK(); }); PerfPoint point(PerfKey::WORKER_PROCESS_GET_OBJECT); - std::vector objectsNeedGetRemote; - auto request = - std::make_shared(objectKeys, std::move(serverApi), clientId, -1, getReqPb, accessRecorderPoint); - if (!offsetInfos.empty()) { - request->SetOffset(offsetInfos); - } - - MarkObjectsInGetProcess(objectKeys); - - Raii getProcessGuard([this, &objectKeys]() { UnmarkObjectsInGetProcess(objectKeys); }); // Try get from local. - TryGetObjectFromLocal(offsetInfos, request, objectsNeedGetRemote); - - // Try get from remote worker or L2 cache. - RETURN_IF_NOT_OK(TryGetObjectFromRemote(subTimeout, request, objectsNeedGetRemote)); + std::set remoteObjectKeys; + RETURN_IF_NOT_OK(TryGetObjectFromLocal(request, remoteObjectKeys)); + RETURN_OK_IF_TRUE(request->AlreadyReturn()); - // Return if already call ReturnFromGetRequest to avoid circular references between Timer and GetRequest. 
- RETURN_OK_IF_TRUE(request->isReturn_); - if (request->isFinished_) { - ReturnToClientByRequest(request); + // Register request for subscribe + if (subTimeout > 0) { + request->Register(&workerRequestManager_); } - INJECT_POINT("worker.Get.beforeReturn"); + // Try get from remote worker or L2 cache. + RETURN_IF_NOT_OK(TryGetObjectFromRemote(subTimeout, request, std::move(remoteObjectKeys))); + RETURN_OK_IF_TRUE(request->AlreadyReturn()); + int64_t remainingTimeMs = reqTimeoutDuration.CalcRealRemainingTime(); - if (request->numSatisfiedObjects_ == request->numWaitingObjects_ || subTimeout == 0 || remainingTimeMs <= 0) { - LOG(INFO) << "The satisfied objects num: " << request->numSatisfiedObjects_ - << ", the waiting objects num: " << request->numWaitingObjects_ + if (request->GetNotReadyCount() == 0 || subTimeout == 0 || remainingTimeMs <= 0) { + LOG(INFO) << "The satisfied objects num: " << request->GetReadyCount() + << ", the waiting objects num: " << request->GetNotReadyCount() << ", the sub timeout: " << subTimeout; - Status rc = workerRequestManager_.ReturnFromGetRequest(request, memoryRefTable_); + Status rc = request->ReturnToClient(); point.Record(); return rc; } - TimerQueue::TimerImpl timer; + auto timer = std::make_unique(); auto traceID = Trace::Instance().GetTraceID(); auto weakThis = weak_from_this(); + // For exclusive connections: inform parent that an async child is deployed + request->GetServerApi()->SetRequestInProgress(); RETURN_IF_NOT_OK(TimerQueue::GetInstance()->AddTimer( std::min(subTimeout, remainingTimeMs), [weakThis, subTimeout, request, traceID, remainingTimeMs]() { TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID); LOG(ERROR) << "The get request times out, the sub timeout: " << subTimeout - << ", remainingTimeMs: " << remainingTimeMs << ", clientId: " << request->clientId_ - << ", satisfied num: " << request->numSatisfiedObjects_ - << ", waiting num: " << request->numWaitingObjects_; - auto workerOcServiceGetImpl = weakThis.lock(); - if (workerOcServiceGetImpl == nullptr) { + << ", remainingTimeMs: " << remainingTimeMs << ", clientId: " << request->GetClientId() + << ", satisfied num: " << request->GetReadyCount() + << ", waiting num: " << request->GetNotReadyCount(); + auto impl = weakThis.lock(); + if (impl == nullptr) { return; } - workerOcServiceGetImpl->workerRequestManager_.ReturnFromGetRequest(request, - workerOcServiceGetImpl->memoryRefTable_); - // Avoid timeCostPoint destruct after traceGuard. 
- request->accessRecorderPoint_.reset(); + LOG_IF_ERROR(request->ReturnToClient(), "ReturnToClient failed"); + // For exclusive connections: inform parent that async child has finished + request->GetServerApi()->SetRequestComplete(); }, - timer)); - request->timer_ = std::make_unique(timer); + *timer)); + + request->SetTimer(std::move(timer)); point.Record(); return Status::OK(); } -void WorkerOcServiceGetImpl::MarkObjectsInGetProcess(const std::vector &keys) -{ - std::lock_guard lock(objectsInGetProcessMutex_); - for (const auto &key : keys) { - objectsInGetProcess_[key]++; - } -} - -void WorkerOcServiceGetImpl::UnmarkObjectsInGetProcess(const std::vector &keys) -{ - std::lock_guard lock(objectsInGetProcessMutex_); - for (const auto &key : keys) { - if (--objectsInGetProcess_[key] <= 0) { - objectsInGetProcess_.erase(key); - } - } -} - -bool WorkerOcServiceGetImpl::IsObjectInGetProcess(const std::string &key) -{ - std::shared_lock lock(objectsInGetProcessMutex_); - if (objectsInGetProcess_.count(key) > 0 && objectsInGetProcess_[key] > 1) { - return true; - } - return false; -} - static Status CheckAndResetStatus(const Status &status, std::set &bypassCode) { // If the error is RPC error, return them directly, other error would be covered up as RUNTIME_ERROR. @@ -352,52 +314,45 @@ static Status CheckAndResetStatus(const Status &status, std::set &by : Status(K_RUNTIME_ERROR, status.GetMsg()); } -void WorkerOcServiceGetImpl::TryGetObjectFromLocal(const std::unordered_map &offsetInfos, - std::shared_ptr &request, - std::vector &objectsNeedGetRemote) +Status WorkerOcServiceGetImpl::TryGetObjectFromLocal(std::shared_ptr &request, + std::set &remoteObjectKeys) { - std::vector localExistKeys; - localExistKeys.reserve(request->deduplicatedObjectKeys_.size()); - for (const auto &objectKey : request->deduplicatedObjectKeys_) { - if (asyncRollbackManager_->IsObjectsInRollBack({ objectKey })) { - if (request->objects_.emplace(objectKey, nullptr)) { - (void)request->numSatisfiedObjects_.fetch_add(1); - } + Status lastRc; + auto &uniqueObjectMap = request->GetObjects(); + + asyncRollbackManager_->UpdateIsRollback(uniqueObjectMap); + + for (auto &[objectKey, objectInfo] : uniqueObjectMap) { + if (objectInfo.isRollBack) { + objectInfo.rc = Status(K_NOT_FOUND, FormatString("ObjectKey %s in rollback", objectKey)); } else { - ReadKey readKey(objectKey); - auto iter = offsetInfos.find(objectKey); - if (iter != offsetInfos.end()) { - readKey.readOffset = iter->second.readOffset; - readKey.readSize = iter->second.readSize; - } - Status status = PreProcessGetObject(readKey, request, objectsNeedGetRemote, localExistKeys); + ReadKey readKey(objectKey, objectInfo.offsetInfo); + Status status = PreProcessGetObject(readKey, objectInfo, remoteObjectKeys); if (status.IsError()) { + objectInfo.rc = status; + lastRc = status; LOG(ERROR) << "PreProcessGetObject failed:" << status.GetMsg(); - static std::set bypassCode{ K_OUT_OF_MEMORY, K_OUT_OF_RANGE }; - Status finalStatus = CheckAndResetStatus(status, bypassCode); - request->SetStatus(finalStatus); - if (request->objects_.emplace(objectKey, nullptr)) { - (void)request->numSatisfiedObjects_.fetch_add(1); - } } } - // Add request even if failed. 
- (void)workerRequestManager_.AddRequest(objectKey, request); } - LOG(INFO) << "Local exist keys: " << VectorToString(localExistKeys); + if (lastRc.IsError()) { + static std::set bypassCode{ K_OUT_OF_MEMORY, K_OUT_OF_RANGE }; + lastRc = CheckAndResetStatus(lastRc, bypassCode); + } + return request->UpdateAfterLocalGet(std::move(lastRc), remoteObjectKeys.size()); } Status WorkerOcServiceGetImpl::TryGetObjectFromRemote(int64_t subTimeout, std::shared_ptr &request, - std::vector &objectsNeedGetRemote) + std::set remoteObjectKeys) { - RETURN_OK_IF_TRUE(objectsNeedGetRemote.empty()); - auto needRemoteGetIds = objectsNeedGetRemote; + RETURN_OK_IF_TRUE(remoteObjectKeys.empty()); + auto needRemoteGetIds = std::move(remoteObjectKeys); PerfPoint pointRemote(PerfKey::WORKER_PROCESS_GET_OBJECT_REMOTE); std::unordered_set failedIds; Status status; do { - std::vector needRetryIds; + std::set needRetryIds; status = ProcessObjectsNotExistInLocal(needRemoteGetIds, subTimeout, failedIds, needRetryIds, request); int64_t remainTimeMs = reqTimeoutDuration.CalcRealRemainingTime(); const int64_t timeoutThresholdMs = 100; @@ -405,7 +360,7 @@ Status WorkerOcServiceGetImpl::TryGetObjectFromRemote(int64_t subTimeout, std::s // If we meets OOM, never try get again because there is no space for us to save the objects. if (status.GetCode() == K_OUT_OF_MEMORY || remainTimeMs <= timeoutThresholdMs) { std::for_each(needRetryIds.begin(), needRetryIds.end(), - [&](ReadKey &key) { failedIds.emplace(key.objectKey); }); + [&](const ReadKey &key) { failedIds.emplace(key.objectKey); }); break; } needRemoteGetIds.swap(needRetryIds); @@ -415,7 +370,7 @@ Status WorkerOcServiceGetImpl::TryGetObjectFromRemote(int64_t subTimeout, std::s pointRemote.Record(); if (status.GetCode() == K_OUT_OF_MEMORY) { LOG(INFO) << "TryGetObjectFromRemote failed, detail: " << status.ToString(); - return workerRequestManager_.ReturnFromGetRequest(request, memoryRefTable_, status); + return request->ReturnToClient(status); } Status lastRc; @@ -424,28 +379,26 @@ Status WorkerOcServiceGetImpl::TryGetObjectFromRemote(int64_t subTimeout, std::s // K_OUT_OF_RANGE: offset > szie. static std::set bypassCodeRemoteGet{ K_OUT_OF_RANGE }; if (status.GetCode() == K_NOT_FOUND_IN_L2CACHE) { - LOG(ERROR) << status.ToString(); + LOG(ERROR) << "ProcessObjectsNotExistInLocal failed with status: " << status.ToString(); auto msg = "Cannot get object from worker and l2 cache"; lastRc = Status(K_NOT_FOUND, msg); } else { lastRc = CheckAndResetStatus(status, bypassCodeRemoteGet); } for (const auto &id : failedIds) { - LOG_IF_ERROR(workerRequestManager_.UpdateRequestForFailed(id, lastRc, memoryRefTable_), - "UpdateRequestForFailed failed"); + LOG_IF_ERROR(request->MarkFailed(id, lastRc), "MarkFailed failed"); } } return Status::OK(); } -Status WorkerOcServiceGetImpl::PreProcessGetObject(const ReadKey &readKey, std::shared_ptr &request, - std::vector &objectsNeedGetRemote, - std::vector &localExistKeys) +Status WorkerOcServiceGetImpl::PreProcessGetObject(const ReadKey &readKey, GetObjInfo &info, + std::set &remoteObjectKeys) { INJECT_POINT("worker.PreProcessGetObject.begin"); // use RLock instead of WLock try get from memory. 
bool objIsValidInMem = true; - Status memGetRes = RLockGetObjectFromMem(readKey, request, objectsNeedGetRemote, objIsValidInMem, localExistKeys); + Status memGetRes = RLockGetObjectFromMem(readKey, info, remoteObjectKeys, objIsValidInMem); INJECT_POINT("set.objectIsInvalidInmem", [&objIsValidInMem]() { objIsValidInMem = false; return Status::OK(); @@ -459,7 +412,7 @@ Status WorkerOcServiceGetImpl::PreProcessGetObject(const ReadKey &readKey, std:: Status rc = objectTable_->Get(readKey.objectKey, entry); RETURN_IF_NOT_OK_EXCEPT(rc, K_NOT_FOUND); if (rc.GetCode() == K_NOT_FOUND) { - objectsNeedGetRemote.push_back(readKey); + (void)remoteObjectKeys.emplace(readKey); return Status::OK(); } ReadObjectKV objectKV(readKey, *entry); @@ -469,12 +422,12 @@ Status WorkerOcServiceGetImpl::PreProcessGetObject(const ReadKey &readKey, std:: rc = entry->WLock(true); RETURN_IF_NOT_OK_EXCEPT(rc, K_NOT_FOUND); if (rc.GetCode() == K_NOT_FOUND) { - objectsNeedGetRemote.push_back(readKey); + (void)remoteObjectKeys.emplace(readKey); return Status::OK(); } Raii unlock([&entry]() { entry->WUnlock(); }); if ((*entry).Get() == nullptr) { - objectsNeedGetRemote.push_back(readKey); + (void)remoteObjectKeys.emplace(readKey); return Status::OK(); } INJECT_POINT("set.objectIsInComplete", [&entry]() { @@ -489,32 +442,28 @@ Status WorkerOcServiceGetImpl::PreProcessGetObject(const ReadKey &readKey, std:: auto res = KeepObjectDataInMemory(objectKV); if (res.IsError()) { // object not found from disk or l2 cache, try get from remote. - objectsNeedGetRemote.emplace_back(readKey); + (void)remoteObjectKeys.emplace(readKey); return Status::OK(); } LOG(INFO) << FormatString("[ObjectKey %s] already load to memory", readKey); } else { // case 2: object exist in local node, but it is the expired version. Status status = TryGetObjectsFromPrimaryWorker((*entry)->GetAddress(), (*entry)->GetDataSize(), objectKV, - objectsNeedGetRemote); + remoteObjectKeys); if (status.IsError()) { return Status::OK(); } } - if (request->objects_.emplace(readKey.objectKey, - GetObjEntryParams::Create(*entry, readKey.readOffset, readKey.readSize))) { - request->numSatisfiedObjects_.fetch_add(1); - } + info.params = GetObjEntryParams::Create(readKey.objectKey, *entry); } else { // case 3: object didn't exist in local node, is not published or sealed yet. - objectsNeedGetRemote.push_back(readKey); + (void)remoteObjectKeys.emplace(readKey); } return Status::OK(); } -Status WorkerOcServiceGetImpl::RLockGetObjectFromMem(const ReadKey &readKey, std::shared_ptr &request, - std::vector &objectsNeedGetRemote, bool &objIsValidInMem, - std::vector &localExistKeys) +Status WorkerOcServiceGetImpl::RLockGetObjectFromMem(const ReadKey &readKey, GetObjInfo &info, + std::set &remoteObjectKeys, bool &objIsValidInMem) { std::shared_ptr entry; // Fetch the object and lock it. @@ -522,7 +471,7 @@ Status WorkerOcServiceGetImpl::RLockGetObjectFromMem(const ReadKey &readKey, std Status rc = objectTable_->Get(readKey.objectKey, entry); RETURN_IF_NOT_OK_EXCEPT(rc, K_NOT_FOUND); if (rc.GetCode() == K_NOT_FOUND) { - objectsNeedGetRemote.push_back(readKey); + (void)remoteObjectKeys.emplace(readKey); return Status::OK(); } // If entry RLock is not found, it means the object is deleting locally. 
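The `Raii unlock([&entry]() { entry->RUnlock(); });` guards in this function rely on a scope-exit helper: the callable runs when the guard is destroyed, so every early `return Status::OK()` still releases the entry lock. A minimal sketch of such a guard, assuming only standard C++ (the project's actual helper in datasystem/common/util/raii.h may differ in detail):

```cpp
#include <functional>
#include <utility>

// Scope guard: runs the given callable when the guard goes out of scope,
// covering every return path of the enclosing function.
class ScopeGuard {
public:
    explicit ScopeGuard(std::function<void()> onExit) : onExit_(std::move(onExit)) {}
    ~ScopeGuard()
    {
        if (onExit_) {
            onExit_();
        }
    }
    ScopeGuard(const ScopeGuard &) = delete;
    ScopeGuard &operator=(const ScopeGuard &) = delete;

private:
    std::function<void()> onExit_;
};

// Usage mirroring the lock handling above:
//   ScopeGuard unlock([&entry]() { entry->RUnlock(); });
```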
@@ -530,12 +479,12 @@ Status WorkerOcServiceGetImpl::RLockGetObjectFromMem(const ReadKey &readKey, std rc = entry->RLock(true); RETURN_IF_NOT_OK_EXCEPT(rc, K_NOT_FOUND); if (rc.GetCode() == K_NOT_FOUND) { - objectsNeedGetRemote.push_back(readKey); + (void)remoteObjectKeys.emplace(readKey); return Status::OK(); } Raii unlock([&entry]() { entry->RUnlock(); }); if ((*entry).Get() == nullptr) { - objectsNeedGetRemote.push_back(readKey); + (void)remoteObjectKeys.emplace(readKey); return Status::OK(); } CHECK_FAIL_RETURN_STATUS((*entry)->IsBinary(), K_INVALID, "Not a Shm Unit"); @@ -548,12 +497,8 @@ Status WorkerOcServiceGetImpl::RLockGetObjectFromMem(const ReadKey &readKey, std objIsValidInMem = false; RETURN_STATUS(K_NOT_FOUND, FormatString("[ObjectKey %s] not exist in memory.", readKey)); } - if (request->objects_.emplace(readKey.objectKey, - GetObjEntryParams::Create(*entry, readKey.readOffset, readKey.readSize))) { - request->numSatisfiedObjects_.fetch_add(1); - } + info.params = GetObjEntryParams::Create(readKey.objectKey, *entry); evictionManager_->Add(readKey.objectKey); - localExistKeys.emplace_back(readKey.objectKey); return Status::OK(); } // case 2: object exist in local node,and it is the expired version. @@ -561,21 +506,17 @@ Status WorkerOcServiceGetImpl::RLockGetObjectFromMem(const ReadKey &readKey, std RETURN_STATUS(K_NOT_FOUND, FormatString("[ObjectKey %s] exists locally but expired.", readKey)); } else { // object didn't exist in local node, is not published or sealed yet. - objectsNeedGetRemote.push_back(readKey); + (void)remoteObjectKeys.emplace(readKey); } return Status::OK(); } -Status WorkerOcServiceGetImpl::ProcessObjectsNotExistInLocal(const std::vector &objectsNeedGetRemote, - const int64_t subTimeout, +Status WorkerOcServiceGetImpl::ProcessObjectsNotExistInLocal(const std::set &objectsNeedGetRemote, + int64_t subTimeout, std::unordered_set &failedIds, - std::vector &needRetryIds, + std::set &needRetryIds, const std::shared_ptr &request) { - std::map readKeys; - for (const auto &id : objectsNeedGetRemote) { - readKeys.insert(std::make_pair(id.objectKey, id)); - } LOG(INFO) << "Begin to process " << objectsNeedGetRemote.size() << " objects that doesn't exist in local: [" << VectorToString(objectsNeedGetRemote) << "]"; AddInRemoteGetObjects(objectsNeedGetRemote); @@ -587,28 +528,28 @@ Status WorkerOcServiceGetImpl::ProcessObjectsNotExistInLocal(const std::vector, bool>> lockedEntries; + std::map lockedEntries; lastRc = BatchLockForGet(objectsNeedGetRemote, lockedEntries, failedIds); Raii unlockRaii([this, &failedIds, &lockedEntries]() { BatchUnlockForGet(failedIds, lockedEntries); }); // If a local publish or remote get finished before we lock the object, we will get a valid object here. 
- AttemptGetObjectsLocally(readKeys, lockedEntries); + AttemptGetObjectsLocally(request, lockedEntries); std::vector needRemoteGetObjects; std::transform(lockedEntries.begin(), lockedEntries.end(), std::back_inserter(needRemoteGetObjects), - [](const auto &kv) { return kv.first; }); + [](const auto &kv) { return kv.first.objectKey; }); QueryMetadataFromMasterResult queryMetaResult; std::vector &queryMetas = queryMetaResult.queryMetas; std::vector &payloads = queryMetaResult.payloads; - std::map &absentObjectKeys = queryMetaResult.absentObjectKeysWithVersion; - Status result = QueryMetadataFromMaster(needRemoteGetObjects, subTimeout, queryMetaResult, - !request->requestInfo_.no_query_l2cache()); + const auto &absentObjectKeys = queryMetaResult.absentObjectKeysWithVersion; + Status result = + QueryMetadataFromMaster(needRemoteGetObjects, subTimeout, queryMetaResult, !request->NoQueryL2Cache()); if (result.IsError()) { // If we query meta from master meets RPC error, do not add these objects to failedIds, // otherwise other concurrent get operations would failed, so we just notify ourselves. if (IsRpcTimeoutOrTryAgain(result)) { for (const auto &objectKey : needRemoteGetObjects) { - workerRequestManager_.UpdateSpecificRequestForFailed(request, objectKey, result, memoryRefTable_); + LOG_IF_ERROR(request->MarkFailed(objectKey, result), "MarkFailed failed"); } } else { failedIds.insert(needRemoteGetObjects.begin(), needRemoteGetObjects.end()); @@ -629,7 +570,7 @@ Status WorkerOcServiceGetImpl::ProcessObjectsNotExistInLocal(const std::vector &objectInfos, - std::map, bool>> &lockedEntries, - std::unordered_set &failedIds) +Status WorkerOcServiceGetImpl::GetObjectsWithoutMeta(const std::map &objectKeys, + std::map &lockedEntries, + std::unordered_set &failedIds) { Status lastRc = Status::OK(); std::vector successIds; - for (const auto &kv : objectInfos) { - const auto objectKey = kv.first; - auto it = lockedEntries.find(objectKey); + for (const auto &kv : objectKeys) { + const auto &objectKey = kv.first; + auto it = lockedEntries.find(ReadKey(objectKey)); if (it == lockedEntries.end()) { lastRc = Status( K_RUNTIME_ERROR, @@ -663,7 +603,7 @@ Status WorkerOcServiceGetImpl::GetObjectsWithoutMeta( } // full read. ReadKey readKey(objectKey); - ReadObjectKV objectKV(readKey, *it->second.first); + ReadObjectKV objectKV(readKey, *it->second.safeObj); objectKV.GetObjEntry()->SetMetadataSize(metadataSize_); Status rc = GetObjectsWithoutMetaFromL2Cache(objectKV, kv.second); if (rc.IsOk()) { @@ -688,7 +628,7 @@ Status WorkerOcServiceGetImpl::GetObjectsWithoutMetaFromL2Cache(ObjectKV &object Status WorkerOcServiceGetImpl::TryGetObjectsFromPrimaryWorker(const std::string &primaryAddress, uint64_t dataSize, ReadObjectKV &objectKV, - std::vector &objectsNeedGetRemote) + std::set &objectsNeedGetRemote) { const auto &objectKey = objectKV.GetObjKey(); LOG(INFO) << FormatString("[ObjectKey %s] exist in local node but expired, remote worker: %s, local worker: %s", @@ -700,12 +640,12 @@ Status WorkerOcServiceGetImpl::TryGetObjectsFromPrimaryWorker(const std::string LOG(INFO) << "Try to Pull from primary worker failed. The system will obtain from other worker or " "l2 cache again. 
Detail: " << status.ToString(); - objectsNeedGetRemote.push_back(objectKV.ConstructReadKey()); + objectsNeedGetRemote.emplace(objectKV.ConstructReadKey()); return status; } RETURN_OK_IF_TRUE(status.IsOk()); } - objectsNeedGetRemote.push_back(objectKV.ConstructReadKey()); + objectsNeedGetRemote.emplace(objectKV.ConstructReadKey()); return status; } @@ -773,14 +713,15 @@ Status WorkerOcServiceGetImpl::GetObjectFromRemoteWorkerAndDump(const std::strin } template -Status WorkerOcServiceGetImpl::PrepareUrmaInfo(uint64_t dataSize, ReadObjectKV &objectKV, Req &reqPb, +Status WorkerOcServiceGetImpl::PrepareGetRequestHelper(uint64_t dataSize, ReadObjectKV &objectKV, Req &reqPb, bool &shmUnitAllocated, std::shared_ptr shmOwner) { + // If URMA is enabled, or if shmOwner is not nullptr, memory distribution/allocation needs to be processed. if (!IsUrmaEnabled() && shmOwner == nullptr) { return Status::OK(); } reqPb.set_data_size(dataSize); - INJECT_POINT("WorkerOcServiceGetImpl.PrepareUrmaInfo.changeSize", [&reqPb](uint64_t testDataSize) { + INJECT_POINT("WorkerOcServiceGetImpl.PrepareGetRequestHelper.changeSize", [&reqPb](uint64_t testDataSize) { reqPb.set_data_size(testDataSize); return Status::OK(); }); @@ -802,22 +743,42 @@ Status WorkerOcServiceGetImpl::PrepareUrmaInfo(uint64_t dataSize, ReadObjectKV & RETURN_IF_NOT_OK( AllocateMemoryForObject(objectKey, dataSize, metaSz, populate, evictionManager_, *shmUnit)); } - shmUnit->id = GetStringUuid(); + shmUnit->id = ShmKey::Intern(GetStringUuid()); entry->SetShmUnit(shmUnit); shmUnitAllocated = true; } - RETURN_IF_NOT_OK(FillUrmaInfo(shmUnit, localAddress_, metaSz, *reqPb.mutable_urma_info())); + // Early exit when URMA info is not needed. + if (!IsUrmaEnabled()) { + return Status::OK(); + } + uint64_t segAddress; + uint64_t dataOffset; + if (FLAGS_urma_register_whole_arena) { + segAddress = reinterpret_cast(shmUnit->GetPointer()) - shmUnit->GetOffset(); + dataOffset = shmUnit->GetOffset() + metaSz; + } else { + segAddress = reinterpret_cast(shmUnit->GetPointer()); + dataOffset = metaSz; + } + auto *urmaInfo = reqPb.mutable_urma_info(); + urmaInfo->set_seg_va(segAddress); + urmaInfo->set_seg_data_offset(dataOffset); + auto *remoteAddr = urmaInfo->mutable_request_address(); + remoteAddr->set_host(localAddress_.Host()); + remoteAddr->set_port(localAddress_.Port()); return Status::OK(); } -Status WorkerOcServiceGetImpl::ConstructBatchGetRequest( - const std::string &address, std::list &metas, const std::map &readKeys, - std::map, bool>> &lockedEntries, - std::vector &successIds, std::vector &needRetryIds, - std::unordered_set &failedIds, BatchGetObjectRemoteReqPb &reqPb) +Status WorkerOcServiceGetImpl::ConstructBatchGetRequest(const std::string &address, std::list &metas, + std::map &lockedEntries, + std::vector &successIds, + std::vector &needRetryIds, + std::unordered_set &failedIds, + BatchGetObjectRemoteReqPb &reqPb) { PerfPoint point(PerfKey::WORKER_CONSTRUCT_BATCH_GET_REQ); - // The function is placed together with PrepareUrmaInfo, as the template definition does not suit in header file. + // The function is placed together with PrepareGetRequestHelper, + // as the template definition does not fit in the header file. Status lastRc = Status::OK(); // Pre-allocate an aggregated chunk of shared memory as ShmOwner, to reduce the number of allocation calls. 
std::vector> shmOwners; @@ -825,39 +786,44 @@ RETURN_IF_NOT_OK(AggregateAllocateHelper(metas, lockedEntries, shmOwners, shmIndexMapping)); bool requestReady = false; - uint32_t objectId = 0; - for (auto metaIter = metas.begin(); metaIter != metas.end(); objectId++) { + uint32_t objectIndex = 0; + for (auto metaIter = metas.begin(); metaIter != metas.end(); objectIndex++) { auto &meta = *metaIter; const auto &objectKey = meta->object_key(); + auto iter = lockedEntries.find(ReadKey(objectKey)); + if (iter == lockedEntries.cend()) { + LOG(WARNING) << "ObjectKey " << objectKey << " does not exist in lockedEntries"; + continue; + } // Checked availability when metas are grouped, so it should be safe to just access the entry here. - auto &pair = lockedEntries.at(objectKey); - auto &entry = pair.first; + auto &entry = iter->second.safeObj; // Re-set object entry in the case of looped for data size change. SetObjectEntryAccordingToMeta(*meta, GetMetadataSize(), *entry); - auto &readKey = readKeys.at(objectKey); + const auto &readKey = iter->first; ReadObjectKV objectKV(readKey, *entry); Status status = objectKV.CheckReadOffset(); if (status.IsError()) { - BatchGetObjectHandleIndividualStatus(status, objectKey, readKey, successIds, needRetryIds, failedIds); + BatchGetObjectHandleIndividualStatus(status, readKey, successIds, needRetryIds, failedIds); metaIter = metas.erase(metaIter); lastRc = status; continue; } - datasystem::BatchGetObjectRemoteReqPb_GetObjectRemoteBaseReqPb subReq; + GetObjectRemoteReqPb subReq; subReq.set_object_key(objectKey); subReq.set_version((*entry)->GetCreateTime()); subReq.set_read_offset(objectKV.GetReadOffset()); subReq.set_read_size(objectKV.GetReadSize()); + subReq.set_data_size(meta->data_size()); // Prepare the protobuf with urma info for data transfer if applicable. // BatchGetObjectHandleIndividualStatus will free ShmUnit upon error, so no need to actually record it here. bool shmUnitAllocated = false; std::shared_ptr shmOwner = nullptr; - if (shmIndexMapping.size() > objectId && shmOwners.size() > shmIndexMapping[objectId]) { - shmOwner = shmOwners[shmIndexMapping[objectId]]; + if (shmIndexMapping.size() > objectIndex && shmOwners.size() > shmIndexMapping[objectIndex]) { + shmOwner = shmOwners[shmIndexMapping[objectIndex]]; } - status = PrepareUrmaInfo(meta->data_size(), objectKV, subReq, shmUnitAllocated, shmOwner); + status = PrepareGetRequestHelper(meta->data_size(), objectKV, subReq, shmUnitAllocated, shmOwner); if (status.IsError()) { - BatchGetObjectHandleIndividualStatus(status, objectKey, readKey, successIds, needRetryIds, failedIds); + BatchGetObjectHandleIndividualStatus(status, readKey, successIds, needRetryIds, failedIds); metaIter = metas.erase(metaIter); lastRc = status; continue; @@ -907,7 +873,7 @@ Status WorkerOcServiceGetImpl::PullObjectDataFromRemoteWorker(const std::string dataSizeChange = false; bool shmUnitAllocated = false; // Prepare the protobuf with urma info for data transfer if applicable. - RETURN_IF_NOT_OK(PrepareUrmaInfo(dataSize, objectKV, reqPb, shmUnitAllocated)); + RETURN_IF_NOT_OK(PrepareGetRequestHelper(dataSize, objectKV, reqPb, shmUnitAllocated)); // If getting data from other AZ, then we leave 3/4 remain time to query from L2 cache in case getting data // failed. 
int64_t timeoutMs = @@ -976,7 +942,7 @@ Status WorkerOcServiceGetImpl::RetrieveRemotePayload( auto shmUnit = std::make_shared(); RETURN_IF_NOT_OK(AllocateMemoryForObject(objectKey, completeDataSize, metaSz, false, evictionManager_, *shmUnit, entry->modeInfo.GetCacheType())); - shmUnit->id = GetStringUuid(); + shmUnit->id = ShmKey::Intern(GetStringUuid()); entry->SetShmUnit(shmUnit); } @@ -1020,22 +986,16 @@ Status WorkerOcServiceGetImpl::RetrieveRemotePayload( return Status::OK(); } -void WorkerOcServiceGetImpl::AttemptGetObjectsLocally( - const std::map &readKeys, - std::map, bool>> &lockedEntries) +void WorkerOcServiceGetImpl::AttemptGetObjectsLocally(const std::shared_ptr &request, + std::map &lockedEntries) { - auto localGet = [this, readKeys](const std::string &objectKey, std::shared_ptr &entry) { + auto localGet = [this, request](const ReadKey &readKey, std::shared_ptr &entry) { if ((entry->Get() != nullptr) && entry->Get()->IsBinary() && !entry->Get()->stateInfo.IsCacheInvalid() && entry->Get()->IsGetDataEnablelFromLocal()) { - auto it = readKeys.find(objectKey); - if (it == readKeys.end()) { - return Status(K_NOT_FOUND, ""); - } - ReadKey readKey = it->second; ReadObjectKV objectKV(readKey, *entry); RETURN_IF_NOT_OK(KeepObjectDataInMemory(objectKV)); - RETURN_IF_NOT_OK(UpdateRequestForSuccess(objectKV)); - RemoveInRemoteGetObject(objectKey); + RETURN_IF_NOT_OK(UpdateRequestForSuccess(objectKV, request)); + RemoveInRemoteGetObject(readKey.objectKey); entry->WUnlock(); return Status::OK(); } @@ -1044,9 +1004,9 @@ void WorkerOcServiceGetImpl::AttemptGetObjectsLocally( auto it = lockedEntries.begin(); while (it != lockedEntries.end()) { - const auto &objectKey = it->first; - auto &entry = it->second.first; - if (localGet(objectKey, entry).IsOk()) { + auto &entry = it->second.safeObj; + const auto &readKey = it->first; + if (localGet(readKey, entry).IsOk()) { lockedEntries.erase(it++); } else { it++; @@ -1208,7 +1168,7 @@ Status WorkerOcServiceGetImpl::QueryMetadataFromMaster(const std::vector> objKeysGrpByMaster; std::unordered_map> objKeysUndecidedMaster; - RETURN_IF_NOT_OK(etcdCM_->GroupObjKeysByMasterHostPort(objectKeys, objKeysGrpByMaster, objKeysUndecidedMaster)); + etcdCM_->GroupObjKeysByMasterHostPort(objectKeys, objKeysGrpByMaster, objKeysUndecidedMaster); // 2. 
Send requests for each master std::vector> futures; std::string traceID = Trace::Instance().GetTraceID(); @@ -1217,12 +1177,9 @@ Status WorkerOcServiceGetImpl::QueryMetadataFromMaster(const std::vector batchQueryResults; batchQueryResults.resize(objKeysGrpByMaster.size()); size_t idx = 0; - size_t threadNum = std::min(objKeysGrpByMaster.size(), FLAGS_rpc_thread_num); - auto batchQueryThreadPool = std::make_unique(1, threadNum, "BatchQureyMeta"); for (auto &item : objKeysGrpByMaster) { BatchQueryMetaResult &res = batchQueryResults[idx++]; - futures.emplace_back(batchQueryThreadPool->Submit([&res, realTimeoutMs, subTimeout, item, traceID, timer, - this]() { + auto func = [&res, realTimeoutMs, subTimeout, item, traceID, timer, this]() { TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID); int64_t elapsed = timer.ElapsedMilliSecond(); reqTimeoutDuration.Init(realTimeoutMs - elapsed); @@ -1239,7 +1196,13 @@ Status WorkerOcServiceGetImpl::QueryMetadataFromMaster(const std::vectorSubmit(std::move(func))); + } } for (auto &f : futures) { f.wait(); @@ -1470,30 +1433,29 @@ Status WorkerOcServiceGetImpl::CorrectQueryMetaResponse(std::vector return Status::OK(); } -Status WorkerOcServiceGetImpl::GetObjectsFromAnywhere( - std::vector &queryMetas, const std::map &readKeys, - const std::shared_ptr &request, std::vector &payloads, - std::map, bool>> &lockedEntries, - std::unordered_set &failedIds, std::vector &needRetryIds) +Status WorkerOcServiceGetImpl::GetObjectsFromAnywhere(std::vector &queryMetas, + const std::shared_ptr &request, + std::vector &payloads, + std::map &lockedEntries, + std::unordered_set &failedIds, + std::set &needRetryIds) { if (FLAGS_enable_worker_worker_batch_get) { - return GetObjectsFromAnywhereBatched(queryMetas, readKeys, request, payloads, lockedEntries, failedIds, - needRetryIds); + return GetObjectsFromAnywhereBatched(queryMetas, request, payloads, lockedEntries, failedIds, needRetryIds); } - return GetObjectsFromAnywhereParallelly(queryMetas, readKeys, request, payloads, lockedEntries, failedIds, - needRetryIds); + return GetObjectsFromAnywhereParallelly(queryMetas, request, payloads, lockedEntries, failedIds, needRetryIds); } -Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereParallelly( - const std::vector &queryMetas, const std::map &readKeys, - const std::shared_ptr &request, std::vector &payloads, - std::map, bool>> &lockedEntries, - std::unordered_set &failedIds, std::vector &needRetryIds) +Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereParallelly(const std::vector &queryMetas, + const std::shared_ptr &request, + std::vector &payloads, + std::map &lockedEntries, + std::unordered_set &failedIds, + std::set &needRetryIds) { const size_t kMinParallelRequests = 2; if (queryMetas.size() < kMinParallelRequests) { - return GetObjectsFromAnywhereSerially(queryMetas, readKeys, request, payloads, lockedEntries, failedIds, - needRetryIds); + return GetObjectsFromAnywhereSerially(queryMetas, request, payloads, lockedEntries, failedIds, needRetryIds); } Status lastRc = Status::OK(); std::vector successIds; @@ -1518,27 +1480,14 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereParallelly( LOG(ERROR) << lastRc; continue; } - auto iter = lockedEntries.find(meta.object_key()); - if (iter == lockedEntries.end()) { - LOG(ERROR) << FormatString("[ObjectKey %s] QueryMeta exist but lock entry absent, should not happen", - meta.object_key()); - lastRc = Status(K_UNKNOWN_ERROR, "QueryMeta exist but lock entry absent, should not happen"); - continue; - } - if 
(readKeys.find(meta.object_key()) == readKeys.end()) { - LOG(ERROR) << FormatString("[ObjectKey %s] cant find offset and size to get", meta.object_key()); - lastRc = Status(K_UNKNOWN_ERROR, "Can not find offset or size to get object"); - continue; - } - ReadKey readKey = readKeys.at(meta.object_key()); Timer timer; int64_t realTimeoutMs = reqTimeoutDuration.CalcRealRemainingTime(); - + auto traceId = Trace::Instance().GetTraceID(); futures.emplace_back(remoteGetThreadPool_->Submit([=, &lockedEntries, &commonMutex, &abortAllTasks, &request, &payloads, &lastRc, &successIds, &needRetryIds, - &failedIds]() { - TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + &failedIds, &traceId]() { + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); int64_t elapsed = timer.ElapsedMilliSecond(); reqTimeoutDuration.Init(realTimeoutMs - elapsed); if (abortAllTasks.load()) { @@ -1546,16 +1495,18 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereParallelly( } const auto &queryMeta = queryMetas[i]; const auto &meta = queryMeta.meta(); - auto subIter = lockedEntries.find(meta.object_key()); + const auto &objectKey = meta.object_key(); + auto subIter = lockedEntries.find(ReadKey(objectKey)); if (subIter == lockedEntries.end()) { std::lock_guard lock(commonMutex); - LOG(INFO) << FormatString("[ObjectKey %s] Object not found in locked entries", meta.object_key()); - lastRc = Status(K_NOT_FOUND, - FormatString("[ObjectKey %s] Object not found in locked entries", meta.object_key())); + LOG(INFO) << FormatString("[ObjectKey %s] Object not found in locked entries", objectKey); + lastRc = + Status(K_NOT_FOUND, FormatString("[ObjectKey %s] Object not found in locked entries", objectKey)); return lastRc; } - std::shared_ptr &subEntry = subIter->second.first; - bool isInsert = subIter->second.second; + const auto &readKey = subIter->first; + std::shared_ptr &subEntry = subIter->second.safeObj; + bool isInsert = subIter->second.insert; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(subEntry->TransferWLockToCurrentThread(), "Lock failed"); Status status = GetObjectFromAnywhereWithLock(readKey, request, subEntry, isInsert, queryMeta, payloads); @@ -1563,19 +1514,18 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereParallelly( std::lock_guard lock(commonMutex); if (status.IsOk()) { - LOG(INFO) << FormatString("[ObjectKey %s] Get from remote success.", meta.object_key()); - successIds.push_back(meta.object_key()); + LOG(INFO) << FormatString("[ObjectKey %s] Get from remote success.", objectKey); + successIds.push_back(objectKey); } else if (status.GetCode() == K_WORKER_PULL_OBJECT_NOT_FOUND) { - LOG(INFO) << FormatString("[ObjectKey %s] Object not found in remote worker.", meta.object_key()); - needRetryIds.emplace_back(readKey); + LOG(INFO) << FormatString("[ObjectKey %s] Object not found in remote worker.", objectKey); + (void)needRetryIds.emplace(readKey); } else if (status.GetCode() == K_OUT_OF_MEMORY) { - LOG(INFO) << FormatString("[ObjectKey %s] Out of memory, get remote abort.", meta.object_key()); + LOG(INFO) << FormatString("[ObjectKey %s] Out of memory, get remote abort.", objectKey); lastRc = status; abortAllTasks.store(true); } else { - LOG(ERROR) << FormatString("[ObjectKey %s] Get from remote failed: %s.", meta.object_key(), - status.ToString()); - failedIds.emplace(meta.object_key()); + LOG(ERROR) << FormatString("[ObjectKey %s] Get from remote failed: %s.", objectKey, status.ToString()); + failedIds.emplace(objectKey); lastRc = status; } @@ -1595,11 +1545,12 @@ Status 
WorkerOcServiceGetImpl::GetObjectsFromAnywhereParallelly( return lastRc; } -Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereSerially( - const std::vector &queryMetas, const std::map &readKeys, - const std::shared_ptr &request, std::vector &payloads, - std::map, bool>> &lockedEntries, - std::unordered_set &failedIds, std::vector &needRetryIds) +Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereSerially(const std::vector &queryMetas, + const std::shared_ptr &request, + std::vector &payloads, + std::map &lockedEntries, + std::unordered_set &failedIds, + std::set &needRetryIds) { Status lastRc; std::vector successIds; @@ -1608,43 +1559,38 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereSerially( const auto &queryMeta = *queryIt; const auto &meta = queryMeta.meta(); const auto dataFormat = static_cast(queryMeta.meta().config().data_format()); + const auto &objectKey = meta.object_key(); if (dataFormat != DataFormat::BINARY && dataFormat != DataFormat::HETERO) { lastRc = Status(K_INVALID, "object data format not match."); - failedIds.emplace(meta.object_key()); + failedIds.emplace(objectKey); LOG(ERROR) << lastRc; continue; } - auto iter = lockedEntries.find(meta.object_key()); + auto iter = lockedEntries.find(ReadKey(objectKey)); if (iter == lockedEntries.end()) { LOG(ERROR) << FormatString("[ObjectKey %s] QueryMeta exist but lock entry absent, should not happen", - meta.object_key()); + objectKey); lastRc = Status(K_UNKNOWN_ERROR, "QueryMeta exist but lock entry absent, should not happen"); continue; } - if (readKeys.find(meta.object_key()) == readKeys.end()) { - LOG(ERROR) << FormatString("[ObjectKey %s] cant find offset and size to get", meta.object_key()); - lastRc = Status(K_UNKNOWN_ERROR, "Can not find offset or size to get object"); - continue; - } - ReadKey readKey = readKeys.at(meta.object_key()); - auto status = GetObjectFromAnywhereWithLock(readKey, request, iter->second.first, iter->second.second, + const auto &readKey = iter->first; + auto status = GetObjectFromAnywhereWithLock(readKey, request, iter->second.safeObj, iter->second.insert, queryMeta, payloads); if (status.IsOk()) { - LOG(INFO) << FormatString("[ObjectKey %s] Get from remote success.", meta.object_key()); - successIds.push_back(meta.object_key()); + LOG(INFO) << FormatString("[ObjectKey %s] Get from remote success.", objectKey); + successIds.push_back(objectKey); } else if (status.GetCode() == K_WORKER_PULL_OBJECT_NOT_FOUND) { - LOG(INFO) << FormatString("[ObjectKey %s] Object not found in remote worker.", meta.object_key()); + LOG(INFO) << FormatString("[ObjectKey %s] Object not found in remote worker.", objectKey); lastRc = Status::OK(); - needRetryIds.emplace_back(readKey); + (void)needRetryIds.emplace(readKey); } else if (status.GetCode() == K_OUT_OF_MEMORY) { - LOG(INFO) << FormatString("[ObjectKey %s] Out of memory, get remote abort.", meta.object_key()); + LOG(INFO) << FormatString("[ObjectKey %s] Out of memory, get remote abort.", objectKey); lastRc = status; break; } else { - LOG(ERROR) << FormatString("[ObjectKey %s] Get from remote failed: %s.", meta.object_key(), - status.ToString()); + LOG(ERROR) << FormatString("[ObjectKey %s] Get from remote failed: %s.", objectKey, status.ToString()); lastRc = status; - failedIds.emplace(meta.object_key()); + failedIds.emplace(objectKey); } } @@ -1682,14 +1628,14 @@ Status WorkerOcServiceGetImpl::GetObjectFromAnywhereWithLock(const ReadKey &read // we will get a valid object here. 
ReadObjectKV objectKV(readKey, *entry); RETURN_IF_NOT_OK(KeepObjectDataInMemory(objectKV)); - RETURN_IF_NOT_OK(UpdateRequestForSuccess(objectKV)); + RETURN_IF_NOT_OK(UpdateRequestForSuccess(objectKV, request)); return Status::OK(); } SetObjectEntryAccordingToMeta(meta, GetMetadataSize(), *entry); ReadObjectKV objectKV(readKey, *entry); Status status = queryMeta.payload_indexs_size() == 0 ? GetObjectFromRemoteOnLock(meta, request, address, queryMeta.single_copy(), objectKV) - : GetObjectFromQueryMetaResultOnLock(queryMeta, payloads, objectKV); + : GetObjectFromQueryMetaResultOnLock(request, queryMeta, payloads, objectKV); if (status.IsError()) { HandleGetFailureHelper(objectKey, meta.version(), entry, isInsert); } @@ -1778,7 +1724,7 @@ Status WorkerOcServiceGetImpl::GetObjectFromRemoteOnLock(const ObjectMetaPb &met LOG(INFO) << FormatString("object(%s) get from remote finish, size:%zu, use %f millisecond.", objKey, entry->GetDataSize(), endToEndTimer.ElapsedMilliSecond()); point.Record(); - return UpdateRequestForSuccessNotReturnForClient(objectKV, request); + return UpdateRequestForSuccess(objectKV, request); } void WorkerOcServiceGetImpl::TryGetObjectFromOtherAZ(const ObjectMetaPb &meta, const HostPort &hostAddr, @@ -1808,7 +1754,7 @@ void WorkerOcServiceGetImpl::TryGetFromL2CacheWhenNotFoundInWorker(const ObjectM bool ifWorkerConnected, ObjectKV &objectKV, Status &status) { - const ObjectMetaPb::ConfigPb &configPb = meta.config(); + const ConfigPb &configPb = meta.config(); bool writeToL2Storage = WriteMode(configPb.write_mode()) != WriteMode::NONE_L2_CACHE && WriteMode(configPb.write_mode()) != WriteMode::NONE_L2_CACHE_EVICT; // If a copy exists and the worker where the copy is located is disconnected, the data will not be cached locally @@ -1869,7 +1815,8 @@ Status WorkerOcServiceGetImpl::GetObjectFromPersistenceAndDump(ObjectKV &objectK return GetObjectFromPersistenceAndDumpWithoutCopyMeta(objectKV); } -Status WorkerOcServiceGetImpl::GetObjectFromQueryMetaResultOnLock(const master::QueryMetaInfoPb &queryMeta, +Status WorkerOcServiceGetImpl::GetObjectFromQueryMetaResultOnLock(const std::shared_ptr &request, + const master::QueryMetaInfoPb &queryMeta, std::vector &payloads, ReadObjectKV &objectKV) { @@ -1902,7 +1849,7 @@ Status WorkerOcServiceGetImpl::GetObjectFromQueryMetaResultOnLock(const master:: evictionManager_->Add(objectKey); } point.Record(); - return UpdateRequestForSuccess(objectKV); + return UpdateRequestForSuccess(objectKV, request); } Status WorkerOcServiceGetImpl::CreateCopyMetaToMaster(ObjectKV &objectKV) @@ -2011,18 +1958,14 @@ bool WorkerOcServiceGetImpl::HaveOtherAZ() return !FLAGS_other_cluster_names.empty(); } -Status WorkerOcServiceGetImpl::BatchLockForGet( - const std::vector &objectKeys, - std::map, bool>> &lockedEntries, - std::unordered_set &failObjects) +Status WorkerOcServiceGetImpl::BatchLockForGet(const std::set &readKeys, + std::map &lockedEntries, + std::unordered_set &failObjects) { Status lastRc; lockedEntries.clear(); - std::set toLockIds; - for (const auto &readKey : objectKeys) { - toLockIds.insert(readKey.objectKey); - } - for (const auto &objectKey : toLockIds) { + for (const auto &readKey : readKeys) { + const auto &objectKey = readKey.objectKey; std::shared_ptr entry; bool isInsert; Status rc = objectTable_->ReserveGetAndLock(objectKey, entry, isInsert); @@ -2030,7 +1973,7 @@ Status WorkerOcServiceGetImpl::BatchLockForGet( if (isInsert) { SetEmptyObjectEntry(objectKey, *entry); } - (void)lockedEntries.emplace(objectKey, 
std::make_pair(std::move(entry), isInsert)); + (void)lockedEntries.emplace(readKey, LockedEntity{ .safeObj = std::move(entry), .insert = isInsert }); } else { LOG(ERROR) << FormatString("[ObjectKey %s] GetObjectFromRemote failed: %s.", objectKey, rc.ToString()); failObjects.emplace(objectKey); @@ -2040,38 +1983,39 @@ Status WorkerOcServiceGetImpl::BatchLockForGet( return lastRc; } -void WorkerOcServiceGetImpl::BatchUnlockForGet( - const std::unordered_set &failedObjectKeys, - std::map, bool>> &lockedEntries) +void WorkerOcServiceGetImpl::BatchUnlockForGet(const std::unordered_set &failedObjectKeys, + std::map &lockedEntries) { for (auto &entry : lockedEntries) { - if (!entry.second.first->IsWLockedByCurrentThread()) { + const auto &objectKey = entry.first.objectKey; + auto &safeObj = entry.second.safeObj; + if (!safeObj->IsWLockedByCurrentThread()) { continue; } - if (failedObjectKeys.find(entry.first) != failedObjectKeys.end() && entry.second.second) { - (void)objectTable_->Erase(entry.first); + if (failedObjectKeys.find(objectKey) != failedObjectKeys.end() && entry.second.insert) { + (void)objectTable_->Erase(objectKey); } - entry.second.first->WUnlock(); + safeObj->WUnlock(); } } -void WorkerOcServiceGetImpl::BatchUnlockForGet( - const std::map &failedObjectKeys, - std::map, bool>> &lockedEntries) +void WorkerOcServiceGetImpl::BatchUnlockForGet(const std::map &objectKeys, + std::map &lockedEntries) { - for (const auto &kv : failedObjectKeys) { - auto iter = lockedEntries.find(kv.first); + for (const auto &kv : objectKeys) { + auto iter = lockedEntries.find(ReadKey(kv.first)); if (iter == lockedEntries.end()) { continue; } + auto &safeObj = iter->second.safeObj; // Not held by the current thread means that the previous process has been cleaned up. 
- if (!iter->second.first->IsWLockedByCurrentThread()) { + if (!safeObj->IsWLockedByCurrentThread()) { continue; } - if (iter->second.second) { - (void)objectTable_->Erase(iter->first); + if (iter->second.insert) { + (void)objectTable_->Erase(kv.first); } - iter->second.first->WUnlock(); + safeObj->WUnlock(); (void)lockedEntries.erase(iter); } } @@ -2082,7 +2026,7 @@ bool WorkerOcServiceGetImpl::IsInRemoteGetObject(const std::string &objectKey) return inRemoteGetIds_.find(objectKey) != inRemoteGetIds_.end(); } -void WorkerOcServiceGetImpl::AddInRemoteGetObjects(const std::vector &objectsNeedGetRemote) +void WorkerOcServiceGetImpl::AddInRemoteGetObjects(const std::set &objectsNeedGetRemote) { std::lock_guard l(inRemoteGetIdsMutex_); for (const auto &id : objectsNeedGetRemote) { @@ -2090,7 +2034,7 @@ void WorkerOcServiceGetImpl::AddInRemoteGetObjects(const std::vector &o } } -void WorkerOcServiceGetImpl::RemoveInRemoteGetObjects(const std::vector &objectsNeedGetRemote) +void WorkerOcServiceGetImpl::RemoveInRemoteGetObjects(const std::set &objectsNeedGetRemote) { std::lock_guard l(inRemoteGetIdsMutex_); for (const auto &id : objectsNeedGetRemote) { @@ -2140,7 +2084,7 @@ Status WorkerOcServiceGetImpl::GetMapOfObjectKeys(const std::vector> objKeysGrpByMaster; std::unordered_map> objKeysNotInHashRing; - RETURN_IF_NOT_OK(etcdCM_->GroupObjKeysByMasterHostPort(objectKeys, objKeysGrpByMaster, objKeysNotInHashRing)); + etcdCM_->GroupObjKeysByMasterHostPort(objectKeys, objKeysGrpByMaster, objKeysNotInHashRing); for (auto &[master, objs] : objKeysGrpByMaster) { HostPort workerAddr = master.GetAddressAndSaveDbName(); master::GetObjectLocationsReqPb masterReq; @@ -2267,7 +2211,7 @@ Status WorkerOcServiceGetImpl::Exist(const ExistReqPb &req, ExistRspPb &rsp) } RequestParam reqParam; - reqParam.objectKey = objectKeysToAbbrStr(req.object_keys()); + reqParam.objectKey = ObjectKeysToAbbrStr(req.object_keys()); posixPoint.Record(rc.GetCode(), "0", reqParam, rc.GetMsg()); workerOperationTimeCost.Append("Total Exist", timer.ElapsedMilliSecond()); LOG(INFO) << FormatString("The operations of Exist %s", workerOperationTimeCost.GetInfo()); diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.h b/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.h index 64491c370ab5c6be7dcfd6f3e1b22d6767415f97..3c032c283bed75e09a9be6ad7c6cefb8811b5db5 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.h +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.h @@ -22,6 +22,7 @@ #include +#include "datasystem/common/object_cache/object_base.h" #include "datasystem/utils/status.h" #include "datasystem/common/ak_sk/ak_sk_manager.h" #include "datasystem/common/rpc/rpc_message.h" @@ -30,6 +31,7 @@ #include "datasystem/worker/cluster_manager/etcd_cluster_manager.h" #include "datasystem/worker/object_cache/object_kv.h" #include "datasystem/worker/object_cache/service/worker_oc_service_crud_common_api.h" +#include "datasystem/worker/object_cache/worker_request_manager.h" namespace datasystem { namespace object_cache { @@ -60,8 +62,8 @@ public: * @param[in] request Get request instance. * @return Status of the call. 
*/ - Status ProcessObjectsNotExistInLocal(const std::vector &objectsNeedGetRemote, const int64_t subTimeout, - std::unordered_set &failedIds, std::vector &needRetryIds, + Status ProcessObjectsNotExistInLocal(const std::set &objectsNeedGetRemote, int64_t subTimeout, + std::unordered_set &failedIds, std::set &needRetryIds, const std::shared_ptr &request = nullptr); /** @@ -144,13 +146,6 @@ public: */ Status GetMetaInfo(const GetMetaInfoReqPb &req, GetMetaInfoRspPb &rsp); - /** - * @brief Check whether the specified object key is still in the Get process. - * @param[in] key The object key to check. - * @return true if the key is still in the Get process and should not be deleted; false otherwise. - */ - bool IsObjectInGetProcess(const std::string &key); - private: using ObjectKeysQueryMetaFailed = std::tuple, std::unordered_set>; @@ -166,6 +161,11 @@ private: std::vector payloads; }; + struct LockedEntity { + std::shared_ptr safeObj; + bool insert; + }; + /** * @brief Get map of objectKeys grouped by master. * @param[in] objectKeys The vector of objectkeys. @@ -185,11 +185,7 @@ private: * @param[in] clientId The client making this request. * @return Status of the call. */ - Status ProcessGetObjectRequest(const std::vector &objectKeys, - const std::unordered_map &offsetInfos, - std::shared_ptr<::datasystem::ServerUnaryWriterReader> serverApi, - const int64_t subTimeout, const std::string &clientId, - std::shared_ptr accessRecorderPoint, const GetReqPb &getReqPb); + Status ProcessGetObjectRequest(int64_t subTimeout, std::shared_ptr &request); /** * @brief Get one object from local. @@ -197,8 +193,7 @@ private: * @param[out] request The GetRequest. * @param[out] objectsNeedGetRemote These objects not exist in local. */ - void TryGetObjectFromLocal(const std::unordered_map &offsetInfos, - std::shared_ptr &request, std::vector &objectsNeedGetRemote); + Status TryGetObjectFromLocal(std::shared_ptr &request, std::set &remoteObjectKeys); /** * @brief Get one object from remote. @@ -208,7 +203,7 @@ private: * @return Status of the call */ Status TryGetObjectFromRemote(int64_t subTimeout, std::shared_ptr &request, - std::vector &objectsNeedGetRemote); + std::set remoteObjectKeys); /** * @brief Preprocess for get one object. @@ -218,8 +213,7 @@ private: * @param[out] localExistKeys vector of keys which exists in local mem. * @return Status of the call */ - Status PreProcessGetObject(const ReadKey &objectKey, std::shared_ptr &request, - std::vector &objectsNeedGetRemote, std::vector &localExistKeys); + Status PreProcessGetObject(const ReadKey &readKey, GetObjInfo &info, std::set &remoteObjectKeys); /** * @brief Preprocess for get one object from memory, in this case, we only hold RLock instead of WLock, because the @@ -232,9 +226,8 @@ private: * @param[out] localExistKeys vector of keys which exists in local mem. * @return Status of the call */ - Status RLockGetObjectFromMem(const ReadKey &objectKey, std::shared_ptr &request, - std::vector &objectsNeedGetRemote, bool &objIsValidInMem, - std::vector &localExistKeys); + Status RLockGetObjectFromMem(const ReadKey &readKey, GetObjInfo &info, std::set &remoteObjectKeys, + bool &objIsValidInMem); /** * @brief Try to get object from primary copy worker. 
@@ -245,7 +238,7 @@ private: * @return Status of the call */ Status TryGetObjectsFromPrimaryWorker(const std::string &primaryAddress, uint64_t dataSize, ReadObjectKV &objectKV, - std::vector &objectsNeedGetRemote); + std::set &objectsNeedGetRemote); /** * @brief Get object data from remote worker based on object meta. @@ -302,7 +295,7 @@ private: * @return Status of the call. */ Status AggregateAllocateHelper(const std::list &metas, - std::map, bool>> &lockedEntries, + std::map &lockedEntries, std::vector> &shmOwners, std::vector &shmIndexMapping); @@ -317,8 +310,8 @@ private: * @return Status of the call. */ template - Status PrepareUrmaInfo(uint64_t dataSize, ReadObjectKV &objectKV, Req &reqPb, bool &shmUnitAllocated, - std::shared_ptr shmOwner = nullptr); + Status PrepareGetRequestHelper(uint64_t dataSize, ReadObjectKV &objectKV, Req &reqPb, bool &shmUnitAllocated, + std::shared_ptr shmOwner = nullptr); /** * @brief Pull object data from remote worker. @@ -357,8 +350,8 @@ private: * @brief Attempt to get object from local before query meta. * @param[in out] lockedEntries Object lock entries. */ - void AttemptGetObjectsLocally(const std::map &readKeys, - std::map, bool>> &lockedEntries); + void AttemptGetObjectsLocally(const std::shared_ptr &request, + std::map &lockedEntries); /** * @brief Query the metadata of the specified objects in the master. @@ -431,10 +424,9 @@ private: * @return Status of the call. */ Status GetObjectsFromAnywhere(std::vector &queryMetas, - const std::map &readKeys, const std::shared_ptr &request, std::vector &payloads, - std::map, bool>> &lockedEntries, - std::unordered_set &failedIds, std::vector &needRetryIds); + std::map &lockedEntries, + std::unordered_set &failedIds, std::set &needRetryIds); /** * @brief Get objects from anywhere parallelly. @@ -447,11 +439,12 @@ private: * @param[out] needRetryIds Need retry get id list. * @return Status of the call. */ - Status GetObjectsFromAnywhereParallelly( - const std::vector &queryMetas, const std::map &readKeys, - const std::shared_ptr &request, std::vector &payloads, - std::map, bool>> &lockedEntries, - std::unordered_set &failedIds, std::vector &needRetryIds); + Status GetObjectsFromAnywhereParallelly(const std::vector &queryMetas, + const std::shared_ptr &request, + std::vector &payloads, + std::map &lockedEntries, + std::unordered_set &failedIds, + std::set &needRetryIds); /** * @brief Get objects from anywhere serially. @@ -464,11 +457,10 @@ private: * @param[out] needRetryIds Need retry get id list. * @return Status of the call. */ - Status GetObjectsFromAnywhereSerially( - const std::vector &queryMetas, const std::map &readKeys, - const std::shared_ptr &request, std::vector &payloads, - std::map, bool>> &lockedEntries, - std::unordered_set &failedIds, std::vector &needRetryIds); + Status GetObjectsFromAnywhereSerially(const std::vector &queryMetas, + const std::shared_ptr &request, std::vector &payloads, + std::map &lockedEntries, + std::unordered_set &failedIds, std::set &needRetryIds); /** * @brief Get objects from anywhere batched. @@ -481,11 +473,10 @@ private: * @param[out] needRetryIds Need retry get id list. * @return Status of the call. 
*/ - Status GetObjectsFromAnywhereBatched( - std::vector &queryMetas, const std::map &readKeys, - const std::shared_ptr &request, std::vector &payloads, - std::map, bool>> &lockedEntries, - std::unordered_set &failedIds, std::vector &needRetryIds); + Status GetObjectsFromAnywhereBatched(std::vector &queryMetas, + const std::shared_ptr &request, std::vector &payloads, + std::map &lockedEntries, + std::unordered_set &failedIds, std::set &needRetryIds); /** * @brief Get object data from remote cache (remote worker or redis) based on object meta. @@ -508,8 +499,8 @@ private: * @param[out] failedIds Failed get object keys. * @return Status of the call. */ - Status GetObjectsWithoutMeta(std::map &objectKeys, - std::map, bool>> &lockedEntries, + Status GetObjectsWithoutMeta(const std::map &objectKeys, + std::map &lockedEntries, std::unordered_set &failedIds); /** @@ -545,12 +536,12 @@ private: * @param[out] failedMetas Failed get object metas. * @return Status of the call. */ - Status BatchGetObjectFromRemoteOnLock( - const std::string &address, std::list &metas, const std::map &readKeys, - const std::shared_ptr &request, - std::map, bool>> &lockedEntries, - std::vector &successIds, std::vector &needRetryIds, - std::unordered_set &failedIds, std::list &failedMetas); + Status BatchGetObjectFromRemoteOnLock(const std::string &address, std::list &metas, + const std::shared_ptr &request, + std::map &lockedEntries, + std::vector &successIds, std::vector &needRetryIds, + std::unordered_set &failedIds, + std::list &failedMetas); /** * @brief Helper function to split query meta based off address and threshold. @@ -573,7 +564,7 @@ private: * @param[out] failedIds Failed get object keys. * @return Status of the call. */ - void BatchGetObjectHandleIndividualStatus(Status &status, const std::string &objectKey, ReadKey readKey, + void BatchGetObjectHandleIndividualStatus(Status &status, const ReadKey &readKey, std::vector &successIds, std::vector &needRetryIds, std::unordered_set &failedIds); @@ -590,12 +581,12 @@ private: * @param[out] failedMetas Failed get object metas. * @return Status of the call. */ - Status BatchGetObjectFromRemoteWorker( - const std::string &address, std::list &metas, const std::map &readKeys, - const std::shared_ptr &request, - std::map, bool>> &lockedEntries, - std::vector &successIds, std::vector &needRetryIds, - std::unordered_set &failedIds, std::list &failedMetas); + Status BatchGetObjectFromRemoteWorker(const std::string &address, std::list &metas, + const std::shared_ptr &request, + std::map &lockedEntries, + std::vector &successIds, std::vector &needRetryIds, + std::unordered_set &failedIds, + std::list &failedMetas); /** * @brief Helper function to construct batch get request. @@ -610,8 +601,7 @@ private: * @return Status of the call. */ Status ConstructBatchGetRequest(const std::string &address, std::list &metas, - const std::map &readKeys, - std::map, bool>> &lockedEntries, + std::map &lockedEntries, std::vector &successIds, std::vector &needRetryIds, std::unordered_set &failedIds, BatchGetObjectRemoteReqPb &reqPb); @@ -661,13 +651,12 @@ private: * @return Status of the call. 
*/ Status ProcessBatchResponse(const std::string &address, Status &checkConnectStatus, - std::list &metas, const std::map &readKeys, - const std::shared_ptr &request, - std::map, bool>> &lockedEntries, - const Status &status, BatchGetObjectRemoteRspPb &rspPb, - std::vector &payloads, std::vector &successIds, - std::vector &needRetryIds, std::unordered_set &failedIds, - std::list &failedMetas, bool &dataSizeChange); + std::list &metas, const std::shared_ptr &request, + std::map &lockedEntries, const Status &status, + BatchGetObjectRemoteRspPb &rspPb, std::vector &payloads, + std::vector &successIds, std::vector &needRetryIds, + std::unordered_set &failedIds, std::list &failedMetas, + bool &dataSizeChange); /** * @brief Try get object from other AZ. @@ -713,7 +702,8 @@ private: * @param[out] objectKV The reserved and locked safe object and its corresponding objectKey. * @return Status of the call. */ - Status GetObjectFromQueryMetaResultOnLock(const master::QueryMetaInfoPb &queryMeta, + Status GetObjectFromQueryMetaResultOnLock(const std::shared_ptr &request, + const master::QueryMetaInfoPb &queryMeta, std::vector &payloads, ReadObjectKV &objectKV); /** @@ -762,8 +752,7 @@ private: * @param[out] failObjects Locked failed object list. * @return Status of the call. */ - Status BatchLockForGet(const std::vector &objectKeys, - std::map, bool>> &lockedEntries, + Status BatchLockForGet(const std::set &objectKeys, std::map &lockedEntries, std::unordered_set &failObjects); /** @@ -772,27 +761,27 @@ private: * @param[in] lockedEntries Locked entry list. */ void BatchUnlockForGet(const std::unordered_set &failedObjectKeys, - std::map, bool>> &lockedEntries); + std::map &lockedEntries); /** * @brief Batch unlock and erase for remote get via failed object keys. * @param[in] failedObjectKeys Failed object key list that needs to be unlocked and erase. * @param[in] lockedEntries Locked entry list. */ - void BatchUnlockForGet(const std::map &failedObjectKeys, - std::map, bool>> &lockedEntries); + void BatchUnlockForGet(const std::map &objectKeys, + std::map &lockedEntries); /** * @brief Add remote get object key list. * @param[in] objectKeys Object key list. */ - void AddInRemoteGetObjects(const std::vector &objectsNeedGetRemote); + void AddInRemoteGetObjects(const std::set &objectsNeedGetRemote); /** * @brief Remove remote get object key list. * @param[in] objectKeys Object key list. */ - void RemoveInRemoteGetObjects(const std::vector &objectsNeedGetRemote); + void RemoveInRemoteGetObjects(const std::set &objectsNeedGetRemote); /** * @brief Remove remote get object key. @@ -800,18 +789,6 @@ private: */ void RemoveInRemoteGetObject(const std::string &objectKey); - /** - * @brief Mark the specified object keys as being in the Get process. - * @param[in] keys The list of object keys that are entering the Get process. - */ - void MarkObjectsInGetProcess(const std::vector &keys); - - /** - * @brief Unmark the specified object keys from the Get process. - * @param[in] keys The list of object keys that are exiting the Get process. - */ - void UnmarkObjectsInGetProcess(const std::vector &keys); - /** * @brief Fill the GetObjMetaInfoRspPb. * @param[in] objectKeys The objects for obtaining meta info. 
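The Get-path signatures above move needRetryIds and the remote-get key collections from std::vector to std::set. A std::set collapses duplicate keys that independent fetch paths can report for the same object and iterates them in a deterministic order; a minimal self-contained sketch of that effect (illustrative key values, not the repo's real call sites):

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

int main()
{
    // Two independent fetch paths may flag the same object key for retry.
    std::vector<std::string> fromPrimary{ "obj-1", "obj-2" };
    std::vector<std::string> fromRemote{ "obj-2", "obj-3" };

    // Collecting into a std::set collapses duplicates and yields a
    // deterministic iteration order, so each key is retried exactly once.
    std::set<std::string> needRetryIds;
    needRetryIds.insert(fromPrimary.begin(), fromPrimary.end());
    needRetryIds.insert(fromRemote.begin(), fromRemote.end());

    for (const auto &key : needRetryIds) {
        std::cout << "retry " << key << "\n";  // obj-1, obj-2, obj-3
    }
    return 0;
}
```

The trade-off is O(log n) insertion, which is negligible next to the RPC cost of an extra duplicate retry.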
@@ -898,6 +875,10 @@ private: std::shared_ptr workerBatchThreadPool_{ nullptr }; + std::shared_ptr workerBatchQueryMetaThreadPool_{ nullptr }; + + std::shared_ptr workerBatchRemoteGetThreadPool_{ nullptr }; + std::shared_ptr threadPool_{ nullptr }; std::unique_ptr remoteGetThreadPool_{ nullptr }; @@ -906,16 +887,12 @@ private: HostPort localAddress_; - std::shared_timed_mutex inRemoteGetIdsMutex_; // the mutex for inRemoteGetIds_ + std::shared_timed_mutex inRemoteGetIdsMutex_; // the mutex for inRemoteGetIds_ - std::unordered_set inRemoteGetIds_; // the object keys that in remote get + std::unordered_set inRemoteGetIds_; // the object keys that in remote get std::vector otherAZNames_; - std::shared_mutex objectsInGetProcessMutex_; // the mutex for objectsInGetProcess_ - - std::unordered_map objectsInGetProcess_; - static constexpr size_t OBJECTS_NOT_EXIST_IDX = 0; static constexpr size_t OBJECTS_PUZZLED_IDX = 1; }; diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_migrate_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_migrate_impl.cpp index b1451cddce9bb1739b5ba573a69d844d62780b34..61e8dadc5e61277b0b83734775573e74bd74bedf 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_migrate_impl.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_migrate_impl.cpp @@ -36,6 +36,7 @@ #include "datasystem/common/inject/inject_point.h" #include "datasystem/common/log/log.h" #include "datasystem/common/rpc/rpc_message.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/format.h" #include "datasystem/common/util/rpc_util.h" #include "datasystem/common/util/status_helper.h" @@ -586,7 +587,7 @@ Status WorkerOcServiceMigrateImpl::AllocateAndAssignData( shmUnit->AllocateMemory(tenantId, needSize, false, ServiceType::OBJECT, static_cast((*entry)->modeInfo.GetCacheType())), FormatString("[Migrate Data] %s allocate memory failed, size: %ld", objectKey, needSize)); - shmUnit->id = GetStringUuid(); + shmUnit->id = ShmKey::Intern(GetStringUuid()); RETURN_IF_NOT_OK_PRINT_ERROR_MSG( shmUnit->MemoryCopy(payloads, memcpyThreadPool_, metaSize), FormatString("[Migrate Data] Memory copy failed, offset: %ld, size: %ld", metaSize, needSize)); diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_multi_publish_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_multi_publish_impl.cpp index 336f2707d0d597ece90ae383fbc2fa9f10eb175c..9dda412cd503435b89bafb3627fa7fbb426ee89f 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_multi_publish_impl.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_multi_publish_impl.cpp @@ -25,6 +25,7 @@ #include "datasystem/common/inject/inject_point.h" #include "datasystem/common/log/log.h" #include "datasystem/common/perf/perf_manager.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/deadlock_util.h" #include "datasystem/common/util/format.h" #include "datasystem/common/util/raii.h" @@ -82,10 +83,10 @@ Status WorkerOcServiceMultiPublishImpl::MultiPublishImpl(const MultiPublishReqPb RETURN_IF_NOT_OK_PRINT_ERROR_MSG(worker::Authenticate(akSkManager_, req, tenantId), "Authenticate failed."); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(req.object_info_size()), StatusCode::K_INVALID, "invalid object info size"); - std::vector shmUnits; + std::vector shmUnits; for (const auto &info : req.object_info()) { if 
(!info.shm_id().empty()) { - shmUnits.emplace_back(info.shm_id()); + shmUnits.emplace_back(ShmKey::Intern(info.shm_id())); } } RETURN_IF_NOT_OK( @@ -104,6 +105,7 @@ Status WorkerOcServiceMultiPublishImpl::MultiPublishImpl(const MultiPublishReqPb std::vector namespaceUri; size_t objectSize = static_cast(req.object_info_size()); + namespaceUri.reserve(objectSize); for (size_t i = 0; i < objectSize; ++i) { namespaceUri.emplace_back( TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.object_info(i).object_key())); @@ -256,6 +258,7 @@ Status WorkerOcServiceMultiPublishImpl::MultiPublishObjectNtx(const MultiPublish std::vector &payloads, Status &lastRc) { // Save shmId to the entry, it also need to copy data to share memory if it's the small data. + const bool shmEnabled = ClientShmEnabled(req.client_id()); for (auto index = successIndex.begin(); index != successIndex.end();) { auto i = *index; if (entries[i]->Get() != nullptr) { @@ -274,7 +277,7 @@ Status WorkerOcServiceMultiPublishImpl::MultiPublishObjectNtx(const MultiPublish GetMetadataSize(), *entries[i]); } - lastRc = AttachShmUnitToObject(req.client_id(), objectKeys[i], req.object_info(i).shm_id(), + lastRc = AttachShmUnitToObject(shmEnabled, objectKeys[i], ShmKey::Intern(req.object_info(i).shm_id()), req.object_info(i).data_size(), *entries[i]); if (lastRc.IsError()) { LOG(ERROR) << "objKey: " << objectKeys[i] << " AttachShmUnitToObject failed, status: " << lastRc.ToString(); @@ -320,6 +323,7 @@ Status WorkerOcServiceMultiPublishImpl::MultiPublishObject(const MultiPublishReq std::vector &payloads) { // Save shmId to the entry, it also need to copy data to share memory if it's the small data. + const bool shmEnabled = ClientShmEnabled(req.client_id()); size_t idx = 0; for (size_t i = 0; i < objectKeys.size(); ++i) { if (entries[i]->Get() != nullptr) { @@ -332,7 +336,8 @@ Status WorkerOcServiceMultiPublishImpl::MultiPublishObject(const MultiPublishReq static_cast(req.cache_type()), req.object_info(i).data_size(), GetMetadataSize(), *entries[i]); } - RETURN_IF_NOT_OK(AttachShmUnitToObject(req.client_id(), objectKeys[i], req.object_info(i).shm_id(), + RETURN_IF_NOT_OK(AttachShmUnitToObject(shmEnabled, objectKeys[i], + ShmKey::Intern(req.object_info(i).shm_id()), req.object_info(i).data_size(), *entries[i])); // Small object use payload to transfer the value, the object and RpcMessage is one-to-one. 
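The ShmKey::Intern calls introduced above come from string_intern/string_ref.h, whose implementation is not part of this diff. A minimal sketch of the interning idea, assuming ShmKey behaves like a cheap shared handle into a pooled string table (all names below are illustrative, not the real API):

```cpp
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Illustrative interning sketch, not the repo's ShmKey/StringRef:
// repeated ids share one pooled std::string, so copies are handle-sized
// and equality degrades to a pointer comparison.
class InternedKey {
public:
    static InternedKey Intern(const std::string &s)
    {
        static std::mutex mu;
        static std::unordered_map<std::string, std::shared_ptr<const std::string>> pool;
        std::lock_guard<std::mutex> lock(mu);
        auto it = pool.find(s);
        if (it == pool.end()) {
            it = pool.emplace(s, std::make_shared<const std::string>(s)).first;
        }
        return InternedKey(it->second);
    }

    const std::string &Str() const { return *ref_; }

    bool operator==(const InternedKey &other) const { return ref_ == other.ref_; }

private:
    explicit InternedKey(std::shared_ptr<const std::string> ref) : ref_(std::move(ref)) {}
    std::shared_ptr<const std::string> ref_;
};

int main()
{
    InternedKey a = InternedKey::Intern("shm-unit-42");
    InternedKey b = InternedKey::Intern("shm-unit-42");
    std::cout << (a == b) << " " << a.Str() << std::endl;  // prints: 1 shm-unit-42
    return 0;
}
```

This pays off here because the same shm id is copied into request tables, ref-count tables, and log paths many times per request.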
if (req.object_info(i).shm_id().empty()) { @@ -349,18 +354,18 @@ Status WorkerOcServiceMultiPublishImpl::MultiPublishObject(const MultiPublishReq } } - CreateMultiMetaRspPb resp; + std::vector versions(objectKeys.size()); std::vector successIndex; for (size_t i = 0; i < objectKeys.size(); i++) { successIndex.emplace_back(i); } if (FLAGS_enable_distributed_master) { - RETURN_IF_NOT_OK(CreateMultiMetaToDistributedMaster(objectKeys, entries, req, resp)); + RETURN_IF_NOT_OK(CreateMultiMetaToDistributedMaster(objectKeys, entries, req, versions)); } else { - RETURN_IF_NOT_OK(CreateMultiMetaToCentralMaster(objectKeys, successIndex, entries, req, resp)); + RETURN_IF_NOT_OK(CreateMultiMetaToCentralMaster(objectKeys, entries, req, versions)); } - UpdateObjectAfterCreatingMeta(objectKeys, entries, resp, successIndex); + UpdateObjectAfterCreatingMeta(objectKeys, entries, versions, successIndex); return Status::OK(); } @@ -372,25 +377,27 @@ void WorkerOcServiceMultiPublishImpl::FillMultiMetaReqPhaseOne( { for (const auto &obj : objectKeys) { auto entry = entries[obj.second]; - datasystem::ObjectMetaPb metadata; + ObjectBaseInfoPb metadata; metadata.set_object_key(obj.first); metadata.set_data_size((*entry)->GetDataSize()); - metadata.set_life_state(static_cast(ObjectLifeState::OBJECT_PUBLISHED)); - metadata.set_ttl_second(pubReq.ttl_second()); - metadata.set_existence(static_cast<::datasystem::ExistenceOptPb>(pubReq.existence())); - ObjectMetaPb::ConfigPb *configPb = metadata.mutable_config(); - configPb->set_write_mode(static_cast((*entry)->modeInfo.GetWriteMode())); - configPb->set_data_format(static_cast((*entry)->stateInfo.GetDataFormat())); - configPb->set_consistency_type(static_cast((*entry)->modeInfo.GetConsistencyType())); - configPb->set_cache_type(pubReq.cache_type()); req.mutable_metas()->Add(std::move(metadata)); } - std::sort(req.mutable_metas()->begin(), req.mutable_metas()->end(), - [](const ObjectMetaPb &fir, const ObjectMetaPb &sec) { return fir.object_key() < sec.object_key(); }); + std::sort( + req.mutable_metas()->begin(), req.mutable_metas()->end(), + [](const ObjectBaseInfoPb &fir, const ObjectBaseInfoPb &sec) { return fir.object_key() < sec.object_key(); }); req.set_address(localAddress_.ToString()); req.set_istx(true); req.set_is_pre_commit(true); req.set_redirect(true); + req.set_life_state(static_cast(ObjectLifeState::OBJECT_PUBLISHED)); + req.set_ttl_second(pubReq.ttl_second()); + req.set_existence(static_cast<::datasystem::ExistenceOptPb>(pubReq.existence())); + auto &firstEntry = *entries[0]; + ConfigPb *configPb = req.mutable_config(); + configPb->set_write_mode(static_cast(firstEntry->modeInfo.GetWriteMode())); + configPb->set_data_format(static_cast(firstEntry->stateInfo.GetDataFormat())); + configPb->set_consistency_type(static_cast(firstEntry->modeInfo.GetConsistencyType())); + configPb->set_cache_type(pubReq.cache_type()); } Status WorkerOcServiceMultiPublishImpl::RetryRollbackMultiMetaWhenMoving(std::shared_ptr api, @@ -653,7 +660,7 @@ Status WorkerOcServiceMultiPublishImpl::CreateMultiMetaPhaseOne( } Status WorkerOcServiceMultiPublishImpl::Process2PCResults(std::vector> &futures, - const ObjGroupMap &objGroup, CreateMultiMetaRspPb &resp) + const ObjGroupMap &objGroup, std::vector &versions) { std::vector> needRollBack; Status lastRc; @@ -665,16 +672,11 @@ Status WorkerOcServiceMultiPublishImpl::Process2PCResults(std::vectorGetHostPort()); - const size_t verSize = static_cast(res.rsp.version_size()); - if (objs.size() != verSize) { - LOG(WARNING) << 
FormatString("[Process2PCResults] The objs size(%lu) and version size(%d) does not match", - objs.size(), verSize); - } - if (verSize == 0) { - continue; - } for (size_t i = 0; i < objs.size(); i++) { - resp.set_version(objs[i].second, res.rsp.version(i % verSize)); + if (objs[i].second > versions.size()) { + continue; + } + versions[objs[i].second] = res.rsp.version(); } } if (lastRc.IsError()) { @@ -685,7 +687,7 @@ Status WorkerOcServiceMultiPublishImpl::Process2PCResults(std::vector &versions) { std::vector> apis(objGroup.size()); std::vector> objs(objGroup.size()); @@ -724,12 +726,12 @@ Status WorkerOcServiceMultiPublishImpl::CreateMultiMetaPhaseTwo(const ObjGroupMa return CreateMeta2PCRes{ rc, rsp, apis[i] }; })); } - return Process2PCResults(futures, objGroup, resp); + return Process2PCResults(futures, objGroup, versions); } Status WorkerOcServiceMultiPublishImpl::CreateMultiMetaToDistributedMaster( const std::vector &objectKeys, const std::vector> &entries, - const MultiPublishReqPb &pubReq, CreateMultiMetaRspPb &resp) + const MultiPublishReqPb &pubReq, std::vector &versions) { CHECK_FAIL_RETURN_STATUS(!asyncRollbackManager_->IsObjectsInRollBack(objectKeys), K_OC_KEY_ALREADY_EXIST, "The object is being rolled back."); @@ -750,23 +752,24 @@ Status WorkerOcServiceMultiPublishImpl::CreateMultiMetaToDistributedMaster( // If phase one causes an asynchronous rollback, we need return. CHECK_FAIL_RETURN_STATUS(!asyncRollbackManager_->IsObjectsInRollBack(objectKeys), K_OC_KEY_ALREADY_EXIST, "The object is being rolled back."); - resp.mutable_version()->Resize(objectKeys.size(), 0); - RETURN_IF_NOT_OK(CreateMultiMetaPhaseTwo(objGroup, pubReq, resp)); + RETURN_IF_NOT_OK(CreateMultiMetaPhaseTwo(objGroup, pubReq, versions)); return Status::OK(); } Status WorkerOcServiceMultiPublishImpl::CreateMultiMetaToDistributedMasterNtx( const std::vector &objectKeys, std::vector &successIndex, const std::vector> &entries, const MultiPublishReqPb &pubReq, - CreateMultiMetaRspPb &totalResp) + CreateMultiMetaRspPb &totalResp, std::vector &versions) { ObjGroupMap objGroup; + std::unordered_map> workerAddrToApi; for (const auto index : successIndex) { std::shared_ptr api; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(workerMasterApiManager_->GetWorkerMasterApi(objectKeys[index], etcdCM_, api), "Getting master api failed. 
ObjectKey = " + objectKeys[index]); auto addr = api->GetHostPort(); objGroup[addr].emplace_back(objectKeys[index], index); + workerAddrToApi[addr] = api; } std::vector> apis(objGroup.size()); @@ -775,13 +778,8 @@ Status WorkerOcServiceMultiPublishImpl::CreateMultiMetaToDistributedMasterNtx( int idx = 0; for (const auto &[masterAddr, objInfos] : objGroup) { auto &req = createReqs[idx]; - RETURN_IF_NOT_OK(workerMasterApiManager_->GetWorkerMasterApiByAddr(masterAddr, etcdCM_, apis[idx])); - for (const auto &obj : objInfos) { - ConstructCreateReq(objectKeys[obj.second], entries[obj.second], pubReq, - pubReq.object_info(obj.second).blob_sizes(), req); - } - req.set_address(localAddress_.ToString()); - req.set_istx(false); + apis[idx] = workerAddrToApi[masterAddr]; + ConstructCreateReq(objInfos, entries, pubReq, req); idx++; } @@ -792,18 +790,14 @@ Status WorkerOcServiceMultiPublishImpl::CreateMultiMetaToDistributedMasterNtx( CHECK_FAIL_RETURN_STATUS(respRes.size() == createReqs.size(), K_RUNTIME_ERROR, "The object size and the versions is not equal"); - totalResp.mutable_version()->Resize(objectKeys.size(), 0); for (size_t index = 0; index < respRes.size(); index++) { auto &resp = respRes[index].rsp; for (const auto &failedId : resp.failed_object_keys()) { totalResp.add_failed_object_keys(failedId); } LOG_IF_ERROR(respRes[index].rc, "Get error with createMeta"); - CHECK_FAIL_RETURN_STATUS(createReqs[index].metas().size() == resp.version().size(), K_RUNTIME_ERROR, - "The object size and the versions is not equal"); - uint64_t version = resp.version_size() > 0 ? resp.version(0) : 0; for (const auto &obj : objGroup.at(respRes[index].api->GetHostPort())) { - totalResp.set_version(obj.second, version); + versions[obj.second] = resp.version(); } if (resp.has_last_rc() && static_cast(resp.last_rc().error_code()) != StatusCode::K_OK) { totalResp.mutable_last_rc()->set_error_code(resp.last_rc().error_code()); @@ -813,32 +807,69 @@ Status WorkerOcServiceMultiPublishImpl::CreateMultiMetaToDistributedMasterNtx( return Status::OK(); } -void WorkerOcServiceMultiPublishImpl::ConstructCreateReq(const std::string &objectKey, - const std::shared_ptr &entry, - const MultiPublishReqPb &pubReq, - const google::protobuf::RepeatedField blobSizes, - CreateMultiMetaReqPb &req) +void WorkerOcServiceMultiPublishImpl::ConstructCreateReqCommon(SafeObjType &entry, const MultiPublishReqPb &pubReq, + CreateMultiMetaReqPb &req) { - datasystem::ObjectMetaPb metadata; - metadata.set_object_key(objectKey); - metadata.set_data_size((*entry)->GetDataSize()); - metadata.set_life_state(static_cast(ObjectLifeState::OBJECT_PUBLISHED)); - metadata.set_ttl_second(pubReq.ttl_second()); - metadata.set_existence(static_cast<::datasystem::ExistenceOptPb>(pubReq.existence())); - metadata.mutable_device_info()->mutable_blob_sizes()->Add(blobSizes.begin(), blobSizes.end()); - ObjectMetaPb::ConfigPb *configPb = metadata.mutable_config(); - configPb->set_write_mode(static_cast((*entry)->modeInfo.GetWriteMode())); - configPb->set_data_format(static_cast((*entry)->stateInfo.GetDataFormat())); - configPb->set_consistency_type(static_cast((*entry)->modeInfo.GetConsistencyType())); - configPb->set_cache_type(static_cast((*entry)->modeInfo.GetCacheType())); + if (pubReq.istx()) { + // Optimize the scenario when worker1 set key1 and key2, while worker2 set key2 and key1, if both request + // arrives the master concurrently, master generate meta for key1 of the worker1 and key2 of the worker2 + // initially, then master try to process key2 of the worker1 
and key1 of the worker2, it will find the keys have + // already been occupied that will caused both request failed. + std::sort(req.mutable_metas()->begin(), req.mutable_metas()->end(), + [](const ObjectBaseInfoPb &fir, const ObjectBaseInfoPb &sec) { + return fir.object_key() < sec.object_key(); + }); + } + req.set_address(localAddress_.ToString()); + req.set_istx(pubReq.istx()); + req.set_life_state(static_cast(ObjectLifeState::OBJECT_PUBLISHED)); + req.set_ttl_second(pubReq.ttl_second()); + req.set_existence(static_cast<::datasystem::ExistenceOptPb>(pubReq.existence())); + ConfigPb *configPb = req.mutable_config(); + configPb->set_write_mode(static_cast(entry->modeInfo.GetWriteMode())); + configPb->set_data_format(static_cast(entry->stateInfo.GetDataFormat())); + configPb->set_consistency_type(static_cast(entry->modeInfo.GetConsistencyType())); + configPb->set_cache_type(static_cast(entry->modeInfo.GetCacheType())); configPb->set_is_replica(pubReq.is_replica()); - req.mutable_metas()->Add(std::move(metadata)); +} + +void WorkerOcServiceMultiPublishImpl::ConstructCreateReq(const std::vector> &objectInfos, + const std::vector> &entries, + const MultiPublishReqPb &pubReq, CreateMultiMetaReqPb &req) +{ + for (const auto &[objectKey, i] : objectInfos) { + ObjectBaseInfoPb meta; + meta.set_object_key(objectKey); + meta.set_data_size((*entries[i])->GetDataSize()); + if (pubReq.object_info(i).blob_sizes_size() > 0) { + meta.mutable_device_info()->mutable_blob_sizes()->Add(pubReq.object_info(i).blob_sizes().begin(), + pubReq.object_info(i).blob_sizes().end()); + } + req.mutable_metas()->Add(std::move(meta)); + } + ConstructCreateReqCommon(*entries[0], pubReq, req); +} + +void WorkerOcServiceMultiPublishImpl::ConstructCreateReq(const std::vector &objectInfos, + const std::vector> &entries, + const MultiPublishReqPb &pubReq, CreateMultiMetaReqPb &req) +{ + for (size_t i = 0; i < objectInfos.size(); i++) { + ObjectBaseInfoPb meta; + meta.set_object_key(objectInfos[i]); + meta.set_data_size((*entries[i])->GetDataSize()); + if (pubReq.object_info(i).blob_sizes_size() > 0) { + meta.mutable_device_info()->mutable_blob_sizes()->Add(pubReq.object_info(i).blob_sizes().begin(), + pubReq.object_info(i).blob_sizes().end()); + } + req.mutable_metas()->Add(std::move(meta)); + } + ConstructCreateReqCommon(*entries[0], pubReq, req); } Status WorkerOcServiceMultiPublishImpl::CreateMultiMetaToCentralMaster( - const std::vector &objectKeys, std::vector &successIndex, - const std::vector> &entries, const MultiPublishReqPb &pubReq, - CreateMultiMetaRspPb &resp) + const std::vector &objectKeys, const std::vector> &entries, + const MultiPublishReqPb &pubReq, std::vector &versions) { std::shared_ptr workerMasterApi = workerMasterApiManager_->GetWorkerMasterApi(objectKeys[0], etcdCM_); @@ -847,38 +878,26 @@ Status WorkerOcServiceMultiPublishImpl::CreateMultiMetaToCentralMaster( LOG(INFO) << FormatString("Create meta to master[%s]", workerMasterApi->GetHostPort()); CreateMultiMetaReqPb req; - for (auto index = successIndex.begin(); index != successIndex.end(); index++) { - auto i = *index; - auto blobSizes = pubReq.object_info(i).blob_sizes(); - ConstructCreateReq(objectKeys[i], entries[i], pubReq, blobSizes, req); - } - if (pubReq.istx()) { - // Optimize the scenario when worker1 set key1 and key2, while worker2 set key2 and key1, if both request - // arrives the master concurrently, master generate meta for key1 of the worker1 and key2 of the worker2 - // initially, then master try to process key2 of the worker1 and 
key1 of the worker2, it will find the keys have - already been occupied that will caused both request failed. - std::sort(req.mutable_metas()->begin(), req.mutable_metas()->end(), - [](const ObjectMetaPb &fir, const ObjectMetaPb &sec) { return fir.object_key() < sec.object_key(); }); - } - req.set_address(localAddress_.ToString()); - req.set_istx(pubReq.istx()); + ConstructCreateReq(objectKeys, entries, pubReq, req); PerfPoint point(PerfKey::WORKER_CREATE_MULTI_META); + + CreateMultiMetaRspPb resp; Status status = RetryWhenDeadlock([&workerMasterApi, &req, &resp] { return workerMasterApi->CreateMultiMeta(req, resp); }); point.Record(); - if (status.IsOk()) { - CHECK_FAIL_RETURN_STATUS(req.metas().size() == resp.version().size(), K_RUNTIME_ERROR, - "The object size and the versions is not equal"); + for (auto &version : versions) { + version = resp.version(); + } return status; } void WorkerOcServiceMultiPublishImpl::UpdateObjectAfterCreatingMeta(std::vector &objectKeys, std::vector> entries, - const master::CreateMultiMetaRspPb &rsp, + const std::vector &versions, std::vector &successIndex) { - auto rollBackPersistenceIfFail = [this, &rsp, &objectKeys](const Status &rc, const ObjectKV &kv, int idx) { + auto rollBackPersistenceIfFail = [this, &versions, &objectKeys](const Status &rc, const ObjectKV &kv, int idx) { if (rc.IsError()) { LOG(ERROR) << FormatString("Multiple set fails to save object %s to l2cache.", objectKeys[idx]); std::shared_ptr workerMasterApi = @@ -887,15 +906,17 @@ void WorkerOcServiceMultiPublishImpl::UpdateObjectAfterCreatingMeta(std::vector< master::RollbackMultiMetaRspPb resp; req.set_persistence_only(true); req.add_object_keys(kv.GetObjKey()); - req.add_versions(rsp.version(idx)); + req.add_versions(versions[idx]); workerMasterApi->RollbackMultiMeta(req, resp); } }; + std::vector objectKeysSucc; + objectKeysSucc.reserve(successIndex.size()); for (auto index = successIndex.begin(); index != successIndex.end(); index++) { auto i = *index; ObjectKV objectKV(objectKeys[i], *entries[i]); - objectKV.GetObjEntry()->SetCreateTime(rsp.version(i)); + objectKV.GetObjEntry()->SetCreateTime(versions[i]); // Save object to L2 cache if ((*entries[i])->IsWriteThroughMode()) { if (IsSupportL2Storage(supportL2Storage_)) { @@ -910,8 +931,6 @@ void WorkerOcServiceMultiPublishImpl::UpdateObjectAfterCreatingMeta(std::vector< } // Update entry information (*entries[i])->stateInfo.SetNeedToDelete(false); - LOG_IF_ERROR(workerRequestManager_.UpdateRequestForPublish(objectKV, memoryRefTable_), - FormatString("Multiple set fails to update object %s get request.", objectKeys[i])); (*entries[i])->SetLifeState(ObjectLifeState::OBJECT_PUBLISHED); (*entries[i])->stateInfo.SetPrimaryCopy(true); (*entries[i])->stateInfo.SetCacheInvalid(false); @@ -921,8 +940,27 @@ void WorkerOcServiceMultiPublishImpl::UpdateObjectAfterCreatingMeta(std::vector< LOG_IF_ERROR(DeleteObjectFromDisk(objectKV), FormatString("Multiple set fails to delete spilled object %s from disk.", objectKeys[i])); } - evictionManager_->Add(objectKeys[i]); + objectKeysSucc.emplace_back(objectKeys[i]); } + threadPool_->Execute([this, objectKeys = std::move(objectKeysSucc)]() { + Status rc; + for (const auto &key : objectKeys) { + evictionManager_->Add(key); + std::shared_ptr entry; + Raii raii([key, &rc]() { LOG_IF_ERROR(rc, FormatString("Fails to update object %s get request.", key)); }); + rc = objectTable_->Get(key, entry); + if (rc.IsError()) { + continue; + } + rc = entry->RLock(); + if (rc.IsError()) { + 
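The transactional branch above sorts the metas by object key before contacting the master. This is the classic resource-ordering trick: if every worker reserves keys in one global order, two publishes over overlapping key sets can no longer each occupy half of the other's keys and mutually fail. A small sketch of the invariant (hypothetical helper, not repo code):

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical helper: every worker sorts its key set before asking the
// master to reserve metadata, so overlapping requests contend on the same
// first key and exactly one of them wins cleanly.
std::vector<std::string> PrepareCreateOrder(std::vector<std::string> keys)
{
    std::sort(keys.begin(), keys.end());
    return keys;
}

int main()
{
    auto worker1 = PrepareCreateOrder({ "key2", "key1" });
    auto worker2 = PrepareCreateOrder({ "key1", "key2" });
    // Both now reserve key1 first, then key2; the interleaving where
    // worker1 holds key1 while worker2 holds key2 can no longer happen.
    std::cout << (worker1 == worker2) << std::endl;  // prints: 1
    return 0;
}
```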
continue; + } + ObjectKV objectKV(key, *entry); + rc = workerRequestManager_.NotifyPendingGetRequest(objectKV); + entry->RUnlock(); + } + }); } Status WorkerOcServiceMultiPublishImpl::SendToMasterAndUpdateObject( @@ -936,7 +974,8 @@ Status WorkerOcServiceMultiPublishImpl::SendToMasterAndUpdateObject( objectIndexMap[objectKeys[index]] = index; } - RETURN_IF_NOT_OK(CreateMultiMetaToDistributedMasterNtx(objectKeys, successIndex, entries, req, rsp)); + std::vector versions(objectKeys.size()); + RETURN_IF_NOT_OK(CreateMultiMetaToDistributedMasterNtx(objectKeys, successIndex, entries, req, rsp, versions)); std::vector failedObjectKey(rsp.failed_object_keys().begin(), rsp.failed_object_keys().end()); std::set sortedFailedIndex; @@ -972,7 +1011,7 @@ Status WorkerOcServiceMultiPublishImpl::SendToMasterAndUpdateObject( Status recvRc(static_cast(rsp.last_rc().error_code()), rsp.last_rc().error_msg()); lastRc = recvRc.IsOk() ? lastRc : recvRc; } - UpdateObjectAfterCreatingMeta(objectKeys, entries, rsp, successIndex); + UpdateObjectAfterCreatingMeta(objectKeys, entries, versions, successIndex); return Status::OK(); } @@ -1071,6 +1110,9 @@ Status WorkerOcServiceMultiPublishImpl::BatchLockForSetNtx(const std::vector objectToIdx; for (size_t i = 0; i < keyNum; ++i) { + if (objectToIdx.find(objectKeys[i]) != objectToIdx.end()) { + continue; + } objectToIdx[objectKeys[i]] = i; } std::vector isFinish(keyNum, false); diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_multi_publish_impl.h b/src/datasystem/worker/object_cache/service/worker_oc_service_multi_publish_impl.h index 9eb55c3df523e8b479ae2f18db7dbecdb1c5cfc4..2e5270cc47a7b003e46134ee0645c04d4b6eb04d 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_multi_publish_impl.h +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_multi_publish_impl.h @@ -134,29 +134,45 @@ private: /** * @brief Create or update metadata to master, object will be unlocked during requesting master. * @param[in] objectKeys Object key list. - * @param[in] successIndex success index of object list * @param[in] entries The object entries. * @param[in] pubReq The request of multipublish. - * @param[out] resp responese info of CreateMultiMeta + * @param[out] versions The versions of objects * @return Status of the call. */ - Status CreateMultiMetaToCentralMaster(const std::vector &objectKeys, std::vector &successIndex, + Status CreateMultiMetaToCentralMaster(const std::vector &objectKeys, const std::vector> &entries, - const MultiPublishReqPb &pubReq, master::CreateMultiMetaRspPb &resp); + const MultiPublishReqPb &pubReq, std::vector &versions); + + /** + * @brief Construct the request info for create multiple meta. + * @param[in] entry The object entry. + * @param[in] pubReq The request of multipublish. + * @param[out] req The multimeta request to construct. + */ + void ConstructCreateReqCommon(SafeObjType &entry, const MultiPublishReqPb &pubReq, + master::CreateMultiMetaReqPb &req); /** * @brief Construct the request info for create multiple meta. - * @param[in] objectKey Object key . - * @param[in] entries The object entry. + * @param[in] objectInfos Object info list. + * @param[in] entries The object entries. * @param[in] pubReq The request of multipublish. - * @param[in] blobSizes the blob size of key - * @param[out] req request info of CreateMultiMetaReqPb + * @param[out] req The multimeta request to construct. 
*/ - void ConstructCreateReq(const std::string &objectKey, const std::shared_ptr &entry, - const MultiPublishReqPb &pubReq, - const google::protobuf::RepeatedField blobSizes, + void ConstructCreateReq(const std::vector> &objectInfos, + const std::vector> &entries, const MultiPublishReqPb &pubReq, master::CreateMultiMetaReqPb &req); + /** + * @brief Construct the request info for create multiple meta. + * @param[in] objectKeys Object key list. + * @param[in] entries The object entries. + * @param[in] pubReq The request of multipublish. + * @param[out] req The multimeta request to construct. + */ + void ConstructCreateReq(const std::vector &objectKeys, + const std::vector> &entries, const MultiPublishReqPb &pubReq, + master::CreateMultiMetaReqPb &req); /** * @brief Create or update metadata to master, object will be unlocked during requesting master. * @param[in] objectKeys Object key list. @@ -164,12 +180,14 @@ private: * @param[in] entries The object entries. * @param[in] pubReq The request of multipublish. * @param[out] resp responese info of CreateMultiMeta + * @param[out] versions The versions of objects. * @return Status of the call. */ Status CreateMultiMetaToDistributedMasterNtx(const std::vector &objectKeys, std::vector &successIndex, const std::vector> &entries, - const MultiPublishReqPb &pubReq, master::CreateMultiMetaRspPb &resp); + const MultiPublishReqPb &pubReq, master::CreateMultiMetaRspPb &resp, + std::vector &versions); /** * @brief Fill multimeta request. @@ -187,12 +205,12 @@ private: * @param[in] objectKeys Object key list. * @param[in] entries The object entries. * @param[in] pubReq The request of multipublish. - * @param[out] resp The responese info of CreateMultiMeta. + * @param[out] versions The versions of objects. * @return Status of the call. */ Status CreateMultiMetaToDistributedMaster(const std::vector &objectKeys, const std::vector> &entries, - const MultiPublishReqPb &pubReq, master::CreateMultiMetaRspPb &resp); + const MultiPublishReqPb &pubReq, std::vector &versions); /** * @brief Create multimeta request to master in parallel. @@ -242,21 +260,21 @@ private: * @brief Create multimeta phase two request to master. * @param[in] objGroup The group of objects. * @param[in] pubReq The request of multipublish. - * @param[out] resp The responese info of CreateMultiMeta. + * @param[out] versions The versions of objects. * @return Status of the call. */ Status CreateMultiMetaPhaseTwo(const ObjGroupMap &objGroup, const MultiPublishReqPb &pubReq, - master::CreateMultiMetaRspPb &resp); + std::vector &versions); /** * @brief Process 2PC results. * @param[in] futures The 2PC request futures. * @param[in] objGroup The group of objects. - * @param[out] resp The responese info of CreateMultiMeta. + * @param[out] versions The versions of objects. * @return Status of the call. */ Status Process2PCResults(std::vector> &futures, const ObjGroupMap &objGroup, - master::CreateMultiMetaRspPb &resp); + std::vector &versions); /** * @brief Rollback metadata request to master. @@ -300,12 +318,12 @@ private: * @brief Fill the entry and save object to L2 cache if success to create meta. * @param[in] objectKeys Object key list. * @param[in] entries The object entries. - * @param[in] rsp Response from master. + * @param[in] versions The versions of objects. * @return K_OK on success; the error code otherwise. 
*/ void UpdateObjectAfterCreatingMeta(std::vector &objectKeys, std::vector> entries, - const master::CreateMultiMetaRspPb &rsp, std::vector &successIndex); + const std::vector &versions, std::vector &successIndex); /** * @brief Publish newly objects. This function will publish entry and save data to cache. diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_publish_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_publish_impl.cpp index bb25e0435d42ee723e20255ca7bea164fa9b2bcb..455302b37da8cf8cbadbd093a3e3b174fcf1360f 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_publish_impl.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_publish_impl.cpp @@ -26,6 +26,7 @@ #include "datasystem/common/log/log.h" #include "datasystem/common/l2cache/l2_storage.h" #include "datasystem/common/perf/perf_manager.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/deadlock_util.h" #include "datasystem/common/util/format.h" #include "datasystem/common/util/raii.h" @@ -90,7 +91,8 @@ Status WorkerOcServicePublishImpl::PrepareForPublish(const PublishReqPb &req, Ob RETURN_IF_NOT_OK(CheckIfL2CacheNeededAndWritable(supportL2Storage_, WriteMode(req.write_mode()))); - return AttachShmUnitToObject(req.client_id(), objectKey, req.shm_id(), req.data_size(), safeObj); + return AttachShmUnitToObject(ClientShmEnabled(req.client_id()), objectKey, ShmKey::Intern(req.shm_id()), + req.data_size(), safeObj); } Status WorkerOcServicePublishImpl::CreateMetadataToMaster(const ObjectKV &objectKV, const PublishParams ¶ms, @@ -110,7 +112,7 @@ Status WorkerOcServicePublishImpl::CreateMetadataToMaster(const ObjectKV &object metadata->set_life_state(static_cast(params.lifeState)); metadata->set_ttl_second(params.ttlSecond); metadata->set_existence(static_cast<::datasystem::ExistenceOptPb>(params.existence)); - ObjectMetaPb::ConfigPb *configPb = metadata->mutable_config(); + ConfigPb *configPb = metadata->mutable_config(); configPb->set_write_mode(static_cast(safeObj->modeInfo.GetWriteMode())); configPb->set_data_format(static_cast(safeObj->stateInfo.GetDataFormat())); configPb->set_consistency_type(static_cast(safeObj->modeInfo.GetConsistencyType())); @@ -265,7 +267,7 @@ Status WorkerOcServicePublishImpl::PublishObject(ObjectKV &objectKV, const Publi // Step 3: Notify GetRequest for subscription purpose. 
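From here on, UpdateRequestForPublish is replaced by the narrower NotifyPendingGetRequest: publishing an object now simply wakes the Get requests parked on that key. A condition-variable sketch of that wake-up pattern, with the caveat that the real worker parks RPC requests rather than threads (all names below are illustrative):

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_set>

// Simplified stand-in for the worker's pending-get registry: getters
// block until the key they want is published, publishers notify them.
class PendingGets {
public:
    void WaitFor(const std::string &key)
    {
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [&] { return published_.count(key) > 0; });
    }

    void NotifyPublished(const std::string &key)
    {
        {
            std::lock_guard<std::mutex> lock(mu_);
            published_.insert(key);
        }
        cv_.notify_all();  // wake every getter; each re-checks its own key
    }

private:
    std::mutex mu_;
    std::condition_variable cv_;
    std::unordered_set<std::string> published_;
};

int main()
{
    PendingGets pending;
    std::thread getter([&] {
        pending.WaitFor("obj-1");
        std::cout << "obj-1 ready\n";
    });
    pending.NotifyPublished("obj-1");
    getter.join();
    return 0;
}
```

notify_all keeps the sketch simple; the worker's request table instead resolves exactly the requests registered for the published key.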
safeObj->stateInfo.SetNeedToDelete(false); - RETURN_IF_NOT_OK(workerRequestManager_.UpdateRequestForPublish(objectKV, memoryRefTable_)); + RETURN_IF_NOT_OK(workerRequestManager_.NotifyPendingGetRequest(objectKV)); safeObj->SetLifeState(params.lifeState); safeObj->stateInfo.SetPrimaryCopy(true); safeObj->stateInfo.SetCacheInvalid(false); @@ -344,7 +346,7 @@ Status WorkerOcServicePublishImpl::PublishImpl(const PublishReqPb &req, PublishR LOG(INFO) << FormatString("[ObjectKey %s] is being publishing [Sz: %zu].", req.object_key(), req.data_size()); std::string tenantId; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(worker::Authenticate(akSkManager_, req, tenantId), "Authenticate failed."); - std::vector shmUnits = { req.shm_id() }; + std::vector shmUnits = { ShmKey::Intern(req.shm_id()) }; RETURN_IF_NOT_OK( WorkerOcServiceCrudCommonApi::CheckShmUnitByTenantId(tenantId, req.client_id(), shmUnits, memoryRefTable_)); std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.object_key()); diff --git a/src/datasystem/worker/object_cache/worker_oc_service_impl.cpp b/src/datasystem/worker/object_cache/worker_oc_service_impl.cpp index 1dc683a630c688ca1bb98746b4e15252cb7c5017..25bcc798909b5e83bb82639317c67f47704b8ceb 100644 --- a/src/datasystem/worker/object_cache/worker_oc_service_impl.cpp +++ b/src/datasystem/worker/object_cache/worker_oc_service_impl.cpp @@ -57,9 +57,11 @@ #include "datasystem/common/object_cache/object_base.h" #include "datasystem/common/object_cache/object_bitmap.h" #include "datasystem/common/object_cache/safe_object.h" +#include "datasystem/common/parallel/parallel_for.h" #include "datasystem/common/rpc/rpc_auth_key_manager.h" #include "datasystem/common/rpc/rpc_stub_cache_mgr.h" #include "datasystem/common/shared_memory/allocator.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/deadlock_util.h" #include "datasystem/common/util/format.h" #include "datasystem/common/util/gflag/common_gflags.h" @@ -196,18 +198,19 @@ Status WorkerOCServiceImpl::InitL2Cache() void WorkerOCServiceImpl::InitServiceImpl() { - WorkerOcServiceCrudParam param{ .workerMasterApiManager = workerMasterApiManager_, - .workerRequestManager = workerRequestManager_, - .memoryRefTable = memoryRefTable_, - .objectTable = objectTable_, - .evictionManager = evictionManager_, - .workerDevOcManager = workerDevOcManager_, - .asyncPersistenceDelManager = asyncPersistenceDelManager_, - .asyncSendManager = asyncSendManager_, - .asyncRollbackManager = asyncRollbackManager_, - .metadataSize = metadataSize_, - .persistenceApi = persistenceApi_, - .etcdCM = etcdCM_, + WorkerOcServiceCrudParam param{ + .workerMasterApiManager = workerMasterApiManager_, + .workerRequestManager = workerRequestManager_, + .memoryRefTable = memoryRefTable_, + .objectTable = objectTable_, + .evictionManager = evictionManager_, + .workerDevOcManager = workerDevOcManager_, + .asyncPersistenceDelManager = asyncPersistenceDelManager_, + .asyncSendManager = asyncSendManager_, + .asyncRollbackManager = asyncRollbackManager_, + .metadataSize = metadataSize_, + .persistenceApi = persistenceApi_, + .etcdCM = etcdCM_, }; createProc_ = std::make_shared(param, etcdCM_, akSkManager_); @@ -242,6 +245,7 @@ Status WorkerOCServiceImpl::Init() CHECK_FAIL_RETURN_STATUS(workerMasterApi != nullptr, K_RUNTIME_ERROR, "Hash master get failed, Init failed"); RETURN_IF_EXCEPTION_OCCURS(threadPool_ = std::make_shared(FLAGS_oc_thread_num, 0, "OcGetThread")); RETURN_IF_EXCEPTION_OCCURS(memCpyThreadPool_ = 
std::make_shared(MEMCOPY_THREAD_NUM)); + datasystem::Parallel::InitParallelThreadPool(PARALLEL_THREAD_NUM); constexpr uint32_t gcThrdNum = 4; RETURN_IF_EXCEPTION_OCCURS(gcThreadPool_ = std::make_unique(gcThrdNum, 0, "OcCleanClient")); @@ -349,13 +353,13 @@ Status WorkerOCServiceImpl::MultiPublish(const MultiPublishReqPb &req, MultiPubl RETURN_IF_NOT_OK(multiPublishProc_->MultiPublish(req, resp, payloads)); if (req.auto_release_memory_ref()) { std::set failedSet{ resp.failed_object_keys().begin(), resp.failed_object_keys().end() }; - std::vector shmIds; + std::vector shmIds; shmIds.reserve(req.object_info_size()); for (auto &info : req.object_info()) { if (failedSet.find(info.object_key()) != failedSet.end()) { continue; } - shmIds.emplace_back(info.shm_id()); + shmIds.emplace_back(ShmKey::Intern(info.shm_id())); } VLOG(1) << "auto release ref " << VectorToString(shmIds); return DecreaseMemoryRef(req.client_id(), shmIds); @@ -997,7 +1001,7 @@ void WorkerOCServiceImpl::FillMetadata(const std::string &objectKey, const MetaA if ((*currSafeObj)->IsBinary()) { SetObjectMetaFields(metadata, objectKey, *currSafeObj); } - ObjectMetaPb::ConfigPb *configPb = metadata->mutable_config(); + ConfigPb *configPb = metadata->mutable_config(); configPb->set_write_mode((uint64_t)(*currSafeObj)->modeInfo.GetWriteMode()); configPb->set_data_format((uint64_t)(*currSafeObj)->stateInfo.GetDataFormat()); isFill = true; @@ -1149,7 +1153,7 @@ Status WorkerOCServiceImpl::MultiCreate(const MultiCreateReqPb &req, MultiCreate Raii raii([&returnStatus, &accessPoint, &req]() { auto &key = req.object_key().empty() ? "" : req.object_key(0); accessPoint.Record(returnStatus.GetCode(), std::to_string(key.size()), - RequestParam{ .objectKey = objectKeysToAbbrStr(req.object_key()) }, returnStatus.GetMsg()); + RequestParam{ .objectKey = ObjectKeysToAbbrStr(req.object_key()) }, returnStatus.GetMsg()); }); ReadLock noRecon; returnStatus = ValidateWorkerState(noRecon, reqTimeoutDuration.CalcRemainingTime()); @@ -1298,7 +1302,7 @@ Status WorkerOCServiceImpl::RefreshMeta(const std::string &clientId) LOG_IF_ERROR(TryUnShmQueueLatch(lockId), "Failed to clear locked id"); }; // 0th: Release the object buffer resource - std::vector shmIds; + std::vector shmIds; memoryRefTable_->GetClientRefIds(clientId, shmIds); for (const auto &shmId : shmIds) { std::shared_ptr shmUnit; @@ -1424,7 +1428,7 @@ void WorkerOCServiceImpl::EraseFailedWorkerMasterApi(HostPort &masterAddr) } Status WorkerOCServiceImpl::GetShmQueueUnit(uint32_t lockId, int &fd, uint64_t &mmapSize, ptrdiff_t &offset, - std::string &id) + ShmKey &id) { size_t index = lockId / SHM_QUEUE_SLOT_NUM; std::shared_ptr circularQueue; @@ -1458,7 +1462,7 @@ void WorkerOCServiceImpl::DecreaseHandlerForShmQueue(uint8_t *element) return; } std::string byteShmId((char *)element + sizeof(uint32_t), UUID_SIZE); - std::string shmId = BytesUuidToString(byteShmId); + auto shmId = ShmKey::Intern(BytesUuidToString(byteShmId)); std::string byteClientId((char *)element + sizeof(uint32_t) + UUID_SIZE, UUID_SIZE); std::string clientId = BytesUuidToString(byteClientId); // to do clear all client ref with the shmId; @@ -1492,7 +1496,7 @@ Status WorkerOCServiceImpl::InitShmCircularQueue(std::shared_ptr(); - shmUnit->id = GetStringUuid(); + shmUnit->id = ShmKey::Intern(GetStringUuid()); RETURN_IF_NOT_OK_PRINT_ERROR_MSG(shmUnit->AllocateMemory(DEFAULT_TENANT_ID, memorySize, true), "Allocate memory failed"); auto result = memset_s(shmUnit->pointer, memorySize, 0, memorySize); @@ -1558,7 +1562,7 @@ Status 
WorkerOCServiceImpl::StartDecreaseReferenceProcess() return Status::OK(); } -Status WorkerOCServiceImpl::DecreaseMemoryRef(const std::string &clientId, const std::vector &shmIds) +Status WorkerOCServiceImpl::DecreaseMemoryRef(const std::string &clientId, const std::vector &shmIds) { workerOperationTimeCost.Clear(); Timer timer; @@ -1585,7 +1589,11 @@ Status WorkerOCServiceImpl::DecreaseReference(const DecreaseReferenceRequest &re if (req.object_keys_size() > 0) { LOG(INFO) << FormatString("[shmId %s] [client: %s] DoDecrease", req.object_keys(0), req.client_id()); } - std::vector shmIds = { req.object_keys().begin(), req.object_keys().end() }; + std::vector shmIds; + // Although the field in pb is called object key, its content is actually shmId, which is very misleading. + shmIds.reserve(req.object_keys().size()); + std::transform(req.object_keys().begin(), req.object_keys().end(), std::back_inserter(shmIds), + [](const auto &key) { return ShmKey::Intern(key); }); auto rc = DecreaseMemoryRef(req.client_id(), shmIds); if (rc.IsError()) { resp.mutable_error()->set_error_code(rc.GetCode()); @@ -1750,9 +1758,6 @@ Status WorkerOCServiceImpl::DeleteObject(const std::string &objectKey, uint64_t { LOG(INFO) << FormatString("[ObjectKey %s] DeleteObject begin%s.", objectKey, (version > 0 ? ", version " + std::to_string(version) : "")); - if (getProc_->IsObjectInGetProcess(objectKey)) { - return Status::OK(); - } std::shared_ptr entry; RETURN_IF_NOT_OK(objectTable_->Get(objectKey, entry)); ObjectKV objectKV(objectKey, *entry); @@ -2009,7 +2014,7 @@ Status WorkerOCServiceImpl::PublishDeviceObject(const PublishDeviceObjectReqPb & { std::string tenantId; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(worker::Authenticate(akSkManager_, req, tenantId), "Authenticate failed."); - std::vector shmUnits = { req.shm_id() }; + std::vector shmUnits = { ShmKey::Intern(req.shm_id()) }; RETURN_IF_NOT_OK( WorkerOcServiceCrudCommonApi::CheckShmUnitByTenantId(tenantId, req.client_id(), shmUnits, memoryRefTable_)); PerfPoint point(PerfKey::WORKER_SEAL_OBJECT); @@ -2158,9 +2163,7 @@ Status WorkerOCServiceImpl::GetP2PMeta( first = false; } LOG(INFO) << FormatString("Worker processes GetP2PMeta from client: %s, allKeys: [%s], threads Statistics: %s", - clientId, - allKeys.str(), - devThreadPool_->GetStatistics()); + clientId, allKeys.str(), devThreadPool_->GetStatistics()); int64_t elapsed = timer.ElapsedMilliSecond(); if (elapsed >= timeout) { LOG(ERROR) << "GetP2PMeta RPC timeout. time elapsed " << elapsed << ", subTimeout:" << timeout diff --git a/src/datasystem/worker/object_cache/worker_oc_service_impl.h b/src/datasystem/worker/object_cache/worker_oc_service_impl.h index ef32f15790dc5b49e762f6f04fae3ec38877bda7..c4538d85ab6233b8697803110cd81a0190401432 100644 --- a/src/datasystem/worker/object_cache/worker_oc_service_impl.h +++ b/src/datasystem/worker/object_cache/worker_oc_service_impl.h @@ -85,6 +85,7 @@ class MasterOCServiceImpl; } namespace object_cache { static constexpr int MEMCOPY_THREAD_NUM = 16; +static constexpr int PARALLEL_THREAD_NUM = 8; class MasterWorkerOCServiceImpl; class WorkerWorkerOCServiceImpl; class WorkerDeviceOcManager; @@ -208,7 +209,7 @@ public: * @return Status of the call */ Status MigrateData(const std::vector &objectKeys, const std::string &taskId, - MigrateStrategy::MigrationStrategyStage stage = MigrateStrategy::MigrationStrategyStage::FIRST); + MigrateStrategy::MigrationStrategyStage stage = MigrateStrategy::MigrationStrategyStage::FIRST); /** * @brief Handle migrate data future results. 
@@ -220,9 +221,9 @@ public: * @return Status of the call. */ Status HandleMigrateDataResult(const std::string &taskId, const std::shared_ptr progress, - const std::unique_ptr &threadPool, - std::vector> &futures, - std::vector> &newFutures); + const std::unique_ptr &threadPool, + std::vector> &futures, + std::vector> &newFutures); /** * @brief Redirect the remote node to migrate data. @@ -626,7 +627,7 @@ public: * @param[out] id The id of this shmUnit. * @return Status of the call. */ - Status GetShmQueueUnit(uint32_t lockId, int &fd, uint64_t &mmapSize, ptrdiff_t &offset, std::string &id); + Status GetShmQueueUnit(uint32_t lockId, int &fd, uint64_t &mmapSize, ptrdiff_t &offset, ShmKey &id); /** * @brief Handle PublishDeviceObject request from the client. @@ -852,7 +853,7 @@ private: * @param[in] shmIds The ids of object reference. * @return K_OK on success; the error code otherwise. */ - Status DecreaseMemoryRef(const std::string &clientId, const std::vector &shmIds); + Status DecreaseMemoryRef(const std::string &clientId, const std::vector &shmIds); /** * @brief Get object data from remote cache (remote worker or redis) based on object meta. @@ -1025,8 +1026,8 @@ private: * @return Status */ static void FindObjectKeyNotInRsp(std::vector &queryMetas, - std::vector ¤tIds, - std::vector &objectKeysMayInOtherAz); + std::vector ¤tIds, + std::vector &objectKeysMayInOtherAz); /** * @brief Check whether the size of the node table in EtcdClusterManager equals to the number of running workers. @@ -1164,7 +1165,7 @@ private: std::shared_ptr gMigrateProc_{ nullptr }; - std::shared_ptr expireProc_{nullptr}; + std::shared_ptr expireProc_{ nullptr }; }; } // namespace object_cache } // namespace datasystem diff --git a/src/datasystem/worker/object_cache/worker_request_manager.cpp b/src/datasystem/worker/object_cache/worker_request_manager.cpp index e80642e9bf4fdd58e2cbddd3e88c50c38dce32f8..56726e9f8546e4fbf31ec4b38548be57d367685f 100644 --- a/src/datasystem/worker/object_cache/worker_request_manager.cpp +++ b/src/datasystem/worker/object_cache/worker_request_manager.cpp @@ -18,8 +18,10 @@ * Description: Defines the worker class to communicate with the worker service. */ #include "datasystem/worker/object_cache/worker_request_manager.h" +#include #include #include +#include #include #include @@ -40,265 +42,361 @@ namespace datasystem { namespace object_cache { std::function WorkerRequestManager::deleteFunc_ = nullptr; -Status WorkerRequestManager::AddRequest(const std::string &objectKey, std::shared_ptr &request) +Status GetRequest::Init(const std::string &tenantId, const GetReqPb &req, + std::shared_ptr shmRefTable, + std::shared_ptr> api) { - return requestTable_.AddRequest(objectKey, request); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(req.object_keys_size()), + StatusCode::K_INVALID, "invalid object size"); + + rawObjectKeys_ = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.object_keys()); + + // Get offset and size. 
+ uint64_t objectsCount = rawObjectKeys_.size(); + uint64_t readOffsetCount = static_cast(req.read_offset_list_size()); + uint64_t readSizeCount = static_cast(req.read_size_list_size()); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + objectsCount == readOffsetCount || readOffsetCount == 0, K_INVALID, + FormatString("Invalid readOffsetCount %zu, should be 0 or equal to objectCount %zu", readOffsetCount, + objectsCount)); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + objectsCount == readSizeCount || readSizeCount == 0, K_INVALID, + FormatString("Invalid readSizeCount %zu, should be 0 or equal to objectCount %zu", readSizeCount, + objectsCount)); + + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + readOffsetCount == readSizeCount, K_INVALID, + FormatString("readOffsetCount %zu should be the same as readSizeCount %zu", readOffsetCount, readSizeCount)); + + clientId_ = req.client_id(); + subTimeout_ = req.sub_timeout(); + shmRefTable_ = std::move(shmRefTable); + serverApi_ = std::move(api); + noQueryL2Cache_ = req.no_query_l2cache(); + enableReturnObjectIndex_ = req.return_object_index(); + for (size_t i = 0; i < objectsCount; i++) { + const auto &objectKey = rawObjectKeys_[i]; + OffsetInfo offsetInfo; + if (readOffsetCount > 0 && readSizeCount > 0) { + offsetInfo.readOffset = req.read_offset_list(static_cast(i)); + offsetInfo.readSize = req.read_size_list(static_cast(i)); + } + GetObjInfo info{ .offsetInfo = offsetInfo, .params = nullptr, .rc = Status::OK() }; + auto [iter, insert] = objects_.emplace(objectKey, std::move(info)); + if (!insert) { + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(iter->second.offsetInfo == offsetInfo, K_INVALID, + FormatString("Duplicate offset read for objectKey %s", objectKey)); + } + VLOG(1) << "objectKey " << objectKey << " added to GetRequest successfully"; + } + + return Status::OK(); } -Status WorkerRequestManager::UpdateRequestForSuccess(ReadObjectKV &objectKV, - std::shared_ptr &memoryRefApi, - bool isDelayToReturn, const std::shared_ptr &request) +Status GetRequest::UpdateAfterLocalGet(Status rc, size_t remoteObjectCount) { - SafeObjType &safeObj = objectKV.GetObjEntry(); - CHECK_FAIL_RETURN_STATUS(safeObj.Get() != nullptr && memoryRefApi != nullptr, K_INVALID, - "The pointer of entry and memoryRefApi for UpdateRequest is null."); - auto entry = GetObjEntryParams::Create(safeObj, objectKV.GetReadOffset(), objectKV.GetReadSize()); - return UpdateRequestImpl(objectKV.GetObjKey(), entry, memoryRefApi, Status::OK(), request, isDelayToReturn); + CHECK_FAIL_RETURN_STATUS(!Registered(), K_RUNTIME_ERROR, + FormatString("UpdateAfterLocalGet called after GetRequest Register")); + auto uniqueObjectCount = objects_.size(); + CHECK_FAIL_RETURN_STATUS(readyCount_ == 0, K_RUNTIME_ERROR, + FormatString("Invalid readyCount_ %zu, should be 0 when calling UpdateAfterLocalGet", readyCount_.load())); + CHECK_FAIL_RETURN_STATUS(uniqueObjectCount >= remoteObjectCount, K_RUNTIME_ERROR, + FormatString("The remote object key count %zu exceeds the request object count %zu", + remoteObjectCount, uniqueObjectCount)); + // Objects that already exist locally, or that failed during the local get, count as ready. + readyCount_ = uniqueObjectCount - remoteObjectCount; + if (rc.IsError()) { + lastRc_ = std::move(rc); + } + + // Return to the client directly if all objects are already ready. + return remoteObjectCount == 0 ? 
ReturnToClient() : Status::OK(); } -Status WorkerRequestManager::UpdateRequestForPublish(ObjectKV &objectKV, - std::shared_ptr &memoryRefApi) +Status GetRequest::MarkSuccess(const ObjectKey &objectKey, SafeObjType &safeObj) { - SafeObjType &safeObj = objectKV.GetObjEntry(); - CHECK_FAIL_RETURN_STATUS(safeObj.Get() != nullptr && memoryRefApi != nullptr, K_INVALID, - "The pointer of entry and memoryRefApi for UpdateRequest is null."); - auto entry = GetObjEntryParams::Create(safeObj, 0, 0); - return UpdateRequestImpl(objectKV.GetObjKey(), entry, memoryRefApi); + VLOG(1) << "MarkSuccess for object key " << objectKey; + auto params = GetObjEntryParams::Create(objectKey, safeObj); + return MarkSuccessImpl(objectKey, std::move(params)); } -Status WorkerRequestManager::UpdateRequestForFailed(const std::string &objectKey, Status lastRc, - std::shared_ptr &memoryRefApi) +Status GetRequest::MarkFailed(const ObjectKey &objectKey, const Status &rc) { - CHECK_FAIL_RETURN_STATUS(memoryRefApi != nullptr, K_INVALID, - "The pointer memoryRefApi for UpdateFailedRequest is null."); - - CHECK_FAIL_RETURN_STATUS(lastRc.IsError(), K_INVALID, "The lastRc not failed."); - return UpdateRequestImpl(objectKey, nullptr, memoryRefApi, lastRc); + VLOG(1) << "MarkFailed for object key " << objectKey; + CHECK_FAIL_RETURN_STATUS(rc.IsError(), K_RUNTIME_ERROR, "Invalid Status when MarkFailed"); + auto iter = objects_.find(objectKey); + CHECK_FAIL_RETURN_STATUS(iter != objects_.cend(), K_RUNTIME_ERROR, + FormatString("Not found object key %s in GetRequest", objectKey)); + readyCount_.fetch_add(1, std::memory_order_relaxed); + { + std::lock_guard locker(mutex_); + lastRc_ = rc; + iter->second.rc = rc; + } + return Status::OK(); } -Status WorkerRequestManager::UpdateSpecificRequestForFailed(const std::shared_ptr &request, - const std::string &objectKey, Status lastRc, - std::shared_ptr &memoryRefApi) +Status GetRequest::MarkSuccessForNotify(const ObjectKey &objectKey, std::unique_ptr params) { - CHECK_FAIL_RETURN_STATUS(memoryRefApi != nullptr, K_INVALID, - "The pointer memoryRefApi for UpdateFailedRequest is null."); + VLOG(1) << "MarkSuccessForNotify for object key " << objectKey; + CHECK_FAIL_RETURN_STATUS(Registered(), K_RUNTIME_ERROR, + FormatString("MarkSuccessForNotify called before GetRequest Register")); + RETURN_IF_NOT_OK(MarkSuccessImpl(objectKey, std::move(params))); + return GetNotReadyCount() == 0 ? 
ReturnToClient() : Status::OK(); +} - CHECK_FAIL_RETURN_STATUS(lastRc.IsError(), K_INVALID, "The lastRc not failed."); - return UpdateRequestImpl(objectKey, nullptr, memoryRefApi, lastRc, request); +Status GetRequest::MarkSuccessImpl(const ObjectKey &objectKey, std::unique_ptr params) +{ + auto iter = objects_.find(objectKey); + CHECK_FAIL_RETURN_STATUS(iter != objects_.cend(), K_RUNTIME_ERROR, + FormatString("Not found object key %s in GetRequest", objectKey)); + { + std::lock_guard locker(mutex_); + RETURN_OK_IF_TRUE(iter->second.params != nullptr); + iter->second.params = std::move(params); + } + readyCount_.fetch_add(1, std::memory_order_relaxed); + return Status::OK(); } -Status WorkerRequestManager::UpdateRequestImpl(const std::string &objectKey, std::shared_ptr entry, - std::shared_ptr &memoryRefApi, Status lastRc, - const std::shared_ptr &request, bool isDelayToReturn) +void GetRequest::SetStatus(const Status &rc) { - auto checkFun = - [&entry](const std::string &objKey, const std::shared_ptr req) { - OffsetInfo info(entry->readOffset, entry->readSize); - return req->IsOffsetAndSizeMatchWithoutLock(objKey, entry->dataSize, info); - }; - - if (!isDelayToReturn) { - return requestTable_.UpdateRequest( - objectKey, entry, lastRc, - [memoryRefApi, this](std::shared_ptr req) mutable { - VLOG(1) << "All object data ready for clientId: " + req->clientId_; - LOG_IF_ERROR(ReturnFromGetRequest(req, memoryRefApi), "ReturnFromGetRequest failed"); - }, - request, false, checkFun); - } else { - // only check to set finish for finished request, not return to client. - return requestTable_.UpdateRequest( - objectKey, entry, lastRc, - [memoryRefApi](std::shared_ptr req) mutable { - VLOG(1) << "All object data ready for clientId: " + req->clientId_; - req->isFinished_ = true; - }, - request, false, checkFun); + if (rc.IsError()) { + lastRc_ = rc; } } -void WorkerRequestManager::CheckAndReturnToClient(const std::string objectKey, - std::shared_ptr &memoryRefApi) +size_t GetRequest::GetReadyCount() const { - auto requests = requestTable_.GetRequestsByObject(objectKey); - for (auto &request : requests) { - if (request->isFinished_) { - LOG_IF_ERROR(ReturnFromGetRequest(request, memoryRefApi), "return to client failed"); - } + return readyCount_; +} + +size_t GetRequest::GetNotReadyCount() const +{ + return objects_.size() - readyCount_; +} + +bool GetRequest::AlreadyReturn() const +{ + return isReturn_; +} + +const std::string &GetRequest::GetClientId() const +{ + return clientId_; +} + +bool GetRequest::NoQueryL2Cache() const +{ + return noQueryL2Cache_; +} + +const std::vector &GetRequest::GetRawObjectKeys() const +{ + return rawObjectKeys_; +} + +std::unordered_map &GetRequest::GetObjects() +{ + return objects_; +} + +std::vector GetRequest::GetUniqueObjectkeys() const +{ + std::vector objectKeys; + objectKeys.reserve(objects_.size()); + for (const auto &kv : objects_) { + objectKeys.emplace_back(kv.first); } + return objectKeys; } -Status WorkerRequestManager::AddEntryToGetResponse( - const std::shared_ptr &request, - const std::pair> &retIdEntry, GetRspPb &resp, - std::vector &outPayloads, std::shared_ptr &memoryRefApi, - std::map &needDeleteObjects) +std::shared_ptr> GetRequest::GetServerApi() const { - const std::string &namespaceUri = *retIdEntry.first; - std::string objectKey; - TenantAuthManager::Instance()->NamespaceUriToObjectKey(namespaceUri, objectKey); - auto safeEntry = retIdEntry.second; - auto clientInfo = worker::ClientManager::Instance().GetClientInfo(request->clientId_); - bool 
shmEnabled = clientInfo != nullptr && clientInfo->ShmEnabled(); - // Only add shm ref when we will return this shmUnit to the client. - if (shmEnabled) { - GetRspPb::ObjectInfoPb *object = resp.add_objects(); - SetObjectInfoPb(objectKey, *safeEntry, *object); - // If object is shm, we increase the refCnt for client. - // The client will be using this object and be responsible for releasing this object. - if (worker::ClientManager::Instance().CheckClientId(request->clientId_).IsOk()) { - memoryRefApi->AddShmUnit(request->clientId_, safeEntry->shmUnit); + return serverApi_; +} + +void GetRequest::Register(WorkerRequestManager *workerRequestManager) +{ + workerRequestManager_ = workerRequestManager; + auto request = shared_from_this(); + for (auto &[objectKey, objectInfo] : objects_) { + // The object key not found in local and remote + VLOG(1) << "Register GetRequest for objectKey " << objectKey << ", params " + << (objectInfo.params == nullptr ? "is null" : "not null") << ", status: " << objectInfo.rc.ToString(); + if (objectInfo.params == nullptr && objectInfo.rc.IsOk()) { + workerRequestManager_->AddRequest(objectKey, request); } - } else { - RETURN_IF_NOT_OK(CopyShmUnitToPayloads(retIdEntry, resp, outPayloads)); } - // If object is shm, must delete entry after memoryRefApi->AddShmUnit so that the ShmUnit won't be released - // immediately when entry is releasing. - bool needDeleted = safeEntry->objectState.IsNeedToDelete(); - INJECT_POINT("worker.AddEntryToGetResponse", [&needDeleted] { - needDeleted = true; - return Status::OK(); - }); - if (needDeleted) { - needDeleteObjects.emplace(namespaceUri, safeEntry->version); - } - return Status::OK(); } -Status WorkerRequestManager::CopyShmUnitToPayloads( - const std::pair> &retIdEntry, GetRspPb &resp, - std::vector &outPayloads) +void GetRequest::UnRegister() { - const std::string &namespaceUri = *retIdEntry.first; - std::string objectKey; - TenantAuthManager::Instance()->NamespaceUriToObjectKey(namespaceUri, objectKey); - auto safeEntry = retIdEntry.second; - const uint64_t metaSize = safeEntry->metaSize; - const uint64_t dataSize = safeEntry->dataSize; - const uint64_t readOffset = safeEntry->readOffset; - const uint64_t readSize = safeEntry->readSize; - - ShmGuard shmGuard(safeEntry->shmUnit, dataSize, metaSize); - if (WorkerOcServiceCrudCommonApi::ShmEnable()) { - RETURN_IF_NOT_OK_PRINT_ERROR_MSG( - shmGuard.TryRLatch(), - FormatString("Try read latch failed while getting object %s from shmUnit.", objectKey)); + if (Registered()) { + workerRequestManager_->RemoveGetRequest(shared_from_this()); } - auto curIndex = outPayloads.size(); - LOG(INFO) << FormatString("CopyShmUnitToPayloads, objectKey: %s, read offset: %ld, read size: %ld", objectKey, - readOffset, readSize); - RETURN_IF_NOT_OK(shmGuard.TransferTo(outPayloads, readOffset, readSize)); - auto lastIndex = outPayloads.size(); - GetRspPb::PayloadInfoPb *payloadInfo = resp.add_payload_info(); - SetPayloadInfoPb(objectKey, *safeEntry, *payloadInfo); - for (auto index = curIndex; index < lastIndex; index++) { - payloadInfo->add_part_index(index); - } - return Status::OK(); } -void WorkerRequestManager::ConstructGetRsp(std::shared_ptr &req, uint64_t &totalSize, Status &lastRc, - std::shared_ptr &memoryRefApi, GetRspPb &resp, - std::vector &payloads, - std::map &needDeleteObjects) +void GetRequest::SetTimer(std::unique_ptr timer) { - std::string objectKey; - for (auto &namespaceUri : req->rawObjectKeys_) { - TenantAuthManager::Instance()->NamespaceUriToObjectKey(namespaceUri, objectKey); - 
GetRequest::TbbGetObjsTable::const_accessor accessor;
-        bool isFindObj = false;
-        if (req->objects_.find(accessor, namespaceUri) && accessor->second != nullptr) {
-            isFindObj = true;
-            totalSize += accessor->second->dataSize;
-            lastRc = AddEntryToGetResponse(req, std::make_pair(&namespaceUri, accessor->second), resp, payloads,
-                                           memoryRefApi, needDeleteObjects);
-        }
-        if (!isFindObj || lastRc.IsError()) {
-            if (req->lastRc_.GetCode() != K_OUT_OF_MEMORY) {
-                req->SetStatus(lastRc);
-            }
-            LOG(ERROR) << FormatString("Can't find object %s or AddEntryToGetResponse failed, clientId %s, rc %s",
-                                       namespaceUri, req->clientId_, lastRc.ToString());
-            GetRspPb::ObjectInfoPb *object = resp.add_objects();
-            SetDefaultObjectInfoPb(objectKey, *object);
-        }
-    }
-    resp.mutable_last_rc()->set_error_code(req->lastRc_.GetCode());
-    resp.mutable_last_rc()->set_error_msg(req->lastRc_.GetMsg());
-    VLOG(1) << FormatString("The total size of the currently get is %llu", totalSize);
+    std::lock_guard locker(mutex_);
+    timer_ = std::move(timer);
+}
+
+bool GetRequest::Registered() const
+{
+    return workerRequestManager_ != nullptr;
 }
 
-Status WorkerRequestManager::ReturnFromGetRequest(std::shared_ptr req,
-                                                  std::shared_ptr &memoryRefApi, Status lastRc)
+Status GetRequest::ReturnToClient(const Status &rc)
 {
+    INJECT_POINT("worker.Get.beforeReturn");
     PerfPoint point(PerfKey::WORKER_RETURN_FROM_GET_REQUEST);
-    RETURN_RUNTIME_ERROR_IF_NULL(req);
     bool expected = false;
-    RETURN_OK_IF_TRUE(!req->isReturn_.compare_exchange_strong(expected, true));
-    VLOG(1) << "Begin to ReturnFromGetRequest, client id: " << req->clientId_;
-
+    RETURN_OK_IF_TRUE(!isReturn_.compare_exchange_strong(expected, true));
+    VLOG(1) << "Begin to ReturnToClient, client id: " << clientId_;
+    Status lastRc;
+    {
+        std::lock_guard locker(mutex_);
+        lastRc = lastRc_;
+    }
+    if (rc.IsError()) {
+        lastRc = rc;
+    }
     uint64_t totalSize = 0;
-    Raii raii([&totalSize, &req] {
+    Raii raii([this, &totalSize, &lastRc] {
-        GetReqPb reqPb;
         RequestParam reqParam;
-        reqParam.subTimeout = "0";
-        if (req->serverApi_->Read(reqPb).IsOk()) {
-            reqParam.subTimeout = std::to_string(reqPb.sub_timeout());
-        }
-        reqParam.objectKey = objectKeysToAbbrStr(req->rawObjectKeys_);
-        StatusCode code = req->lastRc_.GetCode() == K_NOT_FOUND ? K_OK : req->lastRc_.GetCode();
-        req->accessRecorderPoint_->Record(code, std::to_string(totalSize), reqParam, req->lastRc_.GetMsg());
+        reqParam.subTimeout = std::to_string(subTimeout_);
+        reqParam.objectKey = ObjectKeysToAbbrStr(rawObjectKeys_);
+        StatusCode code = lastRc.GetCode() == K_NOT_FOUND ? K_OK : lastRc.GetCode();
+        recorder_.Record(code, std::to_string(totalSize), reqParam, lastRc.GetMsg());
     });
     std::map needDeleteObjects;
-    Raii deleteRaii([&needDeleteObjects] { DeleteObjects(needDeleteObjects); });
-    std::lock_guard lck(req->mutex_);
-    req->SetStatus(lastRc);
+    Raii deleteRaii([&needDeleteObjects] { WorkerRequestManager::DeleteObjects(needDeleteObjects); });
     int64_t remainingTimeMs = reqTimeoutDuration.CalcRealRemainingTime();
     if (remainingTimeMs <= 0) {
-        LOG(ERROR) << "ReturnFromGetRequest timeout when get object: " << VectorToString(req->rawObjectKeys_);
-        RemoveGetRequest(req);
-        auto rc = req->lastRc_.IsOk() ? Status(K_RPC_DEADLINE_EXCEEDED, "Rpc timeout") : req->lastRc_;
-        req->SetStatus(rc);
-        return req->serverApi_->SendStatus(rc);
+        LOG(ERROR) << "ReturnToClient timeout when getting object: " << VectorToString(rawObjectKeys_);
+        UnRegister();
+        auto rc = lastRc.IsOk() ? Status(K_RPC_DEADLINE_EXCEEDED, "Rpc timeout") : lastRc;
+        return serverApi_->SendStatus(rc);
     }
     GetRspPb resp;
     std::vector payloads;
-    ConstructGetRsp(req, totalSize, lastRc, memoryRefApi, resp, payloads, needDeleteObjects);
+    auto constructRc = ConstructResponse(totalSize, resp, payloads, needDeleteObjects);
+    if (constructRc.IsError() && lastRc.GetCode() != K_OUT_OF_MEMORY) {
+        lastRc = constructRc;
+    }
 
     // Remove the get request from each of the relevant object_get_requests hash
     // tables if it is present there. It should only be present there if the get request timed out.
-    RemoveGetRequest(req);
+    UnRegister();
 
-    // Close the request time out event.
-    if (req->timer_ != nullptr) {
-        if (!TimerQueue::GetInstance()->Cancel(*(req->timer_))) {
-            LOG(ERROR) << "Failed to Cancel the timer: " << req->timer_->GetId();
+    {
+        // Close the request time out event.
+        std::lock_guard locker(mutex_);
+        if (timer_ != nullptr) {
+            if (!TimerQueue::GetInstance()->Cancel(*timer_)) {
+                LOG(ERROR) << "Failed to Cancel the timer: " << timer_->GetId();
+            }
+            timer_.reset();
         }
-        req->timer_.reset();
     }
-    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(req->serverApi_->Write(resp), "Write reply to client stream failed.");
-    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(req->serverApi_->SendPayload(payloads), "SendPayload to client stream failed");
+    resp.mutable_last_rc()->set_error_code(lastRc.GetCode());
+    resp.mutable_last_rc()->set_error_msg(lastRc.GetMsg());
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi_->Write(resp), "Write reply to client stream failed.");
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi_->SendPayload(payloads), "SendPayload to client stream failed");
     return Status::OK();
 }
 
-void WorkerRequestManager::SetDefaultObjectInfoPb(const std::string &objectKey, GetRspPb::ObjectInfoPb &info)
+Status GetRequest::ConstructResponse(uint64_t &totalSize, GetRspPb &resp, std::vector &payloads,
+                                     std::map &needDeleteObjects)
 {
-    info.set_object_key(objectKey);
-    info.set_store_fd(-1);
-    info.set_offset(-1);
-    info.set_data_size(-1);
-    info.set_metadata_size(-1);
-    info.set_mmap_size(-1);
-    info.set_version(-1);
-    info.set_is_seal(false);
-    info.set_write_mode(static_cast(WriteMode::NONE_L2_CACHE));
-    info.set_consistency_type(static_cast(ConsistencyType::PRAM));
+    auto clientInfo = worker::ClientManager::Instance().GetClientInfo(clientId_);
+    bool shmEnabled = clientInfo != nullptr && clientInfo->ShmEnabled();
+    Status lastRc;
+    for (size_t objectIndex = 0; objectIndex < rawObjectKeys_.size(); objectIndex++) {
+        auto &objectKeyUri = rawObjectKeys_[objectIndex];
+        Status rc;
+        auto iter = objects_.find(objectKeyUri);
+        if (iter == objects_.cend() || iter->second.params == nullptr) {
+            LOG(ERROR) << FormatString("Can't find object %s, clientId %s", objectKeyUri, clientId_);
+            SetDefaultObjectInfoPb(objectKeyUri, objectIndex, *resp.add_objects());
+            continue;
+        }
+        const auto &params = iter->second.params;
+        totalSize += params->dataSize;
+        rc = AddObjectToResponse(iter->first, iter->second, objectIndex, shmEnabled, resp, payloads);
+        if (shmEnabled) {
+            // If object is shm, we increase the refCnt for client.
+            // The client will be using this object and be responsible for releasing this object.
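+            // (Non-shm clients take the other branch inside AddObjectToResponse:
+            // the object bytes are copied into the outgoing payloads instead, so
+            // no shared-memory ref is taken on their behalf.)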
+            shmRefTable_->AddShmUnit(clientId_, params->shmUnit);
+        }
+
+        bool needDeleted = params->objectState.IsNeedToDelete();
+        INJECT_POINT("worker.AddEntryToGetResponse", [&needDeleted] {
+            needDeleted = true;
+            return Status::OK();
+        });
+        if (needDeleted) {
+            needDeleteObjects.emplace(objectKeyUri, params->version);
+        }
+        if (rc.IsError()) {
+            LOG(ERROR) << FormatString("AddObjectToResponse failed for object %s, clientId %s, rc %s",
+                                       objectKeyUri, clientId_, rc.ToString());
+            lastRc = rc;
+            SetDefaultObjectInfoPb(objectKeyUri, objectIndex, *resp.add_objects());
+        }
+    }
+    VLOG(1) << FormatString("The total size of the current get is %llu", totalSize);
+    return lastRc;
 }
 
-Status WorkerRequestManager::SetObjectInfoPb(const std::string &objectKey, GetObjEntryParams &safeEntry,
-                                             GetRspPb::ObjectInfoPb &info)
+Status GetRequest::AddObjectToResponse(const ObjectKey &objectKeyUri, GetObjInfo &objectInfo, size_t objectIndex,
+                                       bool shmEnabled, GetRspPb &resp, std::vector &outPayloads)
+{
+    const auto &params = objectInfo.params;
+    if (shmEnabled) {
+        GetRspPb::ObjectInfoPb *object = resp.add_objects();
+        SetShmObjectInfoPb(objectKeyUri, objectIndex, *params, *object);
+        return Status::OK();
+    }
+
+    const uint64_t metaSize = params->metaSize;
+    const uint64_t dataSize = params->dataSize;
+    objectInfo.offsetInfo.AdjustReadSize(dataSize);
+    const uint64_t readOffset = objectInfo.offsetInfo.readOffset;
+    const uint64_t readSize = objectInfo.offsetInfo.readSize;
+
+    ShmGuard shmGuard(params->shmUnit, dataSize, metaSize);
+    if (WorkerOcServiceCrudCommonApi::ShmEnable()) {
+        RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
+            shmGuard.TryRLatch(),
+            FormatString("Try read latch failed while getting object %s from shmUnit.", objectKeyUri));
+    }
+    auto curIndex = outPayloads.size();
+    LOG(INFO) << FormatString("Copy shm unit to payloads, objectKey: %s, read offset: %ld, read size: %ld",
+                              objectKeyUri, readOffset, readSize);
+    RETURN_IF_NOT_OK(shmGuard.TransferTo(outPayloads, readOffset, readSize));
+    auto lastIndex = outPayloads.size();
+    GetRspPb::PayloadInfoPb *payloadInfo = resp.add_payload_info();
+    SetNoShmObjectInfoPb(objectKeyUri, objectIndex, objectInfo, *payloadInfo);
+    for (auto index = curIndex; index < lastIndex; index++) {
+        payloadInfo->add_part_index(index);
+    }
+    return Status::OK();
+}
+
+void GetRequest::SetShmObjectInfoPb(const ObjectKey &objectKeyUri, size_t objectIndex, GetObjEntryParams &safeEntry,
+                                    GetRspPb::ObjectInfoPb &info)
 {
     auto &shmUnit = safeEntry.shmUnit;
-    info.set_object_key(objectKey);
+    if (enableReturnObjectIndex_) {
+        info.set_object_index(objectIndex);
+    } else {
+        ObjectKey objectKey;
+        TenantAuthManager::Instance()->NamespaceUriToObjectKey(objectKeyUri, objectKey);
+        info.set_object_key(objectKey);
+    }
     info.set_store_fd(shmUnit->GetFd());
     info.set_offset(static_cast(shmUnit->GetOffset()));
     info.set_data_size(static_cast(safeEntry.dataSize));
@@ -311,31 +409,67 @@ void WorkerRequestManager::SetObjectInfoPb(const std::string &objectKey, GetObjE
     info.set_shm_id(shmUnit->id);
 }
 
-void WorkerRequestManager::SetPayloadInfoPb(const std::string &objectKey, GetObjEntryParams &safeEntry,
-                                            GetRspPb::PayloadInfoPb &info)
+void GetRequest::SetNoShmObjectInfoPb(const ObjectKey &objectKeyUri, size_t objectIndex, const GetObjInfo &objectInfo,
+                                      GetRspPb::PayloadInfoPb &info)
 {
-    info.set_object_key(objectKey);
-    info.set_data_size(static_cast(safeEntry.readSize));
+    if (enableReturnObjectIndex_) {
+        info.set_object_index(objectIndex);
+    } else {
+        ObjectKey objectKey;
+        TenantAuthManager::Instance()->NamespaceUriToObjectKey(objectKeyUri, objectKey);
+        info.set_object_key(objectKey);
+    }
+    const auto &safeEntry = *objectInfo.params;
+    info.set_data_size(static_cast(objectInfo.offsetInfo.readSize));
     info.set_version(static_cast(safeEntry.createTime));
     info.set_is_seal(safeEntry.isSealed);
     info.set_write_mode(static_cast(safeEntry.objectMode.GetWriteMode()));
     info.set_consistency_type(static_cast(safeEntry.objectMode.GetConsistencyType()));
 }
 
-void WorkerRequestManager::RemoveGetRequest(std::shared_ptr &request)
+void GetRequest::SetDefaultObjectInfoPb(const ObjectKey &objectKeyUri, size_t objectIndex, GetRspPb::ObjectInfoPb &info)
 {
-    VLOG(1) << "Begin to RemoveGetRequest, client id: " << request->clientId_;
-    requestTable_.RemoveRequest(request);
+    if (enableReturnObjectIndex_) {
+        info.set_object_index(objectIndex);
+    } else {
+        ObjectKey objectKey;
+        TenantAuthManager::Instance()->NamespaceUriToObjectKey(objectKeyUri, objectKey);
+        info.set_object_key(objectKey);
+    }
+    info.set_store_fd(-1);
+    info.set_offset(-1);
+    info.set_data_size(-1);
+    info.set_metadata_size(-1);
+    info.set_mmap_size(-1);
+    info.set_version(-1);
+    info.set_is_seal(false);
+    info.set_write_mode(static_cast(WriteMode::NONE_L2_CACHE));
+    info.set_consistency_type(static_cast(ConsistencyType::PRAM));
 }
 
-void WorkerRequestManager::SetDeleteObjectsFunc(std::function deleteFunc)
+Status WorkerRequestManager::AddRequest(const std::string &objectKey, std::shared_ptr &request)
 {
-    deleteFunc_ = std::move(deleteFunc);
+    return requestTable_.AddRequest(objectKey, request);
 }
 
-bool WorkerRequestManager::IsInGettingObject(const std::string &objectKey)
+Status WorkerRequestManager::NotifyPendingGetRequest(ObjectKV &objectKV)
 {
-    return requestTable_.ObjectInRequest(objectKey);
+    SafeObjType &safeObj = objectKV.GetObjEntry();
+    CHECK_FAIL_RETURN_STATUS(safeObj.Get() != nullptr, K_INVALID,
+                             "The object entry pointer for NotifyPendingGetRequest is null.");
+    auto params = GetObjEntryParams::Create(objectKV.GetObjKey(), safeObj);
+    return requestTable_.NotifyPendingGetRequest(objectKV.GetObjKey(), std::move(params));
+}
+
+void WorkerRequestManager::RemoveGetRequest(const std::shared_ptr &request)
+{
+    VLOG(1) << "Begin to RemoveGetRequest, client id: " << request->GetClientId();
+    requestTable_.RemoveRequest(request);
+}
+
+void WorkerRequestManager::SetDeleteObjectsFunc(std::function deleteFunc)
+{
+    deleteFunc_ = std::move(deleteFunc);
 }
 
 void WorkerRequestManager::DeleteObjects(std::map &objects)
diff --git a/src/datasystem/worker/object_cache/worker_request_manager.h b/src/datasystem/worker/object_cache/worker_request_manager.h
index 643c8b787335abc6e772c9e2ae75d89186e97c51..4ed17ad888bb700492c10d5a1b542a29615cee65 100644
--- a/src/datasystem/worker/object_cache/worker_request_manager.h
+++ b/src/datasystem/worker/object_cache/worker_request_manager.h
@@ -29,6 +29,7 @@
 #include 
 
 #include "datasystem/common/log/log.h"
+#include "datasystem/common/object_cache/object_base.h"
 #include "datasystem/common/object_cache/object_ref_info.h"
 #include "datasystem/common/object_cache/safe_object.h"
 #include "datasystem/common/util/memory.h"
@@ -40,26 +41,38 @@
 namespace datasystem {
 namespace object_cache {
-struct GetObjEntryParams : public OffsetInfo {
-    static std::shared_ptr Create(SafeObjType &safeObj, uint64_t offset, uint64_t size)
+struct GetObjEntryParams {
+    static std::unique_ptr Create(const std::string &objectKey, SafeObjType &safeObj)
     {
         auto objShmUnit = SafeObjType::GetDerived(safeObj);
-        GetObjEntryParams params;
-        params.dataSize = safeObj->GetDataSize();
-        params.metaSize = safeObj->GetMetadataSize();
-        params.createTime = safeObj->GetCreateTime();
-        params.objectMode = objShmUnit->modeInfo;
-        params.objectState = objShmUnit->stateInfo;
-        params.lifeState = objShmUnit->GetLifeState();
-        params.shmUnit = safeObj->GetShmUnit();
-        params.isSealed = safeObj->IsSealed();
-        params.version = safeObj->GetCreateTime();
-        params.readOffset = offset;
-        params.readSize = size;
-        params.AdjustReadSize(params.dataSize);
-        VLOG(1) << "dataSize: " << params.dataSize << ", metaSize: " << params.metaSize
-                << ", offset:" << params.readOffset << ", readSize:" << params.readSize;
-        return std::make_shared(std::move(params));
+        auto params = std::make_unique();
+        params->dataSize = safeObj->GetDataSize();
+        params->metaSize = safeObj->GetMetadataSize();
+        params->createTime = safeObj->GetCreateTime();
+        params->objectMode = objShmUnit->modeInfo;
+        params->objectState = objShmUnit->stateInfo;
+        params->lifeState = objShmUnit->GetLifeState();
+        params->shmUnit = safeObj->GetShmUnit();
+        params->isSealed = safeObj->IsSealed();
+        params->version = safeObj->GetCreateTime();
+        VLOG(1) << "Create GetObjEntryParams for objectKey " << objectKey << ", dataSize: " << params->dataSize
+                << ", metaSize: " << params->metaSize;
+        return params;
+    }
+
+    std::unique_ptr Clone() const
+    {
+        auto params = std::make_unique();
+        params->dataSize = dataSize;
+        params->metaSize = metaSize;
+        params->createTime = createTime;
+        params->objectMode = objectMode;
+        params->objectState = objectState;
+        params->lifeState = lifeState;
+        params->shmUnit = shmUnit;
+        params->isSealed = isSealed;
+        params->version = version;
+        return params;
     }
 
     uint64_t dataSize;
@@ -73,172 +86,171 @@ struct GetObjEntryParams : public OffsetInfo {
     uint64_t version;
 };
 
-using GetRequest = UnaryRequest;
+struct GetObjInfo {
+    OffsetInfo offsetInfo;
+    std::unique_ptr params;
+    Status rc;
+    bool isRollBack = false;
+    bool NotFound() const
+    {
+        return params == nullptr && rc.IsOk();
+    }
+};
 
-class WorkerRequestManager {
+using ObjectKey = std::string;
+class WorkerRequestManager;
+class GetRequest : public std::enable_shared_from_this {
 public:
-    WorkerRequestManager() = default;
-
-    ~WorkerRequestManager() = default;
-
+    GetRequest(AccessRecorderKey key) noexcept : recorder_(key) {}
     /**
-     * @brief Add request to WorkerRequestManager.
-     * @param[in] objectKey The object key.
-     * @param[in] request The request that is waiting on the object key.
-     * @return Status of the call.
+     * @brief Initialize the GetRequest.
+     * @param[in] tenantId The tenantId.
+     * @param[in] req The GetReqPb instance.
+     * @param[in] shmRefTable The instance of SharedMemoryRefTable.
+     * @param[in] api The instance of server api.
+     * @return Status of this call.
      */
-    Status AddRequest(const std::string &objectKey, std::shared_ptr &request);
+    Status Init(const std::string &tenantId, const GetReqPb &req, std::shared_ptr shmRefTable,
+                std::shared_ptr> api);
 
     /**
-     * @brief Update request info after object sealed.
-     * @param[in] objectKV The safe object and its corresponding objectKey.
-     * @param[in,out] memoryRefApi The memory refCnt table.
-     * @return Status of the call.
+     * @brief Update the GetRequest according to the local get result; return to the client if all data has been
+     * obtained.
+     * @param[in] rc The status of the local get.
+     * @param[in] remoteObjectCount The number of objects that still need to be fetched from remote workers.
+     * @return Status of this call.
      */
-    Status UpdateRequestForSuccess(ReadObjectKV &objectKV, std::shared_ptr &memoryRefApi,
-                                   bool isDelayToReturn, const std::shared_ptr &request = nullptr);
+    Status UpdateAfterLocalGet(Status rc, size_t remoteObjectCount);
 
     /**
-     * @brief Update request info after object publish.
-     * @param[in] objectKV The safe object and its corresponding objectKey.
-     * @param[in,out] memoryRefApi The memory refCnt table.
-     * @return Status of the call.
+     * @brief Mark an object get as successful; called from the remote get logic.
+     * @param[in] objectKey The object key.
+     * @param[in] safeObj The SafeObjType instance reference.
+     * @return Status of this call.
      */
-    Status UpdateRequestForPublish(ObjectKV &objectKV, std::shared_ptr &memoryRefApi);
+    Status MarkSuccess(const ObjectKey &objectKey, SafeObjType &safeObj);
 
     /**
-     * @brief Update request info after object process failed.
     * @param[in] objectKey The object key.
-     * @param[in] lastRc The last error.
-     * @param[in,out] memRefApi The memory refCnt table.
-     * @return Status of the call.
+     * @brief Mark an object get as failed; called from the remote get logic.
+     * @param[in] rc The failure reason.
+     * @return Status of this call.
      */
-    Status UpdateRequestForFailed(const std::string &objectKey, Status lastRc,
-                                  std::shared_ptr &memRefApi);
+    Status MarkFailed(const ObjectKey &objectKey, const Status &rc);
 
     /**
-     * @brief Update request info after object process failed.
-     * @param[in] objectKey The object key.
-     * @param[in] lastRc The last error.
-     * @param[in,out] memRefApi The memory refCnt table.
-     * @return Status of the call.
+     * @brief Mark an object get as successful and try to return to the client once all data has been obtained;
+     * called when the object is published after the get request happened.
+     * @param[in] objectKey The object key.
+     * @param[in] params The instance of GetObjEntryParams.
+     * @return Status of this call.
      */
-    Status UpdateSpecificRequestForFailed(const std::shared_ptr &request, const std::string &objectKey,
-                                          Status lastRc, std::shared_ptr &memRefApi);
+    Status MarkSuccessForNotify(const ObjectKey &objectKey, std::unique_ptr params);
 
     /**
-     * @brief Reply to client with the get request.
-     * @param[in] req The request which to return.
-     * @param[in,out] memoryRefApi The memory refCnt table.
-     * @param[in] lastRc The last error.
-     * @return Status of the call.
+     * @brief Respond to the client.
+     * @param[in] rc The final status returned to the client.
+     * @return Status of this call.
      */
-    Status ReturnFromGetRequest(std::shared_ptr req, std::shared_ptr &memoryRefApi,
-                                Status lastRc = Status::OK());
+    Status ReturnToClient(const Status &rc = Status::OK());
 
     /**
-     * @brief Set DeleteObject function for deleting local cache when object is from other AZ.
-     * @param[in] deleteFunc The delete object function.
+     * @brief Register the current GetRequest instance with the WorkerRequestManager.
+     * @param[in] workerRequestManager The pointer to the WorkerRequestManager.
      */
-    static void SetDeleteObjectsFunc(std::function deleteFunc);
+    void Register(WorkerRequestManager *workerRequestManager);
 
     /**
-     * @brief Check and return to client when request finished or object finish
-     * @param[in] objectKey object key
-     * @param[in] memoryRefApi The memory refCnt table.
-     * @return Status of the call
+     * @brief Unregister the current GetRequest instance from the WorkerRequestManager.
      */
-    void CheckAndReturnToClient(const std::string objectKey, std::shared_ptr &memoryRefApi);
+    void UnRegister();
 
     /**
-     * @brief Check if the object is in getting object.
-     * @param[in] objectKey Object key.
- * @return True if object is in getting. + * @brief Set the timer instance + * @param[in] timer The instance of timer. */ - bool IsInGettingObject(const std::string &objectKey); + void SetTimer(std::unique_ptr timer); + + const std::vector &GetRawObjectKeys() const; + std::unordered_map &GetObjects(); + void SetStatus(const Status &rc); + size_t GetReadyCount() const; + size_t GetNotReadyCount() const; + bool AlreadyReturn() const; + const std::string &GetClientId() const; + bool NoQueryL2Cache() const; + + std::vector GetUniqueObjectkeys() const; + std::shared_ptr> GetServerApi() const; private: - /** - * @brief Add entry information to get response and buffers. - * @param[in] request The request which to return. - * @param[in,out] retIdEntry The id and shared memory unit of the object. - * @param[out] resp The response which to return. - * @param[out] outPayloads The buffers for non-shm passing. - * @param[in,out] memoryRefApi The memory refCnt table. - * @return Status of the call. - */ - static Status AddEntryToGetResponse(const std::shared_ptr &request, - const std::pair> &retIdEntry, - GetRspPb &resp, std::vector &outPayloads, - std::shared_ptr &memoryRefApi, - std::map &needDeleteObjects); + Status MarkSuccessImpl(const ObjectKey &objectKey, std::unique_ptr params); + Status ConstructResponse(uint64_t &totalSize, GetRspPb &resp, std::vector &payloads, + std::map &needDeleteObjects); - /** - * @brief Remove the request from the waiting requests table. - * @param[in] request The request need to remove. - */ - void RemoveGetRequest(std::shared_ptr &request); + Status AddObjectToResponse(const ObjectKey &objectKeyUri, GetObjInfo &objectInfo, size_t index, bool shmEnable, + GetRspPb &resp, std::vector &outPayloads); - /** - * @brief Initialize an ObjCacheShmUnit, give it default value. 
-     * @param[in] objectKey The object key that responds to the request
-     * @param[out] info The objectInfoPb need to init
-     */
-    static void SetDefaultObjectInfoPb(const std::string &objectKey, GetRspPb::ObjectInfoPb &info);
+    void SetShmObjectInfoPb(const ObjectKey &objectKeyUri, size_t objectIndex, GetObjEntryParams &safeEntry,
+                            GetRspPb::ObjectInfoPb &info);
+
+    void SetNoShmObjectInfoPb(const ObjectKey &objectKeyUri, size_t objectIndex, const GetObjInfo &objectInfo,
+                              GetRspPb::PayloadInfoPb &info);
+    void SetDefaultObjectInfoPb(const ObjectKey &objectKeyUri, size_t objectIndex, GetRspPb::ObjectInfoPb &info);
+    bool Registered() const;
+
+    // The mutex protects GetObjInfo and lastRc_.
+    // Only after the GetRequest instance is registered with the WorkerRequestManager can other threads become aware of
+    // it; therefore, accessing data in objects_ before registration requires no lock protection.
+    std::mutex mutex_;
+    std::shared_ptr> serverApi_;
+    std::shared_ptr shmRefTable_;
+    std::string clientId_;
+    std::vector rawObjectKeys_;
+    std::unordered_map objects_;
+    std::atomic readyCount_{ 0 };
+    AccessRecorder recorder_;
+    Status lastRc_;
+    std::atomic isReturn_{ false };
+
+    int64_t subTimeout_{ 0 };
+    WorkerRequestManager *workerRequestManager_{ nullptr };
+    std::unique_ptr timer_;
+    bool noQueryL2Cache_ = false;
+    bool enableReturnObjectIndex_ = false;
+};
+
+class WorkerRequestManager {
+public:
+    WorkerRequestManager() = default;
+
+    ~WorkerRequestManager() = default;
 
     /**
-     * @brief Set objectInfoPb response
-     * @param[in] objectKey The object key that responds to the request
-     * @param[in] safeEntry The safe object entry
-     * @param[out] info The objectInfoPb need to init
+     * @brief Add request to WorkerRequestManager.
+     * @param[in] objectKey The object key.
+     * @param[in] request The request that is waiting on the object key.
+     * @return Status of the call.
      */
-    static void SetObjectInfoPb(const std::string &objectKey, GetObjEntryParams &safeEntry,
-                                GetRspPb::ObjectInfoPb &info);
+    Status AddRequest(const std::string &objectKey, std::shared_ptr &request);
 
     /**
-     * @brief Set PayloadInfoPb response
-     * @param[in] objectKey The object key that responds to the request
-     * @param[in] safeEntry The safe object entry
-     * @param[out] info The PayloadInfoPb need to init
+     * @brief Remove the request from the waiting requests table.
+     * @param[in] request The request need to remove.
      */
-    static void SetPayloadInfoPb(const std::string &objectKey, GetObjEntryParams &safeEntry,
-                                 GetRspPb::PayloadInfoPb &info);
+    void RemoveGetRequest(const std::shared_ptr &request);
 
     /**
      * @brief Update request info after object sealed.
-     * @param[in] objectKey The object key.
-     * @param[in] entry The object entry parameter.
-     * @param[in,out] memoryRefApi The memory refCnt table.
-     * @param[in] lastRc The last error.
-     * @param [in] isDelayToReturn if true, return After get request.
-     * @return Status of the call.
-     */
-    Status UpdateRequestImpl(const std::string &objectKey, std::shared_ptr entry,
-                             std::shared_ptr &memoryRefApi, Status lastRc = Status::OK(),
-                             const std::shared_ptr &request = nullptr, bool isDelayToReturn = false);
-
-    /*
-     * @brief Add entry information to get response and pageloads.
-     * @param[in] retIdEntry The id and shared memory unit of the object.
-     * @param[out] resp The response which to return.
-     * @param[out] outPayloads The buffers for non-shm passing.
     * @return Status of the call.
      */
-    static Status CopyShmUnitToPayloads(const std::pair> &retIdEntry,
-                                        GetRspPb &resp, std::vector &outPayloads);
-
-    /*
-     * @brief Construct GetRsp.
-     * @param[in] req The request which to return.
-     * @param[in] totalSize The size of objects.
-     * @param[in] lastRc The last error.
-     * @param[in] memoryRefApi The memory refCnt table.
-     * @param[in] resp GetRsp.
-     * @param[in] payloads The buffers for non-shm passing.
+    Status NotifyPendingGetRequest(ObjectKV &objectKV);
+
+    /**
+     * @brief Set DeleteObject function for deleting local cache when object is from other AZ.
+     * @param[in] deleteFunc The delete object function.
      */
-    void ConstructGetRsp(std::shared_ptr &req, uint64_t &totalSize, Status &lastRc,
-                         std::shared_ptr &memoryRefApi, GetRspPb &resp,
-                         std::vector &payloads, std::map &needDeleteObjects);
+    static void SetDeleteObjectsFunc(std::function deleteFunc);
 
     /**
      * @brief Delete objects according to object key and version.
@@ -246,8 +258,9 @@ private:
      */
     static void DeleteObjects(std::map &objects);
 
-    RequestTable requestTable_;
+private:
     static std::function deleteFunc_;
+    RequestTable requestTable_;
 };
 } // namespace object_cache
 } // namespace datasystem
diff --git a/src/datasystem/worker/object_cache/worker_worker_oc_service_impl.cpp b/src/datasystem/worker/object_cache/worker_worker_oc_service_impl.cpp
index 30cc86d8223d01ed9de57769f231dbad1cbada19..75ed2098f3fe3619047a656cc495fa5cdcc01315 100644
--- a/src/datasystem/worker/object_cache/worker_worker_oc_service_impl.cpp
+++ b/src/datasystem/worker/object_cache/worker_worker_oc_service_impl.cpp
@@ -21,10 +21,15 @@
 #include 
 
+#include "datasystem/utils/status.h"
+#include "tbb/blocked_range.h"
+#include "tbb/parallel_for.h"
+
 #include "datasystem/common/inject/inject_point.h"
 #include "datasystem/common/log/log.h"
 #include "datasystem/common/object_cache/shm_guard.h"
 #include "datasystem/common/rdma/urma_manager_wrapper.h"
+#include "datasystem/common/util/deadlock_util.h"
 #include "datasystem/common/util/raii.h"
 #include "datasystem/common/util/status_helper.h"
 #include "datasystem/worker/object_cache/object_kv.h"
@@ -32,6 +37,10 @@
 #include "datasystem/common/perf/perf_manager.h"
 
 DS_DECLARE_int32(oc_worker_worker_direct_port);
+DS_DECLARE_int32(oc_worker_worker_parallel_nums);
+DS_DECLARE_int32(oc_worker_worker_parallel_min);
+DS_DECLARE_uint64(oc_worker_aggregate_single_max);
+DS_DECLARE_uint64(oc_worker_aggregate_merge_size);
 
 namespace datasystem {
 namespace object_cache {
@@ -101,17 +110,144 @@ Status WorkerWorkerOCServiceImpl::GetObjectRemote(GetObjectRemoteReqPb &req, Get
     return Status::OK();
 }
 
-Status WorkerWorkerOCServiceImpl::GetObjectRemoteBatchWrite(GetObjectRemoteReqPb &req, GetObjectRemoteRspPb &rsp,
-                                                            std::vector &payload,
-                                                            std::vector &keys)
+void WorkerWorkerOCServiceImpl::GetObjectRemoteBatchWrite(
+    uint32_t paraIndex, const GetObjectRemoteReqPb &subReq, BatchGetObjectRemoteRspPb &rsp,
+    std::vector &payload,
+    std::map, std::vector>> &keys,
+    std::vector &parallelRes, std::shared_ptr batchPtr)
+{
+    bool disabledParallel = parallelRes.empty();
+
+    GetObjectRemoteRspPb &subRsp =
+        disabledParallel ? *(rsp.add_responses()) : parallelRes[paraIndex].respPbs.emplace_back();
+
+    std::vector subPayload;
+    std::vector subKeys;
+
+    auto status = GetObjectRemoteHandler(subReq, subRsp, subPayload, false, subKeys, batchPtr);
+    if (status.IsError()) {
+        subRsp.mutable_error()->set_error_code(status.GetCode());
+        subRsp.mutable_error()->set_error_msg(status.GetMsg());
+        return;
+    }
+
+    // If keys are empty, we 1) get payload from spill or 2) urma is not enabled.
+    // In both cases we send payload as part of the response.
+    // Otherwise we extend the lifecycle of payload only until urma_write is done.
+    if (subKeys.empty()) {
+        auto &localPayload = disabledParallel ? payload : parallelRes[paraIndex].pays;
+        localPayload.insert(localPayload.end(), std::make_move_iterator(subPayload.begin()),
+                            std::make_move_iterator(subPayload.end()));
+    } else {
+        auto &localKps = disabledParallel ? keys : parallelRes[paraIndex].kps;
+        localKps.emplace(paraIndex, std::make_pair(std::move(subKeys), std::move(subPayload)));
+    }
+}
+
+Status WorkerWorkerOCServiceImpl::AllocateAggreagteMemory(uint64_t parallelIndex, AggregateInfo &info,
+                                                          std::shared_ptr &batchPtr)
+{
+    if (!info.canBatchHandler) {
+        return Status::OK();
+    }
+    batchPtr = std::make_shared();
+    batchPtr->batchShmUnit = std::make_shared();
+    Status rc = batchPtr->batchShmUnit->AllocateMemory("", info.batchSizes[parallelIndex], false, ServiceType::OBJECT,
+                                                       static_cast(0));
+    if (rc.IsError()) {
+        LOG(ERROR) << FormatString("Failed to allocate memory for batch get, size: %d", info.batchSizes[parallelIndex]);
+        return rc;
+    }
+    auto ret = memset_s(batchPtr->batchShmUnit->GetPointer(), info.batchSizes[parallelIndex], 0,
+                        info.batchSizes[parallelIndex]);
+    if (ret != EOK) {
+        batchPtr->batchShmUnit->SetHardFreeMemory();
+        batchPtr->batchShmUnit->FreeMemory();
+        LOG(ERROR) << FormatString("[Aggregated memory] Memset failed, errno: %d", ret);
+        RETURN_STATUS(K_RUNTIME_ERROR, "[Aggregated memory] Memset failed");
+    }
+    return Status::OK();
+}
+
+Status WorkerWorkerOCServiceImpl::PrepareAggreagteMemory(BatchGetObjectRemoteReqPb &req, AggregateInfo &info)
+{
+    uint64_t reqSize = req.requests_size();
+
+    info.canBatchHandler = true;
+    // Split the requests into contiguous batches bounded by the merge size and the key-count limit.
+    info.batchReqSize.clear();
+    info.batchStartIndex.clear();
+    info.batchSizes.clear();
+
+    uint64_t batchReqSize = 0;
+    uint64_t batchCap = 0;
+    uint64_t batchStartIndex = 0;
+    const uint64_t batchLimitKeys = 1024;  // must be the same as obj_cache_shm_unit on the request side.
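+    // Illustrative sizing example (not from runtime data; assumes a 64-byte
+    // metadataSize): with the default 2MB merge size, each full 64KB object needs
+    // 65536 + 64 = 65600 bytes, so 31 of them fit in one batch (31 * 65600 =
+    // 2033600 bytes); the 32nd would push the batch past 2MB and therefore starts
+    // a new batch. batchLimitKeys additionally caps any batch at 1024 requests.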
+    uint64_t metadataSize = ocClientWorkerSvc_->GetMetadataSize();
+
+    for (uint64_t i = 0; i < reqSize; i++) {
+        uint64_t dataSize = req.requests(i).data_size();
+        if (dataSize > FLAGS_oc_worker_aggregate_single_max) {
+            info.canBatchHandler = false;
+            return Status::OK();
+        }
+        uint64_t needSize = dataSize + metadataSize;
+        if (batchCap + needSize > FLAGS_oc_worker_aggregate_merge_size || batchReqSize >= batchLimitKeys) {
+            info.batchStartIndex.emplace_back(batchStartIndex);
+            info.batchSizes.emplace_back(batchCap);
+            info.batchReqSize.emplace_back(batchReqSize);
+
+            batchCap = 0;
+            batchReqSize = 0;
+            batchStartIndex = i;
+        }
+
+        batchReqSize++;
+        batchCap += needSize;
+    }
+
+    if (batchReqSize > 0) {
+        info.batchStartIndex.emplace_back(batchStartIndex);
+        info.batchSizes.emplace_back(batchCap);
+        info.batchReqSize.emplace_back(batchReqSize);
+    }
+
+    return Status::OK();
+}
+
+Status WorkerWorkerOCServiceImpl::AggreaedMemorySend(uint64_t subIndex, AggregateInfo &info,
+                                                     std::shared_ptr aggregatedMem,
+                                                     std::vector &parallelRes,
+                                                     BatchGetObjectRemoteReqPb &req)
 {
-    RETURN_IF_NOT_OK(GetObjectRemoteHandler(req, rsp, payload, false, keys));
+    if (!info.canBatchHandler) {
+        return Status::OK();
+    }
+    RETURN_RUNTIME_ERROR_IF_NULL(aggregatedMem);
+
+    std::vector subKeys;
+    std::vector subPayload;
+    auto startPos = info.batchStartIndex[subIndex];
+    auto *subReq = req.mutable_requests(startPos);
+
+    const uint64_t localObjectAddress = reinterpret_cast(aggregatedMem->batchShmUnit->GetPointer());
+    uint64_t localSegAddress = 0;
+    uint64_t localSegSize;
+    GetSegmentInfoFromShmUnit(aggregatedMem->batchShmUnit, localObjectAddress, localSegAddress, localSegSize);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
+        UrmaWritePayload(subReq->urma_info(), localSegAddress, localSegSize, localObjectAddress, 0,
+                         info.batchSizes[subIndex], ocClientWorkerSvc_->GetMetadataSize(), false, subKeys),
+        "Failed in aggregated memory urma write");
+
+    ShmGuard shmGuard(aggregatedMem->batchShmUnit, info.batchSizes[subIndex], 0);
+    RETURN_IF_NOT_OK(shmGuard.TransferTo(subPayload, 0, info.batchSizes[subIndex]));
+    ParallelRes &loc = parallelRes[subIndex];
+    loc.kps.emplace(startPos, std::make_pair(std::move(subKeys), std::move(subPayload)));
     return Status::OK();
 }
 
-Status WorkerWorkerOCServiceImpl::GetObjectRemoteHandler(GetObjectRemoteReqPb &req, GetObjectRemoteRspPb &rsp,
+Status WorkerWorkerOCServiceImpl::GetObjectRemoteHandler(const GetObjectRemoteReqPb &req, GetObjectRemoteRspPb &rsp,
                                                          std::vector &payload, bool blocking,
-                                                         std::vector &keys)
+                                                         std::vector &keys,
+                                                         std::shared_ptr batchPtr)
 {
     const std::string &objectKey = req.object_key();
     const std::string &requestId = req.request_id();
@@ -120,7 +256,7 @@ Status WorkerWorkerOCServiceImpl::GetObjectRemoteHandler(GetObjectRemoteReqPb &r
     INJECT_POINT("worker.worker_worker_remote_get_sleep");
     INJECT_POINT("worker.worker_worker_remote_get_failure");
     CHECK_FAIL_RETURN_STATUS(!objectKey.empty(), K_INVALID, "objectKey is empty.");
-    Status status = GetObjectRemoteImpl(req, rsp, payload, blocking, keys);
+    Status status = GetObjectRemoteImpl(req, rsp, payload, blocking, keys, batchPtr);
     INJECT_POINT("worker.batch_get_failure_for_keys", [&objectKey]() {
         if (objectKey == "key2") {
             return Status(K_RUNTIME_ERROR, "Injected K_RUNTIME_ERROR");
@@ -151,7 +287,10 @@ Status WorkerWorkerOCServiceImpl::GetSafeObjectEntry(const std::string &objectKe
         RETURN_STATUS(StatusCode::K_NOT_FOUND, "Object not found");
     }
     bool insert = false;
-    RETURN_IF_NOT_OK(ocClientWorkerSvc_->objectTable_->ReserveGetAndLock(objectKey, safeEntry, insert, false, false));
+    auto func = [this, &objectKey, &safeEntry, &insert]() {
+        return ocClientWorkerSvc_->objectTable_->ReserveGetAndLock(objectKey, safeEntry, insert, false, false);
+    };
+    RETURN_IF_NOT_OK(RetryWhenDeadlock(func));
     if (insert) {
         Raii innerUnlock([&safeEntry]() { safeEntry->WUnlock(); });
         StatusCode code;
@@ -175,7 +314,8 @@ Status WorkerWorkerOCServiceImpl::GetSafeObjectEntry(const std::string &objectKe
 
 Status WorkerWorkerOCServiceImpl::GetObjectRemoteImpl(const GetObjectRemoteReqPb &req, GetObjectRemoteRspPb &rsp,
                                                       std::vector &outPayload, bool blocking,
-                                                      std::vector &keys)
+                                                      std::vector &keys,
+                                                      std::shared_ptr batchPtr)
 {
     (void)keys;
     (void)blocking;
@@ -242,16 +382,30 @@ Status WorkerWorkerOCServiceImpl::GetObjectRemoteImpl(const GetObjectRemoteReqPb
     PerfPoint p(PerfKey::WORKER_REMOTE_GET_PAYLOAD);
     // Support send payload exceed 2GB
     if (IsUrmaEnabled() && req.has_urma_info()) {
-        // later add a check on data size and read size.
-        auto shmUnit = entry->GetShmUnit();
-        const uint64_t localObjectAddress = reinterpret_cast(shmUnit->GetPointer());
-        uint64_t localSegAddress;
-        uint64_t localSegSize;
-        GetSegmentInfoFromShmUnit(shmUnit, localObjectAddress, localSegAddress, localSegSize);
-        RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
-            ImportSegAndWritePayload(req.urma_info(), localSegAddress, localSegSize, localObjectAddress, offset,
-                                     size, entry->GetMetadataSize(), blocking, keys),
-            "");
+        if (batchPtr) {
+            batchPtr->batchCursor += entry->GetMetadataSize();
+            CHECK_FAIL_RETURN_STATUS(entry->GetMetadataSize() == ocClientWorkerSvc_->GetMetadataSize(),
+                                     K_RUNTIME_ERROR,
+                                     FormatString("Metadata size mismatch, actual = %zu, expected = %zu",
+                                                  entry->GetMetadataSize(), ocClientWorkerSvc_->GetMetadataSize()));
+            uint8_t *destPtr = static_cast(batchPtr->batchShmUnit->GetPointer()) + batchPtr->batchCursor;
+            uint8_t *srcPtr = static_cast(entry->GetShmUnit()->GetPointer()) + entry->GetMetadataSize();
+            auto ret = memcpy_s(destPtr, entry->GetDataSize(), srcPtr, entry->GetDataSize());
+            CHECK_FAIL_RETURN_STATUS(ret == EOK, K_RUNTIME_ERROR,
+                                     FormatString("Copy object data failed, memcpy_s returned: %d", ret));
+            batchPtr->batchCursor += entry->GetDataSize();
+        } else {
+            // later add a check on data size and read size.
+            auto shmUnit = entry->GetShmUnit();
+            const uint64_t localObjectAddress = reinterpret_cast(shmUnit->GetPointer());
+            uint64_t localSegAddress;
+            uint64_t localSegSize;
+            GetSegmentInfoFromShmUnit(shmUnit, localObjectAddress, localSegAddress, localSegSize);
+            RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
+                UrmaWritePayload(req.urma_info(), localSegAddress, localSegSize, localObjectAddress, offset, size,
+                                 entry->GetMetadataSize(), blocking, keys),
+                "Failed in single data urma write");
+        }
         rsp.set_data_in_payload(true);
     }
     // We need to extend the ShmGuard lifecycle if we perform parallel urma_write.
@@ -306,54 +460,72 @@ Status WorkerWorkerOCServiceImpl::BatchGetObjectRemote(
     std::map, std::vector>> keys;
-    std::vector getObjRemoteSubRsp;
     RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed.");
-    for (int i = 0; i < req.requests_size(); i++) {
-        auto tempReq = req.requests(i);
-        GetObjectRemoteReqPb subReq;
-        GetObjectRemoteRspPb subRsp;
-        std::vector subPayload;
-        // Note: Request id remains empty; the trace id should be able to identify the request.
-        subReq.set_object_key(tempReq.object_key());
-        subReq.set_try_lock(tempReq.try_lock());
-        subReq.set_version(tempReq.version());
-        subReq.set_read_offset(tempReq.read_offset());
-        subReq.set_read_size(tempReq.read_size());
-        if (tempReq.has_urma_info()) {
-            subReq.set_data_size(tempReq.data_size());
-            *subReq.mutable_urma_info() = std::move(tempReq.urma_info());
+    if (req.requests_size() > FLAGS_oc_worker_worker_parallel_min && IsUrmaEnabled()) {
+        tbb::task_arena limited;
+        if (FLAGS_oc_worker_worker_parallel_nums > 0) {
+            limited.initialize(FLAGS_oc_worker_worker_parallel_nums);
         }
-        std::vector subKeys;
-        auto status = GetObjectRemoteBatchWrite(subReq, subRsp, subPayload, subKeys);
-        if (status.IsOk()) {
-            // If keys are empty, we 1) get payload from spill or 2) urma is not enabled.
-            // In both cases we send payload as part of the response.
-            // Otherwise we extend the lifecycle of payload only until urma_write is done.
-            if (subKeys.empty()) {
-                payload.insert(payload.end(), std::make_move_iterator(subPayload.begin()),
-                               std::make_move_iterator(subPayload.end()));
-            } else {
-                keys.emplace(i, std::make_pair(std::move(subKeys), std::move(subPayload)));
-            }
-        } else {
-            subRsp.mutable_error()->set_error_code(status.GetCode());
-            subRsp.mutable_error()->set_error_msg(status.GetMsg());
+        std::vector parallelRes;
+
+        AggregateInfo info;
+        RETURN_IF_NOT_OK_PRINT_ERROR_MSG(PrepareAggreagteMemory(req, info), "Prepare aggregate memory failed");
+        uint64_t parallelSize = info.canBatchHandler ? info.batchReqSize.size() : req.requests_size();
+
+        parallelRes.resize(parallelSize);
+        limited.execute([&] {
+            tbb::parallel_for(
+                tbb::blocked_range(0, parallelSize), [&](const tbb::blocked_range &r) {
+                    for (uint64_t i = r.begin(); i != r.end(); ++i) {
+                        uint64_t startPos = info.canBatchHandler ? info.batchStartIndex[i] : i;
+                        uint64_t endPos = info.canBatchHandler ? startPos + info.batchReqSize[i] : startPos + 1;
+                        std::shared_ptr batchPtr = nullptr;
+                        auto rc = AllocateAggreagteMemory(i, info, batchPtr);
+                        if (rc.IsError()) {
+                            LOG(ERROR) << FormatString("[parallel %d] Failed to allocate memory, size: %d", i,
                                                       info.batchSizes[i]);
+                            break;
+                        }
+                        for (uint64_t j = startPos; j < endPos; ++j) {
+                            auto *subReq = req.mutable_requests(j);
+                            GetObjectRemoteBatchWrite(i, *subReq, rsp, payload, keys, parallelRes, batchPtr);
+                        }
+                        LOG_IF_ERROR(AggreaedMemorySend(i, info, batchPtr, parallelRes, req),
+                                     "Send aggregated memory failed");
+                    }
+                });
+        });
+
+        for (ParallelRes &loc : parallelRes) {
+            for (auto &resp : loc.respPbs) {
+                rsp.add_responses()->Swap(&resp);
+            }
+
+            payload.insert(payload.end(), std::make_move_iterator(loc.pays.begin()),
+                           std::make_move_iterator(loc.pays.end()));
+
+            for (auto &[idx, kp] : loc.kps) {
+                keys.emplace(idx, std::move(kp));
+            }
+        }
+    } else {
+        for (int i = 0; i < req.requests_size(); i++) {
+            std::vector emptyRes = {};
+            auto *subReq = req.mutable_requests(i);
+            GetObjectRemoteBatchWrite(i, *subReq, rsp, payload, keys, emptyRes);
         }
-        getObjRemoteSubRsp.emplace_back(std::move(subRsp));
     }
     pointImpl.Record();
     // Wait for urma events if the events are created and not already waited.
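+    // Each entry in `keys` maps a response index to (urma request ids, payload
+    // guards). The guards pin the shared memory until the matching urma_write
+    // completes; they are cleared right after the wait to release the pages early.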
     for (auto &pair : keys) {
         int index = pair.first;
         auto remainingTime = []() { return reqTimeoutDuration.CalcRealRemainingTime(); };
-        auto errorHandler = [index, &getObjRemoteSubRsp](Status &status) {
-            getObjRemoteSubRsp[index].mutable_error()->set_error_code(status.GetCode());
-            getObjRemoteSubRsp[index].mutable_error()->set_error_msg(status.GetMsg());
+        auto errorHandler = [index, &rsp](Status &status) {
+            rsp.mutable_responses()->at(index).mutable_error()->set_error_code(status.GetCode());
+            rsp.mutable_responses()->at(index).mutable_error()->set_error_msg(status.GetMsg());
             return status;
         };
         (void)WaitUrmaEvent(pair.second.first, remainingTime, errorHandler);
         // Early release of ShmGuard.
         pair.second.second.clear();
     }
-    *rsp.mutable_responses() = { getObjRemoteSubRsp.begin(), getObjRemoteSubRsp.end() };
 
     PerfPoint pointWrite(PerfKey::WORKER_SERVER_GET_REMOTE_WRITE);
     RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Write(rsp), "GetObjectRemote write error");
     pointWrite.Record();
diff --git a/src/datasystem/worker/object_cache/worker_worker_oc_service_impl.h b/src/datasystem/worker/object_cache/worker_worker_oc_service_impl.h
index ed60c0d41d5ed5a4f8b7d6b06cc7179e5e972d69..a0c9baed914afe324e1518567e618dfee4ac15ae 100644
--- a/src/datasystem/worker/object_cache/worker_worker_oc_service_impl.h
+++ b/src/datasystem/worker/object_cache/worker_worker_oc_service_impl.h
@@ -101,6 +101,24 @@ public:
                        serverApi) override;
 
 private:
+    struct AggregateInfo {
+        bool canBatchHandler = false;
+        std::vector batchReqSize;
+        std::vector batchSizes;
+        std::vector batchStartIndex;
+    };
+
+    struct AggregateMemory {
+        std::shared_ptr batchShmUnit = nullptr;
+        int64_t batchCursor = 0;
+    };
+
+    struct ParallelRes {
+        std::vector respPbs;
+        std::vector pays;
+        std::map, std::vector>> kps;
+    };
+
     /**
      * @brief Load object data in remote get provider mode.
      * @param[in] req Pb Request for RemoteGet rpc.
@@ -108,20 +126,59 @@ private:
      * @param[out] outPayload Payload buffers.
      * @param[in] blocking Whether to block waiting for the urma_write to finish.
      * @param[out] keys The new request id to wait for if not blocking.
+     * @param[in] batchPtr Batch pointer; nullptr (the default) means not in the aggregate path.
+     * @return Status of the call.
      */
     Status GetObjectRemoteImpl(const GetObjectRemoteReqPb &req, GetObjectRemoteRspPb &rsp,
-                               std::vector &outPayload, bool blocking, std::vector &keys);
+                               std::vector &outPayload, bool blocking, std::vector &keys,
+                               std::shared_ptr batchPtr = nullptr);
 
     /**
      * @brief Helper function to GetObjectRemote, but specialized for the batch get path.
-     * @param[in] req Remote get request.
+     * @param[in] subIndex Sub slot index of the parallel list.
+     * @param[in] req Remote get sub request.
      * @param[out] rsp Remote get response.
     * @param[out] payload Out payloads.
     * @param[out] keys The request id to wait for if not blocking.
+     * @param[out] parallelRes Parallel result.
+     * @param[in] batchPtr Batch pointer; nullptr (the default) means not in the aggregate path.
+     * @return Status of the call.
+     */
+    void GetObjectRemoteBatchWrite(uint32_t subIndex, const GetObjectRemoteReqPb &req, BatchGetObjectRemoteRspPb &rsp,
+                                   std::vector &payload,
+                                   std::map, std::vector>> &keys,
+                                   std::vector &parallelRes,
+                                   std::shared_ptr batchPtr = nullptr);
+
+    /**
+     * @brief Helper function to BatchGetObjectRemote to prepare the aggregate info.
+     * @param[in] req Remote get request.
+     * @param[out] info Aggregated info.
+     * @return Status of the call.
+     */
+    Status PrepareAggreagteMemory(BatchGetObjectRemoteReqPb &req, AggregateInfo &info);
+
+    /**
+     * @brief Helper function to BatchGetObjectRemote to allocate the aggregate memory.
+     * @param[in] parallelIndex Parallel index of the parallel list.
+     * @param[in] info Aggregated info.
+     * @param[out] batchPtr The allocated aggregate memory block.
+     * @return Status of the call.
+     */
+    Status AllocateAggreagteMemory(uint64_t parallelIndex, AggregateInfo &info,
+                                   std::shared_ptr &batchPtr);
+
+    /**
+     * @brief Helper function to BatchGetObjectRemote to send the aggregate memory.
+     * @param[in] subIndex Sub slot index of the parallel list.
+     * @param[in] info Aggregated info.
+     * @param[in] batchPtr The aggregate memory block to send; nullptr means not in the aggregate path.
+     * @param[out] parallelRes Parallel result.
+     * @param[in] req Remote get request.
      * @return Status of the call.
      */
-    Status GetObjectRemoteBatchWrite(GetObjectRemoteReqPb &req, GetObjectRemoteRspPb &rsp,
-                                     std::vector &payload, std::vector &keys);
+    Status AggreaedMemorySend(uint64_t subIndex, AggregateInfo &info, std::shared_ptr batchPtr,
+                              std::vector &parallelRes, BatchGetObjectRemoteReqPb &req);
 
     /**
      * @brief Helper function to pre-process and then trigger GetObjectRemoteImpl.
@@ -130,10 +187,12 @@ private:
      * @param[out] payload Out payloads.
      * @param[in] blocking Whether to block waiting for the urma_write to finish.
     * @param[out] keys The request id to wait for if not blocking.
+     * @param[in] batchPtr Batch pointer; nullptr (the default) means not in the aggregate path.
     * @return Status of the call.
      */
-    Status GetObjectRemoteHandler(GetObjectRemoteReqPb &req, GetObjectRemoteRspPb &rsp,
-                                  std::vector &payload, bool blocking, std::vector &keys);
+    Status GetObjectRemoteHandler(const GetObjectRemoteReqPb &req, GetObjectRemoteRspPb &rsp,
+                                  std::vector &payload, bool blocking, std::vector &keys,
+                                  std::shared_ptr batchPtr = nullptr);
 
     /**
      * @brief Get the safe object entry.
diff --git a/src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.cpp b/src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.cpp
index 146d4fbb4206432c8a27ee22c01dc1ad9c728e83..a072114b51a74629e5e1ba9badf24b4ecf75d4f9 100644
--- a/src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.cpp
+++ b/src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.cpp
@@ -23,6 +23,7 @@
 #include "datasystem/common/constants.h"
 #include "datasystem/common/flags/flags.h"
 #include "datasystem/common/inject/inject_point.h"
+#include "datasystem/common/string_intern/string_ref.h"
 #include "datasystem/common/util/lock_helper.h"
 #include "datasystem/common/util/raii.h"
 #include "datasystem/common/util/status_helper.h"
@@ -301,7 +302,7 @@ Status ExclusivePageQueue::ReserveAdditionalMemory()
     // Check again after we have the lock.
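+    // (Double-checked locking: another thread may have created the free list
+    // between the caller's earlier unlocked check and this lock acquisition,
+    // so the flag is re-checked here before doing any work.)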
RETURN_OK_IF_TRUE(reserveState_.freeListCreated); std::list freeList; - std::vector undoList; + std::vector undoList; bool needRollback = true; Raii raii([this, &needRollback, &undoList]() { if (needRollback) { @@ -360,7 +361,7 @@ Status ExclusivePageQueue::InsertBigElement(void *buf, size_t sz, std::pair([this, &needRollback, &pageUnitInfo]() { if (needRollback) { - std::vector v; + std::vector v; auto pageId = StreamPageBase::CreatePageId(pageUnitInfo); v.push_back(pageId); (void)FreePages(v, true); diff --git a/src/datasystem/worker/stream_cache/page_queue/page_queue_base.cpp b/src/datasystem/worker/stream_cache/page_queue/page_queue_base.cpp index 05cdb8bc5bcb9b5195b145161dac0752ed0bf67b..70cad61d61fc43a1a0a37a8b5f9c20af97769df5 100644 --- a/src/datasystem/worker/stream_cache/page_queue/page_queue_base.cpp +++ b/src/datasystem/worker/stream_cache/page_queue/page_queue_base.cpp @@ -20,6 +20,7 @@ #include #include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/status_helper.h" #include "datasystem/common/constants.h" @@ -189,7 +190,7 @@ Status PageQueueBase::AllocMemory(size_t pageSize, bool bigElement, std::shared_ return Status::OK(); } -Status PageQueueBase::AddPageToPool(const std::string &pageId, std::unique_ptr &&pageUnit, bool bigElement) +Status PageQueueBase::AddPageToPool(const ShmKey &pageId, std::unique_ptr &&pageUnit, bool bigElement) { ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_)); pageUnit->refCount = 1; // The ref count is used in the case of BigElement page only. @@ -494,7 +495,7 @@ Status PageQueueBase::ReleaseBigElementsUpTo(uint64_t cursor, std::shared_ptr("", bigId.back().second, nullptr); + auto pageInfo = std::make_shared(ShmKey::Intern(""), bigId.back().second, nullptr); auto pageId = StreamPageBase::CreatePageId(pageInfo); RETURN_IF_NOT_OK(GetBigElementPageRefCount(pageId, bigRefCount)); if (bigRefCount > 1) { @@ -518,7 +519,7 @@ Status PageQueueBase::ReleaseBigElementsUpTo(uint64_t cursor, std::shared_ptrbigElement) { ++numBigElements; @@ -727,7 +728,7 @@ Status PageQueueBase::ReleaseMemory(const ShmView &pageView) auto bigElementPage = std::make_shared(pageInfo, true); RETURN_IF_NOT_OK(bigElementPage->Init()); LOG(INFO) << FormatString("[%s] Release big element page<%s>", LogPrefix(), bigElementPage->GetPageId()); - std::vector list; + std::vector list; list.push_back(bigElementPage->GetPageId()); return FreePages(list, true); } @@ -783,7 +784,7 @@ void PageQueueBase::ForceUnlockMemViemForPages(uint32_t lockId) Status PageQueueBase::ProcessBigElementPages(std::vector &bigElementId, StreamMetaShm *streamMetaShm) { RETURN_OK_IF_TRUE(bigElementId.empty()); - std::vector freeList; + std::vector freeList; auto func = [this, &freeList](const ShmView &v) { std::shared_ptr shmInfo; RETURN_IF_NOT_OK(LocatePage(v, shmInfo)); @@ -802,7 +803,7 @@ Status PageQueueBase::ProcessBigElementPages(std::vector &bigElementId, Status PageQueueBase::LocatePage(const ShmView &v, std::shared_ptr &out) { - auto pageInfo = std::make_shared("", v, nullptr); + auto pageInfo = std::make_shared(ShmKey::Intern(""), v, nullptr); auto pageId = StreamPageBase::CreatePageId(pageInfo); ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_)); ShmPagesMap::accessor accessor; @@ -886,7 +887,7 @@ Status PageQueueBase::ProcessAckedPages(uint64_t cursor, std::list return AppendFreePages(freeList); } -Status PageQueueBase::FreePages(std::vector &pages, bool bigElementPage, StreamMetaShm *streamMetaShm) 
+Status PageQueueBase::FreePages(std::vector &pages, bool bigElementPage, StreamMetaShm *streamMetaShm) { PerfPoint point(PerfKey::PAGE_RELEASE); ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_)); @@ -958,7 +959,7 @@ Status PageQueueBase::FreePendingList() if (std::chrono::duration_cast(now - start).count() >= interval) { auto ele = std::move(pendingFreePages_.front()); pendingFreePages_.pop_front(); - std::vector freePages; + std::vector freePages; auto &list = std::get(ele); std::transform(list.begin(), list.end(), std::back_inserter(freePages), [](const auto &kv) { return kv->GetPageId(); }); @@ -1165,7 +1166,7 @@ Status PageQueueBase::SendElements(const std::shared_ptr &page, return Status::OK(); } -Status PageQueueBase::IncBigElementPageRefCount(const std::string &pageId) +Status PageQueueBase::IncBigElementPageRefCount(const ShmKey &pageId) { ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_)); ShmPagesMap::accessor accessor; @@ -1196,7 +1197,7 @@ Status PageQueueBase::ExtractBigElement(DataElement &ele, std::shared_ptr("", v, nullptr); + auto pageInfo = std::make_shared(ShmKey::Intern(""), v, nullptr); auto pageId = StreamPageBase::CreatePageId(pageInfo); ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_)); ShmPagesMap::accessor accessor; diff --git a/src/datasystem/worker/stream_cache/page_queue/page_queue_base.h b/src/datasystem/worker/stream_cache/page_queue/page_queue_base.h index d40d873d70575e9d95da706f6c9300d36bb6f8f0..64ba84f8a5e0730d2e2c4a9db50fad23089eeea6 100644 --- a/src/datasystem/worker/stream_cache/page_queue/page_queue_base.h +++ b/src/datasystem/worker/stream_cache/page_queue/page_queue_base.h @@ -27,6 +27,7 @@ #include "datasystem/common/flags/flags.h" #include "datasystem/common/stream_cache/stream_data_page.h" +#include "datasystem/common/string_intern/string_ref.h" DS_DECLARE_uint32(sc_cache_pages); @@ -49,7 +50,7 @@ public: std::chrono::time_point createTime; bool bigElement; }; - using ShmPagesMap = tbb::concurrent_hash_map>; + using ShmPagesMap = tbb::concurrent_hash_map>; using PageShmUnit = std::pair>; PageQueueBase(); @@ -213,9 +214,9 @@ public: */ virtual RemoteWorkerManager *GetRemoteWorkerManager() const = 0; - virtual Status IncBigElementPageRefCount(const std::string &pageId); + virtual Status IncBigElementPageRefCount(const ShmKey &pageId); virtual Status ExtractBigElement(DataElement &ele, std::shared_ptr &bigElementPage); - virtual Status DecBigElementPageRefCount(const std::string &pageId); + virtual Status DecBigElementPageRefCount(const ShmKey &pageId); virtual Status UpdatePageRefIfExist(const ShmView &v, const std::string &logPrefix, bool toggle); /** @@ -249,7 +250,7 @@ protected: std::shared_ptr &lastPage, bool retryOnOOM); Status CreateNewPage(std::shared_ptr &lastPage, bool retryOnOOM); - Status AddPageToPool(const std::string &pageId, std::unique_ptr &&pageUnit, bool bigElement); + Status AddPageToPool(const ShmKey &pageId, std::unique_ptr &&pageUnit, bool bigElement); Status VerifyLastPageRefCountNotLocked() const; Status AppendFreePagesImplNotLocked(uint64_t timeoutMs, Optional> &freeList, bool seal, const bool updateLocalPubLastPage = true); @@ -259,12 +260,12 @@ protected: StreamMetaShm *streamMetaShm = nullptr); Status ReleaseBigElementsUpTo(uint64_t cursor, std::shared_ptr &page, std::vector &bigElementPages, bool &keepThisPageInChain); - Status GetBigElementPageRefCount(const std::string &pageId, int32_t &refCount); + Status GetBigElementPageRefCount(const ShmKey &pageId, int32_t &refCount); Status 
AppendFreePages(std::list &freeList, const bool updateLocalPubLastPage = true); Status ProcessBigElementPages(std::vector &bigElementId, StreamMetaShm *streamMetaShm); Status LocatePage(const ShmView &v, std::shared_ptr &out); Status ProcessAckedPages(uint64_t cursor, std::list &freeList); - Status FreePages(std::vector<std::string> &pages, bool bigElementPage = false, + Status FreePages(std::vector<ShmKey> &pages, bool bigElementPage = false, StreamMetaShm *streamMetaShm = nullptr); Status FreePendingList(); Status MoveFreeListToPendFree(uint64_t cursor, std::list &freeList); diff --git a/src/datasystem/worker/stream_cache/remote_worker_manager.cpp b/src/datasystem/worker/stream_cache/remote_worker_manager.cpp index 95869708e466b5a368f8a0c6d64525eb7cd07261..66308dba64917398f1b007227c4253968a9a6096 100644 --- a/src/datasystem/worker/stream_cache/remote_worker_manager.cpp +++ b/src/datasystem/worker/stream_cache/remote_worker_manager.cpp @@ -144,12 +144,12 @@ Status StreamElementView::IncRefCount() Raii raii([&bigElementLocked, this]() { // Unlock in case of error. if (!ref_ && bigElementLocked) { - std::string pageId = bigElementPage_->GetPageId(); + auto pageId = bigElementPage_->GetPageId(); (void)dataObj_->DecBigElementPageRefCount(pageId); } }); if (bigElement_) { - std::string pageId = bigElementPage_->GetPageId(); + auto pageId = bigElementPage_->GetPageId(); RETURN_IF_NOT_OK(dataObj_->IncBigElementPageRefCount(pageId)); bigElementLocked = true; } diff --git a/src/datasystem/worker/worker_liveness_check.cpp b/src/datasystem/worker/worker_liveness_check.cpp index f563c9e94ff3c12511699975aa8cd87519aacfd1..6fe6ab148b7f29d27c1e177fe1c42c9f11a378bd 100644 --- a/src/datasystem/worker/worker_liveness_check.cpp +++ b/src/datasystem/worker/worker_liveness_check.cpp @@ -271,7 +271,7 @@ Status WorkerLivenessCheck::CheckRocksDbService() metadata->set_data_size(1); metadata->set_life_state(static_cast(ObjectLifeState::OBJECT_PUBLISHED)); metadata->set_ttl_second(0); - ObjectMetaPb::ConfigPb *configPb = metadata->mutable_config(); + ConfigPb *configPb = metadata->mutable_config(); configPb->set_write_mode(static_cast(WriteMode::NONE_L2_CACHE)); configPb->set_data_format(static_cast(DataFormat::BINARY)); configPb->set_consistency_type(static_cast(ConsistencyType::PRAM)); diff --git a/src/datasystem/worker/worker_oc_server.cpp b/src/datasystem/worker/worker_oc_server.cpp index 48195a5d6b22c40b96fbffadc8427bede09a1061..c3668d49936429faeafcfb447d73e784a2f9f9d6 100644 --- a/src/datasystem/worker/worker_oc_server.cpp +++ b/src/datasystem/worker/worker_oc_server.cpp @@ -159,6 +159,15 @@ DS_DEFINE_bool(shared_memory_populate, false, "startup times (depending on shared_memory_size_mb)."); DS_DECLARE_uint32(arena_per_tenant); DS_DECLARE_bool(enable_fallocate); + +DS_DEFINE_int32(oc_worker_worker_parallel_nums, 0, "Parallelism for worker-worker batch responses; default 0 means unlimited"); +DS_DEFINE_int32(oc_worker_worker_parallel_min, 100, + "Minimum data count before worker-worker batch responses are parallelized; default is 100"); +DS_DEFINE_uint64(oc_worker_aggregate_single_max, 65536, + "Maximum single item size when batching worker-worker responses; default is 64KB"); +DS_DEFINE_uint64(oc_worker_aggregate_merge_size, 2097152, + "Target batch size for worker-worker responses; default is 2MB"); + static bool ValidatePopulate(const char *flagName, bool value) { if (!value) { @@ -1250,7 +1259,7 @@ Status WorkerOCServer::Shutdown() return Status::OK(); } -Status WorkerOCServer::GetShmQueueUnit(uint32_t lockId, int &fd, uint64_t &mmapSize, ptrdiff_t &offset, 
std::string &id) +Status WorkerOCServer::GetShmQueueUnit(uint32_t lockId, int &fd, uint64_t &mmapSize, ptrdiff_t &offset, ShmKey &id) { if (!EnableOCService()) { fd = -1; @@ -1295,6 +1304,11 @@ Status WorkerOCServer::AddClient(const std::string &clientId, bool shmEnabled, i clientId, std::bind(&WorkerOCServer::AfterClientLostHandler, this, clientId), HeartbeatType::RPC_HEARTBEAT); } +Status WorkerOCServer::GetExclConnSockPath(std::string &sockPath) +{ + return objCacheClientWorkerSvc_->GetExclConnSockPath(sockPath); +} + void WorkerOCServer::CheckRule(bool isAsyncTasksRunning, int &checkNum) { int updateCheckNum = static_cast(FLAGS_check_async_queue_empty_time_s / CHECK_ASYNC_SLEEP_TIME_S); diff --git a/src/datasystem/worker/worker_oc_server.h b/src/datasystem/worker/worker_oc_server.h index 6655470bad41c80691a9ee5161a1f9bbe562ec61..eda308111cd664696b60fb1493c7a02ea63fa28c 100644 --- a/src/datasystem/worker/worker_oc_server.h +++ b/src/datasystem/worker/worker_oc_server.h @@ -95,7 +95,7 @@ public: * @param[out] id The id of this shmUnit. * @return Status of the call. */ - Status GetShmQueueUnit(uint32_t lockId, int &fd, uint64_t &mmapSize, ptrdiff_t &offset, std::string &id) override; + Status GetShmQueueUnit(uint32_t lockId, int &fd, uint64_t &mmapSize, ptrdiff_t &offset, ShmKey &id) override; /** * @brief After restart crashed server, we need to do some recovery job according to the message from the client. @@ -108,6 +108,8 @@ public: Status ProcessServerReboot(const std::string &clientId, const std::string &tenantId, const std::string &reqToken, const google::protobuf::RepeatedPtrField &msg) override; + Status GetExclConnSockPath(std::string &sockPath) override; + /** * @brief Register a client to client manager. * @param[in] clientId The clientId. 
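[Editor's note] A minimal sketch (illustration only, not part of this patch) of how a client opts in to the new exclusive connection; ConnectOptions::enableExclusiveConnection and the exclusive_conn_sockpath response field appear in the RegisterClient changes that follow, and the ObjectClient construction mirrors the new ExclusiveConn tests:

    ConnectOptions opts;
    opts.enableExclusiveConnection = true;  // sent as enable_exclusive_connection in RegisterClientReqPb
    auto client = std::make_shared<ObjectClient>(opts);
    // During Init()/RegisterClient the worker resolves the dedicated socket path via
    // WorkerOCServer::GetExclConnSockPath() and returns it as exclusive_conn_sockpath.
    DS_ASSERT_OK(client->Init());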
diff --git a/src/datasystem/worker/worker_service_impl.cpp b/src/datasystem/worker/worker_service_impl.cpp index 9e08ee4595e0ac84e25bb419eeb04fa1dd0c07e0..b1e3d8487a345a26189a412c8f99c5337e7169db 100644 --- a/src/datasystem/worker/worker_service_impl.cpp +++ b/src/datasystem/worker/worker_service_impl.cpp @@ -254,9 +254,15 @@ Status WorkerServiceImpl::RegisterClient(const RegisterClientReqPb &req, Registe int fd; uint64_t mmapSize; ptrdiff_t offset; - std::string id; + ShmKey id; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(worker_->GetShmQueueUnit(lockId, fd, mmapSize, offset, id), "worker process get ShmQ unit failed"); + + std::string exclusiveConnSockPath; + if (req.enable_exclusive_connection()) { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(worker_->GetExclConnSockPath(exclusiveConnSockPath), + "worker process get exclusive connection socket path failed"); + } rsp.set_page_size(FLAGS_page_size); rsp.set_quorum_timeout_mult(timeoutMultiplier_); rsp.set_client_id(clientId); @@ -277,6 +283,7 @@ Status WorkerServiceImpl::RegisterClient(const RegisterClientReqPb &req, Registe rsp.set_client_dead_timeout_s(clientDeadTimeoutSec); rsp.set_enable_p2p_transfer(FLAGS_enable_p2p_transfer); rsp.set_client_reconnect_wait_s(FLAGS_client_reconnect_wait_s); + rsp.set_exclusive_conn_sockpath(exclusiveConnSockPath); INJECT_POINT("worker.RegisterClient.end", [&rsp](int fd) { rsp.set_store_fd(fd); diff --git a/tests/st/client/kv_cache/kv_cache_client_test.cpp b/tests/st/client/kv_cache/kv_cache_client_test.cpp index 6e75807d27100f66cb7764d0c3c415098c9ba2b2..d60d831139e1c0e692e778386a0ff73cbd6f5072 100644 --- a/tests/st/client/kv_cache/kv_cache_client_test.cpp +++ b/tests/st/client/kv_cache/kv_cache_client_test.cpp @@ -347,7 +347,7 @@ TEST_F(KVCacheClientTest, TestSetAndGetSubscribeTimeout) DS_ASSERT_OK(client0->Set(key, val)); } -TEST_F(KVCacheClientTest, TestSpecialKeyVal) +TEST_F(KVCacheClientTest, DISABLED_TestSpecialKeyVal) { std::shared_ptr client; InitTestKVClient(0, client); diff --git a/tests/st/client/object_cache/client_get_test.cpp b/tests/st/client/object_cache/client_get_test.cpp index 825cc0bf607ddb368d755294c90448adf7e0725f..41fc8b88c8e2a3fa08a61b87624a6f99d3b9b768 100644 --- a/tests/st/client/object_cache/client_get_test.cpp +++ b/tests/st/client/object_cache/client_get_test.cpp @@ -38,6 +38,7 @@ #include "datasystem/common/util/wait_post.h" #include "datasystem/object_client.h" #include "datasystem/object/object_enum.h" +#include "datasystem/common/perf/perf_manager.h" #include "datasystem/utils/status.h" #include "oc_client_common.h" #include "datasystem/common/metrics/res_metric_collector.h" @@ -1529,9 +1530,9 @@ TEST_F(OCClientRemoteGetTest3, DISABLED_LEVEL1_TestGetOOMScenario) std::vector data(1024 * 1024 * 8, '0'); std::vector objKeys = { "ji", "ni", "tai", "mei" }; std::vector> clients; + const int timeoutMs = 30000; for (size_t i = 0; i < 4; ++i) { std::shared_ptr client; - int timeoutMs = 30'000; InitTestClient(i, client, timeoutMs); if (i == 0) { std::vector failedObjectKeys; @@ -1728,7 +1729,8 @@ TEST_F(OCClientRemoteGetTest4, DISABLED_TestObjectPutAndGetConcurrency) std::string newVal = RandomData().GetRandomString(1024ul * 1024ul); std::thread t2([&client0, &newVal, &objKey, ¶m]() { - usleep(1'000); + const int sleepTime = 1000; + usleep(sleepTime); DS_EXPECT_OK(client0->Put(objKey, (uint8_t *)newVal.c_str(), newVal.size(), param)); }); @@ -1820,5 +1822,232 @@ TEST_F(OCClientRemoteGetTest5, TestRemoteGetAndRemoveLocationFailedThenPut) DS_ASSERT_OK(client0->Put(objKey, (uint8_t *)val.c_str(), 
val.size(), param)); } +// Not a permanent test case; created for now to aid development of the exclusive connection feature +class ExclusiveConn : public OCClientGetTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numWorkers = 2; // only using 1, but have 2 just for giggles + opts.enableDistributedMaster = "false"; + opts.masterIdx = 0; + opts.workerGflagParams = "-shared_memory_size_mb=500 -minloglevel=2"; // for perf runs + opts.disableRocksDB = true; // Maybe the default is true anyway for ctest ExternalCluster tests + opts.numEtcd = 1; + // Since we may be testing perf with these, make sure the AKSK keys are empty to avoid auth overhead + opts.systemAccessKey = ""; + opts.systemSecretKey = ""; + } +}; + +TEST_F(ExclusiveConn, PutTest1) +{ + int numPuts = 1000; + int32_t timeoutMs = 30000; // 30 second timeout. default is 60. + LOG(INFO) << "Testing " << numPuts << " puts with timeout " << timeoutMs; + ConnectOptions connectOptions; + InitConnectOpt(0, connectOptions, timeoutMs); // connect to worker 0 + connectOptions.enableExclusiveConnection = true; + std::shared_ptr<ObjectClient> client0 = std::make_shared<ObjectClient>(connectOptions); + DS_ASSERT_OK(client0->Init()); + + std::vector<std::string> objKeys; + std::string val("a"); + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; + + LOG(INFO) << "create the data locally before putting"; + for (int i = 0; i < numPuts; ++i) { + objKeys.push_back("AmazingKey_" + std::to_string(i)); + } + + LOG(INFO) << "now put the data"; + // A thread is not needed right now; this could run in the current thread. But later, we'll test + // different threads using the same client. + std::thread t1([&client0, &objKeys, &val, &param, &numPuts]() { + for (int i = 0; i < numPuts; ++i) { + DS_ASSERT_OK(client0->Put(objKeys[i], (uint8_t *)val.c_str(), val.size(), param)); + } + LOG(INFO) << "Thread work completed. Thread exits now."; + }); + LOG(INFO) << "Parent joins the thread."; + t1.join(); + + // Maybe this is not needed. + LOG(INFO) << "Calling perf manager print from client side in case it didn't get written."; + PerfManager *perfManager = PerfManager::Instance(); + perfManager->PrintPerfLog(); +} + +TEST_F(ExclusiveConn, GetTest1) +{ + std::chrono::time_point<std::chrono::steady_clock> start; + std::chrono::time_point<std::chrono::steady_clock> end; + uint64_t totalElapsed = 0; + int numClients = 1; + int totalGets = 2000; + int numGetsPerClient = totalGets / numClients; + totalGets = numGetsPerClient * numClients; // Sanity/fix: in case chosen numbers were not multiples + int32_t timeoutMs = 30000; // 30 second timeout. default is 60. + std::vector<std::thread> threads; + + LOG(INFO) << "Testing " << numGetsPerClient << " gets per client. Num clients: " << numClients + << " Using timeout: " << timeoutMs; + ConnectOptions connectOptions; + InitConnectOpt(0, connectOptions, timeoutMs); // connect to worker 0 + // InitConnectOpt from the test framework auto-populates the ak-sk keys for authentication. + // We want to run with authentication disabled, so clear these. 
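+ // (The worker-side counterpart is the new -skip_authenticate flag, exercised in client_skip_auth_test.cpp below.)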
+ connectOptions.accessKey.clear(); + connectOptions.secretKey.Clear(); + connectOptions.enableExclusiveConnection = true; + std::shared_ptr<ObjectClient> client0 = std::make_shared<ObjectClient>(connectOptions); + DS_ASSERT_OK(client0->Init()); + + std::vector<std::string> objKeys; + std::string val("a"); + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; + objKeys.push_back("AmazingKey_1"); + LOG(INFO) << "Put some data to DS that we'll use for getting later."; + DS_ASSERT_OK(client0->Put(objKeys[0], (uint8_t *)val.c_str(), val.size(), param)); + + // We will use a single client, but many threads can share that same client. + start = std::chrono::steady_clock::now(); + for (int i = 0; i < numClients; ++i) { + threads.emplace_back([&client0, &objKeys, &numGetsPerClient]() { + LOG(INFO) << "Begin get loop in thread."; + std::vector<std::shared_ptr<Buffer>> buffers; + for (int i = 0; i < numGetsPerClient; ++i) { + DS_EXPECT_OK(client0->Get(objKeys, 0, buffers)); + buffers.clear(); + } + LOG(INFO) << "Thread work completed. Thread exits now."; + }); + } + + LOG(INFO) << "Parent joins the completed threads."; + for (auto &t : threads) { + t.join(); + } + end = std::chrono::steady_clock::now(); + totalElapsed += std::chrono::duration_cast(end - start).count(); + + // Error level so we get this in perf runs too + LOG(ERROR) << "Count of gets: " << totalGets << "\nTotal time: " << totalElapsed + << "\nAverage per call: " << totalElapsed / totalGets << std::endl; + + // Maybe this is not needed. + LOG(INFO) << "Calling perf manager print from client side in case it didn't get written."; + PerfManager *perfManager = PerfManager::Instance(); + perfManager->PrintPerfLog(); +} + +TEST_F(ExclusiveConn, ChildThreadSinglePutAndGetTest) +{ + FLAGS_v = 10; // get lots of logs at client layer + LOG(INFO) << "Testing single put and get"; + ConnectOptions connectOptions; + InitConnectOpt(0, connectOptions); // connect to worker 0 + connectOptions.enableExclusiveConnection = true; + std::shared_ptr<ObjectClient> client0 = std::make_shared<ObjectClient>(connectOptions); + DS_ASSERT_OK(client0->Init()); + + LOG(INFO) << "create the data locally before putting"; + std::string objKey("AmazingKey_a"); + std::string val("a"); + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; + + LOG(INFO) << "now put the data"; + // A thread is not needed right now; this could run in the current thread. But later, we'll test + // different threads using the same client. 
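+ // Both the Put and the Get below run on a child thread, so the exclusive connection must also work from a thread other than the one that constructed the client.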
+ std::thread t1([&client0, &objKey, &val, &param]() { + DS_ASSERT_OK(client0->Put(objKey, (uint8_t *)val.c_str(), val.size(), param)); + + std::vector<std::shared_ptr<Buffer>> buffers; + DS_EXPECT_OK(client0->Get({ objKey }, 0, buffers)); + ASSERT_EQ(buffers.size(), size_t(1)); + ASSERT_TRUE(buffers[0]); + std::string getVal((char *)buffers[0]->ImmutableData(), buffers[0]->GetSize()); + EXPECT_EQ(getVal, val); + }); + + t1.join(); + + // Maybe this is not needed. + LOG(INFO) << "Calling perf manager print from client side in case it didn't get written."; + PerfManager *perfManager = PerfManager::Instance(); + perfManager->PrintPerfLog(); +} + +TEST_F(ExclusiveConn, MainThreadSinglePutAndGetTest) +{ + FLAGS_v = 10; // get lots of logs at client layer + LOG(INFO) << "Testing single put and get from the main thread"; + ConnectOptions connectOptions; + InitConnectOpt(0, connectOptions); // connect to worker 0 + connectOptions.enableExclusiveConnection = true; + std::shared_ptr<ObjectClient> client0 = std::make_shared<ObjectClient>(connectOptions); + DS_ASSERT_OK(client0->Init()); + std::string objKey = "AmazingKey_1001"; + std::string val("a"); + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; + DS_ASSERT_OK(client0->Put(objKey, (uint8_t *)val.c_str(), val.size(), param)); + + std::vector<std::shared_ptr<Buffer>> buffers; + DS_EXPECT_OK(client0->Get({ objKey }, 0, buffers)); + ASSERT_EQ(buffers.size(), size_t(1)); + ASSERT_TRUE(buffers[0]); + std::string getVal((char *)buffers[0]->ImmutableData(), buffers[0]->GetSize()); + EXPECT_EQ(getVal, val); + + LOG(INFO) << "Calling perf manager print from client side in case it didn't get written."; + PerfManager *perfManager = PerfManager::Instance(); + perfManager->PrintPerfLog(); +} + +TEST_F(ExclusiveConn, MultiThreadSinglePutAndGetTest) +{ + FLAGS_v = 10; // get lots of logs at client layer + int numThreads = 5; + LOG(INFO) << "Testing multi-threaded single put and get"; + ConnectOptions connectOptions; + InitConnectOpt(0, connectOptions); // connect to worker 0 + connectOptions.enableExclusiveConnection = true; + std::shared_ptr<ObjectClient> client0 = std::make_shared<ObjectClient>(connectOptions); + DS_ASSERT_OK(client0->Init()); + + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; + + LOG(INFO) << "now put the data"; + std::vector<std::thread> threads; + for (auto i = 0; i < numThreads; i++) { + std::string objKey = "AmazingKey_" + std::to_string(i); + std::string val("a"); + + // Capture objKey and val by value: they go out of scope at the end of each loop iteration. + threads.emplace_back([i, &client0, objKey, val, &param]() { + LOG(INFO) << "Thread " << std::to_string(i) << " starts."; + DS_ASSERT_OK(client0->Put(objKey, (uint8_t *)val.c_str(), val.size(), param)); + std::vector<std::shared_ptr<Buffer>> buffers; + DS_EXPECT_OK(client0->Get({ objKey }, 0, buffers)); + ASSERT_EQ(buffers.size(), size_t(1)); + ASSERT_TRUE(buffers[0]); + std::string getVal((char *)buffers[0]->ImmutableData(), buffers[0]->GetSize()); + EXPECT_EQ(getVal, val); + LOG(INFO) << "Thread " << std::to_string(i) << " finishes."; + }); + } + + for (auto& t : threads) { + if (t.joinable()) { + t.join(); + } + } + + // Maybe this is not needed. + LOG(INFO) << "Calling perf manager print from client side in case it didn't get written."; + PerfManager *perfManager = PerfManager::Instance(); + perfManager->PrintPerfLog(); +} + +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/object_cache/client_skip_auth_test.cpp b/tests/st/client/object_cache/client_skip_auth_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d26684e6b535ec84bb5034efd098c88a324ebf3b --- /dev/null +++ b/tests/st/client/object_cache/client_skip_auth_test.cpp @@ -0,0 
+1,82 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "common.h" +#include "datasystem/common/log/log.h" +#include "datasystem/common/util/file_util.h" +#include "datasystem/common/util/format.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/common/util/timer.h" +#include "datasystem/common/util/wait_post.h" +#include "datasystem/object_client.h" +#include "datasystem/object/object_enum.h" +#include "datasystem/utils/status.h" +#include "oc_client_common.h" +#include "datasystem/common/metrics/res_metric_collector.h" + +namespace datasystem { +namespace st { +namespace { +} // namespace +class ClientSkipAuthTest : public OCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + auto workerNum = 2; + opts.numWorkers = workerNum; + opts.numEtcd = 1; + opts.workerGflagParams = + FormatString("-skip_authenticate=true -shared_memory_size_mb=50 -v=2 -log_monitor=true"); + } + +protected: +}; + +TEST_F(ClientSkipAuthTest, NoAuthPutGet) +{ + FLAGS_v = 1; + std::shared_ptr cliLocal, clientRemote; + InitTestClient(0, cliLocal); + InitTestClient(1, clientRemote); + std::string val = "is a test "; + std::string objKey = "key"; + CreateParam param; + DS_ASSERT_OK(cliLocal->Put(objKey, (uint8_t *)val.c_str(), val.size(), param)); + std::vector> buffers; + DS_EXPECT_OK(clientRemote->Get({ objKey }, 0, buffers)); + ASSERT_EQ(buffers.size(), size_t(1)); + ASSERT_TRUE(buffers[0]); + char buff[buffers[0]->GetSize() + 1]; + buff[buffers[0]->GetSize()] = '\0'; + memcpy_s(buff, buffers[0]->GetSize(), buffers[0]->ImmutableData(), buffers[0]->GetSize()); + std::string getVal(buff); + EXPECT_EQ(getVal, val); +} + +} // namespace st +} // namespace datasystem \ No newline at end of file diff --git a/tests/st/client/object_cache/object_client_scale_test.cpp b/tests/st/client/object_cache/object_client_scale_test.cpp index b2237125e4cf8a62b03c8ce4ded50068be700390..db3eda2e9332fd22d0c721edee94b13bf5bae522 100644 --- a/tests/st/client/object_cache/object_client_scale_test.cpp +++ b/tests/st/client/object_cache/object_client_scale_test.cpp @@ -1053,7 +1053,7 @@ TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownWorker1NoneL2CacheWithCopy) ASSERT_EQ(metaNum, 400); // obj is 400 } -TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownWorker1WriteBackWithCopy) +TEST_F(OCVoluntaryScaleDownTest, DISABLED_VoluntaryDownWorker1WriteBackWithCopy) { DS_ASSERT_OK(cluster_->StartOBS()); StartWorkerAndWaitReady({ 0, 1, 2 }); @@ -1961,4 +1961,4 @@ TEST_F(OCVScaleDownDiskTest, VoluntaryDownMigrateDataMultiType) } } } // namespace st -} // namespace datasystem \ No newline at end of file +} // namespace datasystem diff --git a/tests/st/client/object_cache/urma_object_client_test.cpp 
b/tests/st/client/object_cache/urma_object_client_test.cpp index 2dda532c9df03101b13e4d7bc6c55bd196541c02..6229ac612acc3aca2b987378e8849d83c6f39243 100644 --- a/tests/st/client/object_cache/urma_object_client_test.cpp +++ b/tests/st/client/object_cache/urma_object_client_test.cpp @@ -520,8 +520,8 @@ TEST_F(UrmaObjectClientTest, UrmaRemoteGetTwoSmallParallel) TEST_F(UrmaObjectClientTest, UrmaRemoteGetSizeChanged) { - DS_ASSERT_OK( - cluster_->SetInjectAction(WORKER, 0, "WorkerOcServiceGetImpl.PrepareUrmaInfo.changeSize", "1*call(1023)")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "WorkerOcServiceGetImpl.PrepareGetRequestHelper.changeSize", + "1*call(1023)")); std::shared_ptr client1; std::shared_ptr client2; InitTestClient(0, client1); @@ -543,8 +543,8 @@ TEST_F(UrmaObjectClientTest, UrmaRemoteGetSizeChanged) TEST_F(UrmaObjectClientTest, UrmaRemoteBatchGetSizeChanged) { // Test that with batch get, a batch of failure due to size change can be retried automatically. - DS_ASSERT_OK( - cluster_->SetInjectAction(WORKER, 0, "WorkerOcServiceGetImpl.PrepareUrmaInfo.changeSize", "10*call(1023)")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "WorkerOcServiceGetImpl.PrepareGetRequestHelper.changeSize", + "10*call(1023)")); std::shared_ptr client1; std::shared_ptr client2; InitTestKVClient(0, client1); @@ -580,8 +580,8 @@ TEST_F(UrmaObjectClientTest, UrmaRemoteBatchGetSizeChanged) TEST_F(UrmaObjectClientTest, UrmaRemoteGetSizeChangedInvalid) { - DS_ASSERT_OK( - cluster_->SetInjectAction(WORKER, 0, "WorkerOcServiceGetImpl.PrepareUrmaInfo.changeSize", "1*call(1023)")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "WorkerOcServiceGetImpl.PrepareGetRequestHelper.changeSize", + "1*call(1023)")); DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "WorkerWorkerOCServiceImpl.GetObjectRemoteImpl.changeDataSize", "1*call(0)")); std::shared_ptr client1; diff --git a/tests/st/master/object_cache/oc_giveup_primary_test.cpp b/tests/st/master/object_cache/oc_giveup_primary_test.cpp index 733c91eddb3c25e73e8ff13e7edde2063869cb71..d6349436a351fd9c134460fcd8baa7c868ad907d 100644 --- a/tests/st/master/object_cache/oc_giveup_primary_test.cpp +++ b/tests/st/master/object_cache/oc_giveup_primary_test.cpp @@ -183,7 +183,7 @@ public: metadata->set_data_size(dataSize); metadata->set_life_state(static_cast(ObjectLifeState::OBJECT_PUBLISHED)); metadata->set_ttl_second(0); - ObjectMetaPb::ConfigPb *configPb = metadata->mutable_config(); + ConfigPb *configPb = metadata->mutable_config(); configPb->set_write_mode(static_cast(mode)); configPb->set_data_format(static_cast(DataFormat::BINARY)); configPb->set_consistency_type(static_cast(ConsistencyType::CAUSAL)); diff --git a/tests/st/master/object_cache/oc_migrate_metadata_manager_test.cpp b/tests/st/master/object_cache/oc_migrate_metadata_manager_test.cpp index 4aa282a59e12113640af942700e3d8330f53ab14..8e0d56e11b1357e34aa4a2cb1d6bfc80de9f5761 100644 --- a/tests/st/master/object_cache/oc_migrate_metadata_manager_test.cpp +++ b/tests/st/master/object_cache/oc_migrate_metadata_manager_test.cpp @@ -150,7 +150,7 @@ public: metadata->set_data_size(dataSize_); metadata->set_life_state(static_cast(lifeState_)); metadata->set_ttl_second(ttlSecond_); - ObjectMetaPb::ConfigPb *configPb = metadata->mutable_config(); + ConfigPb *configPb = metadata->mutable_config(); configPb->set_write_mode(static_cast(writeMode_)); configPb->set_data_format(static_cast(dataFormat_)); configPb->set_consistency_type(static_cast(consistencyType_)); @@ -339,13 +339,13 @@ 
TEST_F(OCMigrateMetadataManagerTest, TestNestRefMigrateSuccess) failedObjectKeys)); for (const auto &id : nestedKeys_) { LOG(INFO) << nestedKeys_.size(); - ASSERT_TRUE(ocMetadataManager->CheckIsNoneNestedRefById()->CheckIsNoneNestedRefById(id)); + ASSERT_TRUE(ocMetadataManager->GetNestedRefManager()->CheckIsNoneNestedRefById(id)); } for (auto id : objectKeys_) { CheckMigrationMetadata(id, true); std::vector objKeys; - ocMetadataManager->CheckIsNoneNestedRefById()->GetNestedRelationship(id, objKeys); + ocMetadataManager->GetNestedRefManager()->GetNestedRelationship(id, objKeys); ASSERT_TRUE(objKeys.empty()); } } @@ -389,15 +389,15 @@ public: CreateMultiMetaReqPb request; CreateMultiMetaRspPb response; for (int i = 0; i < num; ++i) { - datasystem::ObjectMetaPb *metadata = request.add_metas(); + datasystem::ObjectBaseInfoPb *metadata = request.add_metas(); std::string objectKey = "MigrateMetadataTestId" + std::to_string(i); objectKeys_.emplace_back(objectKey); metadata->set_object_key(objectKey); metadata->set_data_size(dataSize_); - metadata->set_life_state(static_cast(lifeState_)); - metadata->set_ttl_second(ttlSecond_); - metadata->set_existence(ExistenceOptPb::NX); - ObjectMetaPb::ConfigPb *configPb = metadata->mutable_config(); + request.set_life_state(static_cast(lifeState_)); + request.set_ttl_second(ttlSecond_); + request.set_existence(ExistenceOptPb::NX); + ConfigPb *configPb = request.mutable_config(); configPb->set_write_mode(static_cast(writeMode_)); configPb->set_data_format(static_cast(dataFormat_)); configPb->set_consistency_type(static_cast(consistencyType_)); diff --git a/tests/st/worker/object_cache/worker_oc_eviction_test.cpp b/tests/st/worker/object_cache/worker_oc_eviction_test.cpp index 146c87013b19b272461e1ef96d33f6dace7761be..42720f3b0b6993382c5a680d6ffb0485b19895ba 100644 --- a/tests/st/worker/object_cache/worker_oc_eviction_test.cpp +++ b/tests/st/worker/object_cache/worker_oc_eviction_test.cpp @@ -290,7 +290,7 @@ public: if (clientApi->IsV2Client()) { RETURN_IF_NOT_OK(clientApi->ReceivePayload(dest, size)); } else { - shmUnit->id = GetStringUuid(); + shmUnit->id = ShmKey::Intern(GetStringUuid()); RETURN_IF_NOT_OK(clientApi->ReceivePayload(payloads)); size_t payloadLen = 0; @@ -780,7 +780,7 @@ TEST_F(EvictionManagerAndMasterTest, DISABLED_WriteBackDelayTest) auto globalRefTable = std::make_shared(); DS_EXPECT_OK(evictionManager->Init(globalRefTable, akSkManager_)); - std::shared_ptr api; + std::shared_ptr api = std::make_shared(); DS_ASSERT_OK(api->Init()); AsyncSendManager asyncMgr(api, evictionManager); // Stop async send thread. diff --git a/tests/ut/CMakeLists.txt b/tests/ut/CMakeLists.txt index e3e022dee98bc2784199c9149256a57880bccadb..a9e4f043bb47fcaa6039c53005811d8fe1f9d7db 100644 --- a/tests/ut/CMakeLists.txt +++ b/tests/ut/CMakeLists.txt @@ -22,6 +22,7 @@ set(DS_UT_DEPEND_LIBS common_persistence_api common_immutable_string string_ref + common_parallel httpclient master_object_cache worker_object_cache diff --git a/tests/ut/bench_helper.h b/tests/ut/bench_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..f6dfd65abd6161f0d5675c323d1fe6ccab72a5e5 --- /dev/null +++ b/tests/ut/bench_helper.h @@ -0,0 +1,144 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Datasystem unit test base class for benchmarks. + */ +#ifndef DATASYSTEM_TEST_UT_BENCH_HELPER_H +#define DATASYSTEM_TEST_UT_BENCH_HELPER_H + +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "datasystem/common/util/timer.h" +#include "datasystem/common/util/wait_post.h" + +namespace datasystem { +namespace ut { +class BenchHelper { +public: + static std::string GetBenchCost(std::vector<std::vector<double>> &costsVec) + { + std::vector<double> costs; + for (const auto &c : costsVec) { + for (const auto &cost : c) { + costs.emplace_back(cost); + } + } + std::stringstream ss; + std::sort(costs.begin(), costs.end()); + double totalTimeCost = std::accumulate(costs.begin(), costs.end(), 0.0); // ms + double avg = totalTimeCost / static_cast<double>(costs.size()); + auto count = costs.size(); + const size_t p90 = 90; + const size_t p99 = 99; + const size_t pmax = 100; + // Output format: count,avg,min,p90,p99,max (all times in ms). + ss << costs.size(); + ss << "," << avg; // avg ms + ss << "," << costs[0]; // min ms + ss << "," << costs[count * p90 / pmax]; // p90 ms + ss << "," << costs[count * p99 / pmax]; // p99 ms + ss << "," << costs[count - 1]; // max ms + return ss.str(); + } + + static std::string GenUniqueString(size_t threadIdx, size_t iter) + { + const size_t appendSize = 64; + return "thread_" + std::to_string(threadIdx) + "_iter_" + std::to_string(iter) + "_" + + std::string(appendSize, 'a'); + } + + static std::string GenDupString(size_t threadIdx, size_t iter) + { + (void)threadIdx; + const size_t appendSize = 64; + const size_t count = 100; + return "thread_x_iter_" + std::to_string(iter % count) + "_" + std::string(appendSize, 'a'); + } + + template <typename G, typename F1, typename F2> + void PerfTwoAction(size_t threadCnt, G &&gen, F1 &&fn1, F2 &&fn2) + { + const size_t countPerThread = 102'400; + const size_t batchCnt = 1024; + std::vector<std::thread> threads; + // generate string + std::vector<std::vector<std::string>> datas; + std::vector<std::vector<double>> costPerThread1; + std::vector<std::vector<double>> costPerThread2; + datas.resize(threadCnt); + costPerThread1.resize(threadCnt); + costPerThread2.resize(threadCnt); + Barrier barrier1(threadCnt); + Barrier barrier2(threadCnt); + for (size_t i = 0; i < threadCnt; i++) { + auto &data = datas[i]; + auto &costs1 = costPerThread1[i]; + auto &costs2 = costPerThread2[i]; + data.reserve(countPerThread); + costs1.reserve(countPerThread / batchCnt); + costs2.reserve(countPerThread / batchCnt); + + for (size_t n = 0; n < countPerThread; n++) { + data.emplace_back(std::move(gen(i, n))); + } + std::shuffle(data.begin(), data.end(), std::mt19937{ std::random_device{}() }); + threads.emplace_back([&] { + size_t count = 0; + barrier1.Wait(); + Timer timer1; + for (const auto &key : data) { + fn1(key); + if (count == batchCnt) { + count = 0; + costs1.emplace_back(timer1.ElapsedMilliSecondAndReset()); + } + count += 1; + } + barrier2.Wait(); + Timer timer2; + count = 0; + for (const auto &key : data) { + fn2(key); + if (count == batchCnt) { + count = 0; + costs2.emplace_back(timer2.ElapsedMilliSecondAndReset()); + } + count += 1; + } + }); + } + for (auto &t : threads) { + t.join(); + } + std::string caseName; + std::string name; + 
ut::GetCurTestName(caseName, name); + LOG(ERROR) << "BENCHMARK," << caseName << "," << name << ",Thread-" << threadCnt << ", Action1," + << GetBenchCost(costPerThread1); + LOG(ERROR) << "BENCHMARK," << caseName << "," << name << ",Thread-" << threadCnt << ", Action2," + << GetBenchCost(costPerThread2); + } +}; +} // namespace ut +} // namespace datasystem +#endif diff --git a/tests/ut/client/parallel_for_local_test.cpp b/tests/ut/client/parallel_for_local_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..618bb9ce1f9992e3f89d2608073aedbcc4454fd0 --- /dev/null +++ b/tests/ut/client/parallel_for_local_test.cpp @@ -0,0 +1,326 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "datasystem/common/parallel/parallel_for.h" + +using namespace datasystem; +using namespace datasystem::Parallel; + +std::vector<uint32_t> g_nodepool; +std::map<std::thread::id, size_t> g_tid_id_map; +std::map<size_t, std::thread::id> g_id_tid_map; +bool g_id_tid_good = true; +std::mutex g_id_tid_check_mu; + +const uint32_t g_taskNum = 1000000; +const uint32_t g_threadsNum = 8; +const uint32_t g_chunksize = 100; + +class ParallelForLocalTest : public testing::Test { +public: + ParallelForLocalTest() {} + ~ParallelForLocalTest() {} + static void SetUpTestCase() + { + InitParallelThreadPool(g_threadsNum); + } + void SetUp() + { + g_nodepool.resize(g_taskNum); + g_id_tid_good = true; + g_tid_id_map.clear(); + g_id_tid_map.clear(); + } + void TearDown() + { + g_nodepool.clear(); + g_tid_id_map.clear(); + g_id_tid_map.clear(); + } +}; + +void BodyFun(uint32_t start, uint32_t end) +{ + for (uint32_t i = start; i < end; i++) { + g_nodepool[i] += i; + } +} + +auto BodyLambda = [](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + g_nodepool[i] += i; + } +}; + +class BodyOperator { +public: + void operator()(size_t start, size_t end) const + { + for (size_t i = start; i < end; i++) { + g_nodepool[i] += i; + } + } +}; + +class BodyClassFun { +public: + void Fun(uint32_t start, uint32_t end) + { + for (uint32_t i = start; i < end; i++) { + g_nodepool[i] += i; + } + } + + static void StaticFun(uint32_t start, uint32_t end) + { + for (uint32_t i = start; i < end; i++) { + g_nodepool[i] += i; + } + } +}; + +// Three-argument handlers: ctx.id and the executing thread id must correspond one-to-one. +void idTidCheck(size_t ctxid) +{ + std::thread::id tid = std::this_thread::get_id(); + std::lock_guard<std::mutex> lock(g_id_tid_check_mu); + auto it = g_tid_id_map.find(tid); + if (it != g_tid_id_map.end() && it->second != ctxid) { + g_id_tid_good = false; + } + auto it2 = g_id_tid_map.find(ctxid); + if (it2 != g_id_tid_map.end() && it2->second != tid) { + g_id_tid_good = false; + } + g_tid_id_map[tid] = ctxid; + g_id_tid_map[ctxid] = tid; + EXPECT_TRUE(g_id_tid_good); +} + +void BodyFunWithCtx(uint32_t start, uint32_t end, const Context &ctx) +{ + idTidCheck(ctx.id); + for (uint32_t i = start; i < 
end; i++) { + g_nodepool[i] += i; + } +} + +auto BodyLambdaWithCtx = [](size_t start, size_t end, const Context &ctx) { + idTidCheck(ctx.id); + for (size_t i = start; i < end; i++) { + g_nodepool[i] += i; + } +}; + +class BodyOperatorWithCtx { +public: + void operator()(size_t start, size_t end, const Context &ctx) const + { + idTidCheck(ctx.id); + for (size_t i = start; i < end; i++) { + g_nodepool[i] += i; + } + } +}; + +class BodyClassFunWithCtx { +public: + void Fun(uint32_t start, uint32_t end, const Context &ctx) + { + idTidCheck(ctx.id); + for (uint32_t i = start; i < end; i++) { + g_nodepool[i] += i; + } + } + + static void StaticFun(uint32_t start, uint32_t end, const Context &ctx) + { + idTidCheck(ctx.id); + for (uint32_t i = start; i < end; i++) { + g_nodepool[i] += i; + } + } +}; + +// ParallelFor(first, last, body, chunkSize[, parallelDegree]) splits [first, last) into chunks of +// chunkSize and invokes body(start, end) across the pool threads. +TEST_F(ParallelForLocalTest, CallBodyOperator) +{ + BodyOperator body; + ParallelFor(0, g_taskNum, body, g_chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepool[i], i); + } +} + +TEST_F(ParallelForLocalTest, CallBodyOperatorWithCtx) +{ + BodyOperatorWithCtx body; + ParallelFor(0, g_taskNum, body, g_chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepool[i], i); + } +} + +TEST_F(ParallelForLocalTest, CallBodyFun) +{ + ParallelFor(0, g_taskNum, &BodyFun, g_chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepool[i], i); + } +} + +TEST_F(ParallelForLocalTest, CallBodyFunWithCtx) +{ + ParallelFor(0, g_taskNum, &BodyFunWithCtx, g_chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepool[i], i); + } +} + +TEST_F(ParallelForLocalTest, CallBodyLambda) +{ + ParallelFor(0, g_taskNum, BodyLambda, g_chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepool[i], i); + } +} + +TEST_F(ParallelForLocalTest, CallBodyLambdaWithCtx) +{ + ParallelFor(0, g_taskNum, BodyLambdaWithCtx, g_chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepool[i], i); + } +} + +TEST_F(ParallelForLocalTest, CallBodyClassFun) +{ + BodyClassFun body; + ParallelFor( + 0, g_taskNum, std::bind(&BodyClassFun::Fun, &body, std::placeholders::_1, std::placeholders::_2), g_chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepool[i], i); + } +} + +TEST_F(ParallelForLocalTest, CallBodyClassFunWithCtx) +{ + BodyClassFunWithCtx body; + ParallelFor(0, g_taskNum, + std::bind(&BodyClassFunWithCtx::Fun, &body, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3), + g_chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepool[i], i); + } +} + +TEST_F(ParallelForLocalTest, CallBodyClassStaticFun) +{ + ParallelFor(0, g_taskNum, &BodyClassFun::StaticFun, g_chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepool[i], i); + } +} + +TEST_F(ParallelForLocalTest, CallBodyClassStaticFunWithCtx) +{ + ParallelFor(0, g_taskNum, &BodyClassFunWithCtx::StaticFun, g_chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepool[i], i); + } +} + +TEST_F(ParallelForLocalTest, CallBodyLambdaChunkSizeIsBigger) +{ + auto chunksize = g_taskNum + 100; + ParallelFor(0, g_taskNum, BodyLambda, chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepool[i], i); + } +} + +/* Multi-thread tests */ +const uint32_t g_teamsNum = 4; +std::array, g_teamsNum> g_nodepools; + +void ThreadParallelForFunc(int teamid) +{ + auto fun = [&](uint32_t start, uint32_t end) { + for (uint32_t i = start; i < end; i++) { + g_nodepools[teamid][i] += i; 
+ } + }; + // 1 master and parallelDegree-1 workers + ParallelFor(0, g_taskNum, fun, g_chunksize); + for (uint32_t i = 0; i < g_taskNum; i++) { + EXPECT_EQ(g_nodepools[teamid][i], i); + } +} + +TEST_F(ParallelForLocalTest, When_Worker_Thread_Size_Is_One_And_TaskNum_Is_Not_One_Should_Do_Ok) +{ + const int chunkSize = 1; + const int want = 100; + const int length = 2; + std::vector arr; + arr.resize(length); + auto useStart = [&arr](size_t start, size_t /* end */) { arr[start] = want; }; + ParallelFor(0, length, useStart, chunkSize, 1); + for (size_t i = 0; i < length; i++) { + EXPECT_EQ(arr[i], want); + } +} + +TEST_F(ParallelForLocalTest, When_Input_UINT32_MAX_Should_Do_Ok) +{ + std::atomic get; + get.store(0); + size_t want = 1; + auto addLambda = [&get](size_t /* start */, size_t /* end */) { get++; }; + ParallelFor(0, UINT32_MAX, addLambda, UINT32_MAX); + EXPECT_EQ(want, get); +} + +TEST_F(ParallelForLocalTest, NestedParallelFor) +{ + size_t n = 5; + std::mutex mu; + size_t cnt = 0; + ParallelFor(0, n, [n, &mu, &cnt](size_t i1, size_t j1) { + for (auto i = i1; i < j1; i++) { + ParallelFor(0, n, [i, n, &mu, &cnt](size_t i2, size_t j2) { + for (auto j = i2; j < j2; j++) { + ParallelFor(0, n, [i, j, &mu, &cnt](size_t i3, size_t j3) { + for (auto k = i3; k < j3; k++) { + std::lock_guard lk(mu); + std::cout << cnt++ << ": [" << syscall(SYS_gettid) << "] " + << i << " " << j << " " << k << std::endl; + } + }, 1); + } + }, 1); + } + }, 1); +} diff --git a/tests/ut/common/object_cache/obj_ref_table_test.cpp b/tests/ut/common/object_cache/obj_ref_table_test.cpp index 87007423e3ca9343bb7c6e0cdd00c6e90b0c9271..def8af5d8d74797f632180141736523151895712 100644 --- a/tests/ut/common/object_cache/obj_ref_table_test.cpp +++ b/tests/ut/common/object_cache/obj_ref_table_test.cpp @@ -21,6 +21,7 @@ #include "common.h" #include "datasystem/common/object_cache/object_ref_info.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/random_data.h" #include "datasystem/common/object_cache/safe_table.h" #include "datasystem/common/util/thread_pool.h" @@ -51,6 +52,7 @@ void ParallelFor(size_t numOfOps, F f, size_t numOfThreads) class ObjRefTableTest : public CommonTest { public: + using ObjectRefInfo = datasystem::object_cache::ObjectRefInfo; void SetUp() override { clientIds_.clear(); @@ -170,7 +172,7 @@ std::vector ObjRefTableTest::GetGlobalRefIds(const std::vector> objects; + std::unordered_map> objects; // Verify Object Table and Ref Cnt. for (const auto &objKeyClients : refClientSets_) { @@ -185,8 +187,8 @@ void ObjRefTableTest::VerifyMemRefTableMatches() ASSERT_EQ((*entry)->GetShmUnit()->GetRefCount(), static_cast(clientSet.size())); entry->RUnlock(); for (const auto &client : clientSet) { - ASSERT_TRUE(memRefTable_.Contains(client, objKey)); - objects[client].emplace(objKey); + ASSERT_TRUE(memRefTable_.Contains(client, ShmKey::Intern(objKey))); + objects[client].emplace(ShmKey::Intern(objKey)); } } @@ -194,7 +196,7 @@ void ObjRefTableTest::VerifyMemRefTableMatches() for (const auto &clientObjects : objects) { const auto &clientId = clientObjects.first; const auto &objectSet = clientObjects.second; - std::vector objKeys; + std::vector objKeys; memRefTable_.GetClientRefIds(clientId, objKeys); ASSERT_EQ(objKeys.size(), objectSet.size()); for (const auto &objKey : objKeys) { @@ -217,7 +219,7 @@ void ObjRefTableTest::TestMemRefTableUniqAdd() std::unique_ptr objPtr = nullptr; // Set ShmUnit. 
auto shmUnit = std::make_shared(); - shmUnit->id = objKeys_[objIndex]; + shmUnit->id = ShmKey::Intern(objKeys_[objIndex]); auto objShmUnit = std::make_unique(); objShmUnit->SetShmUnit(shmUnit); entry->SetRealObject(std::move(objShmUnit)); @@ -254,7 +256,7 @@ void ObjRefTableTest::TestMemRefTableUniqRemove() auto objShmUnit = std::make_unique(); entry->SetRealObject(std::move(objShmUnit)); } - memRefTable_.RemoveShmUnit(clientIds_[clientIndex], objKeys_[objIndex]); + memRefTable_.RemoveShmUnit(clientIds_[clientIndex], ShmKey::Intern(objKeys_[objIndex])); entry->WUnlock(); }, numOfThreads_); @@ -394,7 +396,6 @@ void ObjRefTableTest::TestGlobalRefTableRemove() TEST_F(ObjRefTableTest, ObjRefInfoUniqBranchTest) { - using ObjectRefInfo = datasystem::object_cache::ObjectRefInfo; auto clientInfo = std::make_shared(); size_t dataSz = 32; auto id = randomData_.GetRandomString(dataSz); @@ -436,7 +437,6 @@ TEST_F(ObjRefTableTest, ObjRefInfoUniqBranchTest) TEST_F(ObjRefTableTest, ObjRefInfoRefCntBranchTest) { - using ObjectRefInfo = datasystem::object_cache::ObjectRefInfo; auto clientInfo = std::make_shared(false); size_t dataSz = 32; auto id = randomData_.GetRandomString(dataSz); @@ -454,7 +454,6 @@ TEST_F(ObjRefTableTest, ObjRefInfoRefCntBranchTest) TEST_F(ObjRefTableTest, ObjRefInfoRefCntMultiIdMultiThread) { - using ObjectRefInfo = datasystem::object_cache::ObjectRefInfo; auto clientInfo = std::make_shared(false); int threadNum = 8; ThreadPool threadPool(threadNum); @@ -491,7 +490,6 @@ TEST_F(ObjRefTableTest, ObjRefInfoRefCntMultiIdMultiThread) TEST_F(ObjRefTableTest, ObjRefInfoRefCntOneIdMultiThread) { - using ObjectRefInfo = datasystem::object_cache::ObjectRefInfo; auto clientInfo = std::make_shared(false); int threadNum = 8; ThreadPool threadPool(threadNum); @@ -530,7 +528,6 @@ TEST_F(ObjRefTableTest, ObjRefInfoRefCntOneIdMultiThread) TEST_F(ObjRefTableTest, ObjRefInfoRefCntMultiIdMultiThread2) { - using ObjectRefInfo = datasystem::object_cache::ObjectRefInfo; auto clientInfo = std::make_shared(true); int threadNum = 8; ThreadPool threadPool(threadNum); @@ -566,7 +563,6 @@ TEST_F(ObjRefTableTest, ObjRefInfoRefCntMultiIdMultiThread2) TEST_F(ObjRefTableTest, ObjRefInfoRefCntOneIdMultiThread2) { - using ObjectRefInfo = datasystem::object_cache::ObjectRefInfo; auto clientInfo = std::make_shared(true); int threadNum = 8; ThreadPool threadPool(threadNum); @@ -640,7 +636,7 @@ TEST_F(ObjRefTableTest, GlobalRefTableAddRmTest) TEST_F(ObjRefTableTest, RemoveClientAndDecreaseShmUnit) { - std::vector shmIds; + std::vector shmIds; auto clientId = GetStringUuid(); for (int i = 0; i < 3000; i++) { // id num is 3000 std::shared_ptr entry; @@ -651,11 +647,11 @@ TEST_F(ObjRefTableTest, RemoveClientAndDecreaseShmUnit) std::unique_ptr objPtr = nullptr; // Set ShmUnit. 
auto shmUnit = std::make_shared(); - shmUnit->id = objId; + shmUnit->id = ShmKey::Intern(objId); auto objShmUnit = std::make_unique(); objShmUnit->SetShmUnit(shmUnit); entry->SetRealObject(std::move(objShmUnit)); - shmIds.emplace_back(objId); + shmIds.emplace_back(ShmKey::Intern(objId)); } auto shmUnit = (*entry)->GetShmUnit(); memRefTable_.AddShmUnit(clientId, shmUnit); diff --git a/tests/ut/common/shared_memory/allocator_test.cpp b/tests/ut/common/shared_memory/allocator_test.cpp index dae28140434780c3389a85a736bc843203f5b306..c1cab158e61e259509c1c6dbc17f25eac469c2fa 100644 --- a/tests/ut/common/shared_memory/allocator_test.cpp +++ b/tests/ut/common/shared_memory/allocator_test.cpp @@ -24,6 +24,7 @@ #include #include #include +#include "datasystem/common/string_intern/string_ref.h" #include "gtest/gtest.h" #define JEMALLOC_NO_DEMANGLE @@ -818,7 +819,7 @@ void AllocatorTest::TestShmUnits1() ShmView currView = shmUnit1.GetShmView(); std::string id("123"); // test a construction from view - ShmUnit shmUnit2(id, currView, nullptr); + ShmUnit shmUnit2(ShmKey::Intern(id), currView, nullptr); // test a construction using fd/mmapsize ShmUnit shmUnit3(1, 1); // Test a copy diff --git a/tests/ut/common/string_intern/string_ref_test.cpp b/tests/ut/common/string_intern/string_ref_test.cpp index 0de77d775479b186e264ff24578a0fdedd484d97..4120cb5e5f6416b3c7a7747cc70687aedad10f33 100644 --- a/tests/ut/common/string_intern/string_ref_test.cpp +++ b/tests/ut/common/string_intern/string_ref_test.cpp @@ -313,6 +313,10 @@ TEST_F(StringRefTest, BaseTest) std::unordered_map map; map.emplace(s1, 1); ASSERT_EQ(map[s2], 1); + + auto s4 = ObjectKey::Intern("abcde"); + s4 = s3; + ASSERT_EQ(s3, s4); } TEST_F(StringRefTest, TestMove) diff --git a/tests/ut/common/util/immutable_string_test.cpp b/tests/ut/common/util/immutable_string_test.cpp index 41dc2d408314f368dd17b1e178605464a361f85e..2b8067c6cee8771f4a9ed67b8b71e3912a379851 100644 --- a/tests/ut/common/util/immutable_string_test.cpp +++ b/tests/ut/common/util/immutable_string_test.cpp @@ -24,18 +24,18 @@ namespace datasystem { namespace ut { class ImmutableStringTest : public CommonTest { public: - static void CheckImmutableStringEqual(const ImmutableString &im1, const ImmutableString &im2) + static void CheckImmutableStringEqual(const ImmutableStringImpl &im1, const ImmutableStringImpl &im2) { ASSERT_EQ(im1, im2); ASSERT_EQ(im1.ToString(), im2.ToString()); ASSERT_EQ(&im1.ToString(), &im2.ToString()); } - static void CheckSetErase(tbb::concurrent_unordered_set> &set1, - tbb::concurrent_unordered_set> &set2) + static void CheckSetErase(tbb::concurrent_unordered_set> &set1, + tbb::concurrent_unordered_set> &set2) { set1.unsafe_erase("123"); - set2.unsafe_erase(ImmutableString("123")); + set2.unsafe_erase(ImmutableStringImpl("123")); EXPECT_EQ(ImmutableStringPool::Instance().Size(), 1ul); set1.unsafe_erase("456"); @@ -47,7 +47,7 @@ public: static void CheckSetErase(T &set1, T &set2) { set1.erase("123"); - set2.erase(ImmutableString("123")); + set2.erase(ImmutableStringImpl("123")); EXPECT_EQ(ImmutableStringPool::Instance().Size(), 1ul); set1.erase("456"); @@ -65,8 +65,8 @@ public: T map1; T map2; { - auto im1 = ImmutableString(key1); - // insert by ImmutableString + auto im1 = ImmutableStringImpl(key1); + // insert by ImmutableStringImpl map1[im1] = value1; ASSERT_EQ(map1[key1], value1); @@ -93,14 +93,14 @@ public: auto pair = set1.insert(test1); ASSERT_TRUE(pair.second); - pair = set1.insert(ImmutableString(test2)); + pair = set1.insert(ImmutableStringImpl(test2)); 
ASSERT_TRUE(pair.second); - pair = set1.insert(ImmutableString(test2)); + pair = set1.insert(ImmutableStringImpl(test2)); ASSERT_FALSE(pair.second); // After insert, 2 RefCountString in pool. EXPECT_EQ(ImmutableStringPool::Instance().Size(), 2ul); - // find by ImmutableString - auto iter = set1.find(ImmutableString(test1)); + // find by ImmutableStringImpl + auto iter = set1.find(ImmutableStringImpl(test1)); ASSERT_TRUE(iter != set1.end()); ASSERT_EQ(*iter, test1); // find by std::string @@ -117,9 +117,9 @@ iter = set1.find("789"); ASSERT_TRUE(iter == set1.end()); - pair = set2.insert(ImmutableString(test1)); + pair = set2.insert(ImmutableStringImpl(test1)); ASSERT_TRUE(pair.second); - pair = set2.insert(ImmutableString(test2)); + pair = set2.insert(ImmutableStringImpl(test2)); ASSERT_TRUE(pair.second); EXPECT_EQ(ImmutableStringPool::Instance().Size(), 2ul); @@ -137,12 +137,12 @@ TEST_F(ImmutableStringTest, TestConstructor) { std::string test2 = "456"; char test3[] = "123"; { - ImmutableString im1 = ImmutableString(test1); - ImmutableString im2 = ImmutableString(test1); - ImmutableString im3 = ImmutableString(test2); - ImmutableString im4 = ImmutableString("123"); - ImmutableString im5 = ImmutableString("456"); - ImmutableString im6 = ImmutableString(test3); + ImmutableStringImpl im1 = ImmutableStringImpl(test1); + ImmutableStringImpl im2 = ImmutableStringImpl(test1); + ImmutableStringImpl im3 = ImmutableStringImpl(test2); + ImmutableStringImpl im4 = ImmutableStringImpl("123"); + ImmutableStringImpl im5 = ImmutableStringImpl("456"); + ImmutableStringImpl im6 = ImmutableStringImpl(test3); LOG(INFO) << "check im1, im2"; CheckImmutableStringEqual(im1, im2); @@ -164,7 +164,7 @@ TEST_F(ImmutableStringTest, TestBigString) size_t strSize = 1024ul * 1024 * 1024; std::string str = RandomData().GetPartRandomString(strSize, 100); size_t imSize = 2; - std::vector<ImmutableString> imVec; + std::vector<ImmutableStringImpl> imVec; imVec.reserve(imSize); for (size_t i = 0; i < imSize; i++) { LOG(INFO) << "loop: " << i; @@ -190,7 +190,7 @@ TEST_F(ImmutableStringTest, TestDestructorInParallel) for (size_t i = 0; i < threadNum; i++) { pool->Execute([&strVec, i, strNum]() { for (int j = 0; j < 10000; j++) { - ImmutableString im = ImmutableString(strVec[i % strNum]); + ImmutableStringImpl im = ImmutableStringImpl(strVec[i % strNum]); } }); } @@ -200,25 +200,25 @@ } /** -1. ImmutableString\const char*\std::string can all be used to insert and find +1. ImmutableStringImpl\const char*\std::string can all be used to insert and find 2. Repeated inserts do not increase memory usage 3. After all entries are erased, the memory can be released -4. Under concurrency, the table can safely perform (insert\find\erase) without an external lock, and ImmutableString is not corrupted +4. Under concurrency, the table can safely perform (insert\find\erase) without an external lock, and ImmutableStringImpl is not corrupted 5. 
All tbb and stl map/set types are supported. */ TEST_F(ImmutableStringTest, TestImInTbbUnorderedSet) { - ImSetCheckMemoryReduce>>(); + ImSetCheckMemoryReduce>>(); } TEST_F(ImmutableStringTest, TestImInSTLUnorderedSet) { - ImSetCheckMemoryReduce>(); + ImSetCheckMemoryReduce>(); } TEST_F(ImmutableStringTest, TestImInSTLSet) { - ImSetCheckMemoryReduce>(); + ImSetCheckMemoryReduce>(); } TEST_F(ImmutableStringTest, ImInTbbHashMap) { @@ -228,14 +228,14 @@ auto value1 = RandomData().GetRandomUint32(); auto value2 = RandomData().GetRandomUint32(); - using MapType = tbb::concurrent_hash_map; + using MapType = tbb::concurrent_hash_map; MapType map1; MapType map2; - auto im1 = ImmutableString(key1); + auto im1 = ImmutableStringImpl(key1); MapType::accessor ac; - // insert by ImmutableString + // insert by ImmutableStringImpl map1.insert(ac, im1); ac->second = value1; ac.release(); @@ -257,7 +257,7 @@ TEST_F(ImmutableStringTest, ImInUnorderedMapInParrel) { auto key1 = GetStringUuid(); - using MapType = std::unordered_map; + using MapType = std::unordered_map; MapType map1; auto pool = std::make_unique(10); std::shared_timed_mutex mutex; @@ -285,13 +285,13 @@ TEST_F(ImmutableStringTest, ImInSTLHashMap) { - using MapType = std::map; + using MapType = std::map; ImMapCheckMemoryReduce(); } TEST_F(ImmutableStringTest, ImInSTLUnorderedMap) { - using MapType = std::unordered_map; + using MapType = std::unordered_map; ImMapCheckMemoryReduce(); } } // namespace ut diff --git a/tests/ut/common/util/status_test.cpp b/tests/ut/common/util/status_test.cpp index 51b2fd7521c120cac895151460bb272c3d137d3d..0a55c6a8cf19da05480d141041b23dfc8fcd9427 100644 --- a/tests/ut/common/util/status_test.cpp +++ b/tests/ut/common/util/status_test.cpp @@ -198,15 +198,15 @@ TEST_F(StatusTest, TestStreamOperator) TEST_F(StatusTest, TestStatusLogForMat) { - Status status(StatusCode::K_OK, "This is a msg."); + Status status(StatusCode::K_RUNTIME_ERROR, "This is a msg."); status.AppendMsg("This is appended msg."); ASSERT_EQ(status.GetMsg(), "This is a msg. This is appended msg."); - Status status1(StatusCode::K_OK, "This is a msg"); + Status status1(StatusCode::K_RUNTIME_ERROR, "This is a msg"); status1.AppendMsg("This is appended msg."); ASSERT_EQ(status1.GetMsg(), "This is a msg. 
This is appended msg."); - Status status3(StatusCode::K_OK, ""); + Status status3(StatusCode::K_RUNTIME_ERROR, ""); status3.AppendMsg("This is appended msg."); ASSERT_EQ(status3.GetMsg(), " This is appended msg."); diff --git a/tests/ut/worker/client_manager_test.cpp b/tests/ut/worker/client_manager_test.cpp index 8c1df7a62a135ddcd679bb61339f8265415e8b80..b0de2c71b15f178fa02367b597593f71d1243161 100644 --- a/tests/ut/worker/client_manager_test.cpp +++ b/tests/ut/worker/client_manager_test.cpp @@ -27,6 +27,7 @@ #include "common.h" #include "datasystem/common/inject/inject_point.h" #include "datasystem/common/rpc/unix_sock_fd.h" +#include "datasystem/common/string_intern/string_ref.h" #include "datasystem/common/util/uuid_generator.h" #include "datasystem/worker/client_manager/client_manager.h" @@ -63,7 +64,7 @@ TEST_F(ClientManagerTest, TestAddShmUnitUniqueCount) DS_ASSERT_OK(clientMgr.GetClientSocketFd(clientId, getSocketFd)); ASSERT_TRUE(socketFd == getSocketFd); auto shmUnit = std::make_shared(); - shmUnit->id = GetBytesUuid(); + shmUnit->id = ShmKey::Intern(GetBytesUuid()); shmUnit->fd = socketFd; DS_ASSERT_OK(clientMgr.AddShmUnit(clientId, shmUnit)); ASSERT_EQ(shmUnit->refCount, 1); @@ -87,14 +88,14 @@ TEST_F(ClientManagerTest, TestAddShmUnit) ASSERT_TRUE(clientInfo != nullptr); auto shmUnit = std::make_shared(); - shmUnit->id = GetBytesUuid(); + shmUnit->id = ShmKey::Intern(GetBytesUuid()); shmUnit->fd = socketFd; DS_ASSERT_OK(clientMgr.AddShmUnit(clientId, shmUnit)); DS_ASSERT_OK(clientMgr.AddShmUnit(clientId, shmUnit)); DS_ASSERT_OK(clientMgr.AddShmUnit(clientId, shmUnit)); ASSERT_GE(shmUnit->refCount, 3); - std::unordered_map shmUnitIds; + std::unordered_map shmUnitIds; clientInfo->GetShmUnitIds(shmUnitIds); ASSERT_TRUE(shmUnitIds.find(shmUnit->id) != shmUnitIds.end()); ASSERT_EQ(shmUnitIds[shmUnit->id], static_cast(3)); @@ -171,7 +172,7 @@ TEST_F(ClientManagerTest, TestRemoveShmUnitOfAllClient) DS_ASSERT_OK(clientMgr.GetClientSocketFd(clientId, getSocketFd)); ASSERT_TRUE(socketFd == getSocketFd); auto shmUnit = std::make_shared(); - shmUnit->id = GetBytesUuid(); + shmUnit->id = ShmKey::Intern(GetBytesUuid()); shmUnit->fd = socketFd; DS_ASSERT_OK(clientMgr.AddShmUnit(clientId, shmUnit)); ASSERT_GE(shmUnit->refCount, 1); diff --git a/tests/ut/worker/object_cache/worker_oc_eviction_test.cpp b/tests/ut/worker/object_cache/worker_oc_eviction_test.cpp index 4ba0c0c6e6e31891c5366720607c6fa0695e7a54..931dd0c3e760373f1c145e19b4b4b3d4f2fe35b3 100644 --- a/tests/ut/worker/object_cache/worker_oc_eviction_test.cpp +++ b/tests/ut/worker/object_cache/worker_oc_eviction_test.cpp @@ -22,6 +22,7 @@ #include "securec.h" +#include "bench_helper.h" #include "common.h" #include "datasystem/common/constants.h" #include "datasystem/common/log/log.h" @@ -55,7 +56,6 @@ DS_DECLARE_string(etcd_address); namespace datasystem { namespace ut { - class EvictionManagerTest : public CommonTest, public EvictionManagerCommon { public: void SetUp() override @@ -244,8 +244,8 @@ public: allocator->Init(maxSize_, 0, false, true, 5000, ocPercent_, scPercent_); // decay is 5000 ms. 
std::shared_ptr &objectTable = GetObjectTable(); evictionManager_ = std::make_shared( - objectTable, HostPort("127.0.0.1", 32131), // worker port is 32131, - HostPort("127.0.0.1", 52319)); // master port is 52319; + objectTable, HostPort("127.0.0.1", 32131), // worker port is 32131, + HostPort("127.0.0.1", 52319)); // master port is 52319; auto globalRefTable = std::make_shared(); DS_ASSERT_OK(evictionManager_->Init(globalRefTable, akSkManager_)); scAllocateManager_ = std::make_shared(evictionManager_); @@ -337,5 +337,40 @@ TEST_F(ScEvictionObjectTest, TestEvictObject) evictionManager_->Add(prefix + std::to_string(i)); } } + +class EvictionManagerBenchTest : public CommonTest, public BenchHelper {}; + +TEST_F(EvictionManagerBenchTest, BenchThread1) +{ + const int logLevel = 2; + FLAGS_minloglevel = logLevel; + EvictionList list; + const int threadCnt = 1; + PerfTwoAction( + threadCnt, GenUniqueString, [&list](const std::string &key) { list.Add(key, Q1); }, + [&list](const std::string &key) { list.Erase(key); }); +} + +TEST_F(EvictionManagerBenchTest, BenchThread4) +{ + const int logLevel = 2; + FLAGS_minloglevel = logLevel; + EvictionList list; + const int threadCnt = 4; + PerfTwoAction( + threadCnt, GenUniqueString, [&list](const std::string &key) { list.Add(key, Q1); }, + [&list](const std::string &key) { list.Erase(key); }); +} + +TEST_F(EvictionManagerBenchTest, BenchThread8) +{ + const int logLevel = 2; + FLAGS_minloglevel = logLevel; + EvictionList list; + const int threadCnt = 8; + PerfTwoAction( + threadCnt, GenUniqueString, [&list](const std::string &key) { list.Add(key, Q1); }, + [&list](const std::string &key) { list.Erase(key); }); +} } // namespace ut } // namespace datasystem