From 97f88f5a0d04a546c9b49ce90d614604dc2ad4a6 Mon Sep 17 00:00:00 2001 From: yaohaolin Date: Mon, 17 Nov 2025 14:58:09 +0800 Subject: [PATCH] Hi ALL, stream is available now! --- .clang-tidy | 13 + LOG_README | 50 +- README.md | 2 +- README_CN.md | 2 +- build.sh | 62 +- cli/command.py | 29 +- {example => cli}/cpp_template/CMakeLists.txt | 0 {example => cli}/cpp_template/README.md | 0 .../cpp_template/kv_cache_example.cpp | 0 {example => cli}/cpp_template/run.sh | 0 cli/deploy/conf/worker_config.json | 150 +- cli/generate_helm_chart.py | 2 +- cli/start.py | 1 + cmake/external_libs/libcurl.cmake | 17 +- cmake/external_libs/sdk_c_obs.cmake | 4 +- cmake/external_libs/spdlog.cmake | 5 +- cmake/external_libs/ub.cmake | 33 +- cmake/external_libs/urma.cmake | 5 +- cmake/modules/FindURMA.cmake | 17 +- cmake/package.cmake | 1 + cmake/util.cmake | 118 +- docs/source_en/appendix/k8s_configuration.md | 18 +- docs/source_zh_cn/appendix/hugepage_guide.md | 80 + docs/source_zh_cn/appendix/log_guide.md | 12 +- docs/source_zh_cn/deployment/dscli.md | 22 +- .../deployment/k8s_configuration.md | 18 +- .../datasystem.DsTensorClient.dev_mget.rst | 4 +- .../datasystem.DsTensorClient.dev_recv.rst | 4 +- .../datasystem.DsTensorClient.dev_send.rst | 4 +- ...nsorClient.get_page_attn_layerwise_d2d.rst | 6 +- ...nsorClient.put_page_attn_layerwise_d2d.rst | 6 +- ...em.hetero_client.HeteroClient.dev_mget.rst | 2 + ...em.hetero_client.HeteroClient.dev_mset.rst | 4 +- ...ystem.hetero_client.HeteroClient.exist.rst | 16 + .../datasystem.hetero_client.HeteroClient.rst | 3 + .../datasystem.kv_client.KVClient.exist.rst | 2 +- .../development-guide/example/hetero.md | 4 +- docs/source_zh_cn/index.rst | 2 +- example/README.md | 27 +- example/{ => cpp}/CMakeLists.txt | 37 +- .../datasystem_example.cpp} | 25 +- example/cpp/hetero_client_example.cpp | 159 + .../kv_client_example.cpp} | 5 +- .../object_client_example.cpp} | 15 +- example/cpp/stream_client_example.cpp | 159 + example/python/ds_tensor_client_example.py | 121 + example/python/hetero_client_example.py | 75 + .../kv_cache => python}/kv_client_example.py | 30 +- example/python/object_client_example.py | 76 + example/run-example.sh | 63 +- .../device_object_example.cpp | 0 include/datasystem/datasystem.h | 1 + include/datasystem/hetero_client.h | 2 +- include/datasystem/kv_client.h | 1 - include/datasystem/object_client.h | 5 - include/datasystem/stream/consumer.h | 128 + include/datasystem/stream/element.h | 57 + include/datasystem/stream/producer.h | 111 + include/datasystem/stream/stream_config.h | 134 + include/datasystem/stream_client.h | 147 + include/datasystem/utils/status.h | 18 +- install_tools.sh | 3 + k8s/docker/dockerfile/datasystem.Dockerfile | 12 + .../templates/worker_daemonset.yaml | 38 +- k8s/helm_chart/datasystem/values.yaml | 81 +- python/__init__.py | 2 + python/ds_tensor_client.py | 21 +- python/object_client.py | 29 +- python/stream_client.py | 323 + scripts/modules/llt_util.sh | 3 +- scripts/stream_cache/parse_sc_metrics.py | 207 + setup.py | 2 +- src/datasystem/client/CMakeLists.txt | 19 +- .../client/client_worker_common_api.cpp | 4 +- .../client/client_worker_common_api.h | 1 + src/datasystem/client/context/context.cpp | 4 +- .../client/hetero_cache/device_buffer.h | 9 +- src/datasystem/client/listen_worker.cpp | 3 +- src/datasystem/client/mmap_manager.h | 4 +- .../client/object_cache/client_worker_api.cpp | 56 +- .../client/object_cache/client_worker_api.h | 29 +- .../device/client_device_object_manager.cpp | 52 +- 
.../device/client_device_object_manager.h | 58 +- .../device/device_memory_unit.cpp | 48 +- .../object_cache/device/device_memory_unit.h | 11 +- .../object_cache/device/hccl_comm_factory.cpp | 198 +- .../object_cache/device/hccl_comm_factory.h | 32 +- .../object_cache/device/p2p_subscribe.cpp | 271 +- .../object_cache/device/p2p_subscribe.h | 25 +- .../client/object_cache/object_client.cpp | 24 +- .../object_cache/object_client_impl.cpp | 399 +- .../client/object_cache/object_client_impl.h | 86 +- .../client/stream_cache/client_base_impl.cpp | 192 + .../client/stream_cache/client_base_impl.h | 188 + .../client/stream_cache/client_worker_api.cpp | 277 + .../client/stream_cache/client_worker_api.h | 159 + .../client/stream_cache/consumer.cpp | 90 + .../client/stream_cache/consumer_impl.cpp | 744 ++ .../client/stream_cache/consumer_impl.h | 311 + .../client/stream_cache/producer.cpp | 74 + .../producer_consumer_worker_api.cpp | 245 + .../producer_consumer_worker_api.h | 127 + .../client/stream_cache/producer_impl.cpp | 573 ++ .../client/stream_cache/producer_impl.h | 248 + .../client/stream_cache/receive_element.h | 36 + .../client/stream_cache/stream_client.cpp | 145 + .../stream_cache/stream_client_impl.cpp | 420 + .../client/stream_cache/stream_client_impl.h | 237 + src/datasystem/common/CMakeLists.txt | 4 +- src/datasystem/common/constants.h | 16 + .../device/ascend/acl_device_manager.cpp | 166 +- .../common/device/ascend/acl_device_manager.h | 48 +- .../device/ascend/acl_pipeline_p2p_task.cpp | 21 +- .../device/ascend/acl_pipeline_p2p_task.h | 10 +- .../device/ascend/acl_resource_manager.cpp | 2 +- .../common/device/ascend/callback_thread.cpp | 6 +- .../common/device/ascend/cann_types.h | 1 + .../device/ascend/comm_wrapper_base.cpp | 104 +- .../common/device/ascend/comm_wrapper_base.h | 76 +- .../common/device/ascend/ffts_dispatcher.cpp | 7 +- .../common/device/ascend/ffts_dispatcher.h | 1 - .../device/ascend/hccl_comm_wrapper.cpp | 40 +- .../common/device/ascend/hccl_comm_wrapper.h | 14 +- .../device/ascend/p2phccl_comm_wrapper.cpp | 42 +- .../device/ascend/p2phccl_comm_wrapper.h | 13 +- .../common/eventloop/timer_queue.cpp | 3 +- .../common/httpclient/http_request.cpp | 2 +- .../immutable_string/ref_count_string.cpp | 2 +- .../common/kvstore/etcd/etcd_store.cpp | 16 +- .../common/kvstore/rocksdb/replica.cpp | 55 +- .../common/kvstore/rocksdb/replica.h | 19 +- .../common/kvstore/rocksdb/rocks_store.cpp | 207 +- .../common/kvstore/rocksdb/rocks_store.h | 43 +- src/datasystem/common/log/access_point.def | 7 + src/datasystem/common/log/log.h | 12 +- src/datasystem/common/log/log_helper.h | 9 +- src/datasystem/common/log/logging.cpp | 4 +- .../common/log/spdlog/CMakeLists.txt | 2 +- .../common/log/spdlog/log_message_impl.cpp | 13 +- .../common/log/spdlog/log_message_impl.h | 6 +- src/datasystem/common/log/spdlog/log_param.h | 2 +- .../common/log/spdlog/log_severity.h | 4 +- .../common/log/spdlog/logger_context.cpp | 57 +- .../common/log/spdlog/logger_context.h | 2 +- .../hard_disk_exporter/hard_disk_exporter.cpp | 4 +- .../common/metrics/res_metric_collector.cpp | 6 + .../common/metrics/res_metric_collector.h | 6 + .../common/object_cache/buffer_composer.cpp | 60 +- .../common/object_cache/buffer_composer.h | 16 +- .../common/object_cache/device_buffer.cpp | 20 +- .../common/object_cache/safe_object.h | 21 + src/datasystem/common/perf/perf_point.def | 4 + src/datasystem/common/rdma/CMakeLists.txt | 5 +- src/datasystem/common/rdma/rdma_util.cpp | 13 + src/datasystem/common/rdma/rdma_util.h | 
8 + src/datasystem/common/rdma/urma_info.cpp | 202 + src/datasystem/common/rdma/urma_info.h | 117 + src/datasystem/common/rdma/urma_manager.cpp | 322 +- src/datasystem/common/rdma/urma_manager.h | 72 +- .../common/rdma/urma_manager_wrapper.cpp | 30 +- .../common/rdma/urma_manager_wrapper.h | 2 +- src/datasystem/common/rdma/urma_stub.cpp | 2 +- src/datasystem/common/rpc/CMakeLists.txt | 2 + src/datasystem/common/rpc/rpc_channel.cpp | 60 - src/datasystem/common/rpc/rpc_channel.h | 32 - .../common/rpc/rpc_stub_cache_mgr.cpp | 50 +- .../common/rpc/rpc_stub_cache_mgr.h | 3 + .../common/rpc/zmq/zmq_stub_conn.cpp | 24 +- .../common/shared_memory/allocator.cpp | 61 +- .../common/shared_memory/allocator.h | 65 +- src/datasystem/common/shared_memory/arena.cpp | 11 +- src/datasystem/common/shared_memory/arena.h | 4 +- .../common/shared_memory/shm_unit.cpp | 39 +- .../common/shared_memory/shm_unit.h | 30 + .../common/stream_cache/CMakeLists.txt | 13 + .../common/stream_cache/consumer_meta.h | 142 + src/datasystem/common/stream_cache/cursor.cpp | 518 ++ src/datasystem/common/stream_cache/cursor.h | 383 + .../common/stream_cache/stream_data_page.cpp | 1153 +++ .../common/stream_cache/stream_data_page.h | 575 ++ .../common/stream_cache/stream_fields.h | 154 + .../common/stream_cache/stream_meta_shm.cpp | 81 + .../common/stream_cache/stream_meta_shm.h | 65 + src/datasystem/common/stream_cache/util.h | 53 + .../common/string_intern/CMakeLists.txt | 11 + .../common/string_intern/string_entity.cpp | 97 + .../common/string_intern/string_entity.h | 124 + .../common/string_intern/string_pool.h | 108 + .../common/string_intern/string_ptr.h | 93 + .../common/string_intern/string_ref.h | 262 + .../common/util/gflag/common_gflags.cpp | 22 +- src/datasystem/common/util/id_tool.cpp | 2 +- src/datasystem/common/util/status_code.def | 14 + src/datasystem/common/util/strings_util.h | 41 +- src/datasystem/common/util/thread_pool.h | 105 + src/datasystem/common/util/validator.h | 16 + src/datasystem/common/util/wait_post.cpp | 15 + src/datasystem/common/util/wait_post.h | 29 + src/datasystem/master/CMakeLists.txt | 2 + .../master/object_cache/CMakeLists.txt | 1 + .../device/master_dev_dead_lock_manager.cpp | 132 + .../device/master_dev_dead_lock_manager.h | 101 + .../device/master_dev_hccl_rootinfo.cpp | 36 +- .../device/master_dev_hccl_rootinfo.h | 30 +- .../device/master_dev_oc_manager.cpp | 92 +- .../device/master_dev_oc_manager.h | 5 +- .../object_cache/master_master_oc_api.cpp | 2 +- .../object_cache/master_oc_service_impl.h | 3 + .../object_cache/master_worker_oc_api.cpp | 40 +- .../object_cache/master_worker_oc_api.h | 6 +- .../oc_global_cache_delete_manager.cpp | 2 +- .../oc_global_cache_delete_manager.h | 1 + .../object_cache/oc_metadata_manager.cpp | 33 +- .../master/object_cache/oc_metadata_manager.h | 6 +- .../object_cache/oc_notify_worker_manager.cpp | 100 +- .../object_cache/oc_notify_worker_manager.h | 22 +- .../object_cache/store/meta_async_queue.h | 16 + .../object_cache/store/object_meta_store.cpp | 86 +- .../object_cache/store/object_meta_store.h | 24 +- src/datasystem/master/replica_manager.cpp | 95 +- src/datasystem/master/replica_manager.h | 19 +- .../master/stream_cache/CMakeLists.txt | 26 + .../stream_cache/master_sc_service_impl.cpp | 346 + .../stream_cache/master_sc_service_impl.h | 204 + .../stream_cache/master_worker_sc_api.cpp | 393 + .../stream_cache/master_worker_sc_api.h | 291 + .../stream_cache/rpc_session_manager.cpp | 33 + .../master/stream_cache/rpc_session_manager.h | 61 + 
.../stream_cache/sc_metadata_manager.cpp | 1075 +++ .../master/stream_cache/sc_metadata_manager.h | 432 + .../sc_migrate_metadata_manager.cpp | 340 + .../sc_migrate_metadata_manager.h | 229 + .../stream_cache/sc_notify_worker_manager.cpp | 777 ++ .../stream_cache/sc_notify_worker_manager.h | 394 + .../master/stream_cache/store/CMakeLists.txt | 19 + .../store/rocks_stream_meta_store.cpp | 337 + .../store/rocks_stream_meta_store.h | 233 + .../stream_cache/store/stream_transform.h | 69 + .../master/stream_cache/stream_metadata.cpp | 906 ++ .../master/stream_cache/stream_metadata.h | 546 ++ .../master/stream_cache/topology_manager.cpp | 478 ++ .../master/stream_cache/topology_manager.h | 361 + src/datasystem/protos/CMakeLists.txt | 32 +- src/datasystem/protos/README.md | 8 +- src/datasystem/protos/master_stream.proto | 228 + src/datasystem/protos/meta_zmq.proto | 17 +- src/datasystem/protos/object_posix.proto | 3 + src/datasystem/protos/p2p_subscribe.proto | 2 + src/datasystem/protos/share_memory.proto | 1 + src/datasystem/protos/stream_posix.proto | 419 + src/datasystem/protos/utils.proto | 23 +- src/datasystem/protos/worker_stream.proto | 181 + src/datasystem/pybind_api/CMakeLists.txt | 1 + src/datasystem/pybind_api/pybind_register.cpp | 2 +- .../pybind_api/pybind_register_object.cpp | 8 +- .../pybind_api/pybind_register_stream.cpp | 189 + src/datasystem/worker/CMakeLists.txt | 2 + .../worker/client_manager/client_info.cpp | 2 +- .../cluster_manager/etcd_cluster_manager.cpp | 192 +- .../cluster_manager/etcd_cluster_manager.h | 26 + src/datasystem/worker/hash_ring/hash_ring.cpp | 6 +- .../object_cache/obj_cache_shm_unit.cpp | 108 +- .../worker/object_cache/obj_cache_shm_unit.h | 29 +- .../worker_oc_service_batch_get_impl.cpp | 83 +- .../service/worker_oc_service_create_impl.cpp | 25 + .../service/worker_oc_service_delete_impl.h | 18 +- .../service/worker_oc_service_expire_impl.cpp | 8 +- .../service/worker_oc_service_get_impl.cpp | 178 +- .../service/worker_oc_service_get_impl.h | 54 +- .../worker_oc_service_migrate_impl.cpp | 7 +- .../worker_oc_eviction_manager.cpp | 40 +- .../object_cache/worker_oc_eviction_manager.h | 6 +- .../object_cache/worker_oc_service_impl.cpp | 67 +- .../object_cache/worker_oc_service_impl.h | 10 +- .../worker/stream_cache/CMakeLists.txt | 47 + .../worker/stream_cache/buffer_pool.cpp | 661 ++ .../worker/stream_cache/buffer_pool.h | 344 + .../client_worker_sc_service_impl.cpp | 2585 ++++++ .../client_worker_sc_service_impl.h | 1058 +++ .../worker/stream_cache/consumer.cpp | 132 + src/datasystem/worker/stream_cache/consumer.h | 168 + .../master_worker_sc_service_impl.cpp | 305 + .../master_worker_sc_service_impl.h | 160 + .../stream_cache/metrics/CMakeLists.txt | 16 + .../stream_cache/metrics/sc_metrics.cpp | 142 + .../stream_cache/metrics/sc_metrics.def | 39 + .../worker/stream_cache/metrics/sc_metrics.h | 124 + .../metrics/sc_metrics_monitor.cpp | 180 + .../stream_cache/metrics/sc_metrics_monitor.h | 115 + .../stream_cache/page_queue/CMakeLists.txt | 18 + .../page_queue/exclusive_page_queue.cpp | 603 ++ .../page_queue/exclusive_page_queue.h | 281 + .../page_queue/page_queue_base.cpp | 1239 +++ .../stream_cache/page_queue/page_queue_base.h | 304 + .../page_queue/page_queue_handler.cpp | 420 + .../page_queue/page_queue_handler.h | 191 + .../page_queue/shared_page_queue.cpp | 201 + .../page_queue/shared_page_queue.h | 72 + .../page_queue/shared_page_queue_group.cpp | 132 + .../page_queue/shared_page_queue_group.h | 60 + .../worker/stream_cache/producer.cpp | 55 + 
src/datasystem/worker/stream_cache/producer.h | 107 + .../stream_cache/remote_worker_manager.cpp | 1651 ++++ .../stream_cache/remote_worker_manager.h | 667 ++ .../worker/stream_cache/stream_data_pool.cpp | 380 + .../worker/stream_cache/stream_data_pool.h | 145 + .../worker/stream_cache/stream_manager.cpp | 1638 ++++ .../worker/stream_cache/stream_manager.h | 995 +++ .../worker/stream_cache/stream_producer.h | 42 + .../worker/stream_cache/subscription.cpp | 204 + .../worker/stream_cache/subscription.h | 182 + .../worker/stream_cache/usage_monitor.cpp | 406 + .../worker/stream_cache/usage_monitor.h | 245 + .../stream_cache/worker_master_sc_api.cpp | 315 + .../stream_cache/worker_master_sc_api.h | 240 + .../worker_sc_allocate_memory.cpp | 95 + .../stream_cache/worker_sc_allocate_memory.h | 70 + .../worker_worker_sc_service_impl.cpp | 405 + .../worker_worker_sc_service_impl.h | 176 + src/datasystem/worker/worker_cli.h | 11 + .../worker/worker_liveness_check.cpp | 15 + .../worker/worker_master_api_manager_base.h | 4 +- src/datasystem/worker/worker_oc_server.cpp | 205 +- src/datasystem/worker/worker_oc_server.h | 43 + src/datasystem/worker/worker_service_impl.cpp | 4 +- .../benchmark}/hetero_h2d_d2h_benchmark.py | 0 ...lement-yuanrong-datasystem-connector.patch | 1598 ++++ tests/python/prefetch_tests/start_worker.sh | 3 +- tests/python/test_ds_tensor_client.py | 92 +- tests/python/test_oc_client.py | 10 +- tests/python/test_sc_client.py | 506 ++ tests/st/CMakeLists.txt | 17 +- .../kv_cache/kv_cache_client_expire_test.cpp | 4 +- .../kv_cache/kv_cache_client_storage_test.cpp | 441 + .../client/kv_cache/kv_cache_client_test.cpp | 344 +- tests/st/client/kv_cache/kv_client_common.h | 57 + .../kv_cache/kv_client_cross_az_test.cpp | 31 +- .../kv_cache/kv_client_etcd_dfx_test.cpp | 8 +- .../client/kv_cache/kv_client_mset_test.cpp | 107 +- .../kv_client_offset_read_one_host_test.cpp | 18 +- .../kv_cache/kv_client_replica_test.cpp | 4 +- .../client/kv_cache/kv_client_scale_common.h | 29 +- .../client/kv_cache/kv_client_scale_test.cpp | 4 +- .../kv_client_voluntary_scale_down_test.cpp | 29 + .../client/object_cache/client_dfx_test.cpp | 12 +- .../client/object_cache/client_get_test.cpp | 6 +- .../object_cache/client_update_test.cpp | 6 +- .../object_cache/hetero_client_mock_test.cpp | 29 +- .../object_client_replica_test.cpp | 51 - .../object_cache/object_client_scale_test.cpp | 141 +- .../object_client_tenant_test.cpp | 6 +- .../object_cache/object_client_test.cpp | 32 +- .../object_client_with_token_test.cpp | 2 +- .../oc_client_dist_master_test.cpp | 2 +- .../object_cache/oc_client_publish_test.cpp | 13 +- .../object_cache/oc_client_ref_test.cpp | 12 +- .../object_cache/oc_service_disable_test.cpp | 26 + .../object_cache/shm_threshold_test.cpp | 3 +- .../object_cache/urma_object_client_test.cpp | 147 +- .../client/stream_cache/client_crash_test.cpp | 1815 ++++ .../client_worker_heartbeat_test.cpp | 477 ++ .../stream_cache/consumer_large_page_test.cpp | 134 + .../st/client/stream_cache/consumer_test.cpp | 1977 +++++ .../delete_stream_concurrent_test.cpp | 388 + .../stream_cache/delete_stream_test.cpp | 764 ++ .../mem_ctrl_boundary_case_test.cpp | 189 + .../st/client/stream_cache/mem_ctrl_test.cpp | 607 ++ .../multi_producer_multi_consumer_test.cpp | 695 ++ .../st/client/stream_cache/producer_test.cpp | 4311 ++++++++++ .../stream_cache/pub_sub_complex_test.cpp | 313 + .../stream_cache/pub_sub_multinode_test.cpp | 977 +++ tests/st/client/stream_cache/pub_sub_test.cpp | 299 + 
tests/st/client/stream_cache/pub_sub_utils.h | 85 + .../stream_cache/query_stream_topo_test.cpp | 543 ++ .../client/stream_cache/remote_push_test.cpp | 339 + .../stream_cache/remote_send_recv_test.cpp | 1546 ++++ .../client/stream_cache/reset_stream_test.cpp | 477 ++ .../client/stream_cache/retain_data_test.cpp | 1199 +++ .../stream_cache/sc_client_aksk_auth_test.cpp | 118 + .../st/client/stream_cache/sc_client_common.h | 59 + .../sc_client_evict_object_test.cpp | 151 + .../sc_client_token_auth_test.cpp | 300 + .../client/stream_cache/sc_metrics_test.cpp | 1100 +++ .../shared_page_send_recv_test.cpp | 914 +++ .../single_consuemr_topo_test.cpp | 360 + .../stream_client_replica_test.cpp | 1177 +++ .../stream_cache/stream_client_scale_test.cpp | 1874 +++++ .../stream_data_encryption_test.cpp | 349 + .../stream_dfx_send_recv_test.cpp | 376 + .../client/stream_cache/stream_dfx_test.cpp | 2670 ++++++ .../stream_cache/stream_meta_shm_test.cpp | 351 + .../stream_cache/stream_multi_tenant.cpp | 299 + .../stream_observability_test.cpp | 432 + .../client/stream_cache/stream_size_test.cpp | 241 + .../stream_cache/stream_cache_test.cpp | 290 + tests/st/cluster/base_cluster.h | 13 +- tests/st/cluster/external_cluster.cpp | 13 +- tests/st/cluster/external_cluster.h | 4 +- tests/st/common/kvstore/etcd_store_test.cpp | 4 +- tests/st/common/kvstore/rocks_store_test.cpp | 4 +- .../common/stream_cache/element_generator.cpp | 268 + .../common/stream_cache/element_generator.h | 115 + .../common/stream_cache/mock_evictmanager.h | 44 + tests/st/common/stream_cache/stream_common.h | 46 + tests/st/device/dev_object_client_test.cpp | 0 tests/st/device/dev_object_hetero_test.cpp | 272 +- tests/st/device/hetero_d2h_test.cpp | 166 + tests/st/device/hetero_get_meta_info_test.cpp | 3 +- .../mock/ascend_device_manager_mock.cpp | 7 +- .../object_cache/oc_giveup_primary_test.cpp | 1 + .../oc_migrate_metadata_manager_test.cpp | 1 + tests/st/master/replica_manager_test.cpp | 7 +- .../pub_sub_topo_concurrent_test.cpp | 363 + .../master/stream_cache/pub_sub_topo_test.cpp | 406 + .../st/worker/object_cache/evict_mem_test.cpp | 7 +- .../object_cache/worker_oc_eviction_test.cpp | 32 +- .../master_worker_sc_api_test.cpp | 65 + .../worker_master_sc_api_test.cpp | 64 + tests/ut/CMakeLists.txt | 9 + tests/ut/common.cpp | 3 +- .../ut/common/kvstore/rocks_replica_test.cpp | 18 + tests/ut/common/log/logging_test.cpp | 21 +- .../ut/common/log/spdlog/log_message_test.cpp | 1 + .../common/shared_memory/allocator_test.cpp | 53 +- .../shared_mem_view_lock_test.cpp | 65 + .../stream_cache/stream_meta_shm_test.cpp | 59 + .../string_intern/string_ref_bench_test.cpp | 228 + .../common/string_intern/string_ref_test.cpp | 389 + .../ut/common/util/immutable_string_test.cpp | 2 +- tests/ut/common/util/shm_lock_test.cpp | 119 +- tests/ut/common/util/validator_test.cpp | 2 + .../master_dev_dead_lock_manager_test.cpp | 70 + .../master_dev_oc_manager_test.cpp | 2 +- .../object_cache/object_meta_store_test.cpp | 1 + .../rocks_streammeta_store_test.cpp | 356 + .../sc_migrate_metadata_manager_test.cpp | 106 + .../object_cache/worker_oc_eviction_test.cpp | 49 +- .../ut/worker/stream_cache/lock_map_test.cpp | 344 + .../shared_page_queue_group_test.cpp | 109 + .../stream_cache/shared_page_queue_test.cpp | 273 + .../stream_cache/stream_bufferpool_test.cpp | 168 + .../stream_cache/stream_cursor_test.cpp | 124 + .../stream_cache/stream_data_page_test.cpp | 269 + .../stream_cache/stream_usagemonitor_test.cpp | 291 + .../communicator/P2PCommunicatorManager.h | 1 - 
.../P2P-Transfer/include/external/ra.h | 1 - .../P2P-Transfer/include/external/tsd.h | 1 - .../include/tools/host-interface.h | 1 - .../source/communication/TcpClient.cpp | 1 - .../source/communication/TcpServer.cpp | 1 - .../source/communicator/P2PCommunicator.cpp | 1 - .../communicator/hccs-ipc/HccsReceiver.cpp | 1 - .../communicator/hccs-ipc/HccsSender.cpp | 1 - .../source/communicator/roce/RoceReceiver.cpp | 1 - .../source/communicator/roce/RoceSender.cpp | 1 - .../P2P-Transfer/source/communicator/test.cpp | 1 - third_party/P2P-Transfer/source/npu/Hccp.cpp | 1 - .../P2P-Transfer/source/npu/LocalNotify.cpp | 1 - .../P2P-Transfer/source/npu/P2PMem.cpp | 1 - .../P2P-Transfer/source/npu/P2PNotify.cpp | 1 - .../P2P-Transfer/source/npu/P2PStream.cpp | 1 - .../P2P-Transfer/source/npu/PeerManager.cpp | 1 - .../P2P-Transfer/source/npu/RaWrapper.cpp | 1 - .../source/npu/RdmaErrCollector.cpp | 1 - .../P2P-Transfer/source/npu/RdmaNotify.cpp | 1 - third_party/P2P-Transfer/source/p2p.cpp | 1 - third_party/P2P-Transfer/source/tools/env.cpp | 1 - .../source/tools/hccl-convert.cpp | 1 - .../source/tools/host-interface.cpp | 1 - .../test/source/p2p-transfer_test.cpp | 1 - .../test/source/p2p-transfer_test_batch.cpp | 1 - .../source/p2p-transfer_test_batch_recv.cpp | 1 - .../source/p2p-transfer_test_batch_send.cpp | 1 - .../source/p2p-transfer_test_batch_thread.cpp | 1 - .../test/source/p2p-transfer_test_init.cpp | 1 - .../test/source/test-tools/barrier.h | 1 - .../test/source/test-tools/fifo.h | 1 - .../test/source/test-tools/measure.h | 1 - .../source/test-tools/measurementSeries.h | 1 - .../test/source/test-tools/tools.h | 1 - .../curl/8.8.0/support_old_cmake.patch | 12 + .../obs/3.24.3/obs-sdk-change-spdlog.patch | 119 + .../patches/spdlog/change-namespace.patch | 7274 +++++++++++++++++ 483 files changed, 84903 insertions(+), 2473 deletions(-) create mode 100644 .clang-tidy rename {example => cli}/cpp_template/CMakeLists.txt (100%) rename {example => cli}/cpp_template/README.md (100%) rename {example => cli}/cpp_template/kv_cache_example.cpp (100%) rename {example => cli}/cpp_template/run.sh (100%) create mode 100644 docs/source_zh_cn/appendix/hugepage_guide.md create mode 100644 docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.exist.rst rename example/{ => cpp}/CMakeLists.txt (43%) rename example/{src/ds_example.cpp => cpp/datasystem_example.cpp} (89%) create mode 100644 example/cpp/hetero_client_example.cpp rename example/{src/kv_cache/kv_example.cpp => cpp/kv_client_example.cpp} (98%) rename example/{src/object_cache/object_example.cpp => cpp/object_client_example.cpp} (94%) create mode 100644 example/cpp/stream_client_example.cpp create mode 100644 example/python/ds_tensor_client_example.py create mode 100644 example/python/hetero_client_example.py rename example/{src/python/kv_cache => python}/kv_client_example.py (79%) create mode 100644 example/python/object_client_example.py delete mode 100644 example/src/device_object_cache/device_object_example.cpp create mode 100644 include/datasystem/stream/consumer.h create mode 100644 include/datasystem/stream/element.h create mode 100644 include/datasystem/stream/producer.h create mode 100644 include/datasystem/stream/stream_config.h create mode 100644 include/datasystem/stream_client.h create mode 100644 python/stream_client.py create mode 100644 scripts/stream_cache/parse_sc_metrics.py create mode 100644 src/datasystem/client/stream_cache/client_base_impl.cpp create mode 100644 
src/datasystem/client/stream_cache/client_base_impl.h create mode 100644 src/datasystem/client/stream_cache/client_worker_api.cpp create mode 100644 src/datasystem/client/stream_cache/client_worker_api.h create mode 100644 src/datasystem/client/stream_cache/consumer.cpp create mode 100644 src/datasystem/client/stream_cache/consumer_impl.cpp create mode 100644 src/datasystem/client/stream_cache/consumer_impl.h create mode 100644 src/datasystem/client/stream_cache/producer.cpp create mode 100644 src/datasystem/client/stream_cache/producer_consumer_worker_api.cpp create mode 100644 src/datasystem/client/stream_cache/producer_consumer_worker_api.h create mode 100644 src/datasystem/client/stream_cache/producer_impl.cpp create mode 100644 src/datasystem/client/stream_cache/producer_impl.h create mode 100644 src/datasystem/client/stream_cache/receive_element.h create mode 100644 src/datasystem/client/stream_cache/stream_client.cpp create mode 100644 src/datasystem/client/stream_cache/stream_client_impl.cpp create mode 100644 src/datasystem/client/stream_cache/stream_client_impl.h create mode 100644 src/datasystem/common/rdma/urma_info.cpp create mode 100644 src/datasystem/common/rdma/urma_info.h create mode 100644 src/datasystem/common/stream_cache/CMakeLists.txt create mode 100644 src/datasystem/common/stream_cache/consumer_meta.h create mode 100644 src/datasystem/common/stream_cache/cursor.cpp create mode 100644 src/datasystem/common/stream_cache/cursor.h create mode 100644 src/datasystem/common/stream_cache/stream_data_page.cpp create mode 100644 src/datasystem/common/stream_cache/stream_data_page.h create mode 100644 src/datasystem/common/stream_cache/stream_fields.h create mode 100644 src/datasystem/common/stream_cache/stream_meta_shm.cpp create mode 100644 src/datasystem/common/stream_cache/stream_meta_shm.h create mode 100644 src/datasystem/common/stream_cache/util.h create mode 100644 src/datasystem/common/string_intern/CMakeLists.txt create mode 100644 src/datasystem/common/string_intern/string_entity.cpp create mode 100644 src/datasystem/common/string_intern/string_entity.h create mode 100644 src/datasystem/common/string_intern/string_pool.h create mode 100644 src/datasystem/common/string_intern/string_ptr.h create mode 100644 src/datasystem/common/string_intern/string_ref.h create mode 100644 src/datasystem/master/object_cache/device/master_dev_dead_lock_manager.cpp create mode 100644 src/datasystem/master/object_cache/device/master_dev_dead_lock_manager.h create mode 100644 src/datasystem/master/stream_cache/CMakeLists.txt create mode 100644 src/datasystem/master/stream_cache/master_sc_service_impl.cpp create mode 100644 src/datasystem/master/stream_cache/master_sc_service_impl.h create mode 100644 src/datasystem/master/stream_cache/master_worker_sc_api.cpp create mode 100644 src/datasystem/master/stream_cache/master_worker_sc_api.h create mode 100644 src/datasystem/master/stream_cache/rpc_session_manager.cpp create mode 100644 src/datasystem/master/stream_cache/rpc_session_manager.h create mode 100644 src/datasystem/master/stream_cache/sc_metadata_manager.cpp create mode 100644 src/datasystem/master/stream_cache/sc_metadata_manager.h create mode 100644 src/datasystem/master/stream_cache/sc_migrate_metadata_manager.cpp create mode 100644 src/datasystem/master/stream_cache/sc_migrate_metadata_manager.h create mode 100644 src/datasystem/master/stream_cache/sc_notify_worker_manager.cpp create mode 100644 src/datasystem/master/stream_cache/sc_notify_worker_manager.h create mode 100644 
src/datasystem/master/stream_cache/store/CMakeLists.txt create mode 100644 src/datasystem/master/stream_cache/store/rocks_stream_meta_store.cpp create mode 100644 src/datasystem/master/stream_cache/store/rocks_stream_meta_store.h create mode 100644 src/datasystem/master/stream_cache/store/stream_transform.h create mode 100644 src/datasystem/master/stream_cache/stream_metadata.cpp create mode 100644 src/datasystem/master/stream_cache/stream_metadata.h create mode 100644 src/datasystem/master/stream_cache/topology_manager.cpp create mode 100644 src/datasystem/master/stream_cache/topology_manager.h create mode 100644 src/datasystem/protos/master_stream.proto create mode 100644 src/datasystem/protos/stream_posix.proto create mode 100644 src/datasystem/protos/worker_stream.proto create mode 100644 src/datasystem/pybind_api/pybind_register_stream.cpp create mode 100644 src/datasystem/worker/stream_cache/CMakeLists.txt create mode 100644 src/datasystem/worker/stream_cache/buffer_pool.cpp create mode 100644 src/datasystem/worker/stream_cache/buffer_pool.h create mode 100644 src/datasystem/worker/stream_cache/client_worker_sc_service_impl.cpp create mode 100644 src/datasystem/worker/stream_cache/client_worker_sc_service_impl.h create mode 100644 src/datasystem/worker/stream_cache/consumer.cpp create mode 100644 src/datasystem/worker/stream_cache/consumer.h create mode 100644 src/datasystem/worker/stream_cache/master_worker_sc_service_impl.cpp create mode 100644 src/datasystem/worker/stream_cache/master_worker_sc_service_impl.h create mode 100644 src/datasystem/worker/stream_cache/metrics/CMakeLists.txt create mode 100644 src/datasystem/worker/stream_cache/metrics/sc_metrics.cpp create mode 100644 src/datasystem/worker/stream_cache/metrics/sc_metrics.def create mode 100644 src/datasystem/worker/stream_cache/metrics/sc_metrics.h create mode 100644 src/datasystem/worker/stream_cache/metrics/sc_metrics_monitor.cpp create mode 100644 src/datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h create mode 100644 src/datasystem/worker/stream_cache/page_queue/CMakeLists.txt create mode 100644 src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.cpp create mode 100644 src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.h create mode 100644 src/datasystem/worker/stream_cache/page_queue/page_queue_base.cpp create mode 100644 src/datasystem/worker/stream_cache/page_queue/page_queue_base.h create mode 100644 src/datasystem/worker/stream_cache/page_queue/page_queue_handler.cpp create mode 100644 src/datasystem/worker/stream_cache/page_queue/page_queue_handler.h create mode 100644 src/datasystem/worker/stream_cache/page_queue/shared_page_queue.cpp create mode 100644 src/datasystem/worker/stream_cache/page_queue/shared_page_queue.h create mode 100644 src/datasystem/worker/stream_cache/page_queue/shared_page_queue_group.cpp create mode 100644 src/datasystem/worker/stream_cache/page_queue/shared_page_queue_group.h create mode 100644 src/datasystem/worker/stream_cache/producer.cpp create mode 100644 src/datasystem/worker/stream_cache/producer.h create mode 100644 src/datasystem/worker/stream_cache/remote_worker_manager.cpp create mode 100644 src/datasystem/worker/stream_cache/remote_worker_manager.h create mode 100644 src/datasystem/worker/stream_cache/stream_data_pool.cpp create mode 100644 src/datasystem/worker/stream_cache/stream_data_pool.h create mode 100644 src/datasystem/worker/stream_cache/stream_manager.cpp create mode 100644 
src/datasystem/worker/stream_cache/stream_manager.h create mode 100644 src/datasystem/worker/stream_cache/stream_producer.h create mode 100644 src/datasystem/worker/stream_cache/subscription.cpp create mode 100644 src/datasystem/worker/stream_cache/subscription.h create mode 100644 src/datasystem/worker/stream_cache/usage_monitor.cpp create mode 100644 src/datasystem/worker/stream_cache/usage_monitor.h create mode 100644 src/datasystem/worker/stream_cache/worker_master_sc_api.cpp create mode 100644 src/datasystem/worker/stream_cache/worker_master_sc_api.h create mode 100644 src/datasystem/worker/stream_cache/worker_sc_allocate_memory.cpp create mode 100644 src/datasystem/worker/stream_cache/worker_sc_allocate_memory.h create mode 100644 src/datasystem/worker/stream_cache/worker_worker_sc_service_impl.cpp create mode 100644 src/datasystem/worker/stream_cache/worker_worker_sc_service_impl.h rename {example/src/python/hetero_cache => tests/benchmark}/hetero_h2d_d2h_benchmark.py (100%) create mode 100644 tests/kvconnector/patch/0001-implement-yuanrong-datasystem-connector.patch create mode 100644 tests/python/test_sc_client.py create mode 100644 tests/st/client/kv_cache/kv_cache_client_storage_test.cpp create mode 100644 tests/st/client/kv_cache/kv_client_common.h create mode 100644 tests/st/client/stream_cache/client_crash_test.cpp create mode 100644 tests/st/client/stream_cache/client_worker_heartbeat_test.cpp create mode 100644 tests/st/client/stream_cache/consumer_large_page_test.cpp create mode 100644 tests/st/client/stream_cache/consumer_test.cpp create mode 100644 tests/st/client/stream_cache/delete_stream_concurrent_test.cpp create mode 100644 tests/st/client/stream_cache/delete_stream_test.cpp create mode 100644 tests/st/client/stream_cache/mem_ctrl_boundary_case_test.cpp create mode 100644 tests/st/client/stream_cache/mem_ctrl_test.cpp create mode 100644 tests/st/client/stream_cache/multi_producer_multi_consumer_test.cpp create mode 100644 tests/st/client/stream_cache/producer_test.cpp create mode 100644 tests/st/client/stream_cache/pub_sub_complex_test.cpp create mode 100644 tests/st/client/stream_cache/pub_sub_multinode_test.cpp create mode 100644 tests/st/client/stream_cache/pub_sub_test.cpp create mode 100644 tests/st/client/stream_cache/pub_sub_utils.h create mode 100644 tests/st/client/stream_cache/query_stream_topo_test.cpp create mode 100644 tests/st/client/stream_cache/remote_push_test.cpp create mode 100644 tests/st/client/stream_cache/remote_send_recv_test.cpp create mode 100644 tests/st/client/stream_cache/reset_stream_test.cpp create mode 100644 tests/st/client/stream_cache/retain_data_test.cpp create mode 100644 tests/st/client/stream_cache/sc_client_aksk_auth_test.cpp create mode 100644 tests/st/client/stream_cache/sc_client_common.h create mode 100644 tests/st/client/stream_cache/sc_client_evict_object_test.cpp create mode 100644 tests/st/client/stream_cache/sc_client_token_auth_test.cpp create mode 100644 tests/st/client/stream_cache/sc_metrics_test.cpp create mode 100644 tests/st/client/stream_cache/shared_page_send_recv_test.cpp create mode 100644 tests/st/client/stream_cache/single_consuemr_topo_test.cpp create mode 100644 tests/st/client/stream_cache/stream_client_replica_test.cpp create mode 100644 tests/st/client/stream_cache/stream_client_scale_test.cpp create mode 100644 tests/st/client/stream_cache/stream_data_encryption_test.cpp create mode 100644 tests/st/client/stream_cache/stream_dfx_send_recv_test.cpp create mode 100644 
tests/st/client/stream_cache/stream_dfx_test.cpp create mode 100644 tests/st/client/stream_cache/stream_meta_shm_test.cpp create mode 100644 tests/st/client/stream_cache/stream_multi_tenant.cpp create mode 100644 tests/st/client/stream_cache/stream_observability_test.cpp create mode 100644 tests/st/client/stream_cache/stream_size_test.cpp create mode 100644 tests/st/client_c_api/stream_cache/stream_cache_test.cpp create mode 100644 tests/st/common/stream_cache/element_generator.cpp create mode 100644 tests/st/common/stream_cache/element_generator.h create mode 100644 tests/st/common/stream_cache/mock_evictmanager.h create mode 100644 tests/st/common/stream_cache/stream_common.h delete mode 100644 tests/st/device/dev_object_client_test.cpp create mode 100644 tests/st/device/hetero_d2h_test.cpp create mode 100644 tests/st/master/stream_cache/pub_sub_topo_concurrent_test.cpp create mode 100644 tests/st/master/stream_cache/pub_sub_topo_test.cpp create mode 100644 tests/st/worker/stream_cache/master_worker_sc_api_test.cpp create mode 100644 tests/st/worker/stream_cache/worker_master_sc_api_test.cpp create mode 100644 tests/ut/common/stream_cache/shared_mem_view_lock_test.cpp create mode 100644 tests/ut/common/stream_cache/stream_meta_shm_test.cpp create mode 100644 tests/ut/common/string_intern/string_ref_bench_test.cpp create mode 100644 tests/ut/common/string_intern/string_ref_test.cpp create mode 100644 tests/ut/master/object_cache/master_dev_dead_lock_manager_test.cpp create mode 100644 tests/ut/master/stream_cache/rocks_streammeta_store_test.cpp create mode 100644 tests/ut/master/stream_cache/sc_migrate_metadata_manager_test.cpp create mode 100644 tests/ut/worker/stream_cache/lock_map_test.cpp create mode 100644 tests/ut/worker/stream_cache/shared_page_queue_group_test.cpp create mode 100644 tests/ut/worker/stream_cache/shared_page_queue_test.cpp create mode 100644 tests/ut/worker/stream_cache/stream_bufferpool_test.cpp create mode 100644 tests/ut/worker/stream_cache/stream_cursor_test.cpp create mode 100644 tests/ut/worker/stream_cache/stream_data_page_test.cpp create mode 100644 tests/ut/worker/stream_cache/stream_usagemonitor_test.cpp create mode 100644 third_party/patches/curl/8.8.0/support_old_cmake.patch create mode 100644 third_party/patches/obs/3.24.3/obs-sdk-change-spdlog.patch create mode 100644 third_party/patches/spdlog/change-namespace.patch diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 0000000..29c2c94 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,13 @@ +Checks: > + bugprone-*, + performance-*, + readability-*, + concurrency-mt-unsafe, + clang-analyzer-*, + -readability-identifier-length, + -readability-qualified-auto, + -readability-function-cognitive-complexity, + -readability-simplify-boolean-expr, + -bugprone-easily-swappable-parameters, + -readability-implicit-bool-conversion, + -readability-convert-member-functions-to-static, \ No newline at end of file diff --git a/LOG_README b/LOG_README index 60af057..a549b66 100644 --- a/LOG_README +++ b/LOG_README @@ -13,10 +13,11 @@ 2 /path/worker/access.log 访问worker POSIX接口的日志 3 /path/worker/resource.log Worker资源使用日志 默认关闭,通过log_monitor开关控制是否开启。 4 /path/worker/requestout.log 访问l2cache、ETCD、AGC IAM接口日志 -5 /path/worker/container.log 容器运行日志,管理和监控worker进程的生命周期 +5 /path/worker/sc_metrics.log 流缓存运行数据 默认关闭,通过log_monitor开关控制是否开启。 +6 /path/worker/container.log 容器运行日志,管理和监控worker进程的生命周期 ----------------------------------------------------------------------------------------------------------------------------- -6 Client 
/path/client/ds_client_.log SDK运行日志 -7 /path/client/ds_client_access_.log SDK接口访问日志 +7 Client /path/client/ds_client_.log SDK运行日志 +8 /path/client/ds_client_access_.log SDK接口访问日志 ----------------------------------------------------------------------------------------------------------------------------- 1.1.2 日志格式 @@ -24,11 +25,12 @@ ----------------------------------------------------------------------------------------------------------------------------- 序号 日志 日志格式 ----------------------------------------------------------------------------------------------------------------------------- -1 运行日志 Time | level | filename| pod_name | pid:tid | trace_id | az_name | message -2 访问日志 Time | level | filename |pod_name|pid:tid | trace_id | az_name | status_code | action | cost | data size| request param| response param -3 访问第三方日志 Time | level | filename| pod_name | pid:tid | trace_id | az_name | status_code | action | cost | data size| request param| response param -4 资源日志 Time | level | filename | pod_name | pid:tid | trace_id | az_name | 共享内存信息 | spill磁盘信息 | 客户端数 | Object总数 | Object数据总大小 | WorkerOcService线程池 | WorkerWorkerOcService线程池 | MasterWokrerOcService线程池 | MasterOcService线程池 |写l2cache队列 |写ETCD队列 | ETCD请求成功率 | OBS请求成功率 | Master异步任务线程池 | 流总数 | ClientWorkerSCService线程池 | WorkerWorkerSCService线程池 | MasterWorkerSCService线程池 | MasterSCService线程池 | 流远端推送成功率 | 共享磁盘信息 | scLocalCache信息 -5 容器运行日志 Time | level | filename| pod_name | pid:tid | trace_id | az_name | message +1 运行日志 Time | level | filename| pod_name | pid:tid | trace_id | cluster_name | message +2 访问日志 Time | level | filename |pod_name|pid:tid | trace_id | cluster_name | status_code | action | cost | data size| request param| response param +3 访问第三方日志 Time | level | filename| pod_name | pid:tid | trace_id | cluster_name | status_code | action | cost | data size| request param| response param +4 资源日志 Time | level | filename | pod_name | pid:tid | trace_id | cluster_name | 共享内存信息 | spill磁盘信息 | 客户端数 | Object总数 | Object数据总大小 | WorkerOcService线程池 | WorkerWorkerOcService线程池 | MasterWokrerOcService线程池 | MasterOcService线程池 |写l2cache队列 |写ETCD队列 | ETCD请求成功率 | OBS请求成功率 | Master异步任务线程池 | 流总数 | ClientWorkerSCService线程池 | WorkerWorkerSCService线程池 | MasterWorkerSCService线程池 | MasterSCService线程池 | 流远端推送成功率 | 共享磁盘信息 | scLocalCache信息 +5 流缓存数据日志 Time | level | filename| pod_name | pid:tid | trace_id | cluster_name | sc_metric +6 容器运行日志 Time | level | filename| pod_name | pid:tid | trace_id | cluster_name | message ----------------------------------------------------------------------------------------------------------------------------- @@ -42,7 +44,7 @@ filename 128 输出该条日志的函数所在文件及 pod_name 128 输出当前worker所属的POD名称,超出长度则截断。示例:ds-worker-hs5qm pid:tid 11 该日志所属的进程ID和线程ID。进程号最大值为32757,该字段最大长度11.示例:9:177 traceid 36 请求的traceid。 -az_name 128 输出AZ名称,最大长度为128,超出长度则截断。示例:AZ1。 +cluster_name 128 输出AZ名称,最大长度为128,超出长度则截断。示例:AZ1。 Message 1024 自定义消息内容 status_code 5 该请求的状态,不同消息类型状态值不一样。SDK/Worker 访问日志,0表示成功,其他表示失败。l2cache/AGC 访问日志:http请求,200表示成功,其他表示失败。 action 64 表示该请求所访问的接口名称。约定前缀:SDK接口:DS_STATE_CLINET、DS_OBJECT_CLIENT,Worker接口:DS_OBJECT_POSIX、DS_STREAM_POSIX,ETCD:DS_ETCD,HTTP请求:POST {url path},示例:POST /v1/agc/token @@ -55,6 +57,8 @@ response param 1024 记录该请求的响应信息。最大长 2) physicalMemoryUsage 已分配的物理内存大小。 3) totalLimit 共享内存总大小。 4) Rate 共享内存使用率,memoryUsage/totalLimit, 保留3位有效数字,单位: %. 
+	5) scMemoryUsage 流缓存使用共享内存大小
+	6) scMemoryLimit 流缓存共享内存总大小
 spill磁盘信息	47	记录Spill磁盘使用信息。单位为Byte,按照1T限制大小,每个长度 13 Byte,格式为:spaceUsage/physicalSpaceUsage/totalLimit/rate
 	1) spaceUsage 已使用的磁盘大小,是已Spill的对象大小总和。
 	2) physicalSpaceUsage 已使用的物理磁盘大小。
@@ -92,6 +96,34 @@ MasterSCService线程池	21	线程池使用信息,格式为:idleN
 	3) totalLimit 共享磁盘总大小。
 	4) rate 共享磁盘使用率,usage/totalLimit, 保留3位有效数字,单位: %.
 scLocalCache信息	47	记录scLocalCache使用信息,单位为Byte,按照1T限制大小,每个长度 13 Byte,格式为:usedSize/reservedSize/totalLimit/usedRate
+
+sc_metric	1024	流缓存运行数据(sc_stream_metric)。worker上一个stream的流缓存数据,格式:streamName ["exit"]/numLocalProd/numRemoteProd/numLocalCon/numRemoteCon/sharedMemUsed/localMemUsed/numEleSent/numEleRecv/numEleAck/numSendReq/numRecvReq/numPagesCreated/numPagesReleased/numPagesInUse/numPagesCached/numBigPagesCreated/numBigPagesReleased/numLocalProdBlocked/numRemoteProdBlocked/numRemoteConBlocking/retainData/streamState/numProdMaster/numConMaster
+	1) streamName ["exit"] stream名字,带有" exit"表示stream正要关闭
+	2) numLocalProd 本地producer数量
+	3) numRemoteProd - Number of remote workers with at least one producer (value will be 0, if no local consumers)
+	4) numLocalCon 本地consumer数量
+	5) numRemoteCon 远端consumer数量
+	6) sharedMemUsed stream使用共享内存大小,单位: Byte
+	7) localMemUsed stream使用本地内存大小,单位: Byte
+	8) numEleSent - Total number of elements produced by all local producers
+	9) numEleRecv - Total number of elements received by all local consumers (value will be 0, if no local consumers)
+	10) numEleAck element acked数量
+	11) numSendReq client调用producer.send()次数
+	12) numRecvReq client调用consumer.receive()次数
+	13) numPagesCreated page创建次数
+	14) numPagesReleased page释放次数
+	15) numPagesInUse page in use数量
+	16) numPagesCached page cached数量
+	17) numBigPagesCreated big element page创建次数
+	18) numBigPagesReleased big element page释放次数
+	19) numLocalProdBlocked 本地producer blocked数量
+	20) numRemoteProdBlocked 远端producer blocked数量
+	21) numRemoteConBlocking 远端consumer blocking数量
+	22) retainData retain data state
+	23) streamState stream state
+	24) numProdMaster master上producer数量
+	25) numConMaster master上consumer数量
+	- 如果worker不是stream的master,24-25会没有数据。如果worker只有master数据,2-23会没有数据。
 -----------------------------------------------------------------------------------------------------------------------------------------
 
 1.1.2.2.1 SDK与worker访问日志关键请求参数
diff --git a/README.md b/README.md
index 71b0447..3016bef 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ openYuanrong datasystem 的主要特性包括:
 - **NPU 间高效数据传输**:将 NPU 的 HBM 抽象为异构对象,自动协调 NPU 间 HCCL 收发顺序,实现简单易用的卡间数据异步并发传输。并支持P2P传输负载均衡策略,充分利用卡间链路带宽。
 - **灵活的生命周期管理**:支持设置 TTL、LRU 缓存淘汰以及 delete 接口等多种生命周期管理策略,数据生命周期既可由数据系统管理,也可交由上层应用管理,提供更高的灵活性。
 - **热点数据多副本**:数据跨节点读取时自动在本地保存副本,支撑热点数据高效访问。本地副本使用 LRU 策略自动淘汰。
-- **多种数据可靠性策略**:支持 write_through、wirte_back 及 none 多种持久化策略,满足不同场景的数据可靠性需求。
+- **多种数据可靠性策略**:支持 write_through、write_back 及 none 多种持久化策略,满足不同场景的数据可靠性需求。
 - **数据一致性**:支持 Causal 及 PRAM 两种数据一致性模型,用户可按需选择,实现性能和数据一致性的平衡。
 - **数据发布订阅**:支持数据订阅发布,解耦数据的生产者(发布者)和消费者(订阅者),实现数据的异步传输与共享。
 - **高可靠高可用**:支持分布式元数据管理,实现系统水平线性扩展。支持元数据可靠性,支持动态资源伸缩自动迁移数据,实现系统高可用。
diff --git a/README_CN.md b/README_CN.md
index 4bb130c..c5d1eea 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -22,7 +22,7 @@ openYuanrong datasystem 的主要特性包括:
 - **NPU 间高效数据传输**:将 NPU 的 HBM 抽象为异构对象,自动协调 NPU 间 HCCL 收发顺序,实现简单易用的卡间数据异步并发传输。并支持P2P传输负载均衡策略,充分利用卡间链路带宽。
 - **灵活的生命周期管理**:支持设置 TTL、LRU 缓存淘汰以及 delete 接口等多种生命周期管理策略,数据生命周期既可由数据系统管理,也可交由上层应用管理,提供更高的灵活性。
 - **热点数据多副本**:数据跨节点读取时自动在本地保存副本,支撑热点数据高效访问。本地副本使用 LRU 策略自动淘汰。
-- **多种数据可靠性策略**:支持 write_through、wirte_back 及 none
多种持久化策略,满足不同场景的数据可靠性需求。 +- **多种数据可靠性策略**:支持 write_through、write_back 及 none 多种持久化策略,满足不同场景的数据可靠性需求。 - **数据一致性**:支持 Causal 及 PRAM 两种数据一致性模型,用户可按需选择,实现性能和数据一致性的平衡。 - **数据发布订阅**:支持数据订阅发布,解耦数据的生产者(发布者)和消费者(订阅者),实现数据的异步传输与共享。 - **高可靠高可用**:支持分布式元数据管理,实现系统水平线性扩展。支持元数据可靠性,支持动态资源伸缩自动迁移数据,实现系统高可用。 diff --git a/build.sh b/build.sh index 7c3c0f3..3b00302 100755 --- a/build.sh +++ b/build.sh @@ -19,9 +19,9 @@ source /etc/profile.d/*.sh readonly USAGE=" Usage: bash build.sh [-h] [-r] [-d] [-c off/on/html] [-t off|build|run] [-s on|off] [-j ] [-p on|off] [-S address|thread|undefined|off] [-o ] [-u ] - [-B ] [-P on/off] [-X on/off] [-T ] - [-R on/off] [-D \"on \"/off] [-l ] [-i on/off] [-n on/off] - [-x on/off] + [-B ] [-J on|off] [-P on/off] [-G on/off] [-X on/off] [-e on/off] [-T ] + [-R on/off] [-O on/off] [-I ] [-M \"on|off \"/off] [-D \"on \"/off] + [-C on/off] [-l ] [-i on/off] [-n on/off][-x on/off] Options: -h Output this help and exit. @@ -45,7 +45,8 @@ Options: For communication layer -M Build with URMA framework in addition to ZMQ, choose from on/off, default: off. - -D Download UB package that is needed for URMA, choose from on/off. When on, can also provide UB download options, default: on. + When on, can also provide the URMA mode, choose from IB/UB, default IB (URMA over IB). + -D Download UB package that is needed for URMA, choose from on/off. When on, can also provide UB download options, default: off. Notes to compile and run with URMA: 1. The default packages are for EulerOS-V2R10 environment 2. The downloaded rpm packages and the kernel modules need to be installed before run @@ -124,6 +125,7 @@ function init_default_opts() { # For communication layer export BUILD_WITH_URMA="off" + export URMA_OVER_UB="off" export DOWNLOAD_UB="off" export UB_URL="" export UB_SHA256="" @@ -232,6 +234,27 @@ function parse_thirdparty_versions() { done } +function parse_urma_options() { + local args + check_on_off "$1" M + BUILD_WITH_URMA="$1" + if [[ $# -gt 1 ]]; then + IFS=',' read -ra parts <<< "$2" + for part in "${parts[@]}"; do + case $part in + "IB") + ;; + "UB") + URMA_OVER_UB="on" + ;; + *) + echo "Invalid URMA mode option: $part" + ;; + esac + done + fi +} + function parse_ub_download_options() { local args check_on_off "$1" D @@ -286,17 +309,12 @@ function gen_html_coverage_report() { function build_example() { echo -e "---- building example..." - local example_build_dir="${DATASYSTEM_DIR}/example/build" + local example_build_dir="${DATASYSTEM_DIR}/example/cpp/build" # clean and create build dir. [[ "${BUILD_INCREMENT}" == "off" && -d "${example_build_dir}" ]] && rm -rf "${example_build_dir}" mkdir -p "${example_build_dir}" && cd "${example_build_dir}" - local prefix_path - prefix_path=${INSTALL_DIR}/sdk/cpp - - cmake "${DATASYSTEM_DIR}/example" \ - -DCMAKE_PREFIX_PATH="${prefix_path}" \ - -DBUILD_HETERO="${BUILD_HETERO}" || go_die "---- build example CMake project failed!" + cmake "${DATASYSTEM_DIR}/example/cpp" || go_die "---- build example CMake project failed!" make || go_die "---- example make failed!" echo -e "---- build example success!" 
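Note: the reworked -M option above takes an optional URMA mode word after "on"; OPTARG is expanded unquoted at the call site, so a quoted two-word argument reaches parse_urma_options as two parameters. A sketch of the resulting invocations, inferred from the usage text only (the UB download sub-options accepted by -D beyond on/off are not shown in this hunk, so only the plain form appears here):

    bash build.sh -M on              # URMA enabled, default mode IB (URMA over IB)
    bash build.sh -M "on UB"         # URMA over UB (sets URMA_OVER_UB=on)
    bash build.sh -M "on UB" -D on   # additionally download the UB package URMA needs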
@@ -319,9 +337,7 @@ function run_example() { export LD_LIBRARY_PATH=$new_ld_path echo -e "---- Sanitize LD_LIBRARY_PATH from ${old_ld_path} to ${new_ld_path}" - python3 -m pip install ${INSTALL_DIR}/openyuanrong_datasystem-*.whl --force-reinstall - bash "${DATASYSTEM_DIR}/example/run-example.sh" "${BUILD_HETERO}" "${ENABLE_PERF}" || - (remove_running_pids && go_die "---- Smoke Testing failed!") + bash "${DATASYSTEM_DIR}/example/run-example.sh" || (remove_running_pids && go_die "---- Smoke Testing failed!") echo -e "---- Smoke Testing success!" echo -e "---- [TIMER] Run example: $(($(date +%s)-$baseTime_s)) seconds" @@ -460,6 +476,15 @@ function version_lt() [ "$1" = "$(echo -e "$1\n$2" | sort -V | head -n1)" ] && [ "$1" != "$2" ] } +function generate_config() +{ + local config_file=${BASE_DIR}/config.cmake + [[ -f "${config_file}" ]] && rm -f "${config_file}" + echo "set(INSTALL_DIR \"${INSTALL_DIR}\")" >> "${config_file}" + echo "set(BUILD_HETERO \"${BUILD_HETERO}\")" >> "${config_file}" + echo "set(PACKAGE_PYTHON \"${PACKAGE_PYTHON}\")" >> "${config_file}" +} + function build_datasystem() { # clean and create build dir. @@ -475,6 +500,8 @@ function build_datasystem() BUILD_TESTCASE="on" fi + generate_config + local cmake_options=( "${DATASYSTEM_DIR}" "-DCMAKE_BUILD_TYPE:STRING=${BUILD_TYPE}" @@ -489,6 +516,7 @@ function build_datasystem() "-DBUILD_HETERO:BOOL=${BUILD_HETERO}" "-DSUPPORT_JEPROF:BOOL=${SUPPORT_JEPROF}" "-DBUILD_WITH_URMA:BOOL=${BUILD_WITH_URMA}" + "-DURMA_OVER_UB:BOOL=${URMA_OVER_UB}" ) if [[ "${BUILD_WITH_URMA}" == "on" ]]; then @@ -496,9 +524,10 @@ function build_datasystem() "-DDOWNLOAD_UB:BOOL=${DOWNLOAD_UB}" "-DUB_URL:STRING=${UB_URL}" "-DUB_SHA256:STRING=${UB_SHA256}" + "-DURMA_OVER_UB:BOOL=${URMA_OVER_UB}" ) fi - + if is_on "${PACKAGE_PYTHON}" && [ -n "${PYTHON_ROOT_DIR}" ]; then echo -e "-- Specify python root path: ${PYTHON_ROOT_DIR}" cmake_options=("${cmake_options[@]}" "-DPython3_ROOT_DIR:PATH=${PYTHON_ROOT_DIR}") @@ -662,8 +691,7 @@ function main() { parse_thirdparty_versions "${OPTARG}" ;; M) - check_on_off "${OPTARG}" M - BUILD_WITH_URMA="${OPTARG}" + parse_urma_options ${OPTARG} ;; D) diff --git a/cli/command.py b/cli/command.py index 061ecc1..06e2521 100644 --- a/cli/command.py +++ b/cli/command.py @@ -38,23 +38,24 @@ class BaseCommand: def __init__(self): """Initialize of command""" - - def valid_safe_path(path: str): - unsafe_dirs = ["/bin", "/sbin", "/lib", "/lib64", "/", "/boot", "/dev", "/etc", "/sys", "/proc"] - norm_path = os.path.normpath(os.path.abspath(path)) - if norm_path in unsafe_dirs: - raise ValueError(f"Path {path} is outside allowed directory") - for parent in unsafe_dirs: - if parent == "/": - continue - if norm_path.startswith(parent + os.sep): - raise ValueError(f"Path {path} is outside allowed directory") - return norm_path - if BaseCommand.logger is None: BaseCommand._configure_logging() self._base_dir = str(resources.files("datasystem")) - self._base_dir = valid_safe_path(self._base_dir) + self._base_dir = self.valid_safe_path(self._base_dir) + + @staticmethod + def valid_safe_path(path: str): + """Validate the legality of the input path.""" + unsafe_dirs = ["/bin", "/sbin", "/lib", "/lib64", "/", "/boot", "/dev", "/etc", "/sys", "/proc"] + norm_path = os.path.normpath(os.path.abspath(path)) + if norm_path in unsafe_dirs: + raise ValueError(f"Path {path} is outside allowed directory") + for parent in unsafe_dirs: + if parent == "/": + continue + if norm_path.startswith(parent + os.sep): + raise ValueError(f"Path {path} is outside 
allowed directory")
+        return norm_path
 
     @staticmethod
     def add_arguments(parser):
diff --git a/example/cpp_template/CMakeLists.txt b/cli/cpp_template/CMakeLists.txt
similarity index 100%
rename from example/cpp_template/CMakeLists.txt
rename to cli/cpp_template/CMakeLists.txt
diff --git a/example/cpp_template/README.md b/cli/cpp_template/README.md
similarity index 100%
rename from example/cpp_template/README.md
rename to cli/cpp_template/README.md
diff --git a/example/cpp_template/kv_cache_example.cpp b/cli/cpp_template/kv_cache_example.cpp
similarity index 100%
rename from example/cpp_template/kv_cache_example.cpp
rename to cli/cpp_template/kv_cache_example.cpp
diff --git a/example/cpp_template/run.sh b/cli/cpp_template/run.sh
similarity index 100%
rename from example/cpp_template/run.sh
rename to cli/cpp_template/run.sh
diff --git a/cli/deploy/conf/worker_config.json b/cli/deploy/conf/worker_config.json
index 3d41f98..2bee2d4 100644
--- a/cli/deploy/conf/worker_config.json
+++ b/cli/deploy/conf/worker_config.json
@@ -11,18 +11,22 @@
         "value": "1024",
         "description": "Upper limit of the shared memory, the unit is mb, must be greater than 0."
     },
+    "sc_local_cache_memory_size_mb": {
+        "value": "1024",
+        "description": "Upper limit of the local cache used by stream cache, the unit is mb, must be greater than 0."
+    },
     "oc_shm_threshold_percentage": {
         "value": "100",
         "description": "Upper limit of the shared memory in percentage can be used by OC, must be within (0, 100]"
+    },
+    "sc_shm_threshold_percentage": {
+        "value": "100",
+        "description": "Upper limit of the shared memory in percentage can be used by SC, must be within (0, 100]"
     },
     "heartbeat_interval_ms": {
         "value": "1000",
         "description": "Time interval between worker and etcd heartbeats."
     },
-    "authorization_enable": {
-        "value": "false",
-        "description": "Indicates whether to enable the tenant authentication, default is false."
-    },
     "ipc_through_shared_memory": {
         "value": "true",
         "description": "Using shared memory to exchange data between client and worker. If this parameter is set to true, client and worker will pass control messages through Unix Domain Sockets (UDS); otherwise, they will pass control messages through TCP/IP and exchange data through TCP/IP."
@@ -39,22 +43,70 @@
         "value": "",
         "description": "Specify other az names using the same etcd. Split by ','"
     },
-    "oc_io_from_l2cache_need_metadata": {
-        "value": "true",
-        "description": "Control whether data read and write from the L2 cache daemon depend on metadata. Note: If set to false, it indicates that the metadata is not stored in etcd."
-    },
     "enable_distributed_master": {
         "value": "true",
         "description": "Whether to support distributed master, default is true."
     },
+    "sc_regular_socket_num": {
+        "value": "32",
+        "description": "The number of regular backend sockets for stream cache, must be greater than or equal to 0."
+    },
+    "sc_stream_socket_num": {
+        "value": "32",
+        "description": "The number of stream backend sockets for stream cache, must be greater than or equal to 0."
+    },
     "oc_worker_worker_direct_port": {
         "value": "0",
         "description": "A direct tcp/ip port for worker to workers scenarios to improve latency. Acceptable values:0 or a positive integer. 0 indicates disabled."
     },
+    "sc_worker_worker_direct_port": {
+        "value": "0",
+        "description": "A direct tcp/ip port for worker to workers scenarios to improve latency. Acceptable values: 0 or a positive integer. 0 means disabled."
+ }, "oc_worker_worker_pool_size": { "value": "3", "description": "Number of parallel connections between worker/worker oc service. The flag `oc_worker_worker_direct_port` must be enabled for this setting to take effect" }, + "sc_worker_worker_pool_size": { + "value": "3", + "description": "Number of parallel connections between worker/worker sc service. Flag sc_worker_worker_direct_port must be enabled to take effect." + }, + "sc_gc_interval_ms": { + "value": "50", + "description": "Memory resource clean up interval." + }, + "sc_scan_interval_ms": { + "value": "10", + "description": "Scan interval for remote send" + }, + "sc_metrics_log_interval_s": { + "value": "60", + "description": "Interval between logging stream metrics" + }, + "sc_cache_pages": { + "value": "16", + "description": "Default number of cache pages" + }, + "sc_scan_thread_num": { + "value": "16", + "description": "The num of threads used to scan new elements in shared memory." + }, + "sc_scan_num_buckets": { + "value": "1024", + "description": "Number of partitions for scanning streams." + }, + "sc_shared_page_size_mb": { + "value": "4", + "description": "The shared page size, should be in range [1, 16]." + }, + "sc_shared_page_group_count": { + "value": "4", + "description": "The shared page group count for each remote worker, should be in range [1, 64]." + }, + "master_sc_thread_num": { + "value": "128", + "description": "Max number of threads for (non rpc) master stream cache service work." + }, "payload_nocopy_threshold": { "value": "104857600", "description": "minimum payload size to trigger no memory copy" @@ -75,6 +127,10 @@ "value": "5", "description": "Client reconnect wait seconds, default is 5." }, + "page_size": { + "value": "1048576", + "description": "Size of the page used for caching worker files. The valid range is 4096-1073741824." + }, "etcd_meta_pool_size": { "value": "8", "description": "ETCD metadata async pool size." @@ -155,18 +211,10 @@ "value": "3600", "description": "when we delete the object which is in uploading process, in this scenarios we need delay some time to retry." }, - "cache_rpc_session": { - "value": "true", - "description": "Deprecated: This flag is deprecated and will be removed in future releases." - }, "rocksdb_store_dir": { "value": "./datasystem/rocksdb", "description": "Config MASTER back store directory and must specify in rocksdb scenario. The rocksdb database is used to persistently store the metadata stored in the master so that the metadata before the restart can be re-obtained when the master restarts." }, - "rocksdb_sync_write": { - "value": "false", - "description": "Controls whether rocksdb sets sync to true when writing data." - }, "rocksdb_max_open_file": { "value": "128", "description": "Number of open files that can be used by the rocksdb, default value is 128." @@ -175,6 +223,10 @@ "value": "16", "description": "Number of background threads rocksdb can use for flushing and compacting, default value is 16(Should be greater than 0)." }, + "rocksdb_write_mode": { + "value": "async", + "description": "Config the rocksdb support none, sync or async, async by default. Optional value: 'none', 'sync', 'async'. This represents the method of writing metadata to rocksdb." + }, "node_timeout_s": { "value": "60", "description": "Maximum time interval before a node is considered lost." @@ -236,11 +288,11 @@ "description": "Control whether get meta data from other AZ's worker, if false then get meta data from local AZ." 
}, "enable_worker_worker_batch_get": { - "value" : "false", - "description" : "Enable worker->worker OC batch get, default false." + "value": "false", + "description": "Enable worker->worker OC batch get, default false." }, "batch_get_threshold_mb": { - "value" : "100", + "value": "100", "description": "The payload threshold to batch get objects, the unit is mb, must be greater than 0. Setting to 0 will indicate no split." }, "enable_reconciliation": { @@ -355,10 +407,6 @@ "value": "", "description": "The directory to find ZMQ curve key files. This path must be specified when zmq authentication is enabled." }, - "encrypt_kit": { - "value": "plaintext", - "description": "choose the type of encrypt. Support plaintext, default is plaintext." - }, "enable_etcd_auth": { "value": "false", "description": "Whether to enable ETCD auth, default is false. ETCD certificate will be obtained." @@ -402,5 +450,61 @@ "enable_p2p_transfer": { "value": "false", "description": "Heterogeneous object transfer protocol Enables p2ptransfer." + }, + "enable_cloud_service_token_rotation": { + "value": "false", + "description": "Enable the OBS client to access OBS using a temporary token. After the token expires, obtain a new token and connect to OBS again." + }, + "enable_meta_replica": { + "value": "false", + "description": "Controls whether to enable multiple meta replica" + }, + "enable_urma": { + "value": "false", + "description": "Option to turn on urma for OC worker to worker data transfer, default false." + }, + "urma_connection_size": { + "value": "16", + "description": "Number of jfs and jfr pair" + }, + "urma_event_mode": { + "value": "false", + "description": "Uses interrupt mode to poll completion events." + }, + "urma_poll_size": { + "value": "8", + "description": "Number of complete record to poll at a time, 16 is the max this device can poll" + }, + "urma_register_whole_arena": { + "value": "true", + "description": "Register the whole arena as segment during init, otherwise, register each object as a segment." + }, + "logfile_mode": { + "value": "416", + "description": "Log file mode/permissions." + }, + "oc_shm_transfer_threshold_kb": { + "value": "500", + "description": "The data threshold to transfer obj data between client and worker via shm, unit is KB" + }, + "remote_send_thread_num": { + "value": "8", + "description": "The num of threads used to send elements to remote worker." + }, + "shared_disk_arena_per_tenant": { + "value": "8", + "description": "The number of disk cache Arena for each tenant. Multiple arenas can improve the performance of shared disk allocation for the first time, but each arena will use one more fd. The valid range is 0 to 32." + }, + "shared_disk_directory": { + "value": "", + "description": "Disk cache data placement directory, default value is empty, indicating that disk cache is not enabled." + }, + "shared_disk_size_mb": { + "value": "0", + "description": "Upper limit of the shared disk, the unit is mb." + }, + "stream_idle_time_s": { + "value": "300", + "description": "stream idle time. default 300s (5 minutes)" } } \ No newline at end of file diff --git a/cli/generate_helm_chart.py b/cli/generate_helm_chart.py index d563633..fa75719 100644 --- a/cli/generate_helm_chart.py +++ b/cli/generate_helm_chart.py @@ -43,7 +43,7 @@ class Command(BaseCommand): metavar='OUTPUT_PATH', default=os.getcwd(), help='path to save the generated Helm chart, default path is the current directory. 
\ - Example: dscli generate_helm_chart --output-path /home/user/helmCharts' + Example: dscli generate_helm_chart --output_path /home/user/helmCharts' ) def run(self, args): diff --git a/cli/start.py b/cli/start.py index c102043..a76cf78 100644 --- a/cli/start.py +++ b/cli/start.py @@ -225,6 +225,7 @@ class Command(BaseCommand): if not ready_check_path: raise RuntimeError("ready_check_path is empty") ready_check_path = os.path.abspath(ready_check_path) + ready_check_path = self.valid_safe_path(ready_check_path) if os.path.exists(ready_check_path) and os.path.isfile(ready_check_path): os.remove(ready_check_path) process = subprocess.Popen( diff --git a/cmake/external_libs/libcurl.cmake b/cmake/external_libs/libcurl.cmake index 3c1d1a3..ed739a6 100644 --- a/cmake/external_libs/libcurl.cmake +++ b/cmake/external_libs/libcurl.cmake @@ -14,14 +14,15 @@ set(curl_C_FLAGS ${THIRDPARTY_SAFE_FLAGS}) if (curl_VERSION STREQUAL "8.8.0") set(curl_PATCHES - ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-6197-fix-CVE-2024-6197-for-curl-8.8.0-c.patch - ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-6874-fix-CVE-2024-6874-for-curl-8.8.0-c.patch - ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-7264-fix-CVE-2024-7264-for-curl-8.8.0-c.patch - ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-8096-fix-CVE-2024-8096-for-curl-8.8.0-c.patch - ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-9681-fix-CVE-2024-9681-for-curl-8.8.0-c.patch - ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-11053-fix-CVE-2024-11053-for-curl-8.8.0-c.patch - ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2025-0167-fix-CVE-2025-0167-for-curl-8.8.0-c.patch - ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2025-0725-fix-CVE-2025-0725-for-curl-8.8.0-c.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-6197-fix-CVE-2024-6197-for-curl-8.8.0-c.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-6874-fix-CVE-2024-6874-for-curl-8.8.0-c.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-7264-fix-CVE-2024-7264-for-curl-8.8.0-c.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-8096-fix-CVE-2024-8096-for-curl-8.8.0-c.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-9681-fix-CVE-2024-9681-for-curl-8.8.0-c.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2024-11053-fix-CVE-2024-11053-for-curl-8.8.0-c.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2025-0167-fix-CVE-2025-0167-for-curl-8.8.0-c.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/Backport-CVE-2025-0725-fix-CVE-2025-0725-for-curl-8.8.0-c.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/curl/8.8.0/support_old_cmake.patch ) endif() diff --git a/cmake/external_libs/sdk_c_obs.cmake b/cmake/external_libs/sdk_c_obs.cmake index f0127ea..604c92a 100644 --- a/cmake/external_libs/sdk_c_obs.cmake +++ b/cmake/external_libs/sdk_c_obs.cmake @@ -22,7 +22,9 @@ set(obs_CXX_FLAGS ${THIRDPARTY_SAFE_FLAGS}) set(obs_C_FLAGS ${THIRDPARTY_SAFE_FLAGS}) set(obs_LINK_FLAGS ${THIRDPARTY_SAFE_FLAGS}) -set(obs_PATCHES ${CMAKE_SOURCE_DIR}/third_party/patches/obs/3.24.3/obs-sdk-cmake-install.patch) +set(obs_PATCHES + ${CMAKE_SOURCE_DIR}/third_party/patches/obs/3.24.3/obs-sdk-cmake-install.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/obs/3.24.3/obs-sdk-change-spdlog.patch) 
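# Note (hedged, not part of the patch): the patch lists above for libcurl and
# OBS follow the same repo convention: absolute paths under
# third_party/patches grouped per upstream version, handed to the
# ADD_THIRDPARTY_LIB helper in cmake/util.cmake, which is assumed to apply
# them in list order before the library is configured.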
add_thirdparty_lib(OBS URL ${obs_URL} diff --git a/cmake/external_libs/spdlog.cmake b/cmake/external_libs/spdlog.cmake index 7fc9fe5..22dc110 100644 --- a/cmake/external_libs/spdlog.cmake +++ b/cmake/external_libs/spdlog.cmake @@ -14,7 +14,8 @@ set(spdlog_CMAKE_OPTIONS -DSPDLOG_BUILD_SHARED:BOOL=ON) set(spdlog_PATCHES - ${CMAKE_SOURCE_DIR}/third_party/patches/spdlog/change-filename.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/spdlog/change-filename.patch + ${CMAKE_SOURCE_DIR}/third_party/patches/spdlog/change-namespace.patch ${CMAKE_SOURCE_DIR}/third_party/patches/spdlog/change-rotating-file-sink.patch) add_thirdparty_lib(SPDLOG @@ -36,5 +37,5 @@ find_library(SPDLOG set(CMAKE_PREFIX_PATH ${SPDLOG_ROOT}) find_package(spdlog ${spdlog_VERSION} REQUIRED) -get_property(spdlog_INCLUDE_DIR TARGET spdlog::spdlog PROPERTY INTERFACE_INCLUDE_DIRECTORIES) +get_property(spdlog_INCLUDE_DIR TARGET ds_spdlog::spdlog PROPERTY INTERFACE_INCLUDE_DIRECTORIES) include_directories(SYSTEM ${spdlog_INCLUDE_DIR}) diff --git a/cmake/external_libs/ub.cmake b/cmake/external_libs/ub.cmake index 64ce772..1e625ba 100644 --- a/cmake/external_libs/ub.cmake +++ b/cmake/external_libs/ub.cmake @@ -1,22 +1,31 @@ if (DOWNLOAD_UB) # set the url and archive name to be downloaded - set(UB_ARCHIVE_DATETIME 20240124) # format YYYYMMDDhhmmss - if ("${UB_URL}" STREQUAL "" AND "${UB_SHA256}" STREQUAL "") - if (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") - # wait for modify - set(UB_URL "xxx") - set(UB_SHA256 "xxx") - elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") - # wait for modify - set(UB_URL "xxx") - set(UB_SHA256 "xxx") - else() + if (URMA_OVER_UB) + set(UB_ARCHIVE_DATETIME 20251015) # format YYYYMMDDhhmmss + if ("${UB_URL}" STREQUAL "" OR "${UB_SHA256}" STREQUAL "") + message(FATAL_ERROR "UMDK package download paths need to be specified for URMA over UB scenario") + endif() + if (NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + # Note that only openeuler2403 is supported message(FATAL_ERROR "Unsupported system processor: ${CMAKE_SYSTEM_PROCESSOR}") endif() + else() + set(UB_ARCHIVE_DATETIME 20240124) # format YYYYMMDDhhmmss + if ("${UB_URL}" STREQUAL "" AND "${UB_SHA256}" STREQUAL "") + if (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + set(UB_URL "xxx") + set(UB_SHA256 "xxx") + elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + set(UB_URL "xxx") + set(UB_SHA256 "xxx") + else() + message(FATAL_ERROR "Unsupported system processor: ${CMAKE_SYSTEM_PROCESSOR}") + endif() + endif() endif() ADD_THIRDPARTY_LIB(UB URL ${UB_URL} SHA256 ${UB_SHA256} VERSION ${UB_ARCHIVE_DATETIME}) -endif() \ No newline at end of file +endif() diff --git a/cmake/external_libs/urma.cmake b/cmake/external_libs/urma.cmake index d854c98..6989c23 100644 --- a/cmake/external_libs/urma.cmake +++ b/cmake/external_libs/urma.cmake @@ -23,4 +23,7 @@ include_directories(${URMA_INCLUDE_DIR}) include_directories(${URMA_INCLUDE_DIR}/common) SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-field-initializers -Wno-unused-parameter") -add_definitions(-DUSE_URMA) \ No newline at end of file +add_definitions(-DUSE_URMA) +if (URMA_OVER_UB) + add_definitions(-DURMA_OVER_UB) +endif() \ No newline at end of file diff --git a/cmake/modules/FindURMA.cmake b/cmake/modules/FindURMA.cmake index b74b683..6567963 100644 --- a/cmake/modules/FindURMA.cmake +++ b/cmake/modules/FindURMA.cmake @@ -23,23 +23,8 @@ find_library(URMA_LIBRARY urma NO_CMAKE_SYSTEM_PATH NO_SYSTEM_ENVIRONMENT_PATH) -find_library(URMA_IP_LIBRARY urma_ip - PATHS ${URMA_IP_IB_LIB_LOCATION} - DOC "URMA IP 
library" - NO_CMAKE_SYSTEM_PATH - NO_SYSTEM_ENVIRONMENT_PATH) - -find_library(URMA_IB_LIBRARY urma_ib - PATHS ${URMA_IP_IB_LIB_LOCATION} - DOC "URMA IB library" - NO_CMAKE_SYSTEM_PATH - NO_SYSTEM_ENVIRONMENT_PATH) - include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(URMA REQUIRED_VARS - URMA_LIBRARY URMA_IP_LIBRARY URMA_IB_LIBRARY URMA_INCLUDE_DIR) +find_package_handle_standard_args(URMA REQUIRED_VARS URMA_LIBRARY URMA_INCLUDE_DIR) message(STATUS "URMA_LIBRARY=${URMA_LIBRARY}") -message(STATUS "URMA_IP_LIBRARY=${URMA_IP_LIBRARY}") -message(STATUS "URMA_IB_LIBRARY=${URMA_IB_LIBRARY}") message(STATUS "URMA_INCLUDE_DIR=${URMA_INCLUDE_DIR}") diff --git a/cmake/package.cmake b/cmake/package.cmake index 72fadc1..1842932 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -135,6 +135,7 @@ if (BUILD_GO_API) install(FILES ${CMAKE_SOURCE_DIR}/src/datasystem/c_api/status_definition.h ${CMAKE_SOURCE_DIR}/src/datasystem/c_api/state_cache_c_wrapper.h + ${CMAKE_SOURCE_DIR}/src/datasystem/c_api/stream_cache_c_wrapper.h ${CMAKE_SOURCE_DIR}/src/datasystem/c_api/object_cache_c_wrapper.h ${CMAKE_SOURCE_DIR}/src/datasystem/c_api/utilC.h ${CMAKE_SOURCE_DIR}/src/datasystem/c_api/cipher.h diff --git a/cmake/util.cmake b/cmake/util.cmake index ce97206..aea57de 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -54,7 +54,7 @@ function(__EXEC_COMMAND) endfunction() function(DOWNLOAD_LIB_PKG LIB_NAME URL SHA256) - # OpenEuler tiny package url end with "rpm" suffix, we need + # OpenEuler tiny package url end with "rpm" suffix, we need # to uncompress it and get the real source code package. if (URL MATCHES ".*\.src\.rpm$") FetchContent_Declare( @@ -172,7 +172,7 @@ function(GEN_THIRDPARTY_PKG NAME URL SHA256 FAKE_SHA256 VERSION) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) file(SHA256 "${_DEST_PATH}" _SHA256) endif() - + # Set output variables. 
set(${URL} "${_DEST_PATH}" PARENT_SCOPE) set(${SHA256} "${_SHA256}" PARENT_SCOPE) @@ -374,6 +374,7 @@ function(ADD_THIRDPARTY_LIB LIB_NAME) # extract files from rpm file(GLOB RPM_FILES "${${_LIB_NAME_LOWER}_SOURCE_DIR}/umdk-*.rpm") foreach(file ${RPM_FILES}) + message("process ${file}") execute_process(COMMAND rpm2cpio ${file} COMMAND cpio -idmv WORKING_DIRECTORY ${${_LIB_NAME_LOWER}_SOURCE_DIR} @@ -384,37 +385,58 @@ function(ADD_THIRDPARTY_LIB LIB_NAME) endif() endforeach() # Copy headers - file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/flowbuf/cpp/flowbuffer.h - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/common/ub_util.h - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/urma_opcode.h - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/urma_types.h - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/urma_api.h - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/urpc/cpp - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/urpc/urpc.h - DESTINATION ${${LIB_NAME}_ROOT}/include) + if (URMA_OVER_UB) + file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/ub/umdk/urma/urma_opcode.h + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/ub/umdk/urma/urma_types.h + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/ub/umdk/urma/urma_api.h + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/ub/umdk/urma/urma_ubagg.h + DESTINATION ${${LIB_NAME}_ROOT}/include) + else() + file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/urma_opcode.h + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/urma_types.h + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/urma_api.h + DESTINATION ${${LIB_NAME}_ROOT}/include) + endif() # Copy libs - file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/libflowbuffer.so - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/liburma.so + file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/liburma.so ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/liburma.so.0 ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/liburma.so.0.0.1 ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/liburma_common.so ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/liburma_common.so.0 ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/liburma_common.so.0.0.1 - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/liburpc.so - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/liburpc_cpp.so DESTINATION ${${LIB_NAME}_ROOT}/lib64) - # copy only ib libs to /urma - file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ib.so - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ib.so.0 - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ib.so.0.0.1 - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ip.so - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ip.so.0 - ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ip.so.0.0.1 - DESTINATION ${${LIB_NAME}_ROOT}/lib64/urma) - - # Copy bins - file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/bin/flowc - DESTINATION ${${LIB_NAME}_ROOT}/bin) + # copy the ib/ip/ub libs to /urma if applicable + if (EXISTS ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ib.so) + file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ib.so + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ib.so.0 + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ib.so.0.0.1 + DESTINATION ${${LIB_NAME}_ROOT}/lib64/urma) + endif() + if (EXISTS ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ip.so) + file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ip.so + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ip.so.0 + 
${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ip.so.0.0.1 + DESTINATION ${${LIB_NAME}_ROOT}/lib64/urma) + endif() + if (EXISTS ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ubagg.so) + file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ubagg.so + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ubagg.so.0 + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/urma/liburma_ubagg.so.0.0.1 + DESTINATION ${${LIB_NAME}_ROOT}/lib64/urma) + endif() + # copy libs and bins for URPC + if (BUILD_WITH_URPC) + file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/flowbuf/cpp/flowbuffer.h + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/urpc/cpp + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/include/umdk/urpc/urpc.h + DESTINATION ${${LIB_NAME}_ROOT}/include) + file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/libflowbuffer.so + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/liburpc.so + ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/lib64/liburpc_cpp.so + DESTINATION ${${LIB_NAME}_ROOT}/lib64) + file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/usr/bin/flowc + DESTINATION ${${LIB_NAME}_ROOT}/bin) + endif() elseif(${_TOOLCHAIN_LOWER} STREQUAL "cmake") if (ARG_CXX_FLAGS) list(APPEND ARG_CONF_OPTIONS "-DCMAKE_CXX_FLAGS=${ARG_CXX_FLAGS}") @@ -509,7 +531,7 @@ endfunction() # # ${LIB_NAME}_URL # third-party library download url. -# +# # ${LIB_NAME}_SHA256 # third-party library SHA256 for verify. function(ADJUICE_THIRDPARTY_VERSION LIB_NAME) @@ -596,7 +618,7 @@ endmacro() # # WORKING_DIR # Specify the working directory when running test executable. -# +# # TEST_ENVIRONMENTS ... # Specify the test environments variables when running test executable. function(ADD_DATASYSTEM_TEST TARGET) @@ -604,7 +626,7 @@ function(ADD_DATASYSTEM_TEST TARGET) set(one_value_args WORKING_DIR) set(multi_value_args TEST_ENVIRONMENTS) cmake_parse_arguments(ARG "${options}" "${one_value_args}" "${multi_value_args}" ${ARGN}) - + if (NOT ARG_WORKING_DIR) set(ARG_WORKING_DIR ${CMAKE_CURRENT_BINARY_DIR}) endif() @@ -644,7 +666,7 @@ endfunction() # Clean target build rpath. Some targets like python shared library -# would package as jar/wheel format file first then install. But there +# would package as jar/wheel format file first then install. But their # build rpath need to erase first. # # TARGET is CMake target, shared library or executable is available. @@ -665,8 +687,8 @@ endfunction() # Run StripAndGenHash.cmake to strip libraries in install stage. -# -# LIB_LIST +# +# LIB_LIST # Specify the list of library path waiting to strip. # DST_DIR # Specify the destination directory path of StripAndGenHash run. @@ -691,7 +713,7 @@ endfunction() # PACKAGE_NAME is the package name of python library # # Additional optional arguments: -# +# # CMAKE_INSTALL_PATH # Specify the directory path where python whl file save.
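# A hypothetical invocation matching this documented signature (argument
# values are illustrative only, not taken from the real build scripts):
#   package_datasystem_wheel(datasystem
#       CMAKE_INSTALL_PATH ${CMAKE_BINARY_DIR}/dist
#       NEED_STRIP_LIBS libdatasystem.so)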
# @@ -711,7 +733,7 @@ function(PACKAGE_DATASYSTEM_WHEEL PACKAGE_NAME) set(DATASYSTEM_WHEEL_PATH ${CMAKE_BINARY_DIR}/dist/datasystem) set(DATASYSTEM_SETUP_PATH ${CMAKE_BINARY_DIR}/dist) set(DATASYSTEM_PACKAGE_LIBPATH ${CMAKE_SOURCE_DIR}) - + # Store helm chart set(HELM_CHART_PATH ${CMAKE_SOURCE_DIR}/k8s/helm_chart) set(SERVER_LIB ${CMAKE_INSTALL_PREFIX}/service/lib) @@ -725,7 +747,7 @@ function(PACKAGE_DATASYSTEM_WHEEL PACKAGE_NAME) foreach(_PATTERN ${ARG_THIRDPATRY_LIBS_PATTERN}) file(GLOB_RECURSE THIRDPARTY_LIB_LIST ${_PATTERN}) - foreach(_THIRDPARTY_LIB ${THIRDPARTY_LIB_LIST}) + foreach(_THIRDPARTY_LIB ${THIRDPARTY_LIB_LIST}) install(FILES ${_THIRDPARTY_LIB} DESTINATION ${PYTHON_LIBPATH}) if (NOT IS_SYMLINK ${_THIRDPARTY_LIB}) get_filename_component(_LIB_NAME ${_THIRDPARTY_LIB} NAME) @@ -744,45 +766,45 @@ function(PACKAGE_DATASYSTEM_WHEEL PACKAGE_NAME) # Strip libraries in PYTON_LIBPATH list(TRANSFORM ARG_NEED_STRIP_LIBS PREPEND "${PYTHON_LIBPATH}/") strip_libs_in_install_stage(NEED_STRIP_LIBS PYTHON_LIBPATH) - + # Copy chart files to package lib path install(DIRECTORY ${HELM_CHART_PATH}/ DESTINATION ${DATASYSTEM_WHEEL_PATH}/helm_chart) - + # Copy cpp include files to package lib path install(DIRECTORY ${CMAKE_SOURCE_DIR}/include DESTINATION ${DATASYSTEM_WHEEL_PATH}/) - + # Copy service lib to package lib path install(DIRECTORY ${SERVER_LIB}/ DESTINATION ${DATASYSTEM_WHEEL_PATH}/lib/) - + # Copy sdk lib install(DIRECTORY ${SDK_LIB}/ DESTINATION ${DATASYSTEM_WHEEL_PATH}/lib/ PATTERN "cmake" EXCLUDE) - + # Copy python sdk install(DIRECTORY ${PYTHON_SDK}/ DESTINATION ${DATASYSTEM_WHEEL_PATH}/) - + # Copy ds cli source files to package lib path install(DIRECTORY ${CMAKE_SOURCE_DIR}/cli DESTINATION ${DATASYSTEM_WHEEL_PATH}) - + #Copy setup.py - install(FILES ${CMAKE_SOURCE_DIR}/setup.py DESTINATION ${DATASYSTEM_SETUP_PATH}) - + install(FILES ${CMAKE_SOURCE_DIR}/setup.py DESTINATION ${DATASYSTEM_SETUP_PATH}/) + # Copy VERSION and LICENSE to package lib path install(FILES ${CMAKE_SOURCE_DIR}/VERSION ${CMAKE_SOURCE_DIR}/LICENSE ${CMAKE_SOURCE_DIR}/README.md DESTINATION ${DATASYSTEM_WHEEL_PATH}) # Copy cpp template to package lib path - install(DIRECTORY ${CMAKE_SOURCE_DIR}/example/cpp_template + install(DIRECTORY ${CMAKE_SOURCE_DIR}/cli/cpp_template DESTINATION ${DATASYSTEM_WHEEL_PATH}) # Copy worker and worker_config to package lib path install(FILES ${CMAKE_INSTALL_PREFIX}/service/datasystem_worker ${CMAKE_SOURCE_DIR}/cli/deploy/conf/worker_config.json ${CMAKE_SOURCE_DIR}/cli/deploy/conf/cluster_config.json DESTINATION ${DATASYSTEM_WHEEL_PATH}) - + find_package(Python3 COMPONENTS Interpreter Development) set(CONFIG_PACKAGE_SCRIPT ${CMAKE_BINARY_DIR}/PackageDatasystem.cmake) # Generate PackagePythonSDK.cmake to run setup.py @@ -916,10 +938,10 @@ endfunction() # # DEST_DIR # Specify the destination directory path of installing file. -# +# # PATH_PATTERN ... # Specify the library path pattern, like ${zlib_LIBRARY}/libz.so* . -# +# # PERMISSIONS ... # Specify permissions of copied file. function(INSTALL_FILE_PATTERN) diff --git a/docs/source_en/appendix/k8s_configuration.md b/docs/source_en/appendix/k8s_configuration.md index 977d40b..b5f2827 100644 --- a/docs/source_en/appendix/k8s_configuration.md +++ b/docs/source_en/appendix/k8s_configuration.md @@ -150,6 +150,8 @@ global: | global.rpc.zmqClientIoContext | int | `5` | Optimize the performance of the client stub. Default value is 5. 
The higher the throughput, the higher the value, but should be in range [1, 32] | | global.rpc.zmqChunkSz | int | `1048576` | Parallel payload split chunk size. Default to 1048576 bytes | | global.rpc.maxRpcSessionNum | int | `2048` | Maximum number of sessions that can be cached, must be within [512, 10'000] | +| global.rpc.streamIdleTimes | int | `300` | Stream idle time, default 300s (5 minutes) | +| global.rpc.remoteSendThreadNum | int | `8` | The number of threads used to send elements to remote workers | **Example**: @@ -354,6 +356,7 @@ global: | global.metadata.rocksdbStoreDir | string | `"/home/sn/datasystem/rocksdb"` | Config MASTER back store directory and must specify in rocksdb scenario. The rocksdb database is used to persistently store the metadata stored in the master so that the metadata before the restart can be re-obtained when the master restarts | | global.metadata.rocksdbBackgroundThreads | int | `16` | Number of background threads rocksdb can use for flushing and compacting | | global.metadata.rocksdbMaxOpenFile | int | `128` | Number of open files that can be used by the rocksdb | +| global.metadata.rocksdbWriteMode | string | `async` | The method of writing metadata to rocksdb. Optional values: 'none', 'sync', 'async'; default is 'async' | ### Reliability Configurations @@ -369,6 +372,7 @@ global: | global.reliability.livenessProbeTimeoutS | int | `150` | Timeout interval of kubernetes liveness probe | | global.reliability.addNodeWaitTimeS | int | `60` | Time to wait for the first node that wants to join a working hash ring | | global.reliability.autoDelDeadNode | bool | `true` | Indicate dead nodes marked in the hash ring can be removed or not | +| global.reliability.enableDistributedMaster | bool | `true` | Whether to support distributed master, default is true | ### Graceful Shutdown Configurations @@ -391,6 +395,17 @@ global: | global.performance.arenaPerTenant | int | `16` | The arena count for each tenant. Multiple arenas can improve the performance of share memory allocation for the first time, but each arena will use one more fd, value range: [1, 32] | | global.performance.memoryReclamationTimeSecond | int | `600` | The memory reclamation time after free | | global.performance.asyncDelete | bool | `false` | Set whether to delete object asynchronously. If set to true, master will notify workers to delete objects asynchronously. Client doesn't need to wait for all workers to delete objects.
| +| global.performance.enableP2pTransfer | bool | `false` | Enables p2p transfer for the heterogeneous object transfer protocol | +| global.performance.enableWorkerWorkerBatchGet | bool | `false` | Enable worker->worker OC batch get, default false | +| global.performance.ocShmTransferThresholdKB | int | `500` | The data threshold to transfer object data between client and worker via shm, the unit is KB | +| global.performance.enableUrma | bool | `false` | Option to turn on urma for OC worker to worker data transfer, default false | +| global.performance.urmaPollSize | int | `8` | Number of completion records to poll at a time, 16 is the max this device can poll | +| global.performance.urmaRegisterWholeArena | bool | `true` | Register the whole arena as segment during init, otherwise, register each object as a segment | +| global.performance.urmaConnectionSize | int | `16` | Number of jfs and jfr pairs | +| global.performance.urmaEventMode | bool | `false` | Uses interrupt mode to poll completion events | +| global.performance.sharedDiskDirectory | string | `""` | Disk cache data placement directory, default value is empty, indicating that disk cache is not enabled | +| global.performance.sharedDiskSize | int | `0` | Upper limit of the shared disk, the unit is mb | +| global.performance.sharedDiskArenaPerTenant | int | `8` | The number of disk cache Arena for each tenant. Multiple arenas can improve the performance of shared disk allocation for the first time, but each arena will use one more fd. The valid range is 0 to 32 | ### AK/SK Configurations @@ -485,5 +500,6 @@ global: | Configuration | Type | Default | Description | |-----|------|---------|-------------| | global.annotations | object | `{}` | Kubernetes meta annotation | -| global.enableNonPreemptive | bool | `false` | Configure priorityClass. If the value is false, the default priorityClass is system-cluster-critical. If the value is true, a priorityClass with preemptionPolicy Never is created. | +| global.enableNonPreemptive | bool | `false` | Configure priorityClass. If the value is false, the default priorityClass is system-cluster-critical. If the value is true, a priorityClass with preemptionPolicy Never is created | | global.fsGid | string | `"1002"` | fsGroup configuration. All processes of the container are also part of the supplementary group ID | +| global.rollingUpdateTimeoutS | int | `1800` | Maximum duration of the rolling upgrade, default value is 1800 seconds | \ No newline at end of file diff --git a/docs/source_zh_cn/appendix/hugepage_guide.md b/docs/source_zh_cn/appendix/hugepage_guide.md new file mode 100644 index 0000000..7f178ed --- /dev/null +++ b/docs/source_zh_cn/appendix/hugepage_guide.md @@ -0,0 +1,80 @@ +# 大页内存配置指南 +适用系统:openEuler 22.03 / Ubuntu 22.04 及以上,内核 ≥ 5.10 + +**重要性与风险提示** +为何需要配置大页内存:大页内存能显著减少TLB未命中,提升内存访问效率,对数据密集型应用至关重要。 +风险:错误配置可能导致系统内存不足、内核OOM或系统不稳定。请遵循以下步骤做前置检查,以单个大页 2MB、共 1024 个大页为例。 + +## 数据系统配置大页内存步骤 +### 第一步:内存资源评估 +1. 检查当前可用内存 + ```bash + grep MemAvailable /proc/meminfo + ``` +要求:MemAvailable 的值必须大于 3 GB。这是为了给大页分配和系统正常运行预留足够的安全边界。 + +如果不足:关闭其他占用内存的应用程序,或考虑在系统空闲时进行操作。 + +2. 检查内存碎片程度 + ```bash + cat /proc/buddyinfo + ``` +找到您架构对应的行(如 Node 0, DMA32 或 Normal)。 +从左到右,数字代表连续页块的数量。需要找到 order 为 9 的连续页块(因为 2^9 * 4KB = 2MB)。 +简易判断:查看第10列(从0开始计数)的数字。如果这个数字大于 100,则认为碎片化程度较低,分配成功率高。如果很小(如 < 10),则分配可能失败或非常缓慢。 + +若节点内存分布不均匀可定向扩池,或先整理内存再扩池。 + +### 第二步:执行动态分配 +方法A:适用于单路(UMA)或简单环境 +分配大页(需要root权限) + +```bash +# 分配1280个大页(为系统预留256个) +sudo su -c "echo 1280 > /proc/sys/vm/nr_hugepages" +``` + +方法B:适用于多路服务器(NUMA架构) +1.
查看NUMA节点信息 + + ```bash + lscpu | grep NUMA + ``` + +2. 在指定节点上分配 +在node0分配768个,node1分配512个 + ```bash + sudo su -c "echo 768 > /sys/devices/system/node/node0/hugepages/hugepages-2048kB/nr_hugepages" + sudo su -c "echo 512 > /sys/devices/system/node/node1/hugepages/hugepages-2048kB/nr_hugepages" + ``` + +### 第三步:验证分配结果 +分配命令是异步的,内核需要时间整理内存。请等待并验证。 +1. 等待并检查: + + ```bash + # 等待30秒内分配完成 + timeout 30 bash -c 'while [[ $(grep HugePages_Total /proc/meminfo | awk "{print \$2}") -lt 1280 ]]; do sleep 1; done' + ``` + +2. 查看详细信息: + ```bash + grep -i huge /proc/meminfo + ``` + +关注 HugePages_Total、HugePages_Free 和 HugePages_Rsvd。 +成功标志:HugePages_Total 达到或接近目标值(1280)。 + +### 第四步:部署数据系统worker开启大页内存 +启动数据系统时将enable_huge_tlb配置选项设置为true。 +示例: + ```bash + dscli start -w --worker_address "${host}:${worker_port}" --etcd_address "${host}:${etcd_port}" --shared_memory_size_mb {shared_memory} --enable_huge_tlb true + ``` + +### 第五步:确认应用在使用大页 +数据系统部署后再次运行: + ```bash + grep -i huge /proc/meminfo + ``` +成功标志: HugePages_Free 的数量明显减少,HugePages_Rsvd 可能增加。这表明您的大页已经被应用成功申请和使用。 \ No newline at end of file diff --git a/docs/source_zh_cn/appendix/log_guide.md b/docs/source_zh_cn/appendix/log_guide.md index 60c6dbb..76c9612 100644 --- a/docs/source_zh_cn/appendix/log_guide.md +++ b/docs/source_zh_cn/appendix/log_guide.md @@ -27,11 +27,11 @@ openYuanrong datasystem 不同模块日志分类如下表所示: | 序号 | 日志 | 日志格式 | |-----|-----|------------------------| -| 1 | 运行日志 | Time \| level \| filename \| pod_name \| pid:tid \| trace_id \| az_name \| message | -| 2 | 访问日志 | Time \| level \| filename \| pod_name \| pid:tid \| trace_id \| az_name \| status_code \| action \| cost \| data size \| request param\| response param -| 3 | 访问第三方日志 | Time \| level \| filename \| pod_name \| pid:tid \| trace_id \| az_name \| status_code \| action \| cost \| data size \| request param\| response param -| 4 | 资源日志 | Time \| level \| filename \| pod_name \| pid:tid \| trace_id \| az_name \| shm_info \| spill_disk_info \| client nums \| object nums \| object total datasize \| WorkerOcService threadpool \| WorkerWorkerOcService threadpool \| MasterWokrerOcService threadpool \| MasterOcService threadpool \| write ETCD queue \| ETCDrequest success rate \| OBSrequest success rate \| Master AsyncTask threadpool | -| 5 | 容器运行日志 | Time \| level \| filename \| pod_name \| pid:tid \| trace_id \| az_name \| message | +| 1 | 运行日志 | Time \| level \| filename \| pod_name \| pid:tid \| trace_id \| cluster_name \| message | +| 2 | 访问日志 | Time \| level \| filename \| pod_name \| pid:tid \| trace_id \| cluster_name \| status_code \| action \| cost \| data size \| request param\| response param +| 3 | 访问第三方日志 | Time \| level \| filename \| pod_name \| pid:tid \| trace_id \| cluster_name \| status_code \| action \| cost \| data size \| request param\| response param +| 4 | 资源日志 | Time \| level \| filename \| pod_name \| pid:tid \| trace_id \| cluster_name \| shm_info \| spill_disk_info \| client nums \| object nums \| object total datasize \| WorkerOcService threadpool \| WorkerWorkerOcService threadpool \| MasterWokrerOcService threadpool \| MasterOcService threadpool \| write ETCD queue \| ETCDrequest success rate \| OBSrequest success rate \| Master AsyncTask threadpool | +| 5 | 容器运行日志 | Time \| level \| filename \| pod_name \| pid:tid \| trace_id \| cluster_name \| message | ### 日志字段 @@ -43,7 +43,7 @@ openYuanrong datasystem 不同模块日志分类如下表所示: | pod_name | 128 | 输出当前worker所属的POD名称,超出长度则截断。示例:ds-worker-hs5qm | | pid:tid | 11 | 该日志所属的进程ID和线程ID。进程号最大值为32757,该字段最大长度11.示例:9:177 | | traceid | 
36 | 请求的traceid。 | -| az_name | 128 | 输出日志的组件名,最大长度为128,超出长度则截断。示例:ds-worker。 | +| cluster_name | 128 | 输出日志的组件名,最大长度为128,超出长度则截断。示例:ds-worker。 | | Message | 1024 | 自定义消息内容 | | status_code | 5 | 该请求的状态,不同消息类型状态值不一样。SDK/datasystem_worker 访问日志,0表示成功,其他表示失败。 | | action | 64 | 表示该请求所访问的接口名称。约定前缀:SDK接口:DS_STATE_CLINET、DS_OBJECT_CLIENT,Worker接口:DS_OBJECT_POSIX,ETCD:DS_ETCD,HTTP请求:POST {url path} | diff --git a/docs/source_zh_cn/deployment/dscli.md b/docs/source_zh_cn/deployment/dscli.md index 820ff44..82653e6 100644 --- a/docs/source_zh_cn/deployment/dscli.md +++ b/docs/source_zh_cn/deployment/dscli.md @@ -399,7 +399,7 @@ openYuanrong datasystem单机卸载依赖 [dscli stop](#dscli-stop) 命令: > **注意事项**: > > - 在不同节点执行 dscli start 命令时,需要保证连接的是同一个ETCD,即 `--etcd_address` 的值需要保持一致。 - > - 如果涉及到需要指定 az name 的情况,即需要传 `--az_name` 参数,那么 `--az_name` 的值也需要保持一致。 + > - 如果涉及到需要指定 cluster name 的情况,即需要传 `--cluster_name` 参数,那么 `--cluster_name` 的值也需要保持一致。 #### 多机卸载 @@ -658,6 +658,8 @@ dscli collect_log --cluster_config_path ./cluster_config.json | zmq_client_io_context | int | `5` | ZMQ客户端性能优化参数,其数值与系统吞吐量正相关,取值范围:[1, 32] | | zmq_chunk_sz | int | `1048576` | 并行负载分块大小配置(以字节为单位) | | max_rpc_session_num | int | `2048` | 单个datasystem-worker最大可缓存会话数,取值范围:[512, 10,000] | +| remote_send_thread_num | int | `8` | 配置服务端用于将元素发送到远程工作线程的线程数量 | +| stream_idle_time_s | int | `300` | 配置流的空闲时间。默认值为300秒(5分钟) | #### ETCD相关配置 @@ -703,6 +705,7 @@ dscli collect_log --cluster_config_path ./cluster_config.json | log_monitor_exporter | string | `"harddisk"` | 指定观测日志导出类型,当前仅支持按 `harddisk` 类型导出观测数据,即将观测数据保存到 `logDir` 路径下 | | log_monitor_interval_ms | int | `10000` | 观测日志收集导出的间隔时间(以毫秒为单位) | | minloglevel | int | `0` | 设置记录冗余日志的最低级别,低于这个级别的日志不会被记录 | +| logfile_mode | int | `416` | 设置日志文件模式/权限,值为八进制数 | #### 二级缓存相关配置 @@ -716,12 +719,13 @@ dscli collect_log --cluster_config_path ./cluster_config.json | obs_bucket | string | `""` | 对象存储服务(OBS) 桶的名称 | | obs_https_enabled | bool | `false` | 是否启用HTTPS连接对象存储服务(OBS),默认为HTTP | | sfs_path | string | `""` | 挂载的SFS路径 | +| enable_cloud_service_token_rotation | bool | `false` | 启用OBS客户端使用临时令牌访问OBS,令牌过期后,获取新的令牌并重新连接OBS | #### AZ相关配置 | 配置项 | 类型 | 默认值 | 描述 | |-----|------|---------|-------------| -| other_az_names | string | `""` | 指定其他可用区的名称,如果需要指定多个可用区通过','进行分隔 | +| other_cluster_names | string | `""` | 指定其他可用区的名称,如果需要指定多个可用区通过','进行分隔 | | cross_az_get_data_from_worker | bool | `true` | 是否优先尝试从其他可用区的datasystem-worker获取数据。如果为 `false`,则将直接从二级缓存中检索数据 | | cross_az_get_meta_from_worker | bool | `false` | 是否从其他可用区的datasystem-worker获取元数据,如果为 `false`,则从本地可用区获取元数据 | @@ -732,6 +736,8 @@ dscli collect_log --cluster_config_path ./cluster_config.json | rocksdb_store_dir | string | `"./yr_datasystem/rocksdb"` | 配置元数据持久化目录,元数据通过RocksDB持久化在磁盘中 | | rocksdb_background_threads | int | `16` | RocksDB的后台线程数,用于元数据的刷盘和压缩 | | rocksdb_max_open_file | int | `128` | RocksDB可使用的最大打开文件个数 | +| rocksdb_write_mode | string | `async` | 配置元数据写入RocksDB的方式,支持不写、同步和异步写入,默认值为`async`。可选值包括:'none'(不写)、'sync'(同步)、'async'(异步) | +| enable_meta_replica | bool | `false` | 控制是否启用多个元数据副本 | #### 可靠性相关配置 @@ -746,6 +752,7 @@ dscli collect_log --cluster_config_path ./cluster_config.json | enable_hash_ring_self_healing | bool | `false` | 是否启用哈希环自愈功能,如果该值为 `true`,当哈希环状态异常时会启用自愈修复哈希环 | | add_node_wait_time_s | int | `60` | 新节点加入哈希环的等待超时时间 | | auto_del_dead_node | bool | `true` | 是否启用死亡节点自动清理功能,当该值为 `true` 时,会将死亡节点剔除出集群管理,并触发被动缩容 | +| enable_distributed_master | bool | `true` | 是否启用分布式主节点,默认值为true | #### 优雅退出相关配置 @@ -766,6 +773,17 @@ dscli collect_log --cluster_config_path 
./cluster_config.json | arena_per_tenant | int | `16` | 每个租户的共享内存分配器数量。多分配器可以提高第一次分配共享内存的性能,但每个分配器会多使用一个fd,导致fd资源使用量上升。取值范围:[1, 32] | | memory_reclamation_time_second | int | `600` | 释放后的内存回收时间,未回收的内存可以提供给下次分配复用,提升分配效率 | | async_delete | bool | `false` | 是否异步删除对象,如果设置为 `true` 时,删除对象数据是个异步的过程,客户端不需要等待所有数据副本删除完成即可返回 | +| enable_p2p_transfer | bool | `false` | 是否开启异构对象传输协议支持点对点传输 | +| enable_worker_worker_batch_get | bool | `false` | 是否开启worker到worker的对象数据批量获取,默认值为false | +| enable_urma | bool | `false` | 是否开启Urma以实现对象worker之间的数据传输 | +| urma_connection_size | int | `16` | jfs和jfr对的数量 | +| urma_event_mode | bool | `false` | 是否使用中断模式轮询完成事件 | +| urma_poll_size | int | `8` | 一次可轮询的完整记录数量,该设备最多可轮询16条记录 | +| urma_register_whole_arena | bool | `true` | 是否在初始化时将整个arena注册为一个段,如果设置为`false`,将每个对象分别注册为一个段 | +| oc_shm_transfer_threshold_kb | int | `500` | 在客户端和worker之间通过共享内存传输对象数据的阈值,单位为KB | +| shared_disk_arena_per_tenant | int | `8` | 每个租户的磁盘缓存区域数量,多个区域可以提高首次共享磁盘分配的性能,但每个区域会多占用一个文件描述符(fd)。取值范围:[0, 32] | +| shared_disk_directory | string | `""` | 磁盘缓存数据存放目录,默认为空,表示未启用磁盘缓存 | +| shared_disk_size_mb | int | `0` | 共享磁盘的大小上限,单位为MB,默认为0,表示未启用磁盘缓存 | #### AK/SK相关配置 diff --git a/docs/source_zh_cn/deployment/k8s_configuration.md b/docs/source_zh_cn/deployment/k8s_configuration.md index 69d7932..deb14d1 100644 --- a/docs/source_zh_cn/deployment/k8s_configuration.md +++ b/docs/source_zh_cn/deployment/k8s_configuration.md @@ -153,6 +153,8 @@ global: | global.rpc.zmqClientIoContext | int | `5` | ZMQ客户端性能优化参数,其数值与系统吞吐量正相关,取值范围:[1, 32] | | global.rpc.zmqChunkSz | int | `1048576` | 并行负载分块大小配置(以字节为单位) | | global.rpc.maxRpcSessionNum | int | `2048` | 单个datasystem-worker最大可缓存会话数,取值范围:[512, 10,000] | +| global.rpc.streamIdleTimes | int | `300` | 配置流的空闲时间。默认值为300秒(5分钟) | +| global.rpc.remoteSendThreadNum | int | `8` | 配置服务端用于将元素发送到远程工作线程的线程数量 | **样例**: 配置一个Unix Domain Socket路径为 "/home/uds",并使用31501作为openYuanrong datasystem DaemonSet的监听端口号 @@ -356,6 +358,7 @@ global: | global.metadata.rocksdbStoreDir | string | `"/home/sn/datasystem/rocksdb"` | 配置元数据持久化目录,元数据通过RocksDB持久化在磁盘中 | | global.metadata.rocksdbBackgroundThreads | int | `16` | RocksDB的后台线程数,用于元数据的刷盘和压缩 | | global.metadata.rocksdbMaxOpenFile | int | `128` | RocksDB可使用的最大打开文件个数 | +| global.metadata.rocksdbWriteMode | string | `async` | 配置元数据写入RocksDB的方式,支持不写、同步和异步写入,默认值为`async`。可选值包括:'none'(不写)、'sync'(同步)、'async'(异步) | ### 可靠性相关配置 @@ -372,6 +375,7 @@ global: | global.reliability.livenessProbeTimeoutS | int | `150` | Kubernetes 存活探针超时时间配置 | | global.reliability.addNodeWaitTimeS | int | `60` | 新节点加入哈希环的等待超时时间 | | global.reliability.autoDelDeadNode | bool | `true` | 是否启用死亡节点自动清理功能,当该值为 `true` 时,会将死亡节点剔除出集群管理,并触发被动缩容 | +| global.reliability.enableDistributedMaster | bool | `true` | 是否启用分布式主节点,默认值为true | ### 优雅退出相关配置 @@ -394,6 +398,17 @@ global: | global.performance.arenaPerTenant | int | `16` | 每个租户的共享内存分配器数量。多分配器可以提高第一次分配共享内存的性能,但每个分配器会多使用一个fd,导致fd资源使用量上升。取值范围:[1, 32] | | global.performance.memoryReclamationTimeSecond | int | `600` | 释放后的内存回收时间,未回收的内存可以提供给下次分配复用,提升分配效率 | | global.performance.asyncDelete | bool | `false` | 是否异步删除对象,如果设置为 `true` 时,删除对象数据是个异步的过程,客户端不需要等待所有数据副本删除完成即可返回 | +| global.performance.enableP2pTransfer | bool | `false` | 是否开启异构对象传输协议支持点对点传输 | +| global.performance.enableWorkerWorkerBatchGet | bool | `false` | 是否开启worker到worker的对象数据批量获取,默认值为false | +| global.performance.ocShmTransferThresholdKB | int | `500` | 在客户端和worker之间通过共享内存传输对象数据的阈值,单位为KB | +| global.performance.enableUrma | bool | `false` | 是否开启Urma以实现对象worker之间的数据传输 | +|
global.performance.urmaPollSize | int | `8` | 一次可轮询的完整记录数量,该设备最多可轮询16条记录 | +| global.performance.urmaRegisterWholeArena | bool | `true` | 是否在初始化时将整个arena注册为一个段,如果设置为`false`,将每个对象分别注册为一个段 | +| global.performance.urmaConnectionSize | int | `16` | jfs和jfr对的数量 | +| global.performance.urmaEventMode | bool | `false` | 是否使用中断模式轮询完成事件 | +| global.performance.sharedDiskDirectory | string | `""` | 磁盘缓存数据存放目录,默认为空,表示未启用磁盘缓存 | +| global.performance.sharedDiskSize | int | `0` | 共享磁盘的大小上限,单位为MB,默认为0,表示未启用磁盘缓存 | +| global.performance.sharedDiskArenaPerTenant | int | `8` | 每个租户的磁盘缓存区域数量,多个区域可以提高首次共享磁盘分配的性能,但每个区域会多占用一个文件描述符(fd)。取值范围:[0, 32] | ### AK/SK相关配置 @@ -491,4 +506,5 @@ global: |-----|------|---------|-------------| | global.annotations | object | `{}` | Kubernetes 元注解 | | global.enableNonPreemptive | bool | `false` | 配置priorityClass。如果该值为false,则默认priorityClass为system-cluster-key。如果为true,则会创建一个preemptionPolicy Never的priorityClass | -| global.fsGid | string | `"1002"` | fsGroup配置。容器的所有进程也是附加组ID的一部分 | \ No newline at end of file +| global.fsGid | string | `"1002"` | fsGroup配置。容器的所有进程也是附加组ID的一部分 | +| global.rollingUpdateTimeoutS | int | `1800` | 滚动升级的最大持续时间,默认值为1800秒 | \ No newline at end of file diff --git a/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_mget.rst b/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_mget.rst index 0d8efd6..a0955fe 100644 --- a/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_mget.rst +++ b/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_mget.rst @@ -3,7 +3,9 @@ datasystem.DsTensorClient.dev_mget .. py:method:: datasystem.DsTensorClient.dev_mget(keys, tensors, sub_timeout_ms) - 获取 device 中的数据,并写入到 tensors 的 Tensor 中。数据通过 device to device 通道直接传输。dev_mset 和 dev_mget 需配套使用。 + 获取 device 中的数据,并写入到 tensors 的 Tensor 中。数据通过 device to device 通道直接传输。 + + dev_mset 和 dev_mget 需配套使用。dev_mset 和 dev_mget 传入的 Device 内存地址不能归属于同一张 NPU 卡。 dev_mget 后不会自动删除异构对象,如对象不再使用,可调用 dev_local_delete 或 dev_delete 删除。 diff --git a/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_recv.rst b/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_recv.rst index e638e94..044c397 100644 --- a/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_recv.rst +++ b/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_recv.rst @@ -3,7 +3,9 @@ datasystem.DsTensorClient.dev_recv .. py:method:: datasystem.DsTensorClient.dev_recv(keys, tensors) - 订阅发布到数据系统的异构对象,并接收数据写入 tensors。数据通过 device to device 通道直接传输。dev_send 和 dev_recv 需配套使用。 + 订阅发布到数据系统的异构对象,并接收数据写入 tensors。数据通过 device to device 通道直接传输。 + + dev_send 和 dev_recv 需配套使用。dev_send 和 dev_recv 传入的 Device 内存地址不能归属于同一张 NPU 卡。 通过 dev_recv 获取数据成功后,数据系统会自动删除此异构对象,不再管理此对象对应的 device 内存。 diff --git a/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_send.rst b/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_send.rst index 08ecf2e..7a0e884 100644 --- a/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_send.rst +++ b/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.dev_send.rst @@ -3,7 +3,9 @@ datasystem.DsTensorClient.dev_send .. 
py:method:: datasystem.DsTensorClient.dev_send(keys, tensors) - 将 device 上的内存发布为数据系统的异构对象。发布后的异构对象可通过 dev_recv 获取。dev_send 和 dev_recv 需配套使用。 + 将 device 上的内存发布为数据系统的异构对象。发布后的异构对象可通过 dev_recv 获取。 + + dev_send 和 dev_recv 需配套使用。dev_send 和 dev_recv 传入的 Device 内存地址不能归属于同一张 NPU 卡。 通过 dev_recv 获取数据成功后,数据系统会自动删除此异构对象,不再管理此对象对应的 device 内存。 diff --git a/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.get_page_attn_layerwise_d2d.rst b/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.get_page_attn_layerwise_d2d.rst index a7dd179..d72fa5f 100644 --- a/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.get_page_attn_layerwise_d2d.rst +++ b/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.get_page_attn_layerwise_d2d.rst @@ -3,7 +3,11 @@ datasystem.DsTensorClient.get_page_attn_layerwise_d2d .. py:method:: datasystem.DsTensorClient.get_page_attn_layerwise_d2d(keys, layer_tensors, block_ids) - 将 PagedAttention 的层级 Tensor 作为数据系统的异构对象放在设备上。put_page_attn_layerwise_d2d 和 get_page_attn_layerwise_d2d 需配套使用。 + 将 PagedAttention 的层级 Tensor 作为数据系统的异构对象放在设备上。 + + put_page_attn_layerwise_d2d 和 get_page_attn_layerwise_d2d 需配套使用。 + + put_page_attn_layerwise_d2d 和 get_page_attn_layerwise_d2d 传入的 Device 内存地址不能归属于同一张 NPU 卡。 通过 get_page_attn_layerwise_d2d 获取数据成功后,数据系统会自动删除此异构对象,不再管理此对象对应的 device 内存。 diff --git a/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.put_page_attn_layerwise_d2d.rst b/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.put_page_attn_layerwise_d2d.rst index 968f4dd..13be537 100644 --- a/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.put_page_attn_layerwise_d2d.rst +++ b/docs/source_zh_cn/development-guide/api/python/datasystem.DsTensorClient.put_page_attn_layerwise_d2d.rst @@ -3,7 +3,11 @@ datasystem.DsTensorClient.put_page_attn_layerwise_d2d .. 
py:method:: datasystem.DsTensorClient.put_page_attn_layerwise_d2d(keys, layer_tensors, block_ids) - 将 PagedAttention 的层级 Tensor 发布为数据系统的异构对象。发布后的异构对象可通过 get_page_attn_layerwise_d2d 获取。put_page_attn_layerwise_d2d 和 get_page_attn_layerwise_d2d 需配套使用。 + 将 PagedAttention 的层级 Tensor 发布为数据系统的异构对象。发布后的异构对象可通过 get_page_attn_layerwise_d2d 获取。 + + put_page_attn_layerwise_d2d 和 get_page_attn_layerwise_d2d 需配套使用。 + + put_page_attn_layerwise_d2d 和 get_page_attn_layerwise_d2d 传入的 Device 内存地址不能归属于同一张 NPU 卡。 通过 get_page_attn_layerwise_d2d 获取数据成功后,数据系统会自动删除此异构对象,不再管理此对象对应的 device 内存。 diff --git a/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.dev_mget.rst b/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.dev_mget.rst index 93bf45e..b090252 100644 --- a/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.dev_mget.rst +++ b/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.dev_mget.rst @@ -5,6 +5,8 @@ datasystem.hetero_client.HeteroClient.dev_mget 获取 device 中的数据,并写入到 data_blob_list 中。数据通过 device to device 通道直接传输。 + dev_mset 和 dev_mget 需配套使用。dev_mset 和 dev_mget 传入的 Device 内存地址不能归属于同一张 NPU 卡。 + dev_mget 后不会自动删除异构对象,如对象不再使用,可调用 dev_local_delete 或 dev_delete 删除。 在执行 dev_mget 过程中,执行了 dev_mset 的进程不能退出,否则 dev_mget 会失败。 diff --git a/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.dev_mset.rst b/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.dev_mset.rst index c989db8..80a1fc1 100644 --- a/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.dev_mset.rst +++ b/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.dev_mset.rst @@ -3,7 +3,9 @@ datasystem.hetero_client.HeteroClient.dev_mset .. py:method:: datasystem.hetero_client.HeteroClient.dev_mset(keys, data_blob_list) - 通过数据系统缓存 Device 上的数据,将 data_blob_list 对应的 key 的元数据写入数据系统,可供其他 client 访问。dev_mset 和 dev_mget 需配套使用。 + 通过数据系统缓存 Device 上的数据,将 data_blob_list 对应的 key 的元数据写入数据系统,可供其他 client 访问。 + + dev_mset 和 dev_mget 需配套使用。dev_mset 和 dev_mget 传入的 Device 内存地址不能归属于同一张 NPU 卡。 dev_mget 后不会自动删除异构对象,如对象不再使用,可调用 dev_local_delete 或 dev_delete 删除。 diff --git a/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.exist.rst b/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.exist.rst new file mode 100644 index 0000000..c899d22 --- /dev/null +++ b/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.exist.rst @@ -0,0 +1,16 @@ +datasystem.hetero_client.HeteroClient.exist +============================================== + +.. 
py:method:: datasystem.hetero_client.HeteroClient.exist(self, keys) + + 检查给定的键在数据系统中是否存在。 + + 参数: + - **keys** (list) - 待查询的键列表。约束:传入的key的数量不能超过1万。 + + 返回: + - **exists** (list) - 对应key的存在性。 + + 异常: + - **TypeError** - 如果输入参数无效,将抛出类型错误。 + - **RuntimeError** - 如果查询键是否存在失败,将抛出运行时错误。 \ No newline at end of file diff --git a/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.rst b/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.rst index d4610c3..faaab27 100644 --- a/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.rst +++ b/docs/source_zh_cn/development-guide/api/python/datasystem.hetero_client.HeteroClient.rst @@ -56,6 +56,8 @@ datasystem.hetero_client.HeteroClient - 生成一个带数据系统 Worker UUID 的 key。 * - :doc:`get_meta_info ` - 获取keys 对应的元数据信息。 + * - :doc:`exist ` + - 检查给定的键在数据系统中是否存在。 .. toctree:: :maxdepth: 1 @@ -76,3 +78,4 @@ datasystem.hetero_client.HeteroClient datasystem.hetero_client.HeteroClient.async_dev_delete datasystem.hetero_client.HeteroClient.generate_key datasystem.hetero_client.HeteroClient.get_meta_info + datasystem.hetero_client.HeteroClient.exist diff --git a/docs/source_zh_cn/development-guide/api/python/datasystem.kv_client.KVClient.exist.rst b/docs/source_zh_cn/development-guide/api/python/datasystem.kv_client.KVClient.exist.rst index 5b0f6e7..6217d04 100644 --- a/docs/source_zh_cn/development-guide/api/python/datasystem.kv_client.KVClient.exist.rst +++ b/docs/source_zh_cn/development-guide/api/python/datasystem.kv_client.KVClient.exist.rst @@ -3,7 +3,7 @@ datasystem.kv_client.KVClient.exist .. py:method:: datasystem.kv_client.KVClient.exist(self, keys) - 批量查询一组键是否存在,并返回每个键的存在性状态。 + 批量查询一组键是否存在。 参数: - **keys** (str) - 待查询的键列表,最大支持10000个键。 diff --git a/docs/source_zh_cn/development-guide/example/hetero.md b/docs/source_zh_cn/development-guide/example/hetero.md index 9c3237b..a1540d7 100644 --- a/docs/source_zh_cn/development-guide/example/hetero.md +++ b/docs/source_zh_cn/development-guide/example/hetero.md @@ -16,6 +16,7 @@ openYuanrong datasystem (下文中称为数据系统)的 Hetero 语义中,基 DevPublish / DevSubscribe 为异步接口,提供了返回 Future 供用户获取执行结果,每个key返回一个 Future。当 Future::Get 获取到结果为 OK 时,表示数据已经被对端接收成功。 > **注意**: +> DevPublish / DevSubscribe 传入的 Device 内存地址不能归属于同一张 NPU 卡。 > 在执行 DevSubscribe 过程中,执行了 DevPublish 的进程不能退出,否则 DevSubscribe 会失败。 > 在key,devBlobList内存地址映射关系均一致的情况下,DevPublish在同进程支持重试。 > DevSubscribe单Key的订阅超时时间为20s,多key为60s。 @@ -189,7 +190,8 @@ void HeteroDevSubscribe() **DevMSet / DevMGet**:数据缓存语义,数据生成端执行 DevMSet 将 HBM 数据发布到数据系统,数据接收端申请 HBM 内存后,执行 DevMGet 接口读取数据。当数据被读取后,数据系统不会删除对象,该数据可被反复读取。数据使用完成后需要调用 DevLocalDelete/DevDelete 删除对象。 -> **注意**: +> **注意**: +> DevMSet / DevMGet 传入的 Device 内存地址不能归属于同一张 NPU 卡。 > 在执行 DevMGet 过程中,执行了 DevMSet 的进程不能退出,否则 DevMGet 会失败。 > 在key,devBlobList内存地址映射关系均一致的情况下,DevMGet在同进程支持重试。 diff --git a/docs/source_zh_cn/index.rst b/docs/source_zh_cn/index.rst index 9e0f328..d935954 100644 --- a/docs/source_zh_cn/index.rst +++ b/docs/source_zh_cn/index.rst @@ -9,7 +9,7 @@ openYuanrong datasystem 的主要特性包括: - **NPU 间高效数据传输**:将 NPU 的 HBM 抽象为异构对象,自动协调 NPU 间 HCCL 收发顺序,实现简单易用的卡间数据异步并发传输。并支持P2P传输负载均衡策略,充分利用卡间链路带宽。 - **灵活的生命周期管理**:支持设置 TTL、LRU 缓存淘汰以及 delete 接口等多种生命周期管理策略,数据生命周期既可由数据系统管理,也可交由上层应用管理,提供更高的灵活性。 - **热点数据多副本**:数据跨节点读取时自动在本地保存副本,支撑热点数据高效访问。本地副本使用 LRU 策略自动淘汰。 -- **多种数据可靠性策略**:支持 write_through、wirte_back 及 none 多种持久化策略,满足不同场景的数据可靠性需求。 +- **多种数据可靠性策略**:支持 write_through、write_back 及 none 多种持久化策略,满足不同场景的数据可靠性需求。 - **数据一致性**:支持 Causal 及 PRAM 
两种数据一致性模型,用户可按需选择,实现性能和数据一致性的平衡。 - **数据发布订阅**:支持数据订阅发布,解耦数据的生产者(发布者)和消费者(订阅者),实现数据的异步传输与共享。 - **高可靠高可用**:支持分布式元数据管理,实现系统水平线性扩展。支持元数据可靠性,支持动态资源伸缩自动迁移数据,实现系统高可用。 diff --git a/example/README.md b/example/README.md index 69fad07..1cc33b6 100644 --- a/example/README.md +++ b/example/README.md @@ -1,24 +1,15 @@ This is an example for how to invoke the datasystem client api. -# Build the example -1. Modify the example/CMakeLists.txt line 10 +1. Build the example. ```bash -set(DS_CLIENT_DIR ${DS_BUILD_OUTPUT_PATH}) +cd example/cpp +mkdir build +cd build +cmake .. +make ``` -Set the DS_BUILD_OUTPUT_PATH to the datasystem build output path. -2. Modify the example/CMakeLists.txt line 16 +2. Run the example. ```bash -set(THIRD_DIR ${DEPENDENCE_DIR}) -``` -Set the DEPENDENCE_DIR to the log dependency. - -3. Build -Build and compile the example. -```bash -# cd example -# mkdir build -# cd build -# cmake .. -# make -``` +bash run-example.sh +``` \ No newline at end of file diff --git a/example/CMakeLists.txt b/example/cpp/CMakeLists.txt similarity index 43% rename from example/CMakeLists.txt rename to example/cpp/CMakeLists.txt index fdf95c9..5100987 100644 --- a/example/CMakeLists.txt +++ b/example/cpp/CMakeLists.txt @@ -6,7 +6,24 @@ project(ds_example LANGUAGES CXX C) set(CMAKE_CXX_STANDARD 17) -file(STRINGS "${CMAKE_CURRENT_SOURCE_DIR}/../VERSION" version) +set(BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../..") +file(STRINGS "${BASE_DIR}/VERSION" DS_VERSION) + +set(INSTALL_DIR ${BASE_DIR}/output) +set(BUILD_HETERO "on") + +if(EXISTS "${BASE_DIR}/config.cmake") + include(${BASE_DIR}/config.cmake) +endif() + +if(NOT EXISTS "${INSTALL_DIR}/sdk/cpp") + file(ARCHIVE_EXTRACT + INPUT ${INSTALL_DIR}/yr-datasystem-v${DS_VERSION}.tar.gz + DESTINATION ${INSTALL_DIR} + ) +endif() + +SET(CMAKE_PREFIX_PATH "${INSTALL_DIR}/sdk/cpp") find_package(Datasystem ${Datasystem_version} REQUIRED) @@ -24,13 +41,19 @@ if (BUILD_HETERO) set(Ascend_ROOT /usr/local/Ascend/ascend-toolkit/latest) endif() include_directories(${Ascend_ROOT}/include) + + add_executable(hetero_client_example hetero_client_example.cpp) + target_link_libraries(hetero_client_example datasystem ${Ascend_ROOT}/lib64/libascendcl.so) endif() -add_executable(ds_example src/ds_example.cpp) -target_link_libraries(ds_example datasystem) +add_executable(datasystem_example datasystem_example.cpp) +target_link_libraries(datasystem_example datasystem) + +add_executable(object_client_example object_client_example.cpp) +target_link_libraries(object_client_example datasystem) -add_executable(object_example src/object_cache/object_example.cpp) -target_link_libraries(object_example datasystem) +add_executable(kv_client_example kv_client_example.cpp) +target_link_libraries(kv_client_example datasystem) -add_executable(kv_example src/kv_cache/kv_example.cpp) -target_link_libraries(kv_example datasystem) +add_executable(stream_client_example stream_client_example.cpp) +target_link_libraries(stream_client_example datasystem) diff --git a/example/src/ds_example.cpp b/example/cpp/datasystem_example.cpp similarity index 89% rename from example/src/ds_example.cpp rename to example/cpp/datasystem_example.cpp index 57ca25a..8d35e69 100644 --- a/example/src/ds_example.cpp +++ b/example/cpp/datasystem_example.cpp @@ -17,13 +17,11 @@ /** * Description: The ds client example.
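 * The flow below: connect a DsClient, exercise the object client through
 * Write/Read/Modify (RunObjectTest), then exercise the KV client (RunKVTest).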
*/ +#include "datasystem/datasystem.h" #include <iostream> #include <memory> -#include "datasystem/datasystem.h" -#include "datasystem/context/context.h" - using datasystem::ConnectOptions; using datasystem::Context; @@ -43,7 +41,7 @@ static constexpr int FAILED = -1; static std::shared_ptr<DsClient> dsClient_; static std::shared_ptr<Buffer> buffer_; -static int Write(std::shared_ptr<ObjectClient> client, std::string writeId, std::string data, bool isSeal) +static int Write(const std::shared_ptr<ObjectClient> &client, const std::string &writeId, std::string data, bool isSeal) { std::cout << "Writing data to a buffer" << std::endl; std::cout << std::boolalpha << "Immutable data: " << isSeal << std::endl; @@ -76,13 +74,13 @@ return SUCCESS; } -static int Read(std::shared_ptr<ObjectClient> client, std::string verifyId, std::string verifyData) +static int Read(const std::shared_ptr<ObjectClient> &client, const std::string &verifyId, const std::string &verifyData) { std::cout << "Reading data from a buffer" << std::endl; std::vector<std::string> objKeys = { verifyId }; - int64_t timeout = 60; + const int64_t timeoutMs = 60'000; std::vector<std::shared_ptr<Buffer>> buffers; - Status status = client->Get(objKeys, timeout, buffers); + Status status = client->Get(objKeys, timeoutMs, buffers); if (!status.IsOk()) { std::cerr << "Read Get Fail: " << status.ToString() << std::endl; return FAILED; @@ -101,20 +99,21 @@ return SUCCESS; } -static int Modify(std::shared_ptr<ObjectClient> client, std::string writeId, std::string updateValue) +static int Modify(const std::shared_ptr<ObjectClient> &client, const std::string &writeId, + const std::string &updateValue) { std::cout << "Modifying data in the buffer" << std::endl; std::vector<std::string> objKeys = { writeId }; - int64_t timeout = 60; + const int64_t timeoutMs = 60'000; std::vector<std::shared_ptr<Buffer>> buffers; - Status status = client->Get(objKeys, timeout, buffers); + Status status = client->Get(objKeys, timeoutMs, buffers); if (!status.IsOk()) { std::cerr << "Modify Get Fail: " << status.ToString() << std::endl; return FAILED; } std::cout << "Get: " << status.ToString() << std::endl; auto &buf = buffers[0]; - int newDataSize = updateValue.size(); + auto newDataSize = updateValue.size(); buf->WLatch(); // If the length of the modified buffer is less than the length of the original buffer, the following characters // will remain. @@ -125,7 +124,7 @@ return SUCCESS; } -int RunObjectTest(std::shared_ptr<ObjectClient> client) +int RunObjectTest(const std::shared_ptr<ObjectClient> &client) { std::cout << "Run object client test." << std::endl; std::string writeObjectKey = "writeKey"; @@ -150,7 +149,7 @@ return SUCCESS; } -int RunKVTest(std::shared_ptr<KVClient> kvClient) +int RunKVTest(const std::shared_ptr<KVClient> &kvClient) { std::cout << "Run kv client test." << std::endl; std::string key = "testKey"; diff --git a/example/cpp/hetero_client_example.cpp b/example/cpp/hetero_client_example.cpp new file mode 100644 index 0000000..f4a6d66 --- /dev/null +++ b/example/cpp/hetero_client_example.cpp @@ -0,0 +1,159 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: The hetero client example. + */ +#include "datasystem/datasystem.h" + +#include + +#include + +using datasystem::HeteroClient; +using datasystem::Status; +using datasystem::Context; +using datasystem::DeviceBlobList; +using datasystem::Blob; +using datasystem::ConnectOptions; + +static std::shared_ptr client_; + +static std::string DEFAULT_IP = "127.0.0.1"; +static constexpr int DEFAULT_PORT = 9088; +static constexpr int PARAMETERS_NUM = 3; +static constexpr int SUCCESS = 0; +static constexpr int FAILED = -1; +static constexpr int DEVICE_IDX = 0; +static constexpr int SIZE = 10; + +static bool Write() +{ + (void)Context::SetTraceId("write"); + std::string key = "key1"; + std::string data(SIZE, 'x'); + Blob blob; + blob.size = SIZE; + auto aclRc = aclrtMalloc(&blob.pointer, blob.size, aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST); + if (aclRc != ACL_SUCCESS) { + return false; + } + aclRc = aclrtMemcpy(blob.pointer, blob.size, data.data(), blob.size, ACL_MEMCPY_HOST_TO_DEVICE); + if (aclRc != ACL_SUCCESS) { + return false; + } + DeviceBlobList devSetBlobList; + devSetBlobList.deviceIdx = DEVICE_IDX; + devSetBlobList.blobs = { blob }; + std::vector failedIdList; + auto setRc = client_->DevMSet({ key }, { devSetBlobList }, failedIdList); + if (setRc.IsError()) { + std::cerr << "DevMSet failed: " << setRc.ToString() << std::endl; + return false; + } + std::cout << "DevMSet succeeds." << std::endl; + return true; +} + +static int Read() +{ + (void)Context::SetTraceId("read"); + std::string key = "key1"; + Blob blob; + blob.size = SIZE; + auto aclRc = aclrtMalloc(&blob.pointer, blob.size, aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST); + if (aclRc != ACL_SUCCESS) { + return false; + } + DeviceBlobList devGetBlobList; + devGetBlobList.deviceIdx = DEVICE_IDX; + devGetBlobList.blobs = { blob }; + std::vector failedIdList; + int subTimeoutMs = 30'000; + std::vector devGetBlobLists = { devGetBlobList }; + auto getRc = client_->DevMGet({ key }, devGetBlobLists, failedIdList, subTimeoutMs); + if (getRc.IsError() || !failedIdList.empty()) { + std::cerr << "DevMGet failed: " << getRc.ToString() << std::endl; + return false; + } + std::cout << "DevMGet succeeds." 
<< std::endl; + return true; +} + +static bool InitAcl() +{ + aclError ret = aclInit(nullptr); + if (ret != ACL_SUCCESS) { + return false; + } + ret = aclrtSetDevice(0); + if (ret != ACL_SUCCESS) { + return false; + } + return true; +} + +static bool Start() +{ + return InitAcl() && Write() && Read(); +} + +int main(int argc, char *argv[]) +{ + const int authParametersNum = 6; + std::string ip; + int port = 0; + int index = 0; + std::string clientPublicKey, clientPrivateKey, serverPublicKey; + + if (argc == 1) { + ip = DEFAULT_IP; + port = DEFAULT_PORT; + } else if (argc == PARAMETERS_NUM) { + ip = argv[++index]; + port = atoi(argv[++index]); + } else if (argc == authParametersNum) { + ip = argv[++index]; + port = atoi(argv[++index]); + clientPublicKey = argv[++index]; + clientPrivateKey = argv[++index]; + serverPublicKey = argv[++index]; + } else { + std::cerr << "Invalid input parameters."; + return FAILED; + } + + ConnectOptions connectOpts{ .host = ip, + .port = port, + .connectTimeoutMs = 3 * 1000, + .clientPublicKey = clientPublicKey, + .clientPrivateKey = clientPrivateKey, + .serverPublicKey = serverPublicKey }; + client_ = std::make_shared(connectOpts); + (void)Context::SetTraceId("init"); + Status status = client_->Init(); + if (status.IsError()) { + std::cerr << "Failed to init hetero client, detail: " << status.ToString() << std::endl; + return FAILED; + } + + if (!Start()) { + std::cerr << "The hetero client example run failed." << std::endl; + return FAILED; + } + (void)aclFinalize(); + return SUCCESS; +} diff --git a/example/src/kv_cache/kv_example.cpp b/example/cpp/kv_client_example.cpp similarity index 98% rename from example/src/kv_cache/kv_example.cpp rename to example/cpp/kv_client_example.cpp index 22e6deb..a85ebc1 100644 --- a/example/src/kv_cache/kv_example.cpp +++ b/example/cpp/kv_client_example.cpp @@ -15,13 +15,12 @@ */ /** - * Description: The kv cache example. + * Description: The kv client example. */ +#include "datasystem/datasystem.h" #include -#include "datasystem/kv_client.h" - using datasystem::ConnectOptions; using datasystem::Context; using datasystem::Optional; diff --git a/example/src/object_cache/object_example.cpp b/example/cpp/object_client_example.cpp similarity index 94% rename from example/src/object_cache/object_example.cpp rename to example/cpp/object_client_example.cpp index 191481a..ed2500a 100644 --- a/example/src/object_cache/object_example.cpp +++ b/example/cpp/object_client_example.cpp @@ -15,13 +15,12 @@ */ /** - * Description: The object cache example. + * Description: The object client example. 
*/ +#include "datasystem/datasystem.h" #include -#include "datasystem/object_client.h" - using datasystem::Buffer; using datasystem::ConnectOptions; using datasystem::CreateParam; @@ -63,9 +62,9 @@ static void Read(int64_t size, bool isSeal) std::cout << "Reading data from a buffer" << std::endl; std::string objectKey = "123456789"; std::vector objKeys = { objectKey }; - int64_t timeout = 60; + const int64_t timeoutMs = 60'000; std::vector> buffers; - Status status = client_->Get(objKeys, timeout, buffers); + Status status = client_->Get(objKeys, timeoutMs, buffers); if (!status.IsOk()) { std::cerr << "Read Get Fail: " << status.ToString() << std::endl; return; @@ -84,9 +83,9 @@ static void Modify(int64_t size) std::cout << "Modifying data in the buffer" << std::endl; std::string objectKey = "123456789"; std::vector objKeys = { objectKey }; - int64_t timeout = 60; + const int64_t timeoutMs = 60'000; std::vector> buffers; - Status status = client_->Get(objKeys, timeout, buffers); + Status status = client_->Get(objKeys, timeoutMs, buffers); if (!status.IsOk()) { std::cerr << "Modify Get Fail: " << status.ToString() << std::endl; return; @@ -94,7 +93,7 @@ static void Modify(int64_t size) std::cout << "Get: " << status.ToString() << std::endl; auto &buf = buffers[0]; std::string newData = "test"; - int newDataSize = newData.size(); + auto newDataSize = newData.size(); buf->WLatch(); buf->MemoryCopy((void *)newData.data(), newDataSize); buf->Publish(); diff --git a/example/cpp/stream_client_example.cpp b/example/cpp/stream_client_example.cpp new file mode 100644 index 0000000..a0f7d76 --- /dev/null +++ b/example/cpp/stream_client_example.cpp @@ -0,0 +1,159 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: The stream client example. + */ +#include "datasystem/datasystem.h" + +#include +#include + +using namespace datasystem; + +static std::string DEFAULT_IP = "127.0.0.1"; +static constexpr int DEFAULT_PORT = 9088; +static constexpr int PARAMETERS_NUM = 3; + +static int CreateProducerAndConsumer(std::shared_ptr &client, std::shared_ptr &producer, + std::shared_ptr &consumer) +{ + // Create the producer on the stream + std::string streamName("example1"); + const uint64_t testStreamSize = 64 * 1024 * 1024; // 64M + ProducerConf producerConf; + producerConf.maxStreamSize = testStreamSize; + Status status = client->CreateProducer(streamName, producer, producerConf); + if (status.IsError()) { + std::cerr << "Failed to create producer : " << status.ToString() << std::endl; + return -1; + } + std::cout << "Create producer successfully." 
<< std::endl; + + // Create one subscription with one stream + std::string subName("sub1"); + SubscriptionConfig config(subName, SubscriptionType::STREAM); + status = client->Subscribe(streamName, config, consumer); + if (status.IsError()) { + std::cerr << "Failed to create subscription with one consumer : " << status.ToString() << std::endl; + return -1; + } + std::cout << "Create consumer successfully." << std::endl; + return 0; +} + +static int WriteAndFlush(Producer *producer, std::string &data) +{ + // Write and flush one element + Element element(reinterpret_cast(&data.front()), data.size(), ULONG_MAX); + Status status = producer->Send(element); + if (status.IsError()) { + std::cerr << "Failed to Send one element : " << status.ToString() << std::endl; + return -1; + } + std::cout << "Write one element successfully." << std::endl; + return 0; +} + +static int RecvAndVerify(Consumer *consumer, const std::string &data) +{ + // Read one element + std::vector outElements; + Status status = consumer->Receive(1, 0, outElements); + if (status.IsError()) { + std::cerr << "Failed to Receive one element : " << status.ToString() << std::endl; + return -1; + } + + if (outElements.size() != 1) { + std::cerr << "Should receive one element but receive " << outElements.size() << std::endl; + return -1; + } + if (outElements[0].id != 1) { + std::cerr << "The element id should be 1 but is " << outElements[0].id << std::endl; + return -1; + } + std::string actualData(reinterpret_cast(outElements[0].ptr), outElements[0].size); + if (data != actualData) { + std::cerr << "The element verify failed expect: " << data << ", got:" << actualData << std::endl; + return -1; + } + status = consumer->Ack(outElements[0].id); + if (status.IsError()) { + std::cerr << "Failed to ack one element : " << status.ToString() << std::endl; + return -1; + } + std::cout << "Verify element successfully." 
<< std::endl; + return 0; +} + +int RunExample(const std::string &ip, const int32_t port, const std::string &clientPublicKey, + const std::string &clientPrivateKey, const std::string &serverPublicKey) +{ + ConnectOptions connectOpts{ .host = ip, + .port = port, + .connectTimeoutMs = 60 * 1000, + .clientPublicKey = clientPublicKey, + .clientPrivateKey = clientPrivateKey, + .serverPublicKey = serverPublicKey }; + auto client = std::make_shared(connectOpts); + Status status = client->Init(); + if (status.IsError()) { + std::cerr << "Failed to init stream client : " << status.ToString() << std::endl; + return -1; + } + + std::shared_ptr producer; + std::shared_ptr consumer; + if (CreateProducerAndConsumer(client, producer, consumer)) { + return -1; + } + std::string data = "Hello World"; + if (WriteAndFlush(producer.get(), data)) { + return -1; + } + return RecvAndVerify(consumer.get(), data); +} + +int main(int argc, char *argv[]) +{ + const int authParametersNum = 6; + std::string ip = DEFAULT_IP; + int port = DEFAULT_PORT; + int index = 0; + std::string clientPublicKey; + std::string clientPrivateKey; + std::string serverPublicKey; + if (argc == 1) { + ip = DEFAULT_IP; + port = DEFAULT_PORT; + } else if (argc == PARAMETERS_NUM) { + ip = argv[++index]; + port = atoi(argv[++index]); + } else if (argc == authParametersNum) { + ip = argv[++index]; + port = atoi(argv[++index]); + clientPublicKey = argv[++index]; + clientPrivateKey = argv[++index]; + serverPublicKey = argv[++index]; + } else { + std::cerr << "Invalid input parameters."; + } + + // example call: + // ./stream_example 127.0.0.1 18482 + return RunExample(ip, port, clientPublicKey, clientPrivateKey, serverPublicKey); +} diff --git a/example/python/ds_tensor_client_example.py b/example/python/ds_tensor_client_example.py new file mode 100644 index 0000000..168f189 --- /dev/null +++ b/example/python/ds_tensor_client_example.py @@ -0,0 +1,121 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Hetero client python interface test. 
+""" + +from __future__ import absolute_import + +import argparse +import logging + +from datasystem import DsTensorClient + +is_torch_exist = True +is_mindspore_exist = True + +try: + import acl + import numpy +except ImportError: + is_torch_exist = False + is_mindspore_exist = False + +try: + import torch + import torch_npu +except ImportError: + is_torch_exist = False + +try: + import mindspore +except ImportError: + is_mindspore_exist = False + + +class DsTensorClientExample(): + """This class shows the SDK usage example of the HeteroClient.""" + + def __init__(self): + parser = argparse.ArgumentParser(description="DsTensorClient python interface Test") + parser.add_argument("--host", required=True, help="The IP of worker service") + parser.add_argument("--port", required=True, type=int, help="The port of worker service") + parser.add_argument("--device_id", type=int, default=0, help="The device id") + args = parser.parse_args() + self._host = args.host + self._port = args.port + self._device_id = args.device_id + + logging.basicConfig(level=logging.INFO) + + def torch_dev_mset_and_dev_mget_example(self): + """test pytorch tensor""" + logging.info("Start executing torch_dev_mset_and_dev_mget_example...") + acl.init() + acl.rt.set_device(self._device_id) + torch_npu.npu.set_device(f'npu:{self._device_id}') + + key = "key" + in_tensors = [torch.rand((2, 3), dtype=torch.float16, device=f'npu:{self._device_id}')] + client = DsTensorClient(self._host, self._port, self._device_id) + client.init() + failed_keys = client.dev_mset([key], in_tensors) + if failed_keys: + raise RuntimeError(f"dev_mset failed, failed keys: {failed_keys}") + + out_tensors = [torch.zeros((2, 3), dtype=torch.float16, device=f'npu:{self._device_id}')] + sub_timeout_ms = 30_000 + failed_keys = client.dev_mget([key], out_tensors, sub_timeout_ms) + if failed_keys: + raise RuntimeError(f"dev_mget failed, failed keys: {failed_keys}") + acl.finalize() + logging.info("Execute torch_dev_mset_and_dev_mget_example success.") + + def mindspore_dev_mset_and_dev_mget_example(self): + """test mindspore tensor""" + logging.info("Start executing mindspore_dev_mset_and_dev_mget_example...") + acl.init() + acl.rt.set_device(self._device_id) + mindspore.set_device(device_target="Ascend", device_id=self._device_id) + + key = "key" + data = numpy.random.rand(2, 3) + in_tensors = [mindspore.Tensor(data, dtype=mindspore.float32) + 0] + client = DsTensorClient(self._host, self._port, self._device_id) + client.init() + failed_keys = client.dev_mset([key], in_tensors) + if failed_keys: + raise RuntimeError(f"dev_mset failed, failed keys: {failed_keys}") + + out_tensors = [mindspore.Tensor(numpy.ones(shape=[2, 3]), dtype=mindspore.float32) + 0] + sub_timeout_ms = 30_000 + failed_keys = client.dev_mget([key], out_tensors, sub_timeout_ms) + if failed_keys: + raise RuntimeError(f"dev_mget failed, failed keys: {failed_keys}") + acl.finalize() + logging.info("Execute mindspore_dev_mset_and_dev_mget_example success.") + + +if __name__ == '__main__': + example = DsTensorClientExample() + excute = False + if is_torch_exist: + example.torch_dev_mset_and_dev_mget_example() + excute = True + if is_mindspore_exist: + example.mindspore_dev_mset_and_dev_mget_example() + excute = True + if not excute: + logging.warning("No examples were executed.") diff --git a/example/python/hetero_client_example.py b/example/python/hetero_client_example.py new file mode 100644 index 0000000..f0ff2b5 --- /dev/null +++ b/example/python/hetero_client_example.py @@ -0,0 +1,75 @@ 
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Hetero client python interface test. +""" + +from __future__ import absolute_import + +import argparse + +import acl +from datasystem.hetero_client import ( + HeteroClient, + Blob, + DeviceBlobList, +) + + +class HeteroClientExample(): + """This class shows the SDK usage example of the HeteroClient.""" + + def __init__(self): + parser = argparse.ArgumentParser(description="Hetero client python interface Test") + parser.add_argument("--host", required=True, help="The IP of worker service") + parser.add_argument("--port", required=True, type=int, help="The port of worker service") + parser.add_argument("--device_id", type=int, default=0, help="The device id") + args = parser.parse_args() + self._host = args.host + self._port = args.port + self._device_id = args.device_id + + def dev_mset_and_dev_mget_example(self): + """test dev_mset and dev_mget""" + acl.init() + acl.rt.set_device(self._device_id) + client = HeteroClient(self._host, self._port) + client.init() + key = "key" + value = bytes("val", encoding='utf8') + size = len(value) + in_dev_ptr, _ = acl.rt.malloc(size, 0) + acl.rt.memcpy(in_dev_ptr, size, acl.util.bytes_to_ptr(value), size, 1) + in_blob = Blob(in_dev_ptr, size) + in_blob_list = [DeviceBlobList(self._device_id, [in_blob])] + failed_keys = client.dev_mset([key], in_blob_list) + if failed_keys: + raise RuntimeError(f"dev_mset failed, failed keys: {failed_keys}") + + out_dev_ptr, _ = acl.rt.malloc(size, 0) + out_blob = Blob(out_dev_ptr, size) + out_blob_list = [DeviceBlobList(self._device_id, [out_blob])] + sub_timeout_ms = 30_000 + failed_keys = client.dev_mget([key], out_blob_list, sub_timeout_ms) + if failed_keys: + raise RuntimeError(f"dev_mget failed, failed keys: {failed_keys}") + acl.rt.free(in_dev_ptr) + acl.rt.free(out_dev_ptr) + acl.finalize() + + +if __name__ == '__main__': + example = HeteroClientExample() + example.dev_mset_and_dev_mget_example() diff --git a/example/src/python/kv_cache/kv_client_example.py b/example/python/kv_client_example.py similarity index 79% rename from example/src/python/kv_cache/kv_client_example.py rename to example/python/kv_client_example.py index ea4bada..f45493c 100644 --- a/example/src/python/kv_cache/kv_client_example.py +++ b/example/python/kv_client_example.py @@ -13,27 +13,27 @@ # limitations under the License. """ -State cache client python interface Test. +KV client python interface test. 
""" from __future__ import absolute_import + +import argparse import time from datasystem.kv_client import KVClient -def assert_eq(value, expected_value): - """Compare two values for equality.""" - if value != expected_value: - raise RuntimeError(f"Assert failed, expect {expected_value}, but got {value}") - - -class KVClientExample: +class KVClientExample(): """This class shows the SDK usage example of the KVClient.""" - def __init__(self, host, port): - self._host = host - self._port = port + def __init__(self): + parser = argparse.ArgumentParser(description="KV client python interface Test") + parser.add_argument("--host", required=True, help="The IP of worker service") + parser.add_argument("--port", required=True, type=int, help="The port of worker service") + args = parser.parse_args() + self._host = args.host + self._port = args.port def set_data_example(self): """This function shows the basic usage of set/get/del.""" @@ -44,7 +44,8 @@ class KVClientExample: client.set(key, expected_val) val = client.get([key], True) - assert_eq(val[0], expected_val) + if val[0] != expected_val: + raise RuntimeError(f"Assert failed, expect {expected_val}, but got {val[0]}") client.delete([key]) @@ -56,7 +57,8 @@ class KVClientExample: key = client.set_value(expected_val) val = client.get([key], True) - assert_eq(val[0], expected_val) + if val[0] != expected_val: + raise RuntimeError(f"Assert failed, expect {expected_val}, but got {val[0]}") client.delete(["key"]) @@ -95,7 +97,7 @@ class KVClientExample: if __name__ == '__main__': - example = KVClientExample('127.0.0.1', 31501) + example = KVClientExample() example.set_data_example() example.set_value_with_generate_key_example() example.set_data_with_ttl_example() diff --git a/example/python/object_client_example.py b/example/python/object_client_example.py new file mode 100644 index 0000000..71da3b7 --- /dev/null +++ b/example/python/object_client_example.py @@ -0,0 +1,76 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Object client python interface test. +""" + +from __future__ import absolute_import + +import argparse + +from datasystem.object_client import ObjectClient + + +class ObjectClientExample: + """This class shows the SDK usage example of the ObjectClient.""" + + def __init__(self): + parser = argparse.ArgumentParser(description="Object client python interface Test") + parser.add_argument("--host", required=True, help="The IP of worker service") + parser.add_argument("--port", required=True, type=int, help="The port of worker service") + args = parser.parse_args() + self._host = args.host + self._port = args.port + + def set_data_example(self): + """This function shows the basic usage of g_increase_ref/create/get/g_decrease_ref.""" + client = ObjectClient(self._host, self._port) + # Init object client + client.init() + + # Increase the key's global reference + key = "key" + client.g_increase_ref([key]) + + # Create shared memory buffer for key. 
+ value = bytes("val", encoding="utf8") + size = len(value) + buf = client.create(key, size) + + # Lock shared memory buffer. NOTE: Lock protection for shared memory data + # access is only required in scenarios involving concurrent access from + # multiple instances on a single node. + buf.wlatch() + + # Copy data to shared memory buffer. + buf.memory_copy(value) + + # Publish the key. + buf.publish() + + # Unlock shared memory buffer. + buf.unwlatch() + + # Get the key. + buffer_list = client.get([key], True) + if value != buffer_list[0].immutable_data(): + raise RuntimeError(f"Assert failed, expect {value}, but got {buffer_list[0].immutable_data()}") + + # Decrease the key's global reference, the lifecycle of this key will end afterwards. + client.g_decrease_ref([key]) + + +if __name__ == "__main__": + ObjectClientExample().set_data_example() diff --git a/example/run-example.sh b/example/run-example.sh index f32bdde..cb73d9c 100755 --- a/example/run-example.sh +++ b/example/run-example.sh @@ -15,30 +15,61 @@ set -e -readonly EXAMPLE_DIR=$(dirname "$(readlink -f "$0")") -example_build_dir="${EXAMPLE_DIR}/build" -ds_output_dir="${EXAMPLE_DIR}/../output" +readonly curr_dir=$(dirname "$(readlink -f "$0")") +example_cpp_dir="${curr_dir}/cpp/build" +example_python_dir="${curr_dir}/python" +config_file="${curr_dir}/../config.cmake" old_ld_path=${LD_LIBRARY_PATH} -export PATH=$PATH:${EXAMPLE_DIR}/../scripts/modules +function get_var_from_cmake() { + local var_name="$1" + local file="$2" + grep -E "^set\(${var_name} " "$file" | sed -E "s/^set\(${var_name} \"(.*)\"\)/\1/" +} +ds_output_dir=$(get_var_from_cmake "INSTALL_DIR" "$config_file") +run_hetero=$(get_var_from_cmake "BUILD_HETERO" "$config_file") +run_python=$(get_var_from_cmake "PACKAGE_PYTHON" "$config_file") + +[[ ! -d "${ds_output_dir}"/service ]] && tar -zxf "${ds_output_dir}"/yr-datasystem-v$(cat "${curr_dir}/../VERSION").tar.gz -C ${ds_output_dir} +python3 -m pip install ${ds_output_dir}/openyuanrong_datasystem-*.whl --force-reinstall + +export PATH=$PATH:${curr_dir}/../scripts/modules export PATH=$PATH:/usr/sbin . llt_util.sh -run_hetero="$1" -run_perf="$2" -start_all "${EXAMPLE_DIR}/build" "${ds_output_dir}" +start_all "${example_cpp_dir}" "${ds_output_dir}" +cleanup_once=false +cleanup() { + if [ "$cleanup_once" = true ]; then + return + fi + cleanup_once=true + trap '' INT TERM + echo "stop all service..." + stop_all "${ds_output_dir}" +} +trap cleanup EXIT INT TERM -echo -e "---- Running the example..." +# run cpp example +echo -e "---- Running cpp example..." export LD_LIBRARY_PATH="${ds_output_dir}/sdk/cpp/lib:${LD_LIBRARY_PATH}" echo "Set LD_LIBRARY_PATH=${LD_LIBRARY_PATH} before cpp example test." 
-${example_build_dir}/ds_example "127.0.0.1" "${worker_port}"
-${example_build_dir}/object_example "127.0.0.1" "${worker_port}" "1000" "false"
-${example_build_dir}/kv_example "127.0.0.1" "${worker_port}"
-
-if [ "x$run_perf" == "xon" ]; then
-    ${example_build_dir}/perf_example "127.0.0.1" "${worker_port}"
+${example_cpp_dir}/stream_client_example "127.0.0.1" "${worker_port}"
+${example_cpp_dir}/datasystem_example "127.0.0.1" "${worker_port}"
+${example_cpp_dir}/object_client_example "127.0.0.1" "${worker_port}" "1000" "false"
+${example_cpp_dir}/kv_client_example "127.0.0.1" "${worker_port}"
+if [ "x$run_hetero" == "xon" ]; then
+    ${example_cpp_dir}/hetero_client_example "127.0.0.1" "${worker_port}"
 fi
-
 export LD_LIBRARY_PATH="${old_ld_path}"
-stop_all "${ds_output_dir}"
+# run python example
+if [ "x$run_python" == "xon" ]; then
+    echo -e "---- Running python example..."
+    python ${example_python_dir}/object_client_example.py --host "127.0.0.1" --port "${worker_port}"
+    python ${example_python_dir}/kv_client_example.py --host "127.0.0.1" --port "${worker_port}"
+    if [ "x$run_hetero" == "xon" ]; then
+        python ${example_python_dir}/hetero_client_example.py --host "127.0.0.1" --port "${worker_port}"
+        python ${example_python_dir}/ds_tensor_client_example.py --host "127.0.0.1" --port "${worker_port}"
+    fi
+fi
diff --git a/example/src/device_object_cache/device_object_example.cpp b/example/src/device_object_cache/device_object_example.cpp
deleted file mode 100644
index e69de29..0000000
diff --git a/include/datasystem/datasystem.h b/include/datasystem/datasystem.h
index 63a8b17..c56e1e7 100644
--- a/include/datasystem/datasystem.h
+++ b/include/datasystem/datasystem.h
@@ -26,6 +26,7 @@
 #include "datasystem/hetero_client.h"
 #include "datasystem/object_client.h"
 #include "datasystem/kv_client.h"
+#include "datasystem/stream_client.h"
 #include "datasystem/utils/status.h"
 
 namespace datasystem {
diff --git a/include/datasystem/hetero_client.h b/include/datasystem/hetero_client.h
index a169b18..bf53c6b 100644
--- a/include/datasystem/hetero_client.h
+++ b/include/datasystem/hetero_client.h
@@ -149,7 +149,7 @@ public:
     /// DevMSet and DevMGet must be used together. Heterogeneous objects are not automatically deleted after
     /// DevMGet is executed. If an object is no longer used, invoke DevLocalDelete to delete it.
     /// \param[in] keys Keys corresponding to devBlobList
-    /// \param[in] devBlobList List describing the structure of Device memory
+    /// \param[in,out] devBlobList List describing the structure of Device memory
     /// \param[out] failedKeys Returns failed keys if retrieval fails
     /// \param[in] subTimeoutMs Provides a timeout time, defaulting to 0
     /// \return K_OK when successful; the error code otherwise.
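With `stream_client.h` added to the umbrella header above, a single `#include "datasystem/datasystem.h"` now exposes the stream API alongside the object, KV and hetero clients. A minimal lifecycle sketch; the host, port and stream name are illustrative placeholders, and error handling mirrors stream_client_example.cpp above:

```cpp
#include "datasystem/datasystem.h"

#include <iostream>
#include <memory>

int main()
{
    // Placeholder worker address; substitute the real host and port.
    datasystem::ConnectOptions opts;
    opts.host = "127.0.0.1";
    opts.port = 9088;

    // Init() must succeed before any stream operation.
    datasystem::StreamClient client(opts);
    datasystem::Status status = client.Init();
    if (status.IsError()) {
        std::cerr << "Init failed: " << status.ToString() << std::endl;
        return 1;
    }

    // Producers (and, via Subscribe, consumers) are created from the client,
    // as stream_client_example.cpp demonstrates in full.
    std::shared_ptr<datasystem::Producer> producer;
    status = client.CreateProducer("example1", producer);
    if (status.IsError()) {
        std::cerr << "CreateProducer failed: " << status.ToString() << std::endl;
        return 1;
    }
    return producer->Close().IsOk() ? 0 : 1;
}
```

When no `ProducerConf` is passed, `CreateProducer` uses the defaults from `stream_config.h` (1 MB pages, 100 MB stream cap); override those fields before the call when the limits do not fit the workload.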
diff --git a/include/datasystem/kv_client.h b/include/datasystem/kv_client.h
index b1fd14e..b6de807 100644
--- a/include/datasystem/kv_client.h
+++ b/include/datasystem/kv_client.h
@@ -21,7 +21,6 @@
 #define DATASYSTEM_KV_CLIENT_H
 
 #include
-#include
 #include
 
 #include "datasystem/context/context.h"
diff --git a/include/datasystem/object_client.h b/include/datasystem/object_client.h
index b15cff2..c3a4966 100644
--- a/include/datasystem/object_client.h
+++ b/include/datasystem/object_client.h
@@ -21,11 +21,7 @@
 #ifndef DATASYSTEM_OBJECT_CLIENT_H
 #define DATASYSTEM_OBJECT_CLIENT_H
 
-#include
-#include
 #include
-#include
-#include
 #include
 
 #include
@@ -45,7 +41,6 @@ class ObjectClientImpl;
 
 namespace datasystem {
 struct CreateParam {
-    WriteMode writeMode = WriteMode::NONE_L2_CACHE;
     ConsistencyType consistencyType = ConsistencyType::PRAM;
     CacheType cacheType = CacheType::MEMORY;
 };
diff --git a/include/datasystem/stream/consumer.h b/include/datasystem/stream/consumer.h
new file mode 100644
index 0000000..641a335
--- /dev/null
+++ b/include/datasystem/stream/consumer.h
@@ -0,0 +1,128 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Define api of stream cache consumer.
+ */
+#ifndef DATASYSTEM_STREAM_CACHE_CONSUMER_H
+#define DATASYSTEM_STREAM_CACHE_CONSUMER_H
+
+#include <memory>
+#include <vector>
+
+#include "datasystem/stream/element.h"
+#include "datasystem/stream/stream_config.h"
+#include "datasystem/utils/status.h"
+
+namespace datasystem {
+namespace client {
+namespace stream_cache {
+class StreamClientImpl;
+class ConsumerImpl;
+} // namespace stream_cache
+} // namespace client
+} // namespace datasystem
+
+namespace datasystem {
+class __attribute((visibility("default"))) Consumer {
+public:
+    ~Consumer();
+    /**
+     * @brief Get expectNum elements from the subscription.
+     * @param[in] expectNum The number of elements to be read.
+     * @param[in] timeoutMs The receive timeout in milliseconds.
+     * @param[out] outElements The received elements to be read.
+     * @return K_OK on success; the error code otherwise.
+     *         K_UNKNOWN_ERROR: it's up to return message.
+     *         K_NOT_FOUND: the id of stream is not found.
+     *         K_INVALID: invalid parameter.
+     *         K_RPC_UNAVAILABLE: didn't receive any response from server.
+     *         K_DUPLICATED: the consumer already had a pending receive.
+     *         K_SC_PRODUCER_NOT_FOUND: one or more producers in the stream are dead.
+     *         K_SC_STREAM_IN_RESET_STATE: stream currently in reset state.
+     *         K_SC_ALREADY_CLOSED: consumer is already closed/inactive.
+     *         K_SC_STREAM_IN_USE: another thread is calling API from the same consumer at the same time.
+     */
+    Status Receive(uint32_t expectNum, uint32_t timeoutMs, std::vector<Element> &outElements);
+
+    /**
+     * @brief Get any number of elements already received from the subscription.
+     * @param[in] timeoutMs The receive timeout in milliseconds.
+     * @param[out] outElements The received elements to be read.
+ * @return K_OK on success; the error code otherwise. + * K_UNKNOWN_ERROR: it's up to return message. + * K_NOT_FOUND: the id of stream is not found. + * K_INVALID: invalid parameter. + * K_RPC_UNAVAILABLE: didn't receive any response from server. + * K_DUPLICATED: the consumer already had pending receive. + * K_SC_PRODUCER_NOT_FOUND: one or more producer in the stream are dead. + * K_SC_STREAM_IN_RESET_STATE: stream currently in reset state. + * K_SC_ALREADY_CLOSED: consumer is already closed/inactive. + * K_SC_STREAM_IN_USE: another thread is calling API from the same consumer at the same time. + */ + Status Receive(uint32_t timeoutMs, std::vector &outElements); + + /** + * @brief Acknowledge elements that had been read by this consumer. + * @param[in] elementId The element id that to be acknowledged. + * @return K_OK on success; the error code otherwise. + * K_UNKNOWN_ERROR: it's up to return message. + * K_NOT_FOUND: the id of stream is not found. + * K_INVALID: invalid parameter. + * K_SC_STREAM_IN_RESET_STATE: stream currently in reset state. + * K_SC_ALREADY_CLOSED: consumer is already closed/inactive. + * K_SC_STREAM_IN_USE: another thread is calling API from the same consumer at the same time. + */ + Status Ack(uint64_t elementId); + + /** + * @brief Close the consumer, after close it will not allow Receive and Ack Elements. + * Calling Close() on an already closed consumer will return K_OK. + * @return K_OK on success; the error code otherwise. + * K_UNKNOWN_ERROR: it's up to return message. + * K_NOT_FOUND: the id of stream is not found. + * K_INVALID: invalid parameter. + * K_RUNTIME_ERROR: delete sub node in global scope fail on master process. + * K_SC_STREAM_IN_USE: another thread is calling API from the same consumer at the same time. + */ + Status Close(); + + /** + * @brief Get the amount of received elements since this consumer construct, and the amount of elements + * not processed. + * @param[out] totalElements the amount of received elements since this consumer construct. + * @param[out] notProcessedElements the amount of elements not processed. + */ + void GetStatisticsMessage(uint64_t &totalElements, uint64_t ¬ProcessedElements); + +private: + explicit Consumer(std::unique_ptr impl); + + /** + * @cond Friend does not show up in the documentation. + */ + friend class client::stream_cache::StreamClientImpl; + // @endcond + + // for make_shared to access private/protected constructor. + friend std::shared_ptr std::make_shared(); + // for make_unique to access private/protected constructor. + friend std::unique_ptr std::make_unique(); + + std::unique_ptr impl_; +}; +} // namespace datasystem +#endif // DATASYSTEM_STREAM_CACHE_CONSUMER_H diff --git a/include/datasystem/stream/element.h b/include/datasystem/stream/element.h new file mode 100644 index 0000000..512849e --- /dev/null +++ b/include/datasystem/stream/element.h @@ -0,0 +1,57 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Define api of stream cache. + */ +#ifndef DATASYSTEM_STREAM_CACHE_ELEMENT_H +#define DATASYSTEM_STREAM_CACHE_ELEMENT_H + +#include +#include +#include +#include + +#include "datasystem/utils/status.h" + +namespace datasystem { +/** + * @brief Element struct settings. + */ +struct Element { + Element(uint8_t *ptr = nullptr, uint64_t size = 0, uint64_t id = ULONG_MAX) : ptr(ptr), size(size), id(id) + { + } + + ~Element() = default; + + /** + * @brief The pointer of element. + */ + uint8_t *ptr; + + /** + * @brief The size of element. + */ + uint64_t size; + + /** + * @brief The id of element which can created and increased by datasystem automatically. + */ + uint64_t id; +}; +} // namespace datasystem +#endif // DATASYSTEM_STREAM_CACHE_ELEMENT_H diff --git a/include/datasystem/stream/producer.h b/include/datasystem/stream/producer.h new file mode 100644 index 0000000..4f074de --- /dev/null +++ b/include/datasystem/stream/producer.h @@ -0,0 +1,111 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Declare stream cache producer. + */ +#ifndef DATASYSTEM_STREAM_CACHE_PRODUCER_H +#define DATASYSTEM_STREAM_CACHE_PRODUCER_H + +#include +#include +#include + +#include "datasystem/stream/element.h" +#include "datasystem/utils/status.h" + +namespace datasystem { +namespace client { +namespace stream_cache { +class ProducerImpl; +class StreamClientImpl; +} // namespace stream_cache +} // namespace client +} // namespace datasystem + +namespace datasystem { +class __attribute ((visibility ("default"))) Producer { +public: + ~Producer(); + + /** + * @brief Send one element of the stream. + * @param[in] element The element that to be written. + * @return K_OK on success; the error code otherwise. + * K_UNKNOWN_ERROR: it's up to return message. + * K_RUNTIME_ERROR: producer not init. + * K_OUT_OF_MEMORY: out of memory, or unable to secure enough memory for the element. + * K_RUNTIME_ERROR: element copy failed, it's up to return message. + * K_NOT_FOUND: the id of stream is not found. + * K_RUNTIME_ERROR: can not find mmap file or mmap fd failed. + * K_INVALID: invalid parameter. + * K_SC_STREAM_IN_RESET_STATE: stream currently in reset state. + * K_SC_ALREADY_CLOSED: producer is already closed/inactive. + * K_SC_STREAM_IN_USE: another thread is calling API from the same producer at the same time. + */ + Status Send(const Element &element); + + /** + * @brief Send one element of the stream, blocking version. + * @param[in] element The element that to be written. + * @param[in] timeoutMs The amount of time in milliseconds to wait for the send to complete in the range of + * [0, INT32_MAX]. A value of 0 means that it will immediately return the error reason without waiting if the send + * cannot be completed right away. 
A value greater than 0 makes this a possible blocking call where it will wait for + * the operation to complete if needed. If the wait time exceeds the value then the function will stop waiting and + * return the error reason. + * @return K_OK on success; the error code otherwise. + * K_UNKNOWN_ERROR: it's up to return message. + * K_RUNTIME_ERROR: producer not init. + * K_OUT_OF_MEMORY: out of memory, or unable to secure enough memory for the element within timeoutMs. + * K_RUNTIME_ERROR: element copy failed, it's up to return message. + * K_NOT_FOUND: the id of stream is not found. + * K_RUNTIME_ERROR: can not find mmap file or mmap fd failed. + * K_INVALID: invalid parameter. + * K_SC_STREAM_IN_RESET_STATE: stream currently in reset state. + * K_SC_ALREADY_CLOSED: producer is already closed/inactive. + * K_SC_STREAM_IN_USE: another thread is calling API from the same producer at the same time. + */ + Status Send(const Element &element, int64_t timeoutMs); + + /** + * @brief Close the producer, after close it will not allow Send new Elements, and it will trigger flush operations + * when the local buffer had not flushed elements. Calling Close() on an already closed producer will return K_OK. + * @return K_OK on success; the error code otherwise. + * K_UNKNOWN_ERROR: it's up to return message. + * K_NOT_FOUND: it's up to return message. + * K_RUNTIME_ERROR: it's up to return message. + * K_SC_STREAM_IN_USE: another thread is calling API from the same producer at the same time. + */ + Status Close(); + +private: + explicit Producer(std::shared_ptr impl); + + /** + * @cond Friend does not show up in the documentation. + */ + friend class client::stream_cache::StreamClientImpl; + // @endcond + + // for make_shared to access private/protected constructor. + friend std::shared_ptr std::make_shared(); + // for make_unique to access private/protected constructor. + friend std::unique_ptr std::make_unique(); + + std::shared_ptr impl_; +}; +} // namespace datasystem +#endif // DATASYSTEM_STREAM_CACHE_PRODUCER_H diff --git a/include/datasystem/stream/stream_config.h b/include/datasystem/stream/stream_config.h new file mode 100644 index 0000000..5a5b12a --- /dev/null +++ b/include/datasystem/stream/stream_config.h @@ -0,0 +1,134 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Define api of stream cache + */ +#ifndef DATASYSTEM_STREAM_CACHE_STREAM_CONFIG_H +#define DATASYSTEM_STREAM_CACHE_STREAM_CONFIG_H + +#include +#include + +#include "datasystem/utils/status.h" + +namespace datasystem { +static constexpr int SC_NORMAL_LOG_LEVEL = 1; // Normal output log level for stream cache module +static constexpr int SC_INTERNAL_LOG_LEVEL = 2; // Internal output log level for stream cache module +static constexpr int SC_DEBUG_LOG_LEVEL = 3; // Debug output log level for stream cache module + +/** + * @brief Subscription Types. 
+ * @details Stream Mode, Queue Mode (Round Robin and Key Partition). + */ +enum SubscriptionType { STREAM, ROUND_ROBIN, KEY_PARTITIONS, UNKNOWN }; + +/** + * @brief Subscription configuration. + * @details Consisting of subscription name and type. Optionally, the cache capacity can be adjusted, and the cache + * prefetch low water mark can be enabled (non-zero value will turn prefetching on). + */ +struct SubscriptionConfig { + static constexpr uint32_t SC_CACHE_CAPACITY = 32768; // Default local subscription cache capacity + static constexpr uint16_t SC_CACHE_LWM = 0; // Default cache prefetch percent. + std::string subscriptionName; + SubscriptionType subscriptionType = SubscriptionType::STREAM; + uint32_t cacheCapacity = SC_CACHE_CAPACITY; + uint16_t cachePrefetchLWM = SC_CACHE_LWM; // Enabled when value is greater than 0. Default is off. + // Should the consumer receive notification about the fault of a producer. Default is false. + + SubscriptionConfig(std::string subName, const SubscriptionType subType) + : subscriptionName(std::move(subName)), subscriptionType(subType) + { + } + + SubscriptionConfig(std::string subName, const SubscriptionType subType, uint32_t cacheMax, + uint16_t cachePrefetchPercent) + : subscriptionName(std::move(subName)), subscriptionType(subType), cacheCapacity(cacheMax), + cachePrefetchLWM(cachePrefetchPercent) + { + } + + SubscriptionConfig() = default; + + SubscriptionConfig(const SubscriptionConfig &other) = default; + + SubscriptionConfig &operator=(const SubscriptionConfig &other) = default; + + SubscriptionConfig(SubscriptionConfig &&other) noexcept + { + subscriptionName = std::move(other.subscriptionName); + subscriptionType = other.subscriptionType; + cacheCapacity = other.cacheCapacity; + cachePrefetchLWM = other.cachePrefetchLWM; + } + + SubscriptionConfig &operator=(SubscriptionConfig &&other) noexcept + { + subscriptionName = std::move(other.subscriptionName); + subscriptionType = other.subscriptionType; + cacheCapacity = other.cacheCapacity; + cachePrefetchLWM = other.cachePrefetchLWM; + return *this; + } + + bool operator==(const SubscriptionConfig &config) const + { + return (subscriptionName == config.subscriptionName && subscriptionType == config.subscriptionType + && cacheCapacity == config.cacheCapacity && cachePrefetchLWM == config.cachePrefetchLWM); + } + + bool operator!=(const SubscriptionConfig &config) const + { + return !(*this == config); + } +}; + +enum StreamMode : int32_t { MPMC = 0, MPSC, SPSC }; + +/** + * @brief Producer configuration. + * @details Auto flush time and page size. + */ +struct ProducerConf { + // default auto flush time 5ms. + int64_t delayFlushTime = 5; + + // default page size 1MB, must be a multiple of 4KB, must not greater than 16MB. + int64_t pageSize = 1024 * 1024ul; + + // default max stream size 100MB, must greater then 64KB and less than the shared memory size. + uint64_t maxStreamSize = 100 * 1024 * 1024ul; + + // auto stream clean up when the last producer/consumer exits. + bool autoCleanup = false; + + // the number of consumers to retain data for, default to 0. + // Notice: If a worker is voluntary scaled down, data will be lost if no remote consumer is created, even if + // retainForNumConsumers is set. + uint64_t retainForNumConsumers = 0; + + // enable stream data encryption between workers. + bool encryptStream = false; + + // default reserve size to page size, must be a multiple of page size. + uint64_t reserveSize = 0; + + // default stream mode MPMC. 
+    StreamMode streamMode = StreamMode::MPMC;
+};
+} // namespace datasystem
+#endif // DATASYSTEM_STREAM_CACHE_STREAM_CONFIG_H
diff --git a/include/datasystem/stream_client.h b/include/datasystem/stream_client.h
new file mode 100644
index 0000000..9096f85
--- /dev/null
+++ b/include/datasystem/stream_client.h
@@ -0,0 +1,147 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Declare stream cache client.
+ */
+#ifndef DATASYSTEM_STREAM_CACHE_STREAM_CLIENT_H
+#define DATASYSTEM_STREAM_CACHE_STREAM_CLIENT_H
+
+#include <memory>
+
+#include "datasystem/stream/consumer.h"
+#include "datasystem/stream/element.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream/stream_config.h"
+#include "datasystem/utils/connection.h"
+#include "datasystem/utils/sensitive_value.h"
+#include "datasystem/utils/status.h"
+
+namespace datasystem {
+namespace client {
+namespace stream_cache {
+class StreamClientImpl;
+} // namespace stream_cache
+} // namespace client
+} // namespace datasystem
+
+namespace datasystem {
+class __attribute((visibility("default"))) StreamClient {
+public:
+    /**
+     * @brief Construct StreamClient.
+     * @param[in] connectOptions The connect options.
+     */
+    explicit StreamClient(ConnectOptions connectOptions);
+
+    ~StreamClient();
+
+    /**
+     * @brief Shutdown the stream client.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status ShutDown();
+
+    /**
+     * @brief Initialize the stream client.
+     * @param[in] reportWorkerLost Report to the user that the worker was lost previously.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status Init(bool reportWorkerLost = false);
+
+    /**
+     * @brief Create one Producer to send elements.
+     * @param[in] streamName The name of the stream. The name must not be empty, may only contain English letters
+     * (a-zA-Z), numbers and ~!@#$%^&*.-_, and must be shorter than 256 characters.
+     * @param[out] outProducer The output Producer that the user can use to send elements.
+     * @param[in] producerConf The producer configuration.
+     * @return K_OK on success; the error code otherwise.
+     *         K_UNKNOWN_ERROR: it's up to return message.
+     *         K_NOT_FOUND: the id of stream is not found.
+     *         K_INVALID: invalid parameter.
+     *         K_RUNTIME_ERROR: delete pub node in global scope fail on master process.
+     *         K_RUNTIME_ERROR: fail to init mmap memory for producer.
+     *         K_NOT_READY: the worker is not ready.
+     *         K_IO_ERROR: can not open curve key from file.
+     */
+    Status CreateProducer(const std::string &streamName, std::shared_ptr<Producer> &outProducer,
+                          ProducerConf producerConf = {});
+
+    /**
+     * @brief Create a subscription and generate one Consumer to receive elements.
+     * @param[in] streamName The name of the stream. The name must not be empty, may only contain English letters
+     * (a-zA-Z), numbers and ~!@#$%^&*.-_, and must be shorter than 256 characters.
+     * @param[in] config The config of subscription.
+     * @param[out] outConsumer The output Consumer that the user can use to receive data elements.
+     * @param[in] autoAck Optional setting that toggles whether automatic Acks are enabled.
+     * @return K_OK on success; the error code otherwise.
+     *         K_UNKNOWN_ERROR: it's up to return message.
+     *         K_NOT_FOUND: the id of stream is not found.
+     *         K_RUNTIME_ERROR: add pub node in global scope fail on master process.
+     *         K_NOT_READY: the worker is not ready.
+     */
+    Status Subscribe(const std::string &streamName, const struct SubscriptionConfig &config,
+                     std::shared_ptr<Consumer> &outConsumer, bool autoAck = false);
+
+    /**
+     * @brief Delete one stream.
+     * @param[in] streamName The name of the stream. The name must not be empty, may only contain English letters
+     * (a-zA-Z), numbers and ~!@#$%^&*.-_, and must be shorter than 256 characters.
+     * @return K_OK on success; the error code otherwise.
+     *         K_UNKNOWN_ERROR: it's up to return message.
+     *         K_NOT_FOUND: the id of stream is not found.
+     *         K_NOT_READY: the worker is not ready.
+     *         K_RUNTIME_ERROR: not allowed to delete stream when producer is running.
+     *         K_RUNTIME_ERROR: not allowed to delete stream when consumer is running.
+     *         K_RUNTIME_ERROR: not allowed to delete stream when remote producer is running.
+     *         K_RUNTIME_ERROR: not allowed to delete stream when remote consumer is running.
+     *         K_RUNTIME_ERROR: has pub node in global scope.
+     *         K_RUNTIME_ERROR: has sub node in global scope.
+     *         K_IO_ERROR: repeat deleting.
+     *         K_KVSTORE_ERROR: can not delete the key.
+     */
+    Status DeleteStream(const std::string &streamName);
+
+    /**
+     * @brief Query the number of global producers.
+     * @param[in] streamName The target stream. The name must not be empty, may only contain English letters
+     * (a-zA-Z), numbers and ~!@#$%^&*.-_, and must be shorter than 256 characters.
+     * @param[out] gProducerNum The number of global producers.
+     * @return Status of the call.
+     */
+    Status QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum);
+
+    /**
+     * @brief Query the number of global consumers.
+     * @param[in] streamName The target stream. The name must not be empty, may only contain English letters
+     * (a-zA-Z), numbers and ~!@#$%^&*.-_, and must be shorter than 256 characters.
+     * @param[out] gConsumerNum The number of global consumers.
+     * @return Status of the call.
+     */
+    Status QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum);
+
+private:
+    // for make_unique to access private/protected constructor.
+    friend std::unique_ptr<client::stream_cache::StreamClientImpl>
+        std::make_unique<client::stream_cache::StreamClientImpl>();
+    friend std::unique_ptr<StreamClient> std::make_unique<StreamClient>();
+
+    std::shared_ptr<client::stream_cache::StreamClientImpl> impl_;
+    std::string ip_;
+    int32_t port_;
+    SensitiveValue token_;
+};
+} // namespace datasystem
+#endif // DATASYSTEM_STREAM_CACHE_STREAM_CLIENT_H
diff --git a/include/datasystem/utils/status.h b/include/datasystem/utils/status.h
index ee10d16..9f90bd1 100644
--- a/include/datasystem/utils/status.h
+++ b/include/datasystem/utils/status.h
@@ -59,8 +59,9 @@ enum StatusCode : uint32_t {
     K_RETRY_IF_LEAVING = 30,
     K_SCALE_DOWN = 31,
     K_SCALING = 32,
-    K_LRU_HARD_LIMIT = 33,
-    K_LRU_SOFT_LIMIT = 34,
+    K_CLIENT_DEADLOCK = 33,
+    K_LRU_HARD_LIMIT = 34,
+    K_LRU_SOFT_LIMIT = 35,
 
     // rpc error code, range: [1000, 2000)
     K_RPC_CANCELLED = 1000,
@@ -77,6 +78,19 @@ enum StatusCode : uint32_t {
     K_OC_KEY_ALREADY_EXIST = 2004,
     K_WORKER_PULL_OBJECT_NOT_FOUND = 2005,
 
+    // stream error code, range: [3000, 4000)
+    K_SC_STREAM_NOT_FOUND = 3000,
+    K_SC_PRODUCER_NOT_FOUND = 3001,
+    K_SC_CONSUMER_NOT_FOUND = 3002,
+    K_SC_END_OF_PAGE = 3003,
+    K_SC_STREAM_IN_RESET_STATE = 3004,
+    K_SC_WORKER_WAS_LOST = 3005,
+    K_SC_STREAM_IN_USE = 3006,
+    K_SC_STREAM_DELETE_IN_PROGRESS = 3007,
+    K_SC_STREAM_RESOURCE_ERROR = 3008,
+    K_SC_ALREADY_CLOSED = 3009,
+    K_SC_STREAM_NOTIFICATION_PENDING = 3010,
+
     // Heterogeneous error code, range: [5000, 6000]
     K_ACL_ERROR = 5000,
     K_HCCL_ERROR = 5001,
diff --git a/install_tools.sh b/install_tools.sh
index 3c3d0e9..33483f5 100644
--- a/install_tools.sh
+++ b/install_tools.sh
@@ -240,6 +240,9 @@ export GONOSUMDB=*
 export GOPROXY=https://goproxy.cn,direct
 export PATH=$GOROOT/bin:$GOPATH/bin:$PATH
 
+go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
+go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
+
 # --- Python ---
 sudo tee /etc/profile.d/python-env.sh > /dev/null << 'EOF'
 export PYTHON_PATH_3911=/opt/buildtools/python3.9
diff --git a/k8s/docker/dockerfile/datasystem.Dockerfile b/k8s/docker/dockerfile/datasystem.Dockerfile
index 8cfdf4b..efc8eb1 100644
--- a/k8s/docker/dockerfile/datasystem.Dockerfile
+++ b/k8s/docker/dockerfile/datasystem.Dockerfile
@@ -27,6 +27,13 @@ ARG DATASYSTEM_ROOT=${HOME}/datasystem
 ARG TARGET_SYSTEM
 ARG ARCHITECTURE
 
+RUN sed -i 's|repo.openeuler.org|mirrors.huaweicloud.com/openeuler|g' /etc/yum.repos.d/*.repo
+
+RUN dnf clean all && \
+    dnf makecache && \
+    dnf install -y shadow-utils && \
+    dnf clean all
+
 RUN mkdir -p ${DATASYSTEM_ROOT} && \
     groupadd -g ${GROUP_ID} ${GROUP_NAME} && \
     useradd -u ${USER_UID} -g ${GROUP_ID} -s /sbin/nologin ${USER_NAME} && \
@@ -58,6 +65,11 @@ RUN chmod -R 500 ${DATASYSTEM_ROOT}/bin && \
     chmod 500 ${DATASYSTEM_ROOT}/lib && \
     chmod 400 ${DATASYSTEM_ROOT}/lib/*
 
+RUN sh -c 'if [ -d "${DATASYSTEM_ROOT}/lib/urma" ]; then \
+        chmod 500 ${DATASYSTEM_ROOT}/lib/urma && \
+        chmod 400 ${DATASYSTEM_ROOT}/lib/urma/*; \
+    fi'
+
 RUN if [ -f /etc/sudoers ]; then \
         sed -i "s|%wheel|#%wheel|g" "/etc/sudoers"; \
     fi
diff --git a/k8s/helm_chart/datasystem/templates/worker_daemonset.yaml b/k8s/helm_chart/datasystem/templates/worker_daemonset.yaml
index cba2d34..58d580d 100644
--- a/k8s/helm_chart/datasystem/templates/worker_daemonset.yaml
+++ b/k8s/helm_chart/datasystem/templates/worker_daemonset.yaml
@@ -23,6 +23,7 @@ spec:
     rollingUpdate:
       maxUnavailable: {{ $.Values.global.maxUnavailable }}
     type: RollingUpdate
+  progressDeadlineSeconds: {{ $.Values.global.rollingUpdateTimeoutS }}
   template:
     metadata:
       labels:
@@ -95,12 +96,30 @@
         - -log_compress={{ $.Values.global.log.logCompress }}
        - -logfile_mode=416
        - 
-shared_memory_size_mb=$(SHARE_MEMORY_SIZE) + - -sc_local_cache_memory_size_mb={{ $.Values.global.resources.datasystemWorker.scLocalCacheMemorySizeMb }} + - -sc_scan_thread_num={{ $.Values.global.rpc.scScanThreadNum }} + - -sc_scan_num_buckets={{ $.Values.global.rpc.scScanNumBuckets }} + - -sc_shared_page_size_mb={{ $.Values.global.resources.datasystemWorker.scSharedPageSizeMb }} + - -sc_shared_page_group_count={{ $.Values.global.resources.datasystemWorker.scSharedPageGroupCount }} + - -oc_shm_threshold_percentage={{ $.Values.global.resources.datasystemWorker.ocShmThresholdPercentage }} + - -sc_shm_threshold_percentage={{ $.Values.global.resources.datasystemWorker.scShmThresholdPercentage }} - -rpc_thread_num={{ $.Values.global.rpc.rpcThreadNum }} - -oc_thread_num={{ $.Values.global.rpc.ocThreadNum }} + - -sc_regular_socket_num={{ $.Values.global.rpc.scRegularSocketNum }} + - -sc_stream_socket_num={{ $.Values.global.rpc.scStreamSocketNum }} - -enable_huge_tlb={{ $.Values.global.performance.enableHugeTlb }} - -enable_fallocate={{ $.Values.global.performance.enableFallocate }} - -eviction_thread_num={{ $.Values.global.spill.evictionThreadNum }} - -eviction_reserve_mem_threshold_mb={{ $.Values.global.spill.evictionReserveMemThresholdMB }} + - -sc_worker_worker_direct_port={{ $.Values.global.rpc.scWorkerWorkerDirectPort }} + - -sc_worker_worker_pool_size={{ $.Values.global.rpc.scWorkerWorkerPoolSize }} + - -sc_gc_interval_ms={{ $.Values.global.rpc.scGcIntervalMs }} + - -sc_scan_interval_ms={{ $.Values.global.rpc.scScanIntervalMs }} + - -sc_metrics_log_interval_s={{ $.Values.global.log.scMetricsLogIntervalS }} + - -sc_cache_pages={{ $.Values.global.resources.datasystemWorker.scCachePages }} + - -page_size={{ $.Values.global.resources.datasystemWorker.pageSize }} + - -remote_send_thread_num={{ $.Values.global.rpc.remoteSendThreadNum }} + - -master_sc_thread_num={{ $.Values.global.rpc.masterScThreadNum }} - -unix_domain_socket_dir={{ $.Values.global.ipc.udsDir }} - -health_check_path=$(HEALTH_CHECK_PATH) - -ready_check_path=$(READY_CHECK_PATH) @@ -135,6 +154,7 @@ spec: - -rocksdb_store_dir=/home/sn/datasystem/rocksdb - -rocksdb_max_open_file={{ $.Values.global.metadata.rocksdbMaxOpenFile }} - -rocksdb_background_threads={{ $.Values.global.metadata.rocksdbBackgroundThreads }} + - -rocksdb_write_mode={{ $.Values.global.metadata.rocksdbWriteMode }} - -node_timeout_s={{ $.Values.global.reliability.nodeTimeoutS }} - -node_dead_timeout_s={{ $.Values.global.reliability.nodeDeadTimeoutS }} - -enable_etcd_auth={{ $.Values.global.etcd.enableEtcdAuth }} @@ -161,8 +181,8 @@ spec: - -zmq_client_io_context={{ $.Values.global.rpc.zmqClientIoContext }} - -zmq_chunk_sz={{ $.Values.global.rpc.zmqChunkSz | int64 }} - -etcd_address={{ $.Values.global.etcd.etcdAddress }} - - -az_name={{ $.Values.global.azName }} - - -other_az_names={{ $.Values.global.crossAz.otherAzNames }} + - -cluster_name={{ $.Values.global.clusterName }} + - -other_cluster_names={{ $.Values.global.crossAz.otherClusterNames }} - -async_delete={{ $.Values.global.performance.asyncDelete}} - -etcd_meta_pool_size={{ $.Values.global.etcd.etcdMetaPoolSize }} - -arena_per_tenant={{ $.Values.global.performance.arenaPerTenant }} @@ -178,6 +198,20 @@ spec: - -liveness_probe_timeout_s={{ $.Values.global.reliability.livenessProbeTimeoutS }} - -check_async_queue_empty_time_s={{ $.Values.global.gracefulShutdown.checkAsyncQueueEmptyTimeS }} - -enable_lossless_data_exit_mode={{ $.Values.global.gracefulShutdown.enableLosslessDataExitMode }} + - -enable_urma={{ 
$.Values.global.performance.enableUrma }}
+ - -urma_poll_size={{ $.Values.global.performance.urmaPollSize }}
+ - -urma_register_whole_arena={{ $.Values.global.performance.urmaRegisterWholeArena }}
+ - -urma_connection_size={{ $.Values.global.performance.urmaConnectionSize }}
+ - -urma_event_mode={{ $.Values.global.performance.urmaEventMode }}
+ - -oc_shm_transfer_threshold_kb={{ $.Values.global.performance.ocShmTransferThresholdKB }}
+ - -shared_disk_directory={{ $.Values.global.performance.sharedDiskDirectory }}
+ - -shared_disk_size_mb={{ $.Values.global.performance.sharedDiskSize }}
+ - -shared_disk_arena_per_tenant={{ $.Values.global.performance.sharedDiskArenaPerTenant }}
+ - -stream_idle_time_s={{ $.Values.global.rpc.streamIdleTimes }}
+ - -enable_distributed_master={{ $.Values.global.reliability.enableDistributedMaster }}
+ - -enable_p2p_transfer={{ $.Values.global.performance.enableP2pTransfer }}
+ - -rolling_update_timeout_s={{ $.Values.global.rollingUpdateTimeoutS }}
+ - -enable_worker_worker_batch_get={{ $.Values.global.performance.enableWorkerWorkerBatchGet }}
command: - /home/sn/worker_entry.sh securityContext:
diff --git a/k8s/helm_chart/datasystem/values.yaml b/k8s/helm_chart/datasystem/values.yaml index 72b8c49..4073f35 100644 --- a/k8s/helm_chart/datasystem/values.yaml +++ b/k8s/helm_chart/datasystem/values.yaml
@@ -10,7 +10,7 @@ global: datasystem: "openyuanrong-datasystem:0.5.0" # Config ETCD table prefix, the value should only contain english alphabetics (a-zA-Z), numbers (0-9) only.
- azName: "AZ1"
+ clusterName: "AZ1"
etcd: # ETCD configuration
@@ -57,10 +57,25 @@ global: # Upper limit of the shared memory, the default unit for shared memory is MB. # To prevent being rendered as scientific notation by helm, numbers with more than 5 digits should be configured as strings. sharedMemory: 2048 -
+ # Upper limit of the stream cache local memory, the unit is MB.
+ scLocalCacheMemorySizeMb: 1024
+ # Percentage of the shared memory that can be used by OC, must be within (0, 100].
+ ocShmThresholdPercentage: 100
+ # Percentage of the shared memory that can be used by SC, must be within (0, 100].
+ scShmThresholdPercentage: 100
# Maximum number of clients that can be connected to a worker. # Value range: [1, 10000], default value: 200. maxClientNum: 200
+ # Number of cached pages. A higher number improves performance.
+ scCachePages: 16
+ # The shared page size in MB, should be in range [1, 16].
+ scSharedPageSizeMb: 4
+ # The shared page group count for each remote worker, should be in range [1, 64].
+ scSharedPageGroupCount: 4
+ # Size of the page used for caching worker files.
+ # The valid range is 4096-1073741824.
+ # Unit for page size is Bytes.
+ pageSize: "1048576"
spill: # The path of the spilling, empty means local_dick spill disabled.
@@ -114,6 +129,8 @@ global: logCompress: true # Prefix of log filename, default is program invocation short name. Use standard characters only. logFilename: ""
+ # Interval in seconds between logging stream metrics.
+ scMetricsLogIntervalS: 60
observability: # Record performance and resource logs
@@ -193,6 +210,8 @@ global: # Number of background threads rocksdb can use for flushing and compacting, default value is 16. The value can be # modified according to cpu limitations but should be greater than 0 to ensure that there are backend threads to handle tasks. rocksdbBackgroundThreads: 16
+ # The rocksdb write mode, 'async' by default. Optional values: 'none', 'sync', 'async'.
This represents how metadata is written to rocksdb.
+ rocksdbWriteMode: "async"
rpc: # Whether to enable the authentication function between components(worker, master)
@@ -218,6 +237,12 @@ global: # Number of parallel connections from worker to worker scenarios to improve throughput. # OC_WORKER_WORKER_DIRECT_PORT must be enabled to take effect. ocWorkerWorkerPoolSize: 3
+ # A direct TCP/IP port for worker-to-worker scenarios to improve latency.
+ # Acceptable values: 0 or a positive integer. 0 means disabled.
+ scWorkerWorkerDirectPort: 0
+ # Number of parallel connections in worker-to-worker scenarios to improve throughput.
+ # SC_WORKER_WORKER_DIRECT_PORT must be enabled to take effect.
+ scWorkerWorkerPoolSize: 3
# The payload threshold to batch get objects, unit is MB. Setting to 0 will indicate no split. batchGetThresholdMb: 100
@@ -238,6 +263,25 @@ global: zmqChunkSz: 1048576 # Maximum number of sessions that can be cached, must be within [512, 10'000] maxRpcSessionNum: 2048
+ # The number of threads used to scan new elements in shared memory.
+ scScanThreadNum: 16
+ # Number of partitions for scanning streams.
+ scScanNumBuckets: 1024
+ # The number of regular backend sockets for stream cache.
+ scRegularSocketNum: 64
+ # The number of stream backend sockets for stream cache.
+ scStreamSocketNum: 64
+ # The maximum number of threads for non-rpc tasks in the master.
+ masterScThreadNum: 128
+ # The number of threads used to send elements to remote workers.
+ remoteSendThreadNum: 8
+
+ # Memory resource cleanup interval in milliseconds.
+ scGcIntervalMs: 50
+ # Interval in milliseconds between scans for elements to send to remote workers.
+ scScanIntervalMs: 10
+ # Stream idle time in seconds, default 300 (5 minutes).
+ streamIdleTimes: 300
ipc: # Determines whether the shared memory feature is enabled.
@@ -266,6 +310,8 @@ global: addNodeWaitTimeS: 60 # Decide whether to remove the node from hash ring or not when node is dead autoDelDeadNode: true
+ # Whether to support distributed master, default is true.
+ enableDistributedMaster: true
gracefulShutdown: # Scale in taint, format is key=value:effect.
@@ -309,10 +355,32 @@ global: # Client doesn't need to wait for all workers to delete objects. # The default value is false. asyncDelete: false
+ # Enable the P2P transfer protocol for heterogeneous objects.
+ enableP2pTransfer: false
+ # Enable worker->worker OC batch get, default false.
+ enableWorkerWorkerBatchGet: false
+ # The data threshold to transfer obj data between client and worker via shm, unit is KB.
+ ocShmTransferThresholdKB: 500
+ # Option to turn on urma for OC worker-to-worker data transfer, default false.
+ enableUrma: false
+ # Number of completion records to poll at a time; 16 is the maximum this device can poll.
+ urmaPollSize: 8
+ # Register the whole arena as a segment during init; otherwise, register each object as a segment.
+ urmaRegisterWholeArena: true
+ # Number of jfs and jfr pairs.
+ urmaConnectionSize: 16
+ # Use interrupt mode to poll completion events.
+ urmaEventMode: false
+ # Disk cache data placement directory, default value is empty, indicating that disk cache is not enabled.
+ sharedDiskDirectory: ""
+ # Upper limit of the shared disk, the unit is MB.
+ sharedDiskSize: 0
+ # The number of disk cache arenas for each tenant. Multiple arenas can improve the performance of first-time shared disk allocation, but each arena uses one more fd. The valid range is 0 to 32.
+ sharedDiskArenaPerTenant: 8
crossAz: # Specify other az names using the same etcd.
Only split by ',' - otherAzNames: "" + otherClusterNames: "" # Control whether to try to get data from other AZ's worker first. If false, data will be retrieved directly from the L2 cache. crossAzGetDataFromWorker: true # Control whether to get meta data from other AZ's worker, if false then get meta data from local AZ. @@ -329,7 +397,7 @@ global: tenantSecretKey: "" # Request expiration time in seconds, the maximum value is 300s. requestExpireTimeS: 300 - + # fsGroup configuration # All processes of the container are also part of the supplementary group ID. fsGid: "1002" @@ -438,4 +506,7 @@ global: terminationGracePeriodSeconds: 1800 # Used to control how many Pods can be in an unavailable state during a rolling update. - maxUnavailable: "100%" \ No newline at end of file + maxUnavailable: "100%" + + # Maximum duration of the rolling upgrade, default value is 1800 seconds. + rollingUpdateTimeoutS: 1800 \ No newline at end of file diff --git a/python/__init__.py b/python/__init__.py index 0a800a0..446dbfa 100644 --- a/python/__init__.py +++ b/python/__init__.py @@ -23,6 +23,7 @@ __all__ = [ "ObjectClient", "KVClient", "Status", + "StreamClient", "SubconfigType", "WriteMode", "Context", @@ -38,6 +39,7 @@ __all__ = [ from datasystem.object_client import Buffer, ConsistencyType from datasystem.object_client import ObjectClient, WriteMode from datasystem.lib.libds_client_py import FutureTimeoutException +from datasystem.stream_client import SubconfigType, StreamClient from datasystem.ds_client import DsClient from datasystem.kv_client import KVClient from datasystem.hetero_client import HeteroClient, Blob, DeviceBlobList, MetaInfo, Future diff --git a/python/ds_tensor_client.py b/python/ds_tensor_client.py index 685e52b..322d083 100644 --- a/python/ds_tensor_client.py +++ b/python/ds_tensor_client.py @@ -193,17 +193,32 @@ class DsTensorClient: ) self._device_id = device_id + @staticmethod + def _get_tensor_device_type(tensor: Tensor) -> str: + """Safely get tensor device type with proper error handling""" + if hasattr(tensor, 'device') and isinstance(tensor.device, str): + error_msg = ("tensor.device is a string, 'str' object has no attribute 'type'. This usually indicates that " + "MindSpore is being used without msadapter, or there's a version mismatch.\n " + "Solutions:\n " + "1. Install msadapter for proper tensor operations: \n" + " pip install msadapter\n " + "2. 
Ensure MindSpore and msadapter version compatibility.\n") + raise AttributeError(error_msg) + + return tensor.device.type + @staticmethod def _is_ms_tensor(tensor: Tensor) -> str: """check if the tensor is mindspore type""" - is_ms = (tensor.device.type == "Ascend") + is_ms = (DsTensorClient._get_tensor_device_type(tensor) == "Ascend") return is_ms @staticmethod def _check_tensor_device_type(tensor: Tensor) -> None: """check the tensor type""" - if tensor.device.type not in ["Ascend", "npu"]: - raise ValueError(f"{tensor.device.type} tensor, not a npu/Ascend tensor") + device_type = DsTensorClient._get_tensor_device_type(tensor) + if device_type not in ["Ascend", "npu"]: + raise ValueError(f"{device_type} tensor, not a npu/Ascend tensor") @staticmethod def _check_tensors_is_contiguous(tensors: List[Tensor]) -> None: diff --git a/python/object_client.py b/python/object_client.py index ffa9f34..89da5c0 100644 --- a/python/object_client.py +++ b/python/object_client.py @@ -364,16 +364,14 @@ class ObjectClient: @staticmethod def _check_or_set_default_create_param(param: Dict): - key_write_mode = "write_mode" key_consistency_type = "consistency_type" if param is None: param = { - key_write_mode: WriteMode.NONE_L2_CACHE, key_consistency_type: ConsistencyType.PRAM, } validator.check_args_types([["param", param, dict]]) - return validator.check_key_exists(param, [key_write_mode, key_consistency_type]) + return validator.check_key_exists(param, [key_consistency_type]) def init(self): """ @@ -394,13 +392,7 @@ class ObjectClient: object_key(str): The id of the object to be created. size(int): The size in bytes of object. param(dict): which contains the following three "key: value" pairs successively: - (1) "write_mode", write_mode(Enum): Indicating whether the object will be written through L2 cache. - There are 3 options: - 1) WriteMode.NONE_L2_CACHE; - 2) WriteMode.WRITE_THROUGH_L2_CACHE; - 3) WriteMode.WRITE_BACK_L2_CACHE; - 4) WriteMode.NONE_L2_CACHE_EVICT; - (2) "consistency_type": consistency_type(Enum): Indicating which consistency type will be used. + (1) "consistency_type": consistency_type(Enum): Indicating which consistency type will be used. There are 2 options: 1) ConsistencyType.PRAM; 2) ConsistencyType.CAUSAL; @@ -413,17 +405,16 @@ class ObjectClient: RuntimeError: Raise a runtime error if the client fails to connect to the worker. """ params = self._check_or_set_default_create_param(param) - write_mode, consistency_type = params[0], params[1] + consistency_type = params[0] args = [ ["object_key", object_key, str], ["size", size, int], - ["write_mode", write_mode, WriteMode], ["consistency_type", consistency_type, ConsistencyType], ] validator.check_args_types(args) create_status, buffer = self.client.create( - object_key, size, write_mode.value, consistency_type.value + object_key, size, consistency_type.value ) if create_status.is_error(): raise RuntimeError(create_status.to_string()) @@ -439,13 +430,7 @@ class ObjectClient: object_key(str): The id of the object to be created. value(memoryview, bytes or bytearray): the data to be put param(dict): which contains the following three "key: value" pairs successively: - (1) "write_mode", write_mode(Enum): Indicating whether the object will be written through L2 cache. - There are 3 options: - 1) WriteMode.NONE_L2_CACHE; - 2) WriteMode.WRITE_THROUGH_L2_CACHE; - 3) WriteMode.WRITE_BACK_L2_CACHE; - 4) WriteMode.NONE_L2_CACHE_EVICT - (2) "consistency_type": consistency_type(Enum): Indicating which consistency type will be used. 
+ (1) "consistency_type": consistency_type(Enum): Indicating which consistency type will be used. There are 2 options: 1) ConsistencyType.PRAM; 2) ConsistencyType.CAUSAL; @@ -456,7 +441,7 @@ class ObjectClient: RuntimeError: Raise a runtime error if the put fails. """ params = self._check_or_set_default_create_param(param) - write_mode, consistency_type = params[0], params[1] + consistency_type = params[0] if nested_object_keys is None: nested_object_keys = [] @@ -464,7 +449,6 @@ class ObjectClient: args = [ ["object_key", object_key, str], ["value", value, memoryview, bytes, bytearray], - ["write_mode", write_mode, WriteMode], ["consistency_type", consistency_type, ConsistencyType], ["nested_object_keys", nested_object_keys, list], ] @@ -472,7 +456,6 @@ class ObjectClient: put_status = self.client.put( object_key, value, - write_mode.value, consistency_type.value, nested_object_keys, ) diff --git a/python/stream_client.py b/python/stream_client.py new file mode 100644 index 0000000..b463cc7 --- /dev/null +++ b/python/stream_client.py @@ -0,0 +1,323 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Stream cache client python interface. +""" + +from enum import Enum +import datasystem.lib.libds_client_py as ds +from datasystem.util import Validator as validator + + +class SubconfigType(Enum): + """The type of stream""" + STREAM = 0 + ROUND_ROBIN = 1 + KEY_PARTITIONS = 2 + + +class StreamClient: + """the client of stream""" + + def __init__(self, + host: str, + port: int, + client_public_key: str = "", + client_private_key: str = "", + server_public_key: str = "", + access_key="", + secret_key="", + tenant_id=""): + """ Constructor of the StreamClient class + + Args: + host(str): The worker address host. + port(str): The worker address port. + client_public_key(str): The client's public key, for curve authentication. + client_private_key(str): The client's private key, for curve authentication. + server_public_key(str): The worker server's public key, for curve authentication. + access_key(str): The access key used by AK/SK authorize. + secret_key(str): The secret key for AK/SK authorize. + tenant_id(str): The tenant ID. + """ + + if isinstance(client_private_key, str): + client_private_key = str.encode(client_private_key) + if isinstance(secret_key, str): + secret_key = str.encode(secret_key) + self._client = ds.StreamClient(host, port, client_public_key, client_private_key, server_public_key, access_key, + secret_key, tenant_id) + + def init(self): + """ Init a stream client to connect to a worker. + + Raises: + RuntimeError: Raise a runtime error if the client fails to connect to the worker. 
+ """ + init_status = self._client.init() + if init_status.is_error(): + raise RuntimeError(init_status.to_string()) + + def create_producer(self, + stream_name, + delay_flush_time_ms=5, + page_size_byte=1024 * 1024, + max_stream_size_byte=1024 * 1024 * 1024, + auto_cleanup=False, + retain_for_num_consumers=0, + encrypt_stream=False, + reserve_size=0): + """ Create one Producer to send element. + + Args: + stream_name: The name of the stream. + delay_flush_time_ms: The time used in automatic flush after send and default is 5ms. + page_size_byte: The size used in allocate page and default is 1MB. + must be a multiple of 4KB. + max_stream_size_byte: The max stream size in worker and default is 1GB. + must between greater then 64KB and less than the shared memory size. + auto_cleanup: Should auto delete when the last producer/consumer exit. + retain_for_num_consumers: The number of consumers to retain data for, default to 0. + encrypt_stream: Enable stream data encryption between workers, default to false. + reserve_size: default reserve size to page size, must be a multiple of page size. + Return: + outProducer: The output Producer that user can use it to send element. + + Raise: + TypeError: Raise a type error if the input parameter is invalid. + RutimeError: Raise a runtime error if creating a producer fails. + """ + if not isinstance(stream_name, str): + raise TypeError("The input of stream_name should be string.") + if not isinstance(delay_flush_time_ms, int): + raise TypeError("The input of delay_flush_time_ms should be int.") + validator.check_param_range("delay_flush_time_ms", delay_flush_time_ms, 0, validator.INT32_MAX_SIZE) + if not isinstance(page_size_byte, int): + raise TypeError("The input of page_size_byte should be int.") + status, out_producer = self._client.CreateProducer(stream_name, delay_flush_time_ms, page_size_byte, + max_stream_size_byte, auto_cleanup, retain_for_num_consumers, + encrypt_stream, reserve_size) + if status.is_error(): + raise RuntimeError(status.to_string()) + return Producer(out_producer) + + def subscribe(self, stream_name, sub_name, subscription_type): + """ Subscribe a new consumer onto master request + + Args: + stream_name: The name of the stream. + sub_name: The name of subscription + subscription_type: The type of subscription. + + Return: + outConsumer: The output Consumer that user can use it to receive element. + + Raise: + TypeError: Raise a type error if the input parameter is invalid. + RuntimeError: Raise a runtime error if subscribing a new consumer fails. + """ + if not isinstance(stream_name, str): + raise TypeError("The input of stream_name should be string.") + if not isinstance(subscription_type, int): + raise TypeError("The input of type should be int.") + status, out_consumer = self._client.Subscribe(stream_name, sub_name, subscription_type) + if status.is_error(): + raise RuntimeError(status.to_string()) + return Consumer(out_consumer) + + def delete_stream(self, stream_name): + """ Delete one stream + + Args: + stream_name: The name of the stream. + + Raise: + TypeError: Raise a type error if the input parameter is invalid. + RutimeError: Raise a runtime error if deleting one stream fails. 
+ """ + if not isinstance(stream_name, str): + raise TypeError("The input of stream_name should be string.") + status = self._client.DeleteStream(stream_name) + if status.is_error(): + raise RuntimeError(status.to_string()) + + def query_global_producer_num(self, stream_name): + """ Query number of producer in global worker node + + Args: + stream_name: The name of the target stream. + + Returns: + global_producer_num: Query result. + + Raise: + TypeError: Raise a type error if the input parameter is invalid. + RutimeError: Raise a runtime error if querying global producer number fails. + """ + if not isinstance(stream_name, str): + raise TypeError("The input of stream_name should be string.") + status, global_producer_num = self._client.QueryGlobalProducersNum(stream_name) + if status.is_error(): + raise RuntimeError(status.to_string()) + return global_producer_num + + def query_global_consumer_num(self, stream_name): + """ Query number of consumer in global worker node + + Args: + stream_name: The name of the target stream. + + Returns: + global_consumer_num: Query result. + + Raise: + TypeError: Raise a type error if the input parameter is invalid. + RutimeError: Raise a runtime error if querying global consumer number fails. + """ + if not isinstance(stream_name, str): + raise TypeError("The input of stream_name should be string.") + status, global_consumer_num = self._client.QueryGlobalConsumersNum(stream_name) + if status.is_error(): + raise RuntimeError(status.to_string()) + return global_consumer_num + + +class Producer: + """the producer of stream in client""" + + def __init__(self, producer): + if not isinstance(producer, ds.Producer): + raise TypeError("The input of parament should be Producer.") + self._producer = producer + + def send(self, element_bytes, timeout_ms=None): + """ Produce send one element of the stream each time + + Args: + element: The element that to be written. + timeout_ms: The amount of time in milliseconds to wait for the send to complete in the range of + [0, INT32_MAX]. A value of 0 means that it will immediately return the error reason without waiting if + the send cannot be completed right away. A value greater than 0 makes this a possible blocking call + where it will wait for the operation to complete if needed. If the wait time exceeds the value then + the function will stop waiting and return the error reason. + + Raise: + TypeError: Raise a type error if the input parameter is invalid. + RutimeError: Raise a runtime error if sending one element fails. + """ + if not isinstance(element_bytes, memoryview) and not isinstance(element_bytes, bytes) and not isinstance( + element_bytes, bytearray): + raise TypeError("The input of parament should be memoryview or bytes or bytearray.") + if timeout_ms is None: + status = self._producer.Send(element_bytes) + if status.is_error(): + raise RuntimeError(status.to_string()) + else: + if not isinstance(timeout_ms, int): + raise TypeError("The input of timeout_ms should be int.") + validator.check_param_range("timeout_ms", timeout_ms, 0, validator.INT32_MAX_SIZE) + status = self._producer.Send(element_bytes, timeout_ms) + if status.is_error(): + raise RuntimeError(status.to_string()) + + def close(self): + """ Close a producer, register a publisher to a stream. + + Raise: + RuntimeError: Raise a runtime error if closing a producer fails. 
+ """ + status = self._producer.Close() + if status.is_error(): + raise RuntimeError(status.to_string()) + + +class Consumer: + """the consumer of stream in client""" + + def __init__(self, consumer): + if not isinstance(consumer, ds.Consumer): + raise TypeError("The input of parament should be Consumer.") + self._consumer = consumer + + def receive(self, expect_num, timeout_ms): + """ Receive an expected number of elements. + + Args: + expect_num: The number of elements to be read. + timeout_ms: The timeout in milliseconds to wait or until number of expected elements has been received. + + Return: + values: element has been received + + Raise: + TypeError: Raise a type error if the input parameter is invalid. + RutimeError: Raise a runtime error if receiving elements meta falis. + """ + if not isinstance(expect_num, int): + raise TypeError("The input of expect_num should be int.") + validator.check_param_range("expect_num", expect_num, 1, validator.INT32_MAX_SIZE) + if not isinstance(timeout_ms, int): + raise TypeError("The input of timeout_ms should be int.") + validator.check_param_range("timeout_ms", timeout_ms, 0, validator.INT32_MAX_SIZE) + status, element = self._consumer.Receive(expect_num, timeout_ms) + if status.is_error(): + raise RuntimeError(status.to_string()) + return element + + def receive_any(self, timeout_ms): + """ Receive any number of elements that is available. + + Args: + timeout_ms: The timeout in milliseconds to wait or until any number of elements has been received. + + Return: + values: element has been received + + Raise: + TypeError: Raise a type error if the input parameter is invalid. + RutimeError: Raise a runtime error if receiving elements meta falis. + """ + if not isinstance(timeout_ms, int): + raise TypeError("The input of timeout_ms should be int.") + validator.check_param_range("timeout_ms", timeout_ms, 0, validator.INT32_MAX_SIZE) + status, element = self._consumer.ReceiveAny(timeout_ms) + if status.is_error(): + raise RuntimeError(status.to_string()) + return element + + def ack(self, element_id): + """ Acknowledge elements that had been read by this consumer. + + Args: + element_id: The element id that to be acknowledged. + + Raise: + TypeError: Raise a type error if the input parameter is invalid. + RutimeError: Raise a runtime error if acknowledging elements falis. + """ + if not isinstance(element_id, int): + raise TypeError("The input of element_id should be int.") + status = self._consumer.Ack(element_id) + if status.is_error(): + raise RuntimeError(status.to_string()) + + def close(self): + """ Close the consumer, after close it will not allow Receive and Ack Elements. + + Raise: + RuntimeError: Raise a runtime error if closing the consumer falis. + """ + status = self._consumer.Close() + if status.is_error(): + raise RuntimeError(status.to_string()) diff --git a/scripts/modules/llt_util.sh b/scripts/modules/llt_util.sh index 8b75718..84f384b 100644 --- a/scripts/modules/llt_util.sh +++ b/scripts/modules/llt_util.sh @@ -58,6 +58,7 @@ function get_random_port() function check_etcd_ready() { + echo "start check etcd status" local check_flag=false for i in $(seq 1 ${WAIT_COMPONENT_READY_TIMES}); do if etcdctl --endpoints ${etcd_client_url} endpoint health; then @@ -168,6 +169,6 @@ function stop_all() if ps -p ${etcd_pid} >/dev/null; then # interrupt signal will shutdown the etcd cluster echo "Shutting down etcd service pid: ${etcd_pid}" - kill -2 ${etcd_pid} && echo "Success" || echo "Error: $!" 
+ kill -2 ${etcd_pid} || echo "stop etcd failed: $!" fi } diff --git a/scripts/stream_cache/parse_sc_metrics.py b/scripts/stream_cache/parse_sc_metrics.py new file mode 100644 index 0000000..7d82fc8 --- /dev/null +++ b/scripts/stream_cache/parse_sc_metrics.py @@ -0,0 +1,207 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import argparse +import os + +stream_headers = [ + "Time", + "Stream Name", + "NumLocalProducers", + "NumRemoteProducers", + "NumLocalConsumers", + "NumRemoteConsumers", + "SharedMemoryUsed", + "LocalMemoryUsed", + "NumTotalElementsSent", + "NumTotalElementsReceived", + "NumTotalElementsAcked", + "NumSendRequests", + "NumReceiveRequests", + "NumPagesCreated", + "NumPagesReleased", + "NumPagesCached", + "NumBigPagesCreated", + "NumBigPagesReleased", + "NumLocalProducersBlocked", + "NumRemoteProducersBlocked", + "NumRemoteConsumersBlocking", + "RetainDataState", + "StreamState", +] + +worker_headers = [ + "Time", + "TotalNumberStreams", + "TotalNumberInActiveStreams", + "TotalStreamMemoryUsed", + "TotalStreamMemoryLimit", +] + +retain_data_state = ["NONE", "INIT", "RETAIN", "NOT_RETAIN"] + +stream_manager_state = [ + "ACTIVE", + "RESET_IN_PROGRESS", + "RESET_COMPLETE", + "DELETE_IN_PROGRESS", +] + +METRIC_LOG_ENTRY_INDEX = 7 + + +def parse_args(): + """ + Parse arguments + """ + parser = argparse.ArgumentParser( + description="""Parses worker and stream metrics in sc_metrics.log file. + \rOutputs metrics to sc_worker_metrics.csv and sc_stream_metrics.csv respectively""", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog="""Example: Get worker and stream metrics from sc_metrics.log file, filtering on stream1,stream2 + \r $ python parse_sc_metrics.py sc_metrics.log -t sw -f stream1,stream2""", + ) + parser.add_argument("filename", help="path to the sc_metrics.log file") + parser.add_argument( + "-t", + "--type", + help="metric types to parse from the file. (s = stream, w = worker, sw or ws = stream and worker)", + choices=["s", "w", "sw", "ws"], + default="sw", + ) + parser.add_argument( + "-f", + "--filter", + help='filter stream metrics by stream names. 
Separate stream names with ","',
+    )
+    parser.add_argument(
+        "-o", "--output", help="path of the output folder, default: ./", default="./"
+    )
+    args = parser.parse_args()
+
+    parse_worker = "w" in args.type
+    parse_stream = "s" in args.type
+
+    if args.filter is None:
+        stream_filter = None
+    else:
+        stream_filter = args.filter.split(",")
+
+    return (
+        args.filename,
+        parse_worker,
+        parse_stream,
+        stream_filter,
+        args.output,
+    )
+
+
+def print_info_message(
+    filename, parse_worker, parse_stream, stream_filter, output
+):
+    """
+    Prints info message
+    """
+    if parse_worker and parse_stream:
+        print("Parsing worker, stream metrics from", filename)
+    elif parse_worker:
+        print("Parsing worker metrics from", filename)
+    elif parse_stream:
+        print("Parsing stream metrics from", filename)
+    else:
+        print("error: no metric type selected")
+    if stream_filter:
+        print("Filtering on stream names:", stream_filter)
+    print("Outputting to directory:", output)
+
+
+def parse_sc_metrics(filename, parse_worker, parse_stream, stream_filter, output):
+    """
+    Parse metrics from log file
+    """
+    # Start parsing log file
+    file = open(filename, "r")
+    # create output directory
+    if not os.path.exists(output):
+        os.makedirs(output)
+    if parse_stream:
+        stream_csv = open(os.path.join(output, "sc_stream_metrics.csv"), "w")
+        stream_wr = csv.writer(stream_csv)
+        stream_wr.writerow(stream_headers)
+    if parse_worker:
+        worker_csv = open(os.path.join(output, "sc_worker_metrics.csv"), "w")
+        worker_wr = csv.writer(worker_csv)
+        worker_wr.writerow(worker_headers)
+
+    for line in file:
+        categories = line.split("|")
+        # actual log line located at index 7
+        metric_log = categories[METRIC_LOG_ENTRY_INDEX]
+        time = categories[0].rstrip()
+        if parse_worker and "Worker metrics" in metric_log:
+            metrics = metric_log.split("/")
+            metrics[len(metrics) - 1] = metrics[len(metrics) - 1].rstrip()
+            metrics.pop(0)
+            metrics.insert(0, time)
+            worker_wr.writerow(metrics)
+        elif parse_stream and "Worker metrics" not in metric_log and "master " not in metric_log:
+            metrics = metric_log.split("/")
+            metrics[0] = metrics[0].lstrip()
+            # filter based on stream name
+            if stream_filter is None or metrics[0] in stream_filter:
+                metrics[len(metrics) - 1] = metrics[len(metrics) - 1].rstrip()
+                metrics.insert(0, time)
+                # convert enums to string
+                metrics[
+                    stream_headers.index("RetainDataState")
+                ] = retain_data_state[
+                    int(metrics[stream_headers.index("RetainDataState")])
+                ]
+                metrics[stream_headers.index("StreamState")] = stream_manager_state[
+                    int(metrics[stream_headers.index("StreamState")])
+                ]
+                stream_wr.writerow(metrics)
+
+    # close files
+    if parse_stream:
+        stream_csv.close()
+    if parse_worker:
+        worker_csv.close()
+    file.close()
+
+
+def main():
+    """
+    Main execution
+    """
+    (
+        filename,
+        parse_worker,
+        parse_stream,
+        stream_filter,
+        output,
+    ) = parse_args()
+
+    print_info_message(
+        filename, parse_worker, parse_stream, stream_filter, output
+    )
+
+    parse_sc_metrics(filename, parse_worker, parse_stream, stream_filter, output)
+
+    print("Parse metrics success")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/setup.py b/setup.py index dd68f0b..17ec13f 100644 --- a/setup.py +++ b/setup.py
@@ -44,7 +44,7 @@ package_datas = { '': ['sdk_lib_list', 'datasystem_worker', '*.py', 'worker_config.json', 'cluster_config.json', '*so', '*so*', 'helm_chart/**/*', 'helm_chart/*', 'helm_chart/**/**/*',
- 'include/*', 'include/**/**/*', 'include/**/*', 'lib/**/**/*', 'lib/*',
+ 'include/*', 'include/**/**/*', 'include/**/*', 'lib/**/**/*', 'lib/*',
'lib/urma/*', 'cpp_template/**/*', 'cpp_template/*', 'cpp_template/**/**/*', 'cpp_template/**/**/**/*'] } diff --git a/src/datasystem/client/CMakeLists.txt b/src/datasystem/client/CMakeLists.txt index 9b871b4..869c6be 100644 --- a/src/datasystem/client/CMakeLists.txt +++ b/src/datasystem/client/CMakeLists.txt @@ -18,7 +18,16 @@ list(APPEND CLIENT_SRCS object_cache/device/page_attn_utils.cpp kv_cache/read_only_buffer.cpp kv_cache/kv_client.cpp - hetero_cache/hetero_client.cpp) + hetero_cache/hetero_client.cpp + stream_cache/client_worker_api.cpp + stream_cache/client_base_impl.cpp + stream_cache/consumer.cpp + stream_cache/consumer_impl.cpp + stream_cache/producer.cpp + stream_cache/producer_impl.cpp + stream_cache/stream_client.cpp + stream_cache/stream_client_impl.cpp + stream_cache/producer_consumer_worker_api.cpp) list(APPEND CLIENT_DEPEND_LIBS ${SECUREC_LIBRARY} @@ -30,6 +39,7 @@ list(APPEND CLIENT_DEPEND_LIBS common_inject common_log common_perf + common_sc common_shm_unit_info common_util common_immutable_string @@ -37,6 +47,7 @@ list(APPEND CLIENT_DEPEND_LIBS posix_protos_client share_memory_protos_client worker_object_protos_client + worker_stream_protos_client common_acl_device common_shared_memory common_rdma) @@ -65,10 +76,12 @@ add_dependencies(datasystem_static share_memory_protos_client posix_protos_client worker_object_protos_client - master_object_protos_client) + master_object_protos_client + worker_stream_protos_client) add_dependencies(datasystem share_memory_protos_client posix_protos_client worker_object_protos_client - master_object_protos_client) + master_object_protos_client + worker_stream_protos_client) diff --git a/src/datasystem/client/client_worker_common_api.cpp b/src/datasystem/client/client_worker_common_api.cpp index a63dbb4..62bba00 100644 --- a/src/datasystem/client/client_worker_common_api.cpp +++ b/src/datasystem/client/client_worker_common_api.cpp @@ -58,6 +58,7 @@ ClientWorkerCommonApi::ClientWorkerCommonApi(HostPort hostPort, RpcCredential cr bool enableCrossNodeConnection) : hostPort_(std::move(hostPort)), cred_(std::move(cred)), + pageSize_(0), socketFd_(-1), heartbeatType_(heartbeatType), signature_(signature), @@ -110,7 +111,7 @@ Status ClientWorkerCommonApi::Init(int32_t timeoutMs) CHECK_FAIL_RETURN_STATUS(TimerQueue::GetInstance()->Initialize(), K_RUNTIME_ERROR, "TimerQueue init failed!"); RegisterClientReqPb req; RETURN_IF_NOT_OK(Connect(req, timeoutMs)); - VLOG(1) << "The new client id is: " << clientId_; + VLOG(1) << "The new client id is: " << clientId_ << ", Received pageSize= " << pageSize_ << " from worker."; return Status::OK(); } @@ -365,6 +366,7 @@ Status ClientWorkerCommonApi::RegisterClient(RegisterClientReqPb &req, int32_t t workerTimeoutMult_ = rsp.quorum_timeout_mult(); clientId_ = rsp.client_id(); workerStartId_ = rsp.worker_start_id(); + pageSize_ = static_cast(rsp.page_size()); lockId_ = rsp.lock_id(); (void)workerVersion_.fetch_add(1, std::memory_order_relaxed); shmThreshold_ = rsp.shm_threshold(); diff --git a/src/datasystem/client/client_worker_common_api.h b/src/datasystem/client/client_worker_common_api.h index 0948657..0158a4e 100644 --- a/src/datasystem/client/client_worker_common_api.h +++ b/src/datasystem/client/client_worker_common_api.h @@ -422,6 +422,7 @@ protected: std::shared_timed_mutex standbyWorkerMutex_; std::unordered_set standbyWorkerAddrs_; bool isUseStandbyWorker_ = false; + uint32_t pageSize_{ 0 }; // The page size used when reading files. 
HostPort masterAddress_; std::string clientId_; std::string workerStartId_; // To judge whether the worker is restarted. diff --git a/src/datasystem/client/context/context.cpp b/src/datasystem/client/context/context.cpp index 3e1bbc4..6a6b8e5 100644 --- a/src/datasystem/client/context/context.cpp +++ b/src/datasystem/client/context/context.cpp @@ -35,13 +35,13 @@ Status Context::SetTraceId(const std::string &traceId) CHECK_FAIL_RETURN_STATUS(traceId.length() <= Trace::TRACEID_PREFIX_SIZE, K_INVALID, FormatString("The length of trace id should less than %d", Trace::TRACEID_PREFIX_SIZE)); Trace::Instance().SetPrefix(traceId); - VLOG(1) << "set trace id:" << traceId; + VLOG(1) << "Set trace id: " << traceId; return Status::OK(); } void Context::SetTenantId(const std::string &tenantId) { g_ContextTenantId = tenantId; - LOG(INFO) << "set tenant id : " << g_ContextTenantId; + LOG(INFO) << "Set tenant id: " << g_ContextTenantId; } } // namespace datasystem \ No newline at end of file diff --git a/src/datasystem/client/hetero_cache/device_buffer.h b/src/datasystem/client/hetero_cache/device_buffer.h index f67b0f2..f6f333f 100644 --- a/src/datasystem/client/hetero_cache/device_buffer.h +++ b/src/datasystem/client/hetero_cache/device_buffer.h @@ -31,6 +31,7 @@ #include #include "datasystem/client/hetero_cache/device_util.h" +#include "datasystem/hetero/device_common.h" #include "datasystem/hetero/future.h" #include "datasystem/object/object_enum.h" #include "datasystem/utils/status.h" @@ -82,11 +83,11 @@ public: /// \return Status of the result. Return error if lifetime is not MOVE. Status GetSendStatus(std::vector &futureVec); - /// \brief Gets the list of DataInfo. + /// \brief Gets the list of Blob. /// /// - /// \return The list of DataInfo. - std::vector GetDataInfoList() const; + /// \return The list of Blob. + std::vector GetDevBlobList() const; /// \brief Detach a directory location /// @@ -123,7 +124,7 @@ private: /// \brief Release the buffer owned resources. void Release(); - + /// \brief Get the device memory unit. /// /// \return The device memory unit. diff --git a/src/datasystem/client/listen_worker.cpp b/src/datasystem/client/listen_worker.cpp index 2877de8..2be38ce 100644 --- a/src/datasystem/client/listen_worker.cpp +++ b/src/datasystem/client/listen_worker.cpp @@ -107,6 +107,7 @@ Status ListenWorker::StartListenWorker(int socketFd) nullptr)); } else { workerListenedThread_ = Thread(&ListenWorker::CheckHeartbeat, this); + workerListenedThread_.set_name("ListenWorker"); firstHeartbeatWaitPost_->WaitFor(clientCommonWorker_->GetConnectTimeoutMs()); INJECT_POINT("listen_worker.StartListenWorker"); if (!firstHeartbeatReceived_.load()) { @@ -268,7 +269,7 @@ void ListenWorker::RunAllCallback() LOG(INFO) << "All callback size: " << callBackTable_.size() << ", local worker: " << isLocalWorker_; auto traceId = Trace::Instance().GetTraceID(); auto func = [this, traceId]() { - auto traceGuard = Trace::Instance().SetSubTraceID(traceId); + auto traceGuard = Trace::Instance().SetTraceNewID(traceId); std::shared_lock l(callbackMutex_); for (const auto &func : callBackTable_) { if (stop_) { diff --git a/src/datasystem/client/mmap_manager.h b/src/datasystem/client/mmap_manager.h index a336a21..7455fe3 100644 --- a/src/datasystem/client/mmap_manager.h +++ b/src/datasystem/client/mmap_manager.h @@ -44,7 +44,7 @@ public: /** * @brief Loop the input share memory unit and mmap the fd if it was not mmapped in the client. - * @param[in] tenantId for producer consumer get clientfd by tenantId. 
+ * @param[in] tenantId The tenant id used by the stream producer/consumer to get the client fd.
* @param[in] unit The input share memory unit. * @return Status of the call. */
@@ -52,7 +52,7 @@ public: /** * @brief Loop the input share memory unit and mmap the fd if it was not mmapped in the client.
- * @param[in] tenantId for producer consumer get clientfd by tenantId
+ * @param[in] tenantId The tenant id used by the stream producer/consumer to get the client fd.
* @param[in] units The input share memory unit. * @return Status of the call. */
diff --git a/src/datasystem/client/object_cache/client_worker_api.cpp b/src/datasystem/client/object_cache/client_worker_api.cpp index 221f498..9b4e970 100644 --- a/src/datasystem/client/object_cache/client_worker_api.cpp +++ b/src/datasystem/client/object_cache/client_worker_api.cpp
@@ -144,10 +144,11 @@ Status ClientWorkerApi::Create(const std::string &objectKey, int64_t dataSize, u return Status::OK(); }
-Status ClientWorkerApi::MultiCreate(std::vector &createParams, uint32_t &version)
+Status ClientWorkerApi::MultiCreate(bool skipCheckExistence, std::vector &createParams,
+ uint32_t &version, std::vector &exists, bool &useShmTransfer)
{ MultiCreateReqPb req; -
+ req.set_skip_check_existence(skipCheckExistence);
req.set_client_id(GetClientId()); for (auto &param : createParams) { req.add_object_key(param.objectKey);
@@ -172,7 +173,26 @@ createParams.size() == static_cast(rsp.results().size()), K_INVALID, FormatString("The length of objectKeyList (%zu) and dataSizeList (%zu) should be the same.", createParams.size(), rsp.results().size()));
+ auto rspExists = rsp.exists();
+ CHECK_FAIL_RETURN_STATUS(static_cast(rspExists.size()) == createParams.size(), K_INVALID,
+ "The size of rspExists is not consistent with createParams");
+ exists.reserve(createParams.size());
+ for (auto val : rspExists) {
+ exists.emplace_back(val);
+ }
+ for (auto res : rsp.results()) {
+ if (!res.shm_id().empty()) {
+ useShmTransfer = true;
+ break;
+ }
+ }
+ if (!useShmTransfer) {
+ return Status::OK();
+ }
for (auto i = 0ul; i < createParams.size(); i++) {
+ if (exists[i]) {
+ continue;
+ }
auto &shmBuf = createParams[i].shmBuf; auto subRsp = rsp.results()[i]; shmBuf->fd = subRsp.store_fd();
@@ -374,6 +394,7 @@ Status ClientWorkerApi::MultiPublish(const std::vector(param.existence)); req.set_is_replica(param.isReplica);
+ req.set_auto_release_memory_ref(!bufferInfo[0]->shmId.empty());
std::vector payloads; for (size_t i = 0; i < bufferInfo.size(); ++i) { if (bufferInfo[i]->shmId.empty()) {
@@ -798,28 +819,28 @@ Status ClientWorkerApi::SubscribeReceiveEvent(int32_t deviceId, SubscribeReceive }
void ClientWorkerApi::FillDevObjMeta(const std::shared_ptr &bufferInfo,
- const std::vector &dataInfoList, DeviceObjectMetaPb *metaPb)
+ const std::vector &blobs, DeviceObjectMetaPb *metaPb)
{ metaPb->set_object_key(bufferInfo->devObjKey); metaPb->set_lifetime(LifetimeParamPb(static_cast(bufferInfo->lifetimeType))); auto loc = metaPb->add_locations(); loc->set_client_id(GetClientId()); loc->set_device_id(bufferInfo->deviceIdx);
- for (const auto &dataInfo : dataInfoList) {
- const auto &dataInfos = metaPb->add_data_infos();
- dataInfos->set_data_type(static_cast(dataInfo.dataType));
- dataInfos->set_count(dataInfo.count);
+ for (const auto &blob : blobs) {
+ const auto &blobInfos = metaPb->add_data_infos();
+ blobInfos->set_data_type(static_cast(DataType::DATA_TYPE_INT8));
+ blobInfos->set_count(blob.size);
} metaPb->set_src_offset(bufferInfo->srcOffset); }
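// FillDevObjMeta above serializes one DeviceObjectMetaPb per device object: the
// object key, the lifetime, the local client/device location, one data_infos
// entry per Blob (typed as INT8 with count = blob.size), and the source offset.
// PutP2PMeta/GetP2PMeta below exchange this metadata through the worker so the
// receiving side can rebuild the same blob layout.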
Status ClientWorkerApi::PutP2PMeta(const std::shared_ptr &bufferInfo, - const std::vector &dataInfoList) + const std::vector &blobs) { PutP2PMetaReqPb req; PutP2PMetaRspPb resp; auto subReq = req.add_dev_obj_meta(); - FillDevObjMeta(bufferInfo, dataInfoList, subReq); + FillDevObjMeta(bufferInfo, blobs, subReq); RpcOptions opts; opts.SetTimeout(timeoutMs_); INJECT_POINT("ClientWorkerApi.PutP2PMeta.timeoutDuration", [](int time) { @@ -835,7 +856,7 @@ Status ClientWorkerApi::PutP2PMeta(const std::shared_ptr &buff } Status ClientWorkerApi::GetP2PMeta(std::vector> &bufferInfoList, - std::vector> &dataInfoList, GetP2PMetaRspPb &resp, + std::vector &devBlobList, GetP2PMetaRspPb &resp, int64_t subTimeoutMs) { INJECT_POINT("GETP2PMeta.subTimeoutMs", [&subTimeoutMs](int64_t t) { @@ -847,13 +868,13 @@ Status ClientWorkerApi::GetP2PMeta(std::vector Validator::IsInNonNegativeInt32(timeoutMs), K_INVALID, FormatString("timeoutMs %d is out of range., which should be between [%d, %d]", timeoutMs, 0, INT32_MAX)); GetP2PMetaReqPb req; - if (bufferInfoList.size() != dataInfoList.size()) { + if (bufferInfoList.size() != devBlobList.size()) { LOG(ERROR) << "buffer info list size not matching data info list size"; return Status(K_INVALID, "buffer info list size not matching data info list size"); } for (size_t i = 0; i < bufferInfoList.size(); i++) { auto subReq = req.add_dev_obj_meta(); - FillDevObjMeta(bufferInfoList[i], dataInfoList[i], subReq); + FillDevObjMeta(bufferInfoList[i], devBlobList[i].blobs, subReq); } req.set_sub_timeout(subTimeoutMs); RpcOptions opts; @@ -867,9 +888,8 @@ Status ClientWorkerApi::GetP2PMeta(std::vector return stub_->GetP2PMeta(opts, req, resp); } -Status ClientWorkerApi::SendRootInfo(SendRootInfoReqPb &req) +Status ClientWorkerApi::SendRootInfo(SendRootInfoReqPb &req, SendRootInfoRspPb &resp) { - SendRootInfoRspPb resp; RpcOptions opts; opts.SetTimeout(timeoutMs_); reqTimeoutDuration.Init(ClientGetRequestTimeout(timeoutMs_)); @@ -910,7 +930,7 @@ Status ClientWorkerApi::AckRecvFinish(AckRecvFinishReqPb &req) return stub_->AckRecvFinish(opts, req, resp); } -Status ClientWorkerApi::GetDataInfo(const std::string &devObjKey, int32_t timeoutMs, std::vector &dataInfos) +Status ClientWorkerApi::GetBlobsInfo(const std::string &devObjKey, int32_t timeoutMs, std::vector &blobs) { RpcOptions opts; auto rpcTimeout = std::max(timeoutMs, rpcTimeoutMs_); @@ -928,11 +948,11 @@ Status ClientWorkerApi::GetDataInfo(const std::string &devObjKey, int32_t timeou GetDataInfoRspPb resp; PerfPoint perfPoint(PerfKey::RPC_HETERO_CLIENT_GET_DATA_INFO); RETURN_IF_NOT_OK(stub_->GetDataInfo(opts, req, resp)); - // Obtains dataInfos from resp + // Obtains the blobs from resp std::vector dataInfoPbs = { resp.data_infos().begin(), resp.data_infos().end() }; - dataInfos.reserve(dataInfoPbs.size()); + blobs.reserve(dataInfoPbs.size()); for (const auto &dataInfoPb : dataInfoPbs) { - dataInfos.emplace_back(DataInfo{ nullptr, static_cast(dataInfoPb.data_type()), dataInfoPb.count() }); + blobs.emplace_back(Blob{ nullptr, dataInfoPb.count() }); } return Status::OK(); } diff --git a/src/datasystem/client/object_cache/client_worker_api.h b/src/datasystem/client/object_cache/client_worker_api.h index 58c6f26..ecaf9b2 100644 --- a/src/datasystem/client/object_cache/client_worker_api.h +++ b/src/datasystem/client/object_cache/client_worker_api.h @@ -36,6 +36,7 @@ #include "datasystem/common/util/strings_util.h" #include "datasystem/common/util/uuid_generator.h" #include "datasystem/common/util/status_helper.h" +#include 
"datasystem/hetero/device_common.h" #include "datasystem/object/buffer.h" #include "datasystem/client/hetero_cache/device_util.h" #include "datasystem/protos/master_object.pb.h" @@ -287,29 +288,29 @@ public: /** * @brief Put the p2p metadata to worker. * @param[in] bufferInfo The info of device buffer. - * @param[in] dataInfoList The list of data info. + * @param[in] blobs The list of device blob. * @return Status of the call */ - Status PutP2PMeta(const std::shared_ptr &bufferInfo, const std::vector &dataInfoList); + Status PutP2PMeta(const std::shared_ptr &bufferInfo, const std::vector &blobs); /** * @brief Get the p2p metadata from worker. * @param[in] bufferInfoList The info of device buffer. - * @param[in] dataInfoList The list of data info. + * @param[in] devBlobList The list of device blob. * @param[out] resp The response of the rpc call. * @param[in] subTimeoutMs The maximum time elapse of subscriptions. * @return Status of the call */ Status GetP2PMeta(std::vector> &bufferInfoList, - std::vector> &dataInfoList, GetP2PMetaRspPb &resp, - int64_t subTimeoutMs = 500); + std::vector &devBlobList, GetP2PMetaRspPb &resp, int64_t subTimeoutMs = 500); /** * @brief Send the root info to worker. * @param[in] req The request of the call. + * @param[out] resp The response of the call. * @return Status of the call. */ - Status SendRootInfo(SendRootInfoReqPb &req); + Status SendRootInfo(SendRootInfoReqPb &req, SendRootInfoRspPb &resp); /** * @brief Receive the root info from worker. @@ -320,14 +321,14 @@ public: Status RecvRootInfo(RecvRootInfoReqPb &req, RecvRootInfoRspPb &resp); /** - * @brief Obtains the DataInfos, including the number of DataInfo, and the count and DataType of each DataInfo. + * @brief Obtains the BlobInfos, including the number of DataInfo, and the count and DataType of each DataInfo. * @param[in] devObjKey The object key. * @param[in] timeoutMs Waiting for the result return if object not ready. A positive integer number required. * 0 means no waiting time allowed. And the range is [0, INT32_MAX]. - * @param[out] dataInfos The list of data info. (Include pointer、count and data type) + * @param[out] blobs The list of blob info. (Include pointer、count and data type) * @return K_OK on any object success; the error code otherwise. */ - Status GetDataInfo(const std::string &devObjKey, int32_t timeoutMs, std::vector &dataInfos); + Status GetBlobsInfo(const std::string &devObjKey, int32_t timeoutMs, std::vector &blobs); /** * @brief Acknowledge the get operation is ok, and indicate whether worker as data providers @@ -359,11 +360,15 @@ public: /** * @brief Create multiple objects in the store-server. + * @param[in] skipCheckExistence Whether skip existence check. * @param[out] createParams The params for create operation. * @param[out] version Object version. + * @param[out] exists The Key exist list. + * @param[out] useShmTransfer Transfer by shm. * @return K_OK on success; the error code otherwise. */ - Status MultiCreate(std::vector &createParams, uint32_t &version); + Status MultiCreate(bool skipCheckExistence, std::vector &createParams, uint32_t &version, + std::vector &exists, bool &useShmTransfer); /** * @brief Invoke worker client to query the size of objectKeys (include the objectKeys of other AZ). @@ -448,10 +453,10 @@ private: /** * @brief Fill device object meta to Pb. * @param[in] bufferInfo The info of device buffer. - * @param[in] dataInfoList The list of data info. + * @param[in] blobs The list of blob info. * @param[out] metaPb The device object meta pb. 
*/
- void FillDevObjMeta(const std::shared_ptr &bufferInfo, const std::vector &dataInfoList,
+ void FillDevObjMeta(const std::shared_ptr &bufferInfo, const std::vector &blobs,
DeviceObjectMetaPb *metaPb);
// To protect the decreaseRPCQ_ and waitRespMap_ from being manipulated by different threads of the same client.
diff --git a/src/datasystem/client/object_cache/device/client_device_object_manager.cpp b/src/datasystem/client/object_cache/device/client_device_object_manager.cpp index 919fc11..a367c63 100644 --- a/src/datasystem/client/object_cache/device/client_device_object_manager.cpp +++ b/src/datasystem/client/object_cache/device/client_device_object_manager.cpp
@@ -57,39 +57,27 @@ Status ClientDeviceObjectManager::Init() return Status::OK(); }
-Status ClientDeviceObjectManager::CreateDevBuffer(const std::string &devObjKey, uint64_t size, void *devPtr,
- int32_t deviceIdx, std::shared_ptr &deviceBuffer)
-{
- std::vector dataInfoList{ { devPtr, DataType::DATA_TYPE_INT8, size } };
- auto bufferInfo =
- std::make_shared(devObjKey, deviceIdx, LifetimeType::REFERENCE, true, TransferType::HOST);
- return CreateDevBufferImpl(bufferInfo, dataInfoList, deviceBuffer);
-}
-
-Status ClientDeviceObjectManager::CreateDevBuffer(const std::string &devObjKey,
- const std::vector &dataInfoList, int32_t deviceIdx,
+Status ClientDeviceObjectManager::CreateDevBuffer(const std::string &devObjKey, const DeviceBlobList &devBlobList,
const CreateDeviceParam &param, std::shared_ptr &deviceBuffer)
{
- auto bufferInfo = std::make_shared(devObjKey, deviceIdx, param.lifetime, param.cacheLocation,
- TransferType::P2P);
- return CreateDevBufferImpl(bufferInfo, dataInfoList, deviceBuffer);
+ auto bufferInfo = std::make_shared(devObjKey, devBlobList.deviceIdx, param.lifetime,
+ param.cacheLocation, TransferType::P2P);
+ return CreateDevBufferImpl(bufferInfo, devBlobList, deviceBuffer);
}
Status ClientDeviceObjectManager::CreateDevBufferImpl(std::shared_ptr bufferInfo,
- const std::vector &dataInfoList,
+ const DeviceBlobList &devBlobList,
std::shared_ptr &deviceBuffer)
{
// check input parameter
RETURN_IF_NOT_OK(objClientImpl_->IsClientReady());
CHECK_FAIL_RETURN_STATUS(!bufferInfo->devObjKey.empty(), K_INVALID, "The devObjKey is empty");
- CHECK_FAIL_RETURN_STATUS(!dataInfoList.empty(), K_INVALID, "The dataInfoList is empty");
- CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(bufferInfo->devObjKey), K_INVALID,
- "The devObjKey maybe contains illegal char(s) or the length of id is > 255.");
+ CHECK_FAIL_RETURN_STATUS(!devBlobList.blobs.empty(), K_INVALID, "The devBlobList is empty");
+ RETURN_IF_NOT_OK(ObjectClientImpl::CheckValidObjectKey(bufferInfo->devObjKey));
int32_t deviceIdxNow = -1;
- RETURN_IF_NOT_OK_APPEND_MSG(devInterImpl_->GetDeviceIdx(deviceIdxNow),
- "May not create context or set device in this thread.");
+ RETURN_IF_NOT_OK(devInterImpl_->GetDeviceIdx(deviceIdxNow));
auto &deviceIdx = bufferInfo->deviceIdx; if (deviceIdx < 0) { deviceIdx = deviceIdxNow;
@@ -99,11 +87,11 @@ Status ClientDeviceObjectManager::CreateDevBufferImpl(std::shared_ptr 0, K_INVALID, "The size value should be bigger than zero.");
+ for (auto &blob : devBlobList.blobs) {
+ CHECK_FAIL_RETURN_STATUS(blob.size > 0, K_INVALID, "The size value should be bigger than zero.");
}
- auto memUnit = std::make_shared(bufferInfo->devObjKey, dataInfoList);
+ auto memUnit = std::make_shared(bufferInfo->devObjKey, devBlobList.blobs);
RETURN_IF_NOT_OK(memUnit->MallocDeviceMemoryIfUserNotSet());
deviceBuffer = DeviceBuffer::CreateDeviceBuffer(bufferInfo,
memUnit, objClientImpl_->shared_from_this()); @@ -123,17 +111,17 @@ Status ClientDeviceObjectManager::PublishDeviceObject(const std::shared_ptr &buffer) { auto &bufferInfo = buffer->bufferInfo_; - DataInfo dataInfo; - RETURN_IF_NOT_OK_PRINT_ERROR_MSG(buffer->GetDeviceMemUnit()->CheckAndGetSingleDataInfo(dataInfo), - "The device object with host buffer just support single dataInfo for now"); - auto dataSize = dataInfo.Size(); + Blob blob; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(buffer->GetDeviceMemUnit()->CheckAndGetSingleBlob(blob), + "The device object with host buffer just support single blob for now"); + auto dataSize = blob.size; auto &devObjKey = bufferInfo->devObjKey; std::shared_ptr hostBuffer; RETURN_IF_NOT_OK(objClientImpl_->Create(devObjKey, dataSize, {}, hostBuffer)); auto hostBufferInfo = ObjectClientImpl::GetBufferInfo(hostBuffer); bufferInfo->shmId = hostBufferInfo->shmId; bufferInfo->version = hostBufferInfo->version; - RETURN_IF_NOT_OK(devInterImpl_->MemCopyD2H(hostBuffer->MutableData(), dataSize, dataInfo.devPtr, dataSize)); + RETURN_IF_NOT_OK(devInterImpl_->MemCopyD2H(hostBuffer->MutableData(), dataSize, blob.pointer, dataSize)); std::shared_ptr workerApi; RETURN_IF_NOT_OK(objClientImpl_->GetAvailableWorkerApi(workerApi)); return workerApi->PublishDeviceObject(bufferInfo, dataSize, !bufferInfo->shmId.empty(), hostBuffer->MutableData()); @@ -152,7 +140,7 @@ Status ClientDeviceObjectManager::GetDevBufferWithHost(const std::vectorIsClientReady()); - RETURN_IF_NOT_OK(objClientImpl_->CheckStringVector(devObjKeys)); + RETURN_IF_NOT_OK(ObjectClientImpl::CheckValidObjectKeyVector(devObjKeys)); if (devObjKeys.size() > 1 || !map.empty()) { RETURN_STATUS(K_INVALID, "The resharding get is not supported now, please keep the devObjKeys only have one objectKey and " @@ -285,13 +273,13 @@ Status ClientDeviceObjectManager::GetSendStatus(const std::shared_ptr> &dataInfoList, +Status ClientDeviceObjectManager::MemCopyBetweenDevAndHost(const std::vector &devBlobList, std::vector &bufferList, aclrtMemcpyKind copyKind, bool enableHugeTlb) { DeviceBatchCopyHelper helper; - RETURN_IF_NOT_OK(helper.Prepare(dataInfoList, bufferList, copyKind)); - auto deviceId = dataInfoList[0][0].deviceIdx; + RETURN_IF_NOT_OK(helper.Prepare(devBlobList, bufferList, copyKind)); + auto deviceId = devBlobList[0].deviceIdx; aclResourceMgr_.SetD2HPolicyByHugeTlb(enableHugeTlb); INJECT_POINT("NO_USE_FFTS", [this]() { aclResourceMgr_.SetPolicyDirect(); diff --git a/src/datasystem/client/object_cache/device/client_device_object_manager.h b/src/datasystem/client/object_cache/device/client_device_object_manager.h index ed4f210..5c1c281 100644 --- a/src/datasystem/client/object_cache/device/client_device_object_manager.h +++ b/src/datasystem/client/object_cache/device/client_device_object_manager.h @@ -31,6 +31,7 @@ #include "datasystem/client/object_cache/device/device_memory_unit.h" #include "datasystem/client/object_cache/device/p2p_subscribe.h" #include "datasystem/common/device/ascend/acl_device_manager.h" +#include "datasystem/hetero/device_common.h" #include "datasystem/object_client.h" #include "datasystem/utils/status.h" @@ -83,20 +84,20 @@ struct DeviceBatchCopyHelper { return (address & alignmentMask) == 0; } - Status Prepare(const std::vector> &dataInfoList, std::vector &bufferList, + Status Prepare(const std::vector &devBlobList, std::vector &bufferList, aclrtMemcpyKind copyKind) { std::vector hostPointerList; std::vector devPointerList; std::vector hostBuffers; std::vector deviceBuffers; - 
hostBuffers.reserve(dataInfoList.size()); - deviceBuffers.reserve(dataInfoList.size()); - CHECK_FAIL_RETURN_STATUS(!dataInfoList.empty(), K_INVALID, "The dataInfoList is empty."); + hostBuffers.reserve(devBlobList.size()); + deviceBuffers.reserve(devBlobList.size()); + CHECK_FAIL_RETURN_STATUS(!devBlobList.empty(), K_INVALID, "The devBlobList is empty."); CHECK_FAIL_RETURN_STATUS(!bufferList.empty(), K_INVALID, "The bufferList is empty."); size_t keyStartInBlobs = 0; - for (size_t i = 0; i < dataInfoList.size(); i++) { - auto &dataInfos = dataInfoList[i]; + for (size_t i = 0; i < devBlobList.size(); i++) { + auto &blobs = devBlobList[i].blobs; if (bufferList[i] == nullptr) { continue; } @@ -106,18 +107,18 @@ struct DeviceBatchCopyHelper { auto sz = *offsetArrPtr; auto offsets = offsetArrPtr + 1; CHECK_FAIL_RETURN_STATUS( - sz == dataInfos.size() && sz > 0, K_INVALID, + sz == blobs.size() && sz > 0, K_INVALID, FormatString("Blobs count mismatch in devBlobList between sender and receiver, sender count is: %ld, " "receiver count is: %ld, mismatch devBlobList index: %s, mismatch key index: %s", - sz, dataInfos.size(), i, i)); + sz, blobs.size(), i, i)); size_t dataSize = buffer->GetSize() - offsets[0]; bufferMetas.emplace_back( - BufferMetaInfo{ .blobCount = dataInfos.size(), .firstBlobOffset = keyStartInBlobs, .size = dataSize }); + BufferMetaInfo{ .blobCount = blobs.size(), .firstBlobOffset = keyStartInBlobs, .size = dataSize }); hostBuffers.emplace_back(BufferView{ .ptr = hostRawPointer + offsets[0], .size = dataSize }); - for (size_t j = 0; j < dataInfos.size(); j++) { + for (size_t j = 0; j < blobs.size(); j++) { auto hostDataSize = offsets[j + 1] - offsets[j]; - auto devicePointer = dataInfos[j].devPtr; - auto deviceDataSize = dataInfos[j].Size(); + auto devicePointer = blobs[j].pointer; + auto deviceDataSize = blobs[j].size; auto hostPointer = hostRawPointer + offsets[j]; if (!is64BitAligned(hostPointer)) { LOG(WARNING) << "host memory is not 64 aligned: " << hostRawPointer; @@ -133,7 +134,7 @@ struct DeviceBatchCopyHelper { dataSizeList.emplace_back(hostDataSize); batchSize++; } - keyStartInBlobs += dataInfos.size(); + keyStartInBlobs += blobs.size(); } if (copyKind == aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE) { srcBuffers = std::move(hostBuffers); @@ -169,20 +170,6 @@ public: Status Init(); - /** - * @brief Invoke worker client to create a device object. - * @param[in] objectKey The Key of the device object to create. Key should not be empty and should only contains - * english alphabetics (a-zA-Z), numbers and ~!@#$%^&*.-_ only. Key length should less than 256. - * @param[in] size The size in bytes of device object. - * @param[in] devPtr The device memory pointer. Pass the pointer if user want do malloc by self. - * Pass the nullptr then client will malloc device memory and free when DeviceBuffer is destructed. - * @param[in] deviceIdx The device index of the device memory. - * @param[out] deviceBuffer The device buffer for the object. - * @return Status K_OK on success; the error code otherwise. - */ - Status CreateDevBuffer(const std::string &devObjKey, uint64_t size, void *devPtr, int32_t deviceIdx, - std::shared_ptr &deviceBuffer); - /** * @brief Publish device object to datasystem with host. * @param[in] buffer The device buffer ready to publish. @@ -233,13 +220,12 @@ public: * @brief Invoke worker client to create a device object with p2p. * @param[in] objectKey The Key of the device object to create. 
Key should not be empty and should contain
     *            only English letters (a-zA-Z), numbers and the characters ~!@#$%^&*.-_ . Key length should be less than 256.
-     * @param[in] dataInfoList The list of data info.
-     * @param[in] deviceIdx The device index of the device memory.
+     * @param[in] devBlobList The list of blob info.
     * @param[in] param The create param of device object.
     * @param[out] deviceBuffer The device buffer for the object.
     * @return Status K_OK on success; the error code otherwise.
     */
-    Status CreateDevBuffer(const std::string &devObjKey, const std::vector<DataInfo> &dataInfoList, int32_t deviceIdx,
+    Status CreateDevBuffer(const std::string &devObjKey, const DeviceBlobList &devBlobList,
                            const CreateDeviceParam &param, std::shared_ptr<DeviceBuffer> &deviceBuffer);

    /**
@@ -264,11 +250,11 @@ public:
    /**
     * @brief The implementation of create device buffer.
     * @param[in] bufferInfo The info of device buffer.
-     * @param[in] dataInfoList The list of data info.
+     * @param[in] devBlobList The list of blob info.
     * @param[out] deviceBuffer The device buffer for the object.
     * @return Status K_OK on success; the error code otherwise.
     */
-    Status CreateDevBufferImpl(std::shared_ptr<DeviceBufferInfo> bufferInfo, const std::vector<DataInfo> &dataInfoList,
+    Status CreateDevBufferImpl(std::shared_ptr<DeviceBufferInfo> bufferInfo, const DeviceBlobList &devBlobList,
                               std::shared_ptr<DeviceBuffer> &deviceBuffer);

    /**
@@ -279,15 +265,15 @@ public:
    Status GetSendStatus(const std::shared_ptr<DeviceBuffer> &buffer, std::vector &futureVec);
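A sketch of how a sender might consume GetSendStatus after publishing a P2P device object. The element type of futureVec is not spelled out in this patch, so std::future<Status> is an assumption here, as are the manager and deviceBuffer variables:

    // Assumption: futureVec elements behave like std::future<Status>.
    std::vector<std::future<Status>> futureVec;
    RETURN_IF_NOT_OK(manager.GetSendStatus(deviceBuffer, futureVec));
    for (auto &fut : futureVec) {
        Status sendRc = fut.get();  // blocks until the HcclSend for one blob finishes
        if (sendRc.IsError()) {
            LOG(ERROR) << "send failed: " << sendRc.ToString();
        }
    }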
    /**
-     * @brief The memory copy between dataInfoList and bufferList.
-     * @param[in] dataInfoList The 2D list of dataInfo.
+     * @brief The memory copy between devBlobList and bufferList.
+     * @param[in] devBlobList The 2D list of blob info.
     * @param[in] bufferList The list of buffer.
     * @param[in] copyKind The memory copy kind in CANN.
     * @param[in] enableHugeTlb Whether huge TLB is enabled for the memory.
     * @return Status K_OK on success; the error code otherwise.
     */
-    Status MemCopyBetweenDevAndHost(const std::vector<std::vector<DataInfo>> &dataInfoList,
-                                    std::vector<Buffer *> &bufferList, aclrtMemcpyKind copyKind, bool enableHugeTlb);
+    Status MemCopyBetweenDevAndHost(const std::vector<DeviceBlobList> &devBlobList, std::vector<Buffer *> &bufferList,
+                                    aclrtMemcpyKind copyKind, bool enableHugeTlb);

    /**
     * @brief Print MSetD2H detail info
diff --git a/src/datasystem/client/object_cache/device/device_memory_unit.cpp b/src/datasystem/client/object_cache/device/device_memory_unit.cpp
index 6c79753..6c805c9 100644
--- a/src/datasystem/client/object_cache/device/device_memory_unit.cpp
+++ b/src/datasystem/client/object_cache/device/device_memory_unit.cpp
@@ -24,64 +24,62 @@
 namespace datasystem {
-DeviceMemoryUnit::DeviceMemoryUnit(const std::string &devMemId, std::vector<DataInfo> dataInfoStorage)
-    : devMemId_(devMemId),
-      dataInfoStorage_(std::move(dataInfoStorage)),
-      dsAllocatedStorage_(dataInfoStorage_.size(), false)
+DeviceMemoryUnit::DeviceMemoryUnit(const std::string &devMemId, std::vector<Blob> blobStorage)
+    : devMemId_(devMemId), blobStorage_(std::move(blobStorage)), dsAllocatedStorage_(blobStorage_.size(), false)
 {
 }

 Status DeviceMemoryUnit::MallocDeviceMemoryIfUserNotSet()
 {
-    CHECK_FAIL_RETURN_STATUS(dataInfoStorage_.size() == dsAllocatedStorage_.size(), K_RUNTIME_ERROR,
-                             "The size of dataInfoStorage and dsAllocatedStorage is not same.");
-    for (auto i = 0ul; i < dataInfoStorage_.size(); i++) {
-        auto &dataInfo = dataInfoStorage_[i];
-        if (dataInfo.devPtr == nullptr) {
-            VLOG(1) << "Malloc device memory, size: " << dataInfo.Size();
-            RETURN_IF_NOT_OK(acl::AclDeviceManager::Instance()->MallocDeviceMemory(dataInfo.Size(), dataInfo.devPtr));
+    CHECK_FAIL_RETURN_STATUS(blobStorage_.size() == dsAllocatedStorage_.size(), K_RUNTIME_ERROR,
+                             "The size of blobStorage and dsAllocatedStorage is not same.");
+    for (auto i = 0ul; i < blobStorage_.size(); i++) {
+        auto &blob = blobStorage_[i];
+        if (blob.pointer == nullptr) {
+            VLOG(1) << "Malloc device memory, size: " << blob.size;
+            RETURN_IF_NOT_OK(acl::AclDeviceManager::Instance()->MallocDeviceMemory(blob.size, blob.pointer));
             dsAllocatedStorage_[i] = true;
         }
     }
     return Status::OK();
 }

-const std::vector<DataInfo> &DeviceMemoryUnit::GetDataInfoStorage() const
+const std::vector<Blob> &DeviceMemoryUnit::GetBlobsStorage() const
 {
-    return dataInfoStorage_;
+    return blobStorage_;
 }

-Status DeviceMemoryUnit::CheckAndGetSingleDataInfo(DataInfo &dataInfo) const
+Status DeviceMemoryUnit::CheckAndGetSingleBlob(Blob &blob) const
 {
-    if (dataInfoStorage_.empty()) {
+    if (blobStorage_.empty()) {
         RETURN_STATUS(K_RUNTIME_ERROR, "The list of data info in device buffer is empty.");
     }
-    if (dataInfoStorage_.size() > 1) {
+    if (blobStorage_.size() > 1) {
         RETURN_STATUS(K_RUNTIME_ERROR, "The size of data info list in device buffer > 1");
     }
-    dataInfo = dataInfoStorage_[0];
+    blob = blobStorage_[0];
     return Status::OK();
 }

 DeviceMemoryUnit::~DeviceMemoryUnit()
 {
     std::vector<size_t> freeIndexVec;
-    freeIndexVec.reserve(dataInfoStorage_.size());
-    for (auto i = 0ul; i < dataInfoStorage_.size(); i++) {
-        const auto &dataInfo = dataInfoStorage_[i];
-        if (dsAllocatedStorage_[i] && dataInfo.devPtr) {
+    freeIndexVec.reserve(blobStorage_.size());
+    for (auto i = 0ul; i < blobStorage_.size(); i++) {
+        const auto &blob = blobStorage_[i];
+        if (dsAllocatedStorage_[i] && blob.pointer) {
             freeIndexVec.push_back(i);
-            LOG_IF_ERROR(acl::AclDeviceManager::Instance()->FreeDeviceMemory(dataInfo.devPtr),
+            LOG_IF_ERROR(acl::AclDeviceManager::Instance()->FreeDeviceMemory(blob.pointer),
                          "Release device memory allocated by datasystem failed.");
         }
     }
-    VLOG(1) << "Free device memory unit: " << devMemId_ << ", dataInfo index: " << VectorToString(freeIndexVec);
+    VLOG(1) << "Free device memory unit: " << devMemId_ << ", blob index: " << VectorToString(freeIndexVec);
 }
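The ownership rule encoded above is worth making explicit: blobs that arrive with pointer == nullptr are allocated by the unit (and flagged in dsAllocatedStorage_), and only those are freed in the destructor; caller-provided pointers are never touched. A minimal sketch, assuming a hypothetical caller-owned `userPtr`:

    std::vector<Blob> blobs(2);
    blobs[0].pointer = userPtr;   // caller-owned memory: the unit never frees it
    blobs[0].size = 4096;
    blobs[1].pointer = nullptr;   // unit-owned: filled in by MallocDeviceMemoryIfUserNotSet
    blobs[1].size = 8192;
    auto unit = std::make_shared<DeviceMemoryUnit>("unit-0", std::move(blobs));
    RETURN_IF_NOT_OK(unit->MallocDeviceMemoryIfUserNotSet());
    // ... use unit->GetBlobsStorage() ...
    // When `unit` is destroyed, only the second blob's device memory is freed,
    // because dsAllocatedStorage_ tracked it as datasystem-allocated.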
failed."); } } - VLOG(1) << "Free device memory unit: " << devMemId_ << ", dataInfo index: " << VectorToString(freeIndexVec); + VLOG(1) << "Free device memory unit: " << devMemId_ << ", blob index: " << VectorToString(freeIndexVec); } Status DeviceMemoryUnit::CheckEmptyPointer() const { - for (auto i = 0ul; i < dataInfoStorage_.size(); i++) { - CHECK_FAIL_RETURN_STATUS(dataInfoStorage_[i].devPtr != nullptr, K_INVALID, + for (auto i = 0ul; i < blobStorage_.size(); i++) { + CHECK_FAIL_RETURN_STATUS(blobStorage_[i].pointer != nullptr, K_INVALID, FormatString("The device pointer [index: %zu] in device buffer is nullptr.", i)); } return Status::OK(); diff --git a/src/datasystem/client/object_cache/device/device_memory_unit.h b/src/datasystem/client/object_cache/device/device_memory_unit.h index c465371..8e1e12e 100644 --- a/src/datasystem/client/object_cache/device/device_memory_unit.h +++ b/src/datasystem/client/object_cache/device/device_memory_unit.h @@ -28,11 +28,12 @@ #include "datasystem/common/immutable_string/immutable_string.h" #include "datasystem/client/hetero_cache/device_util.h" #include "datasystem/common/util/status_helper.h" +#include "datasystem/hetero/device_common.h" namespace datasystem { class DeviceMemoryUnit { public: - explicit DeviceMemoryUnit(const std::string &devMemId, std::vector dataInfoStorage); + explicit DeviceMemoryUnit(const std::string &devMemId, std::vector blobStorage); /** * @brief Malloc device memory if user not set the device pointer in data info list. @@ -44,14 +45,14 @@ public: * @brief Get the data info list. * @return The list of data info. */ - const std::vector &GetDataInfoStorage() const; + const std::vector &GetBlobsStorage() const; /** * @brief Check and get the single data info. - * @param[out] dataInfo The data info. + * @param[out] blob The data info. * @return Return ok only if memory unit have only one data info. */ - Status CheckAndGetSingleDataInfo(DataInfo &dataInfo) const; + Status CheckAndGetSingleBlob(Blob &blob) const; /** * @brief Check if the device pointer is nullptr in data info list. 
diff --git a/src/datasystem/client/object_cache/device/hccl_comm_factory.cpp b/src/datasystem/client/object_cache/device/hccl_comm_factory.cpp
index c9ba271..6965f21 100644
--- a/src/datasystem/client/object_cache/device/hccl_comm_factory.cpp
+++ b/src/datasystem/client/object_cache/device/hccl_comm_factory.cpp
@@ -21,6 +21,7 @@
 #include "datasystem/common/device/ascend/cann_types.h"
 #include "datasystem/common/device/ascend/hccl_comm_wrapper.h"
 #include "datasystem/common/device/device_helper.h"
+#include "datasystem/common/inject/inject_point.h"
 #include "datasystem/utils/status.h"

 namespace datasystem {
@@ -63,6 +64,15 @@ HcclCommFactory::~HcclCommFactory()
     ShutDown();
 }

+Status HcclCommFactory::SetStateIfError(std::shared_ptr<CommWrapperBase> &comm, Status status)
+{
+    if (status.IsError()) {
+        comm->SetHcclDetailState(status);
+        return status;
+    }
+    return Status::OK();
+}
+
 std::string HcclCommFactory::GetHcclCommKey(P2PEventType eventType, int32_t localDeviceId,
                                             const std::string &remoteClientId, int32_t remoteDeviceId)
 {
@@ -80,11 +90,13 @@ Status HcclCommFactory::GetOrCreateHcclComm(P2PEventType eventType, int32_t localDeviceId,
     PerfPoint perfPoint(PerfKey::GET_OR_CREATE_HCCL_COMMONE);
     auto commKey = GetHcclCommKey(eventType, localDeviceId, remoteClientId, remoteDeviceId);
     TbbHcclCommTable::accessor acc;
+    VLOG(1) << FormatString("Trying to acquire read lock, commKey: %s", commKey);
     std::shared_lock<std::shared_timed_mutex> lock(mutex_);
     if (commTable_.find(acc, commKey)) {
         comm = acc->second;
         return CreateHcclCommCheckError(comm);
     }
+
     // If the insert fails, it means the hccl comm already exists; get it and return.
     if (!commTable_.insert(acc, commKey)) {
         comm = acc->second;
@@ -115,51 +127,69 @@ Status HcclCommFactory::GetOrCreateHcclComm(P2PEventType eventType, int32_t localDeviceId,
     return CreateHcclCommCheckError(comm);
 }
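The find-then-insert sequence above is the standard get-or-create pattern on tbb::concurrent_hash_map: the accessor holds a per-bucket lock, so the loser of an insert race simply reads the winner's entry. A condensed, self-contained sketch of the same pattern, where Comm is a stand-in for the HCCL comm wrapper:

    #include <memory>
    #include <string>
    #include <tbb/concurrent_hash_map.h>

    struct Comm {};  // stand-in for the real comm wrapper type
    using CommTable = tbb::concurrent_hash_map<std::string, std::shared_ptr<Comm>>;

    std::shared_ptr<Comm> GetOrCreate(CommTable &table, const std::string &key)
    {
        CommTable::accessor acc;
        if (table.find(acc, key)) {      // fast path: entry already created
            return acc->second;
        }
        if (!table.insert(acc, key)) {   // lost the race: another thread inserted first
            return acc->second;
        }
        acc->second = std::make_shared<Comm>();  // we won the race: initialize the entry
        return acc->second;
    }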
-Status HcclCommFactory::InitRootInfoReq(int32_t localDeviceId, int32_t remoteDeviceId,
-                                        const std::string &remoteClientId, HcclRootInfo &rootInfo)
-{
-    RecvRootInfoReqPb rootInfoReq;
-    auto localClientId = clientWorkerApi_->GetClientId();
-    rootInfoReq.set_dst_client_id(remoteClientId);
-    rootInfoReq.set_dst_device_id(remoteDeviceId);
-    rootInfoReq.set_src_client_id(localClientId);
-    rootInfoReq.set_src_device_id(localDeviceId);
-    RecvRootInfoRspPb rootInfoResp;
-    auto peerId = GetHcclPeerId(localClientId, localDeviceId, remoteClientId, remoteDeviceId);
-    LOG(INFO) << FormatString("Start to recv RootInfo from worker, peerId: %s", peerId);
-    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(clientWorkerApi_->RecvRootInfo(rootInfoReq, rootInfoResp), "Failed with receive");
-
-    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
-        rootInfoResp.root_info().length() == sizeof(rootInfo.internal), K_RUNTIME_ERROR,
-        "The rsp rootInfo size is not as expected: " + std::to_string(rootInfoResp.root_info().length()));
-    auto ret = memcpy_s(static_cast<void *>(rootInfo.internal), sizeof(rootInfo.internal),
-                        static_cast<const void *>(rootInfoResp.root_info().c_str()), rootInfoResp.root_info().length());
-    if (ret != EOK) {
-        RETURN_STATUS(K_RUNTIME_ERROR, FormatString("Copy root info failed, the memcpy_s return: %d", ret));
-    }
-    LOG(INFO) << "Sender start init hccl comm";
-    PrintRootInfo(rootInfo);
-    return Status::OK();
-}
-
 void HcclCommFactory::CreateHcclCommInSend(int32_t localDeviceId, const std::string &remoteClientId,
                                            int32_t remoteDeviceId, bool isSameNode, std::shared_ptr<CommWrapperBase> &comm)
 {
-    auto process = [this, comm, localDeviceId, remoteDeviceId, remoteClientId, isSameNode]() mutable {
+    auto traceId = Trace::Instance().GetTraceID();
+    auto processFunc = [this, comm, localDeviceId, remoteDeviceId, remoteClientId, isSameNode, traceId]() mutable {
+        TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId);
+        INJECT_POINT("CreateHcclCommInSend.sleep");
+        PerfPoint point(PerfKey::CLIENT_CREATE_HCCL_IN_SEND);
+        auto localClientId = clientWorkerApi_->GetClientId();
+        auto peerId = GetHcclPeerId(localClientId, localDeviceId, remoteClientId, remoteDeviceId);
+        VLOG(1) << FormatString("[Sender] Try to acquire write lock, peerId: %s", peerId);
         std::lock_guard<std::shared_timed_mutex> lock(mutex_);
+
+        LOG(INFO) << FormatString("[Sender] Start to recv RootInfo from worker, peerId: %s", peerId);
+        RecvRootInfoReqPb rootInfoReq;
+        rootInfoReq.set_dst_client_id(remoteClientId);
+        rootInfoReq.set_dst_device_id(remoteDeviceId);
+        rootInfoReq.set_src_client_id(localClientId);
+        rootInfoReq.set_src_device_id(localDeviceId);
+        RecvRootInfoRspPb rootInfoResp;
+        RETURN_IF_NOT_OK_PRINT_ERROR_MSG(clientWorkerApi_->RecvRootInfo(rootInfoReq, rootInfoResp),
+                                         "Failed to receive RootInfo");
+        if (rootInfoResp.is_dead_lock()) {
+            std::string msg = "[Sender] Deadlock detected, releasing lock and retrying.";
+            LOG(WARNING) << msg;
+            RETURN_IF_NOT_OK(SetStateIfError(comm, Status(K_CLIENT_DEADLOCK, msg)));
+        }
+
         HcclRootInfo rootInfo;
-        RETURN_IF_NOT_OK(InitRootInfoReq(localDeviceId, remoteDeviceId, remoteClientId, rootInfo));
-        RETURN_IF_NOT_OK(comm->InitCommunicator(rootInfo, HcclCommDirection::SEND, isSameNode));
+        if (rootInfoResp.root_info().length() != sizeof(rootInfo.internal)) {
+            std::string msg = FormatString(
+                "The rsp rootInfo size is not as expected: %d, which usually indicates that the receiver "
+                "did not send the rootInfo properly.
Check if there are any errors on the receiver side, peerId: %s", + rootInfoResp.root_info().length(), + peerId); + RETURN_IF_NOT_OK(SetStateIfError(comm, Status(K_RUNTIME_ERROR, msg))); + } + auto ret = memcpy_s(static_cast(rootInfo.internal), + sizeof(rootInfo.internal), + static_cast(rootInfoResp.root_info().c_str()), + rootInfoResp.root_info().length()); + if (ret != EOK) { + RETURN_STATUS(K_RUNTIME_ERROR, FormatString("Copy root info failed, the memcpy_s return: %d", ret)); + } + LOG(INFO) << "[Sender] Start init hccl comm"; + PrintRootInfo(rootInfo); + auto rc = comm->InitCommunicator(rootInfo, HcclCommDirection::SEND, isSameNode); + RETURN_IF_NOT_OK(SetStateIfError(comm, rc)); PerfPoint perfPoint(PerfKey::CLIENT_HCCL_WARMUP_IN_SEND); return comm->WarmUpComm(HcclCommDirection::SEND); }; - auto traceId = Trace::Instance().GetTraceID(); - comm->Execute([this, comm, process, traceId]() mutable { + + auto commPtr = comm; + comm->Execute([this, commPtr, processFunc, traceId]() mutable { TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); - auto rc = process(); - auto checkRc = CreateHcclCommCheckError(comm); - StatusComparisonWithSetStatus(comm, rc, checkRc); + constexpr int32_t timeoutMs = 60 * 1000; + this->AsyncRetryWithTimeout( + commPtr, processFunc, [this](auto comm) { return this->CreateHcclCommCheckError(comm); }, timeoutMs, + { StatusCode::K_CLIENT_DEADLOCK }, + [this](auto comm, Status result, Status checkRc) { + this->StatusComparisonWithSetStatus(comm, result, checkRc); + }); }); } @@ -177,50 +207,104 @@ void HcclCommFactory::StatusComparisonWithSetStatus(std::shared_ptr &comm) { - auto asyncError = comm->HcclGetCommAsyncError(); - Status rc = Status::OK(); - if (asyncError != HCCL_SUCCESS) { - rc = Status(K_RUNTIME_ERROR, FormatString("Hccl comm async error, code is : %d", asyncError)); - } - comm->SetHcclDetailState(asyncError); - return rc; + auto status = comm->HcclGetCommAsyncError(); + comm->SetHcclDetailState(status); + return status; +} + +void HcclCommFactory::AsyncRetryWithTimeout( + std::shared_ptr comm, std::function processFunc, + std::function)> errorCheckFunc, int32_t timeoutMs, + const std::vector retryableErrors, + std::function, Status, Status)> finalHandler) +{ + auto startTime = std::chrono::steady_clock::now(); + auto retryFunction = std::make_shared>(); + *retryFunction = [comm, processFunc, errorCheckFunc, startTime, timeoutMs, retryableErrors, finalHandler, + retryFunction]() mutable { + // Execute processing function + Status result = processFunc(); + + // Check if error is retriable + bool shouldRetry = false; + for (auto errorCode : retryableErrors) { + if (result.GetCode() == errorCode) { + shouldRetry = true; + break; + } + } + if (shouldRetry) { + // Verify timeout + auto currentTime = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast(currentTime - startTime).count(); + if (elapsed < timeoutMs) { + // Requeue for retry + comm->Execute(*retryFunction); + return; + } + } + + // Final processing step + auto checkRc = errorCheckFunc(comm); + finalHandler(comm, result, checkRc); + comm->SetCommReady(true); + }; + + // Start execution + (*retryFunction)(); } void HcclCommFactory::CreateHcclCommInRecv(int32_t localDeviceId, const std::string &remoteClientId, int32_t remoteDeviceId, bool isSameNode, std::shared_ptr &comm) { - auto process = [this, comm, localDeviceId, remoteDeviceId, remoteClientId, isSameNode]() mutable { + auto traceId = Trace::Instance().GetTraceID(); + auto process = [this, comm, localDeviceId, 
remoteDeviceId, remoteClientId, isSameNode, traceId]() mutable {
+        TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId);
+        INJECT_POINT("CreateHcclCommInRecv.sleep");
         PerfPoint point(PerfKey::CLIENT_CREATE_HCCL_IN_RECV);
+        auto localClientId = clientWorkerApi_->GetClientId();
+        auto peerId = GetHcclPeerId(remoteClientId, remoteDeviceId, localClientId, localDeviceId);
+        VLOG(1) << FormatString("[Receiver] Try to acquire write lock, peerId: %s", peerId);
         std::lock_guard<std::shared_timed_mutex> lock(mutex_);
+
         HcclRootInfo rootInfo;
-        RETURN_IF_NOT_OK(comm->CreateRootInfo(rootInfo));
+        RETURN_IF_NOT_OK(SetStateIfError(comm, comm->CreateRootInfo(rootInfo)));

-        auto localClientId = clientWorkerApi_->GetClientId();
-        SendRootInfoReqPb req;
         // rootInfo contains '\0' bytes, so the string must be constructed from the full buffer range;
         // use c_str() when converting the string back to rootInfo.
+        SendRootInfoReqPb req;
         req.set_root_info(std::string(std::begin(rootInfo.internal), std::end(rootInfo.internal)));
         req.set_dst_client_id(localClientId);
         req.set_dst_device_id(localDeviceId);
         req.set_src_client_id(remoteClientId);
         req.set_src_device_id(remoteDeviceId);
-        auto peerId = GetHcclPeerId(remoteClientId, remoteDeviceId, localClientId, localDeviceId);
-        LOG(INFO) << "Send root info to worker, peerId: " << peerId;
-        RETURN_IF_NOT_OK(clientWorkerApi_->SendRootInfo(req));
-        LOG(INFO) << "Receiver start init hccl comm";
+
+        LOG(INFO) << "[Receiver] Send root info to worker, peerId: " << peerId;
+        SendRootInfoRspPb rsp;
+        auto rc = clientWorkerApi_->SendRootInfo(req, rsp);
+        if (rc.GetCode() == StatusCode::K_CLIENT_DEADLOCK) {
+            LOG(WARNING) << "[Receiver] Deadlock occurred, release the lock and retry";
+            RETURN_IF_NOT_OK(SetStateIfError(comm, rc));
+        }
+
+        LOG(INFO) << "[Receiver] Start init hccl comm";
         PrintRootInfo(rootInfo);
-        RETURN_IF_NOT_OK(comm->InitCommunicator(rootInfo, HcclCommDirection::RECV, isSameNode));
+        rc = comm->InitCommunicator(rootInfo, HcclCommDirection::RECV, isSameNode);
+        RETURN_IF_NOT_OK(SetStateIfError(comm, rc));
         PerfPoint perfPoint(PerfKey::CLIENT_HCCL_WARMUP_IN_RECV);
         return comm->WarmUpComm(HcclCommDirection::RECV);
     };
-    auto traceId = Trace::Instance().GetTraceID();
-    comm->Execute([this, comm, process, traceId]() mutable {
+    auto commPtr = comm;
+    comm->Execute([this, commPtr, process, traceId]() mutable {
         TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId);
-        auto rc = process();
-        comm->SetStatus(rc);
-        auto checkRc = CreateHcclCommCheckError(comm);
-        StatusComparisonWithSetStatus(comm, rc, checkRc);
+        constexpr int32_t timeoutMs = 60 * 1000;
+        this->AsyncRetryWithTimeout(
+            commPtr, process, [this](auto comm) { return this->CreateHcclCommCheckError(comm); }, timeoutMs,
+            { StatusCode::K_CLIENT_DEADLOCK },
+            [this](auto comm, Status result, Status checkRc) {
+                this->StatusComparisonWithSetStatus(comm, result, checkRc);
+            });
     });
 }
diff --git a/src/datasystem/client/object_cache/device/hccl_comm_factory.h b/src/datasystem/client/object_cache/device/hccl_comm_factory.h
index b8b65ce..fc9c666 100644
--- a/src/datasystem/client/object_cache/device/hccl_comm_factory.h
+++ b/src/datasystem/client/object_cache/device/hccl_comm_factory.h
@@ -151,8 +151,36 @@ public:
                             int32_t remoteDeviceId);

 private:
-    Status InitRootInfoReq(int32_t localDeviceId, int32_t remoteDeviceId, const std::string &remoteClientId,
-                           HcclRootInfo &rootInfo);
+    /**
+     * @brief Handle communicator creation errors and set detailed error state.
+     *
+     * This function checks whether the given status indicates an error and, if so,
+     * sets the detailed HCCL communication state on the communicator object
+     * before returning the error status. If no error is detected, it returns an OK status.
+     *
+     * @param[in] comm Pointer to the communicator object to set the error state on.
+     * @param[in] status The status to check for errors.
+     * @return The original error status if status.IsError() is true,
+     *         otherwise Status::OK().
+     */
+    Status SetStateIfError(std::shared_ptr<CommWrapperBase> &comm, Status status);
+
+    /**
+     * @brief Asynchronously retry an operation with timeout and error handling.
+     * @param[in] comm The HCCL communicator wrapper shared pointer.
+     * @param[in] processFunc The main processing function to be executed and retried.
+     * @param[in] errorCheckFunc Function to check for errors in the communicator state before retrying.
+     * @param[in] timeoutMs Maximum timeout in milliseconds for the entire retry operation.
+     * @param[in] retryableErrors List of error status codes that should trigger a retry.
+     * @param[in] finalHandler Callback function to handle the final result (success, timeout, or fatal error).
+     * @note This function retries processFunc on retryable errors until success or timeout.
+     *       Non-retryable errors or a timeout trigger the finalHandler with the appropriate status.
+     */
+    void AsyncRetryWithTimeout(std::shared_ptr<CommWrapperBase> comm, std::function<Status()> processFunc,
+                               std::function<Status(std::shared_ptr<CommWrapperBase>)> errorCheckFunc,
+                               int32_t timeoutMs, const std::vector<StatusCode> retryableErrors,
+                               std::function<void(std::shared_ptr<CommWrapperBase>, Status, Status)> finalHandler);
+
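To make the retry semantics concrete, here is a simplified synchronous analogue of AsyncRetryWithTimeout; the real implementation in this patch re-queues itself on the comm's executor instead of looping, and hands the final result to finalHandler. Names and the Code enum below are illustrative only:

    #include <chrono>
    #include <functional>
    #include <vector>

    enum class Code { Ok, Deadlock, Fatal };

    // Retry processFunc while it returns a retryable code and the deadline
    // has not passed; otherwise return the final result for final handling.
    Code RetryWithTimeout(const std::function<Code()> &processFunc,
                          const std::vector<Code> &retryable, int timeoutMs)
    {
        auto start = std::chrono::steady_clock::now();
        for (;;) {
            Code rc = processFunc();
            bool shouldRetry = false;
            for (Code c : retryable) {
                if (rc == c) { shouldRetry = true; break; }
            }
            auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
                               std::chrono::steady_clock::now() - start).count();
            if (!shouldRetry || elapsed >= timeoutMs) {
                return rc;  // the real code invokes finalHandler and marks the comm ready here
            }
        }
    }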
     TbbHcclCommTable commTable_;
     // To prevent two threads from trying to create a communication domain at the same time.
     std::shared_timed_mutex mutex_;
diff --git a/src/datasystem/client/object_cache/device/p2p_subscribe.cpp b/src/datasystem/client/object_cache/device/p2p_subscribe.cpp
index bd087b4..a976d40 100644
--- a/src/datasystem/client/object_cache/device/p2p_subscribe.cpp
+++ b/src/datasystem/client/object_cache/device/p2p_subscribe.cpp
@@ -136,72 +136,81 @@ void P2PSubscribe::ProcessP2PSend(
         auto recvDeviceId = kv.first.remoteDeviceId;
         bool isSameNode = kv.first.sameNode;
         auto &npuEvents = kv.second;
+        std::vector<std::string> objectKeys;
+        std::transform(npuEvents.begin(), npuEvents.end(), std::back_inserter(objectKeys),
+                       [](const SubscribeReceiveNpuEventPb &npuEvent) { return npuEvent.object_key(); });
+        LOG(INFO) << FormatString("Get send event from npuId: %s;%d, keys:%s", recvClientId, recvDeviceId,
+                                  VectorToString(objectKeys));
         StartMonitorThread();
         std::shared_ptr<CommWrapperBase> comm;
         Status rc = commFactory_->GetOrCreateHcclComm(P2PEventType::SEND, deviceId_, recvClientId, recvDeviceId,
                                                       isSameNode, clientEnableP2Ptransfer_, comm);
         if (rc.IsError()) {
-            std::vector<std::string> objectKeys;
-            std::transform(npuEvents.begin(), npuEvents.end(), std::back_inserter(objectKeys),
-                           [](const SubscribeReceiveNpuEventPb &npuEvent) { return npuEvent.object_key(); });
             LOG(ERROR) << "ObjectKeys: " << VectorToString(objectKeys) << ", GetOrCreateHcclComm failed, "
                        << rc.ToString();
-            continue;
+            return;
         }
         CommRefCheckMoreThanOne();
+
+        // Register a callback function to be executed after the communication domain is ready.
         Timer timer;
         auto traceId = Trace::Instance().GetTraceID();
-        comm->Execute([this, npuEvents = std::move(npuEvents), comm, traceId, timer]() {
-            auto elapsedMs = static_cast(timer.ElapsedMicroSecond() * ONE_SECOND_MS);
-            PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_PUB_SUBMIT_DELAY, elapsedMs);
-
PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_PUB_SUBMIT_KEY_COUNT, npuEvents.size()); - PerfPoint point(PerfKey::CLIENT_P2P_PUB_PIPELINE_PREPARE); - TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); - point.RecordAndReset(PerfKey::CLIENT_P2P_PUB_PIPELINE_SUBMIT_ALL); - for (const auto &npuEvent : npuEvents) { - const auto &objectKey = npuEvent.object_key(); - std::shared_ptr putRequest; - auto found = GetPutRequest(objectKey, putRequest); - if (!found) { - LOG(ERROR) << FormatString("Can't find %s P2PPutRequest info", objectKey); - continue; - } - PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_PUB_SUBMIT_KEY_SIZE, putRequest->GetTotalSize()); - putRequest->CreateEvent(); - size_t srcOffset = npuEvent.src_offset(); - size_t length = npuEvent.length(); - std::vector dataInfos = putRequest->GetDataInfoStorage(); - // Calculate minimum size from all dataInfos - size_t minSize = - std::min_element(dataInfos.begin(), dataInfos.end(), [](const DataInfo &a, const DataInfo &b) { - return a.Size() < b.Size(); - })->Size(); - // Execute if receiver expects only partial data - if (srcOffset > 0 || length < minSize) { - VLOG(1) << "Adjusting data info parameters: srcOffset=" << srcOffset << ", length=" << length - << ", minSize=" << minSize; - for (auto &dataInfo : dataInfos) { - dataInfo.devPtr = static_cast(static_cast(dataInfo.devPtr) + srcOffset); - dataInfo.count = npuEvent.length(); + comm->AddReadyCallback([this, npuEvents = std::move(npuEvents), comm, traceId, timer]() { + comm->Execute([this, npuEvents, comm, traceId, timer]() { + auto elapsedMs = static_cast(timer.ElapsedMicroSecond() * ONE_SECOND_MS); + PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_PUB_SUBMIT_DELAY, elapsedMs); + PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_PUB_SUBMIT_KEY_COUNT, npuEvents.size()); + PerfPoint point(PerfKey::CLIENT_P2P_PUB_PIPELINE_PREPARE); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + point.RecordAndReset(PerfKey::CLIENT_P2P_PUB_PIPELINE_SUBMIT_ALL); + for (const auto &npuEvent : npuEvents) { + std::shared_ptr putRequest; + const auto &objectKey = npuEvent.object_key(); + auto found = GetPutRequest(objectKey, putRequest); + if (!found) { + LOG(ERROR) << FormatString("Can't find %s P2PPutRequest info", objectKey); + continue; } + PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_PUB_SUBMIT_KEY_SIZE, putRequest->GetTotalSize()); + putRequest->CreateEvent(); + size_t srcOffset = npuEvent.src_offset(); + size_t length = npuEvent.length(); + std::vector blobs = putRequest->GetBlobsStorage(); + // Calculate minimum size from all blobs + size_t minSize = + std::min_element(blobs.begin(), blobs.end(), [](const Blob &a, const Blob &b) { + return a.size < b.size; + })->size; + // Execute if receiver expects only partial data + if (srcOffset > 0 || length < minSize) { + VLOG(1) << "Adjusting data info parameters: srcOffset=" << srcOffset << ", length=" << length + << ", minSize=" << minSize; + for (auto &blob : blobs) { + blob.pointer = + static_cast(static_cast(blob.pointer) + srcOffset); + blob.size = npuEvent.length(); + } + } + LOG(INFO) << "Start submit send task for object key:" << objectKey; + acl::P2PSendTask sendTask{ .srcBuffers = putRequest->GetBlobsStorage(), + .totalSize = putRequest->GetTotalSize(), + .comm = comm, + .event = putRequest->GetEvent() }; + auto rc = comm->SubmitPipelineTask(std::move(sendTask)); + if (rc.IsError()) { + LOG(ERROR) << FormatString( + "Submitted P2P send task execution failed for object:%s, error msg: [%s]", + objectKey, + rc.GetMsg()); + 
LOG_IF_ERROR(putRequest->SetPromiseValue(rc), "promise set value failed."); + return; + } + std::shared_ptr ackReq = std::make_shared(putRequest); + p2pAckQueue_.Push(ackReq); } - LOG(INFO) << "Start submit send task for object key:" << objectKey; - acl::P2PSendTask sendTask{ .srcBuffers = dataInfos, - .totalSize = putRequest->GetTotalSize(), - .comm = comm, - .event = putRequest->GetEvent() }; - auto rc = comm->SubmitPipelineTask(std::move(sendTask)); - if (rc.IsError()) { - LOG(ERROR) << FormatString("ObjectKey %s submit task failed, %s", objectKey, rc.ToString()); - putRequest->SetPromiseValue(rc); - continue; - } - - std::shared_ptr ackReq = std::make_shared(putRequest); - p2pAckQueue_.Push(ackReq); - } - point.RecordAndReset(PerfKey::CLIENT_P2P_PUB_PIPELINE_OTHER); + point.RecordAndReset(PerfKey::CLIENT_P2P_PUB_PIPELINE_OTHER); + }); }); } } @@ -219,7 +228,7 @@ void P2PSubscribe::RunP2PRecvLoop() if (p2pGetRequests->Size() == 0) { continue; } - TraceGuard traceGuard = Trace::Instance().SetTraceNewID(GetStringUuid()); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(p2pGetRequests->getTraceId_); if (!first) { auto elapsedMs = static_cast(lastGetTimer.ElapsedMicroSecond() * ONE_SECOND_MS); PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_SUB_NEXT_GET_DELAY, elapsedMs); @@ -282,6 +291,7 @@ void P2PSubscribe::RunP2PAckLoop() if (p2pAckQueue_.Pop(p2pAckReq).IsError() || p2pAckReq == nullptr) { continue; } + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); if (p2pAckReq->type == P2PAckReqType::GET) { auto &p2pGetRequest = p2pAckReq->p2pGetRequest; if (p2pGetRequest->GetEvent() == nullptr) { @@ -318,15 +328,15 @@ Status P2PSubscribe::ProcessP2PGet(const std::shared_ptr return Status(K_NOT_FOUND, "p2p meta data get timeout"); } std::vector> bufferInfoList; - std::vector> dataInfoStorageList; + std::vector blobStorageList; std::unordered_map> objKeyToP2PRequest; for (size_t i = 0; i < p2pGetRequests->Size(); i++) { auto &p2pGetRequest = p2pGetRequests->requestList_[i]; const auto &bufferInfo = p2pGetRequest->GetBufferInfo(); const auto &objectKey = bufferInfo->devObjKey; - const auto &dataInfoStorage = p2pGetRequest->GetDataInfoStorage(); + const auto &blobStorage = p2pGetRequest->GetBlobsStorage(); bufferInfoList.emplace_back(bufferInfo); - dataInfoStorageList.emplace_back(dataInfoStorage); + blobStorageList.emplace_back(DeviceBlobList{ .blobs = blobStorage, .deviceIdx = -1 }); (void)objKeyToP2PRequest.emplace(objectKey, p2pGetRequest); VLOG(1) << FormatString("%s is ready to P2PGet", objectKey); } @@ -336,7 +346,7 @@ Status P2PSubscribe::ProcessP2PGet(const std::shared_ptr std::chrono::duration_cast(now - p2pGetRequests->initializationTime_).count(); auto subTimeout = elapsedTime > p2pGetRequests->subTimeout_ ? 
0 : p2pGetRequests->subTimeout_ - elapsedTime; point.RecordAndReset(PerfKey::CLIENT_P2P_SUB_GETMETA); - auto ret = clientWorkerApi_->GetP2PMeta(bufferInfoList, dataInfoStorageList, resp, subTimeout); + auto ret = clientWorkerApi_->GetP2PMeta(bufferInfoList, blobStorageList, resp, subTimeout); if (ret.IsError()) { LOG(ERROR) << "GetP2PMeta error,msg:" << ret.GetMsg(); if (ret.GetCode() == K_RPC_DEADLINE_EXCEEDED) { @@ -363,7 +373,6 @@ void P2PSubscribe::ProcessP2PRecv( auto isSameNode = kv.first.sameNode; auto &respList = kv.second; std::shared_ptr comm; - auto traceId = Trace::Instance().GetTraceID(); StartMonitorThread(); auto rc = commFactory_->GetOrCreateHcclComm(P2PEventType::RECV, deviceId_, srcClientId, srcDeviceId, isSameNode, clientEnableP2Ptransfer_, comm); @@ -376,54 +385,63 @@ void P2PSubscribe::ProcessP2PRecv( } finishedList.insert(objectKeys.cbegin(), objectKeys.cend()); CommRefCheckMoreThanOne(); + + // Register a callback function to be executed after the communication domain is ready Timer timer; - comm->Execute([this, respList = std::move(respList), comm, objKeyToP2PRequest, srcClientId, srcDeviceId, - traceId, timer]() mutable { - auto elapsedMs = static_cast(timer.ElapsedMicroSecond() * ONE_SECOND_MS); - PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_SUB_SUBMIT_DELAY, elapsedMs); - PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_SUB_SUBMIT_KEY_COUNT, respList.size()); - PerfPoint point(PerfKey::CLIENT_P2P_SUB_PIPELINE_PREPARE); - TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); - std::vector> requests; - size_t maxObjectSize = 0; - for (const auto &resp : respList) { - const auto &objectKey = resp.object_key(); - auto iter = objKeyToP2PRequest.find(objectKey); - if (iter == objKeyToP2PRequest.end()) { - LOG(ERROR) << "object key:" << objectKey << " not found in objKeyToP2PRequest"; - continue; - } - auto &getRequest = iter->second; - requests.emplace_back(getRequest); - maxObjectSize = std::max(maxObjectSize, getRequest->GetTotalSize()); - } - point.RecordAndReset(PerfKey::CLIENT_P2P_SUB_PIPELINE_SUBMIT_ALL); - for (const auto &p2pGetRequest : requests) { - PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_SUB_SUBMIT_KEY_SIZE, p2pGetRequest->GetTotalSize()); - const auto &objectKey = p2pGetRequest->GetObjectKey(); - const auto &bufferInfo = p2pGetRequest->GetBufferInfo(); - auto dataInfoStorage = p2pGetRequest->GetDataInfoStorage(); - LOG(INFO) << FormatString("Start submit recv task for object key: %s", objectKey); - acl::P2PRecvTask recvTask{ .destBuffers = dataInfoStorage, - .totalSize = p2pGetRequest->GetTotalSize(), - .comm = comm, - .event = p2pGetRequest->GetEvent() }; - auto rc = comm->SubmitPipelineTask(std::move(recvTask)); - if (rc.IsError()) { - LOG(ERROR) << "P2Precv error objkey: " << objectKey << " with error " << rc.GetMsg(); - LOG_IF_ERROR(p2pGetRequest->SetPromiseValue(rc), "promise set value failed."); - continue; + auto traceId = Trace::Instance().GetTraceID(); + comm->AddReadyCallback([this, respList = std::move(respList), comm, objKeyToP2PRequest, srcClientId, + srcDeviceId, traceId, timer]() { + comm->Execute([this, respList, comm, objKeyToP2PRequest, srcClientId, srcDeviceId, traceId, + timer]() mutable { + auto elapsedMs = static_cast(timer.ElapsedMicroSecond() * ONE_SECOND_MS); + PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_SUB_SUBMIT_DELAY, elapsedMs); + PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_SUB_SUBMIT_KEY_COUNT, respList.size()); + PerfPoint point(PerfKey::CLIENT_P2P_SUB_PIPELINE_PREPARE); + TraceGuard traceGuard = 
Trace::Instance().SetTraceNewID(traceId); + std::vector> requests; + size_t maxObjectSize = 0; + for (const auto &resp : respList) { + const auto &objectKey = resp.object_key(); + auto iter = objKeyToP2PRequest.find(objectKey); + if (iter == objKeyToP2PRequest.end()) { + LOG(ERROR) << "object key:" << objectKey << " not found in objKeyToP2PRequest"; + continue; + } + auto &getRequest = iter->second; + requests.emplace_back(getRequest); + maxObjectSize = std::max(maxObjectSize, getRequest->GetTotalSize()); } - if (bufferInfo->cacheLocation) { - AddSubscribe(bufferInfo, dataInfoStorage); - (void)devMemUnitTable_.insert(std::make_pair(objectKey, p2pGetRequest->GetMemUnit())); + point.RecordAndReset(PerfKey::CLIENT_P2P_SUB_PIPELINE_SUBMIT_ALL); + for (const auto &p2pGetRequest : requests) { + PerfPoint::RecordElapsed(PerfKey::CLIENT_P2P_SUB_SUBMIT_KEY_SIZE, p2pGetRequest->GetTotalSize()); + const auto &objectKey = p2pGetRequest->GetObjectKey(); + const auto &bufferInfo = p2pGetRequest->GetBufferInfo(); + auto blobStorage = p2pGetRequest->GetBlobsStorage(); + LOG(INFO) << FormatString("Start submit recv task for object key: %s", objectKey); + acl::P2PRecvTask recvTask{ .destBuffers = blobStorage, + .totalSize = p2pGetRequest->GetTotalSize(), + .comm = comm, + .event = p2pGetRequest->GetEvent() }; + auto rc = comm->SubmitPipelineTask(std::move(recvTask)); + if (rc.IsError()) { + LOG(ERROR) << FormatString( + "Submitted P2P receive task execution failed for object:%s, error msg: [%s]", + objectKey, + rc.GetMsg()); + LOG_IF_ERROR(p2pGetRequest->SetPromiseValue(rc), "promise set value failed."); + continue; + } + if (bufferInfo->cacheLocation) { + AddSubscribe(bufferInfo, blobStorage); + (void)devMemUnitTable_.insert(std::make_pair(objectKey, p2pGetRequest->GetMemUnit())); + } + p2pGetRequest->SetSrcClientId(srcClientId); + p2pGetRequest->SetSrcDeviceId(srcDeviceId); + std::shared_ptr req = std::make_shared(p2pGetRequest); + p2pAckQueue_.Push(req); } - p2pGetRequest->SetSrcClientId(srcClientId); - p2pGetRequest->SetSrcDeviceId(srcDeviceId); - std::shared_ptr req = std::make_shared(p2pGetRequest); - p2pAckQueue_.Push(req); - } - point.RecordAndReset(PerfKey::CLIENT_P2P_SUB_PIPELINE_OTHER); + point.RecordAndReset(PerfKey::CLIENT_P2P_SUB_PIPELINE_OTHER); + }); }); } } @@ -456,16 +474,24 @@ Status P2PSubscribe::ProcessP2PResponse( groupedSubResp[groupKey].emplace_back(std::move(subResp)); } ProcessP2PRecv(groupedSubResp, objKeyToP2PRequest, finishedList); + std::stringstream retryKeys; + bool first = true; auto remainTasks = std::make_shared(p2pGetRequests->prefetchTimeout_, p2pGetRequests->subTimeout_); for (size_t i = 0; i < p2pGetRequests->Size(); i++) { const auto &objectKey = p2pGetRequests->requestList_[i]->GetBufferInfo()->devObjKey; if (finishedList.find(objectKey) == finishedList.end()) { remainTasks->requestList_.emplace_back(std::move(p2pGetRequests->requestList_[i])); + if (!first) { + retryKeys << ", "; + } + retryKeys << objectKey; + first = false; } remainTasks->initializationTime_ = p2pGetRequests->initializationTime_; } if (remainTasks->Size() > 0) { + VLOG(1) << FormatString("Re-adding unfinished keys [%s] to p2pGetQueue_", retryKeys.str()); p2pGetQueue_.Push(remainTasks); } return Status::OK(); @@ -490,9 +516,9 @@ bool P2PSubscribe::GetPutRequest(const std::string &objectKey, std::shared_ptr

P2PSubscribe::AddSubscribe(const std::shared_ptr &bufferInfo, - const std::vector &dataInfoList) + const std::vector &blobs) { - auto putRequest = std::make_shared(bufferInfo, dataInfoList); + auto putRequest = std::make_shared(bufferInfo, blobs); (void)objKey2PutReqTable_.insert(std::make_pair(bufferInfo->devObjKey, putRequest)); return putRequest; } @@ -509,11 +535,11 @@ Status P2PSubscribe::PublishDeviceObject(const std::shared_ptr &bu RETURN_STATUS_LOG_ERROR(K_RUNTIME_ERROR, FormatString("The ID already exists,ID:%s", bufferInfo->devObjKey)); } - auto &storagedataInfoVec = buffAcc->second->GetDataInfoStorage(); - auto &newDataInfoVec = buffer->GetDeviceMemUnit()->GetDataInfoStorage(); + auto &storageBlobVec = buffAcc->second->GetBlobsStorage(); + auto &newBlobVec = buffer->GetDeviceMemUnit()->GetBlobsStorage(); auto sameDataPtr = - std::equal(storagedataInfoVec.begin(), storagedataInfoVec.end(), newDataInfoVec.begin(), - newDataInfoVec.end(), [](const DataInfo &a, const DataInfo &b) { return a.devPtr == b.devPtr; }); + std::equal(storageBlobVec.begin(), storageBlobVec.end(), newBlobVec.begin(), + newBlobVec.end(), [](const Blob &a, const Blob &b) { return a.pointer == b.pointer; }); if (sameDataPtr) { return Status::OK(); } @@ -522,9 +548,9 @@ Status P2PSubscribe::PublishDeviceObject(const std::shared_ptr &bu bufferInfo->devObjKey)); } auto devMemUnit = buffer->GetDeviceMemUnit(); - auto putRequest = AddSubscribe(bufferInfo, devMemUnit->GetDataInfoStorage()); + auto putRequest = AddSubscribe(bufferInfo, devMemUnit->GetBlobsStorage()); VLOG(1) << "PutP2PMeta to worker, objectKey: " << buffer->GetObjectKey(); - auto rc = clientWorkerApi_->PutP2PMeta(bufferInfo, devMemUnit->GetDataInfoStorage()); + auto rc = clientWorkerApi_->PutP2PMeta(bufferInfo, devMemUnit->GetBlobsStorage()); INJECT_POINT("PublishDeviceObject.PutP2PMeta.Timeout", [&rc]() { rc = Status(StatusCode::K_RUNTIME_ERROR, "timeout"); return Status::OK(); @@ -544,7 +570,7 @@ Status P2PSubscribe::GetSendStatus(const std::string &objectKey, std::vectorsecond; - return putRequest->CreateEventAndFutureList(putRequest->GetDataInfoStorage().size(), futureVec); + return putRequest->CreateEventAndFutureList(putRequest->GetBlobsStorage().size(), futureVec); } RETURN_STATUS(K_NOT_FOUND, FormatString("The objectKey [ %s ] is not found in this client.", objectKey)); } @@ -602,7 +628,7 @@ void P2PSubscribe::MonitorLoop() for (const auto &comm : hcclCommVec) { auto rc = comm->CheckHealth(connectTimeOutMS_); if (rc.IsError()) { - LOG(ERROR) << rc.ToString(); + LOG(ERROR) << FormatString("Hccl comm health check failed, %s", rc.ToString()); (void)commFactory_->DelComm(comm->GetCommId()); } } @@ -655,14 +681,16 @@ Status P2PSubscribe::AsyncGet(const std::vector> & std::string devObjKey = buffer->bufferInfo_->devObjKey; TbbP2PPutRequestTable::accessor acc; if (objKey2PutReqTable_.find(acc, devObjKey)) { + LOG(INFO) << "Get key " << devObjKey + << " found locally in put request table, performing on-device D2D copy directly."; std::shared_ptr putRequest = acc->second; - const std::vector dataInfosInPut = putRequest->GetDataInfoStorage(); - std::vector dataInfosInGet = buffer->GetDataInfoList(); - for (size_t i = 0; i < dataInfosInPut.size(); i++) { - auto adjustedPtr = static_cast(static_cast(dataInfosInPut[i].devPtr) + const std::vector blobsInPut = putRequest->GetBlobsStorage(); + std::vector blobsInGet = buffer->GetDevBlobList(); + for (size_t i = 0; i < blobsInPut.size(); i++) { + auto adjustedPtr = 
static_cast(static_cast(blobsInPut[i].pointer) + buffer->bufferInfo_->srcOffset); - RETURN_IF_NOT_OK(aclImpl_->MemCopyD2D(dataInfosInGet[i].devPtr, dataInfosInGet[i].Size(), - static_cast(adjustedPtr), dataInfosInGet[i].Size())); + RETURN_IF_NOT_OK(aclImpl_->MemCopyD2D(blobsInGet[i].pointer, blobsInGet[i].size, + static_cast(adjustedPtr), blobsInGet[i].size)); } auto promise = std::make_shared(devObjKey); promise->CreateEventAndFutureList(0, futureVec); @@ -670,11 +698,12 @@ Status P2PSubscribe::AsyncGet(const std::vector> & continue; } auto getRequest = - std::make_shared(buffer->bufferInfo_, buffer->GetDataInfoList(), buffer->GetDeviceMemUnit()); + std::make_shared(buffer->bufferInfo_, buffer->GetDevBlobList(), buffer->GetDeviceMemUnit()); p2pRequestsWrapper->requestList_.emplace_back(getRequest); - RETURN_IF_NOT_OK(getRequest->CreateEventAndFutureList(buffer->GetDataInfoList().size(), futureVec)); + RETURN_IF_NOT_OK(getRequest->CreateEventAndFutureList(buffer->GetDevBlobList().size(), futureVec)); } p2pRequestsWrapper->subTimeout_ = subTimeoutMs; + p2pRequestsWrapper->getTraceId_ = Trace::Instance().GetTraceID(); p2pGetQueue_.Push(p2pRequestsWrapper); return Status::OK(); } diff --git a/src/datasystem/client/object_cache/device/p2p_subscribe.h b/src/datasystem/client/object_cache/device/p2p_subscribe.h index b95d865..91c182c 100644 --- a/src/datasystem/client/object_cache/device/p2p_subscribe.h +++ b/src/datasystem/client/object_cache/device/p2p_subscribe.h @@ -119,10 +119,10 @@ private: class P2PPutRequest : public PromiseWithEvent { public: - P2PPutRequest(std::shared_ptr deviceBufferInfo, std::vector dataInfoStorage) + P2PPutRequest(std::shared_ptr deviceBufferInfo, std::vector blobStorage) : PromiseWithEvent(deviceBufferInfo->devObjKey), bufferInfo_(std::move(deviceBufferInfo)), - dataInfoStorage_(std::move(dataInfoStorage)) + blobStorage_(std::move(blobStorage)) { } std::shared_ptr GetBufferInfo() @@ -130,15 +130,15 @@ public: return bufferInfo_; } - const std::vector &GetDataInfoStorage() const + const std::vector &GetBlobsStorage() const { - return dataInfoStorage_; + return blobStorage_; } size_t GetTotalSize() const { - return std::accumulate(dataInfoStorage_.begin(), dataInfoStorage_.end(), 0ul, - [](size_t total, const DataInfo &info) { return total + info.size; }); + return std::accumulate(blobStorage_.begin(), blobStorage_.end(), 0ul, + [](size_t total, const Blob &info) { return total + info.size; }); } const std::string &GetObjectKey() const @@ -148,14 +148,14 @@ public: private: std::shared_ptr bufferInfo_; - std::vector dataInfoStorage_; + std::vector blobStorage_; }; class P2PGetRequest : public P2PPutRequest { public: - P2PGetRequest(std::shared_ptr deviceBufferInfo, std::vector dataInfoStorage, + P2PGetRequest(std::shared_ptr deviceBufferInfo, std::vector blobStorage, std::shared_ptr memUnit) - : P2PPutRequest(std::move(deviceBufferInfo), std::move(dataInfoStorage)), devMemUnit_(memUnit) + : P2PPutRequest(std::move(deviceBufferInfo), std::move(blobStorage)), devMemUnit_(memUnit) { } @@ -225,6 +225,7 @@ public: std::vector> requestList_; int64_t prefetchTimeout_; int64_t subTimeout_; + std::string getTraceId_; }; enum class P2PEventReqType { CREATE, UNCREATE }; @@ -276,12 +277,12 @@ public: /** * @brief Add the device object key to the subscribe queue. * @param[in] bufferInfo The info of device buffer. - * @param[in] dataInfoList The list of data info. + * @param[in] blobs The list of blob info. * @return The future vector of HcclSend result. 
You can use the Get() method of the future object corresponding to * sendDataList to wait for and access the result of HcclSend. */ std::shared_ptr AddSubscribe(const std::shared_ptr &bufferInfo, - const std::vector &dataInfoList); + const std::vector &blobs); /** * @brief Remove the device object key from the subscribe queue in worker. @@ -342,7 +343,7 @@ public: /** * @brief Get the Data Info object. * @param[in] objectKey The device object key. - * @param[in] dataInfos The list of data info. + * @param[in] putRequest The request of put. * @return true if get data info success. */ bool GetPutRequest(const std::string &objectKey, std::shared_ptr &putRequest); diff --git a/src/datasystem/client/object_cache/object_client.cpp b/src/datasystem/client/object_cache/object_client.cpp index e5b6d86..045b6b8 100644 --- a/src/datasystem/client/object_cache/object_client.cpp +++ b/src/datasystem/client/object_cache/object_client.cpp @@ -64,12 +64,16 @@ Status ObjectClient::Create(const std::string &objectKey, uint64_t size, const C { TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); AccessRecorder accessPoint(AccessRecorderKey::DS_OBJECT_CLIENT_CREATE); - Status rc = impl_->Create(objectKey, size, param, buffer); + object_cache::FullParam innerParam; + innerParam.writeMode = WriteMode::NONE_L2_CACHE; + innerParam.consistencyType = param.consistencyType; + innerParam.cacheType = param.cacheType; + Status rc = impl_->Create(objectKey, size, innerParam, buffer); RequestParam reqParam; reqParam.objectKey = objectKey.substr(0, LOG_OBJECT_KEY_SIZE_LIMIT); - reqParam.writeMode = std::to_string(static_cast(param.writeMode)); - reqParam.consistencyType = std::to_string(static_cast(param.consistencyType)); - reqParam.cacheType = std::to_string(static_cast(param.cacheType)); + reqParam.writeMode = std::to_string(static_cast(innerParam.writeMode)); + reqParam.consistencyType = std::to_string(static_cast(innerParam.consistencyType)); + reqParam.cacheType = std::to_string(static_cast(innerParam.cacheType)); accessPoint.Record(rc.GetCode(), std::to_string(size), reqParam, rc.GetMsg()); return rc; } @@ -127,12 +131,16 @@ Status ObjectClient::Put(const std::string &objectKey, const uint8_t *data, uint { TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); AccessRecorder accessPoint(AccessRecorderKey::DS_OBJECT_CLIENT_PUT); - Status rc = impl_->Put(objectKey, data, size, param, nestedObjectKeys); + object_cache::FullParam innerParam; + innerParam.writeMode = WriteMode::NONE_L2_CACHE; + innerParam.consistencyType = param.consistencyType; + innerParam.cacheType = param.cacheType; + Status rc = impl_->Put(objectKey, data, size, innerParam, nestedObjectKeys); RequestParam reqParam; reqParam.objectKey = objectKey.substr(0, LOG_OBJECT_KEY_SIZE_LIMIT); - reqParam.writeMode = std::to_string(static_cast(param.writeMode)); - reqParam.consistencyType = std::to_string(static_cast(param.consistencyType)); - reqParam.cacheType = std::to_string(static_cast(param.cacheType)); + reqParam.writeMode = std::to_string(static_cast(innerParam.writeMode)); + reqParam.consistencyType = std::to_string(static_cast(innerParam.consistencyType)); + reqParam.cacheType = std::to_string(static_cast(innerParam.cacheType)); accessPoint.Record(rc.GetCode(), std::to_string(size), reqParam, rc.GetMsg()); return rc; } diff --git a/src/datasystem/client/object_cache/object_client_impl.cpp b/src/datasystem/client/object_cache/object_client_impl.cpp index 1e5efe4..35c7ccc 100644 --- a/src/datasystem/client/object_cache/object_client_impl.cpp 
+++ b/src/datasystem/client/object_cache/object_client_impl.cpp @@ -255,7 +255,7 @@ void ObjectClientImpl::MGetAsyncRpcThread(const std::shared_ptrpromise.set_value({ result, resourcePtr->failList }); return; } - auto rc = HostDataCopy2Device(resourcePtr->dataInfoList, resourcePtr->existBufferList); + auto rc = HostDataCopy2Device(resourcePtr->devBlobList, resourcePtr->existBufferList); resourcePtr->promise.set_value({ rc, resourcePtr->failList }); } @@ -541,33 +541,6 @@ Status ObjectClientImpl::GetAvailableWorkerApi(std::shared_ptr return Status::OK(); } -Status ObjectClientImpl::ConvertToDataInfoList(const std::vector &devBlobList, - std::vector> &dataInfoList) -{ - auto defaultType = DataType::DATA_TYPE_INT8; - auto dataByteSize = GetBytesFromDataType(defaultType); - if (dataByteSize == 0) { - return Status(K_RUNTIME_ERROR, "Get unexpected data type!"); - } - if (devBlobList.empty()) { - return Status::OK(); - } - dataInfoList.resize(devBlobList.size()); - std::vector device; - for (size_t i = 0; i < devBlobList.size(); i++) { - for (const auto &blob : devBlobList[i].blobs) { - CHECK_FAIL_RETURN_STATUS( - blob.size > 0, K_INVALID, - FormatString("Got empty or illegal size in devBlobList and the illegal size is: %lu", blob.size)); - DataInfo info{ blob.pointer, defaultType, (blob.size / dataByteSize), blob.size, devBlobList[i].deviceIdx }; - dataInfoList[i].emplace_back(std::move(info)); - } - device.emplace_back(devBlobList[i].deviceIdx); - } - RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CheckDeviceValid(device), "Check device failed."); - return Status::OK(); -} - std::shared_future ObjectClientImpl::MGetH2D(const std::vector &objectKeys, const std::vector &devBlobList, uint64_t timeoutMs) @@ -600,7 +573,10 @@ std::shared_future ObjectClientImpl::MGetH2D(const std::vector &existBufferList = asyncResource->existBufferList; existBufferList.reserve(bufferList.size()); + std::vector devices; + devices.reserve(objectKeys.size()); for (auto i = 0ul; i < objectKeys.size(); i++) { + devices.emplace_back(devBlobList[i].deviceIdx); if (!bufferList[i]) { asyncResource->failList.emplace_back(objectKeys[i]); existBufferList.emplace_back(nullptr); @@ -609,8 +585,8 @@ std::shared_future ObjectClientImpl::MGetH2D(const std::vector> &dataInfoList = asyncResource->dataInfoList; - RETURN_IF_NOT_OK(ConvertToDataInfoList(devBlobList, dataInfoList)); + asyncResource->devBlobList = devBlobList; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CheckDeviceValid(devices), "Check device failed."); return Status::OK(); }); @@ -622,11 +598,11 @@ std::shared_future ObjectClientImpl::MGetH2D(const std::vector> &dataInfoList, +Status ObjectClientImpl::HostDataCopy2Device(std::vector &devBlobList, std::vector &existBufferList) { PerfPoint point(PerfKey::CLIENT_H2D_MEMCPY); - RETURN_IF_NOT_OK(devOcImpl_->MemCopyBetweenDevAndHost(dataInfoList, existBufferList, + RETURN_IF_NOT_OK(devOcImpl_->MemCopyBetweenDevAndHost(devBlobList, existBufferList, aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE, workerApi_[LOCAL_WORKER]->IsEnableHugeTlb())); @@ -638,35 +614,46 @@ Status ObjectClientImpl::HostDataCopy2Device(std::vector> Status ObjectClientImpl::DeviceDataCreate(const std::vector &objectKeys, const std::vector &devBlobList, const SetParam &setParam, - std::vector> &bufferList, - std::vector &destroyBufferList) + std::vector> &bufferList, std::vector &exists) { PerfPoint point(PerfKey::CLIENT_MULTI_CREATE_OBJECT); CHECK_FAIL_RETURN_STATUS(!objectKeys.empty(), K_INVALID, "The keys are empty"); CHECK_FAIL_RETURN_STATUS(objectKeys.size() == 
devBlobList.size(), K_INVALID, "The size of objectKeys and devBlobList does not match"); - std::vector> dataInfoList; - RETURN_IF_NOT_OK(ConvertToDataInfoList(devBlobList, dataInfoList)); - CreateParam param; + FullParam param; param.writeMode = setParam.writeMode; param.cacheType = setParam.cacheType; std::vector dataSizeList; dataSizeList.reserve(objectKeys.size()); BlobListInfo blobInfo; - RETURN_IF_NOT_OK(PrepareDataSizeList(dataSizeList, dataInfoList, blobInfo)); + RETURN_IF_NOT_OK(PrepareDataSizeList(dataSizeList, devBlobList, blobInfo)); LOG(INFO) << blobInfo.ToString(true); + RETURN_IF_NOT_OK(MultiCreate(objectKeys, dataSizeList, param, false, bufferList, exists)); + std::vector> filterBufferList; + std::vector filterDevBlobList; + filterBufferList.reserve(objectKeys.size()); + filterDevBlobList.reserve(objectKeys.size()); + for (auto idx = 0u; idx < objectKeys.size(); idx++) { + if (exists[idx]) { + continue; + } + filterBufferList.emplace_back(bufferList[idx]); + filterDevBlobList.emplace_back(devBlobList[idx]); + } - RETURN_IF_NOT_OK(MultiCreate(objectKeys, dataSizeList, param, bufferList)); + bufferList = filterBufferList; + if (bufferList.empty()) { + return Status::OK(); + } point.RecordAndReset(PerfKey::CLIENT_D2H_MEMCPY); - ComposeBufferData(bufferList, dataInfoList); - - // destroyBufferList same as bufferList - destroyBufferList.reserve(bufferList.size()); + ComposeBufferData(bufferList, filterDevBlobList); + std::vector bufferRawPtrList; + bufferRawPtrList.reserve(bufferList.size()); for (auto &buff : bufferList) { - destroyBufferList.emplace_back(buff.get()); + bufferRawPtrList.emplace_back(buff.get()); } - RETURN_IF_NOT_OK(devOcImpl_->MemCopyBetweenDevAndHost(dataInfoList, destroyBufferList, + RETURN_IF_NOT_OK(devOcImpl_->MemCopyBetweenDevAndHost(filterDevBlobList, bufferRawPtrList, aclrtMemcpyKind::ACL_MEMCPY_DEVICE_TO_HOST, workerApi_[LOCAL_WORKER]->IsEnableHugeTlb())); @@ -698,14 +685,14 @@ std::shared_future ObjectClientImpl::MSet(const std::vector ObjectClientImpl::MSet(const std::vector> bufferList; std::vector exists; - { - PerfPoint point(PerfKey::CLIENT_MSET_CHECK_EXISTS); - (void)Exist(objectKeys, exists, false, true); + auto rc = DeviceDataCreate(objectKeys, devBlobList, setParam, bufferList, exists); + if (rc.IsError()) { + result.status = rc; + return result; } - // Filter non-existing objects std::vector nonExistobjectKeys; std::vector nonExistDevBlobList; + std::vector devices; for (size_t i = 0; i < objectKeys.size(); ++i) { if (!exists[i]) { nonExistobjectKeys.emplace_back(objectKeys[i]); nonExistDevBlobList.emplace_back(devBlobList[i]); + devices.emplace_back(devBlobList[i].deviceIdx); } } + auto deviceCheckRc = CheckDeviceValid(devices); + if (deviceCheckRc.IsError()) { + result.status = deviceCheckRc; + return result; + } // If all objects already exist, return success immediately if (nonExistobjectKeys.empty()) { result.status = Status::OK(); return result; } - - // Step2: execute DeviceDataCreate - std::vector> bufferList; - std::vector destroyBufferList; - auto rc = DeviceDataCreate(nonExistobjectKeys, nonExistDevBlobList, setParam, bufferList, destroyBufferList); - if (rc.IsError()) { - result.status = rc; - return result; - } - // Step3: Execute final MultiPublish operation { PerfPoint point(PerfKey::CLIENT_MULTI_PUBLISH_OBJECT); std::vector> blobSizes; - blobSizes.reserve(devBlobList.size()); - for (auto &devblob : devBlobList) { + blobSizes.reserve(nonExistDevBlobList.size()); + for (auto &devblob : nonExistDevBlobList) { std::vector 
sizeList; sizeList.reserve(devblob.blobs.size()); for (auto &blob : devblob.blobs) { @@ -766,12 +751,6 @@ std::shared_future ObjectClientImpl::MSet(const std::vector &buffer) { std::shared_lock shutdownLck(shutdownMux_); RETURN_IF_NOT_OK(IsClientReady()); CHECK_FAIL_RETURN_STATUS(!objectKey.empty(), K_INVALID, "The objectKey is empty"); - CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(objectKey), K_INVALID, "The objectKey contains illegal char(s)."); + RETURN_IF_NOT_OK(CheckValidObjectKey(objectKey)); CHECK_FAIL_RETURN_STATUS(dataSize > 0, K_INVALID, "The dataSize value should be bigger than zero."); RETURN_IF_NOT_OK(CheckConnection()); PerfPoint createPoint(PerfKey::CLIENT_CREATE_OBJECT); @@ -858,52 +837,34 @@ Status ObjectClientImpl::Create(const std::string &objectKey, uint64_t dataSize, } Status ObjectClientImpl::ConstructMultiCreateParam(const std::vector &objectKeyList, - const std::vector &dataSizeList, const CreateParam ¶m, + const std::vector &dataSizeList, std::vector> &bufferList, std::vector &multiCreateParamList) { CHECK_FAIL_RETURN_STATUS(objectKeyList.size() == dataSizeList.size(), K_INVALID, "The length of objectKeyList and dataSizeList should be the same."); - auto totalDataSize = 0ul; for (size_t i = 0; i < objectKeyList.size(); i++) { auto &objectKey = objectKeyList[i]; auto dataSize = dataSizeList[i]; CHECK_FAIL_RETURN_STATUS(!objectKey.empty(), K_INVALID, "The objectKey is empty"); - CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(objectKey), K_INVALID, - "The objectKey contains illegal char(s)."); + RETURN_IF_NOT_OK(CheckValidObjectKey(objectKey)); CHECK_FAIL_RETURN_STATUS(dataSize > 0, K_INVALID, "The dataSize value should be bigger than zero."); - totalDataSize += dataSize; } bufferList.resize(objectKeyList.size()); // if total size >=500k , transfer by shm - if (totalDataSize >= workerApi_[LOCAL_WORKER]->GetShmThreshold() && ShmEnable()) { - size_t index = 0; - std::for_each(objectKeyList.begin(), objectKeyList.end(), - [&index, &multiCreateParamList, &dataSizeList](const std::string &objectKey) { - multiCreateParamList.emplace_back(index, objectKey, dataSizeList[index]); - index++; - }); - return Status::OK(); - } - for (size_t i = 0; i < objectKeyList.size(); i++) { - auto &objectKey = objectKeyList[i]; - auto dataSize = dataSizeList[i]; - auto version = 0u; - std::shared_ptr newBuffer; - ObjectBufferInfo bufferInfo = SetObjectBufferInfo(objectKey, nullptr, dataSize, 0, param, false, version); - auto rc = Buffer::CreateBuffer(bufferInfo, shared_from_this(), newBuffer); - if (rc.IsError()) { - bufferList.clear(); - return rc; - } - bufferList[i] = std::move(newBuffer); - } + size_t index = 0; + std::for_each(objectKeyList.begin(), objectKeyList.end(), + [&index, &multiCreateParamList, &dataSizeList](const std::string &objectKey) { + multiCreateParamList.emplace_back(index, objectKey, dataSizeList[index]); + index++; + }); return Status::OK(); } Status ObjectClientImpl::MultiCreate(const std::vector &objectKeyList, - const std::vector &dataSizeList, const CreateParam ¶m, - std::vector> &bufferList) + const std::vector &dataSizeList, const FullParam ¶m, + const bool skipCheckExistence, std::vector> &bufferList, + std::vector &exists) { std::shared_lock shutdownLck(shutdownMux_); RETURN_IF_NOT_OK(IsClientReady()); @@ -912,15 +873,37 @@ Status ObjectClientImpl::MultiCreate(const std::vector &objectKeyLi std::shared_lock lck(memoryRefMutex_); std::vector multiCreateParamList; - RETURN_IF_NOT_OK(ConstructMultiCreateParam(objectKeyList, dataSizeList, param, bufferList, 
multiCreateParamList)); - // if multiCreateParamList is empty, not call MultiCreate rpc - if (multiCreateParamList.empty()) { - return Status::OK(); - } + RETURN_IF_NOT_OK(ConstructMultiCreateParam(objectKeyList, dataSizeList, bufferList, multiCreateParamList)); // If failed with create, need to rollback. auto version = 0u; - RETURN_IF_NOT_OK(workerApi_[LOCAL_WORKER]->MultiCreate(multiCreateParamList, version)); + auto useShmTransfer = false; + auto sizeSum = std::accumulate(dataSizeList.begin(), dataSizeList.end(), 0ULL); + if (!skipCheckExistence || static_cast(sizeSum) >= workerApi_[LOCAL_WORKER]->GetShmThreshold()) { + RETURN_IF_NOT_OK(workerApi_[LOCAL_WORKER]->MultiCreate(skipCheckExistence, multiCreateParamList, version, + exists, useShmTransfer)); + } else { + exists.resize(objectKeyList.size(), false); + } + if (!useShmTransfer) { + for (size_t i = 0; i < objectKeyList.size(); i++) { + if (exists[i]) { + continue; + } + auto &objectKey = objectKeyList[i]; + auto dataSize = dataSizeList[i]; + auto version = 0u; + std::shared_ptr newBuffer; + ObjectBufferInfo bufferInfo = SetObjectBufferInfo(objectKey, nullptr, dataSize, 0, param, false, version); + auto rc = Buffer::CreateBuffer(bufferInfo, shared_from_this(), newBuffer); + if (rc.IsError()) { + bufferList.clear(); + return rc; + } + bufferList[i] = std::move(newBuffer); + } + return Status::OK(); + } bool isInactive = false; Raii handlerCreateFailed([&isInactive, &bufferList, this]() { if (isInactive) { @@ -934,9 +917,11 @@ Status ObjectClientImpl::MultiCreate(const std::vector &objectKeyLi } bufferList.clear(); }); - Status injectRC = Status::OK(); for (auto &createParam : multiCreateParamList) { + if (exists[createParam.index]) { + continue; + } PerfPoint mmapPoint(PerfKey::CLIENT_LOOK_UP_MMAP_FD); auto &shmBuf = createParam.shmBuf; RETURN_IF_NOT_OK(mmapManager_->LookupUnitsAndMmapFd("", shmBuf)); @@ -961,7 +946,6 @@ Status ObjectClientImpl::MultiCreate(const std::vector &objectKeyLi bufferList[createParam.index] = std::move(newBuffer); } isInactive = true; return Status::OK(); } @@ -1097,7 +1081,7 @@ Status ObjectClientImpl::Seal(const std::shared_ptr &bufferInf RETURN_IF_NOT_OK(IsClientReady()); PerfPoint sealPoint(PerfKey::CLIENT_SEAL_OBJECT); RETURN_IF_NOT_OK(CheckConnection()); - RETURN_IF_NOT_OK(CheckStringVector(nestedObjectKeys, true)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(nestedObjectKeys, true)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( Validator::IsBatchSizeUnderLimit(nestedObjectKeys.size()), K_INVALID, FormatString("The nestedObjectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); @@ -1123,7 +1107,7 @@ Status ObjectClientImpl::Publish(const std::shared_ptr &buffer RETURN_IF_NOT_OK(IsClientReady()); PerfPoint perfPoint(PerfKey::CLIENT_PUBLISH_OBJECT); RETURN_IF_NOT_OK(CheckConnection()); - RETURN_IF_NOT_OK(CheckStringVector(nestedObjectKeys, true)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(nestedObjectKeys, true)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( Validator::IsBatchSizeUnderLimit(nestedObjectKeys.size()), K_INVALID, FormatString("The nestedObjectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); @@ -1141,14 +1125,14 @@ Status ObjectClientImpl::InvalidateBuffer(const std::string &objectKey) { RETURN_IF_NOT_OK(IsClientReady()); - CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(objectKey), K_INVALID, "The objectKey contains illegal char(s)."); + RETURN_IF_NOT_OK(CheckValidObjectKey(objectKey)); RETURN_IF_NOT_OK(CheckConnection());
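+    // The checks above are purely local; the worker owns the shared buffer, so the
+    // invalidation itself is delegated to the local worker in the call below.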
RETURN_IF_NOT_OK(workerApi_[LOCAL_WORKER]->InvalidateBuffer(objectKey)); return Status::OK(); } Status ObjectClientImpl::ProcessShmPut(const std::string &objectKey, const uint8_t *data, uint64_t size, - const CreateParam ¶m, + const FullParam ¶m, const std::unordered_set &nestedObjectKeys, uint32_t ttlSecond, const std::shared_ptr &workerApi, int existence) { @@ -1248,14 +1232,14 @@ Status ObjectClientImpl::Publish(const std::vector return result; } -Status ObjectClientImpl::Put(const std::string &objectKey, const uint8_t *data, uint64_t size, const CreateParam ¶m, +Status ObjectClientImpl::Put(const std::string &objectKey, const uint8_t *data, uint64_t size, const FullParam ¶m, const std::unordered_set &nestedObjectKeys, uint32_t ttlSecond, int existence) { std::shared_lock shutdownLck(shutdownMux_); RETURN_IF_NOT_OK(IsClientReady()); PerfPoint perfPoint(PerfKey::CLIENT_PUT_OBJECT); CHECK_FAIL_RETURN_STATUS(!objectKey.empty(), K_INVALID, "The objectKey should not be empty."); - CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(objectKey), K_INVALID, "The objectKey contains illegal char(s)."); + RETURN_IF_NOT_OK(CheckValidObjectKey(objectKey)); CHECK_FAIL_RETURN_STATUS(data != nullptr, K_INVALID, "The data pointer should not be null."); CHECK_FAIL_RETURN_STATUS(size > 0, K_INVALID, "The dataSize value should be bigger than zero."); CHECK_FAIL_RETURN_STATUS(nestedObjectKeys.find(objectKey) == nestedObjectKeys.end(), K_UNKNOWN_ERROR, @@ -1304,7 +1288,7 @@ Status ObjectClientImpl::Get(const std::vector &objectKeys, int64_t { PerfPoint perfPoint(PerfKey::CLIENT_GET_OBJECT); RETURN_IF_NOT_OK(IsClientReady()); - RETURN_IF_NOT_OK(CheckStringVector(objectKeys)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(objectKeys)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(objectKeys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); std::shared_ptr workerApi; @@ -1342,7 +1326,7 @@ Status ObjectClientImpl::Read(const std::vector &readParams, std::vec for (const auto ¶m : readParams) { objectKeys.emplace_back(param.key); } - RETURN_IF_NOT_OK(CheckStringVector(objectKeys)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(objectKeys)); GetParam getParam{ .objectKeys = objectKeys, .subTimeoutMs = 0, .readParams = readParams }; Status rc = GetBuffersFromWorker(workerApi, getParam, objectBuffers); buffers.clear(); @@ -1365,9 +1349,10 @@ Status ObjectClientImpl::SetShmObjectBuffer(const std::string &objectKey, const std::shared_ptr mmapEntry; uint8_t *pointer; RETURN_IF_NOT_OK(MmapShmUnit(info.store_fd(), info.mmap_size(), info.offset(), mmapEntry, pointer)); - CreateParam param{ .writeMode = WriteMode(info.write_mode()), - .consistencyType = ConsistencyType(info.consistency_type()), - .cacheType = CacheType(info.cache_type()) }; + FullParam param; + param.writeMode = WriteMode(info.write_mode()); + param.consistencyType = ConsistencyType(info.consistency_type()); + param.cacheType = CacheType(info.cache_type()); ObjectBufferInfo bufferInfo = SetObjectBufferInfo(objectKey, pointer, info.data_size(), info.metadata_size(), param, info.is_seal(), version, info.shm_id(), nullptr, std::move(mmapEntry)); @@ -1397,7 +1382,7 @@ Status ObjectClientImpl::MmapShmUnit(int64_t fd, uint64_t mmapSize, ptrdiff_t of } ObjectBufferInfo ObjectClientImpl::SetObjectBufferInfo(const std::string &objectKey, uint8_t *pointer, uint64_t size, - uint64_t metaSize, const CreateParam ¶m, bool isSeal, + uint64_t metaSize, const FullParam ¶m, bool isSeal, uint32_t version, const 
std::string &shmId, const std::shared_ptr &payloadPointer, std::shared_ptr mmapEntry) @@ -1508,9 +1493,10 @@ Status ObjectClientImpl::SetNonShmObjectBuffer(const std::string &objectKey, con int version, std::vector &payloads, std::shared_ptr &bufferPtr) { - CreateParam param{ .writeMode = WriteMode(payloadInfo.write_mode()), - .consistencyType = ConsistencyType(payloadInfo.consistency_type()), - .cacheType = CacheType(payloadInfo.cache_type()) }; + FullParam param; + param.writeMode = WriteMode(payloadInfo.write_mode()); + param.consistencyType = ConsistencyType(payloadInfo.consistency_type()); + param.cacheType = CacheType(payloadInfo.cache_type()); int payloadIndexSize = payloadInfo.part_index().size(); if (payloadIndexSize == 1) { std::shared_ptr payloadSharedPtr = @@ -1566,9 +1552,10 @@ Status ObjectClientImpl::SetOffsetReadObjectBuffer(const std::string &objectKey, std::shared_ptr mmapEntry; uint8_t *pointer; MmapShmUnit(info.store_fd(), info.mmap_size(), info.offset(), mmapEntry, pointer); - CreateParam param{ .writeMode = WriteMode(info.write_mode()), - .consistencyType = ConsistencyType(info.consistency_type()), - .cacheType = CacheType(info.cache_type()) }; + FullParam param; + param.writeMode = WriteMode(info.write_mode()); + param.consistencyType = ConsistencyType(info.consistency_type()); + param.cacheType = CacheType(info.cache_type()); ObjectBufferInfo bufferInfo = SetObjectBufferInfo(objectKey, pointer, info.data_size(), info.metadata_size(), param, info.is_seal(), version, info.shm_id(), nullptr, std::move(mmapEntry)); @@ -1598,7 +1585,7 @@ Status ObjectClientImpl::GIncreaseRef(const std::vector &objectKeys PerfPoint point(PerfKey::CLIENT_GINCREASE_REFERENCE); std::shared_lock shutdownLck(shutdownMux_); RETURN_IF_NOT_OK(IsClientReady()); - RETURN_IF_NOT_OK(CheckStringVector(objectKeys)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(objectKeys)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(objectKeys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(failedObjectKeys.empty(), K_INVALID, "The failedObjectKeys not empty"); @@ -1701,8 +1688,7 @@ Status ObjectClientImpl::GDecreaseRef(const std::vector &objectKeys PerfPoint point(PerfKey::CLIENT_GDECREASE_REFERENCE); RETURN_IF_NOT_OK(IsClientReady()); for (auto &objectKey : objectKeys) { - CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(objectKey), K_INVALID, - "The objectKey contains illegal char(s)."); + RETURN_IF_NOT_OK(CheckValidObjectKey(objectKey)); } CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(objectKeys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); @@ -1782,6 +1768,15 @@ void ObjectClientImpl::GDecreaseRefRollback(const std::vector &roll LOG(WARNING) << "[Ref] failed GDecreaseRef objectKeys " << VectorToString(rollbackObjectKeys); } +Status ObjectClientImpl::CheckValidObjectKey(const std::string &key) +{ + CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(key), K_INVALID, + FormatString("The key contains illegal char(s), allowed regex format: %s " + "or the length of key must be no more than 255, current key length is %d.", + Validator::objKeyFormat, key.size())); + return Status::OK(); +} + void ObjectClientImpl::RemoveZeroGlobalRefByRefTable(const std::vector &checkIds, std::map &accessorTable) { @@ -1832,7 +1827,7 @@ Status ObjectClientImpl::Delete(const std::vector &objectKeys, std: { PerfPoint perfPoint(PerfKey::HETERO_CLIENT_DELETE); 
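    // Validate client state, key format, and batch size before asking the worker to delete anything.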
RETURN_IF_NOT_OK(IsClientReady()); - RETURN_IF_NOT_OK(CheckStringVector(objectKeys)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(objectKeys)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(objectKeys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); std::shared_ptr workerApi; @@ -1866,10 +1861,11 @@ void ObjectClientImpl::AddTbbLockForGlobalRefIds(const std::vector Status ObjectClientImpl::Set(const std::string &key, const StringView &val, const SetParam &setParam) { RETURN_IF_NOT_OK(IsClientReady()); - CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(key), K_INVALID, "The key contains illegal char(s)."); - CreateParam param{ .writeMode = setParam.writeMode, - .consistencyType = ConsistencyType::CAUSAL, - .cacheType = setParam.cacheType }; + RETURN_IF_NOT_OK(CheckValidObjectKey(key)); + FullParam param; + param.writeMode = setParam.writeMode; + param.consistencyType = ConsistencyType::CAUSAL; + param.cacheType = setParam.cacheType; return Put(key, reinterpret_cast(val.data()), val.size(), param, {}, setParam.ttlSecond, static_cast(setParam.existence)); } @@ -1891,19 +1887,12 @@ Status ObjectClientImpl::CheckMultiSetInputParamValidationNtx(const std::vector< std::map &kv) { CHECK_FAIL_RETURN_STATUS(keys.size() > 0, K_INVALID, "The keys should not be empty."); - CHECK_FAIL_RETURN_STATUS( - keys.size() < BATCH_SET_MAX_KEY_COUNT, K_INVALID, - FormatString("The maximum size of keys in single operation is less than %d.", BATCH_SET_MAX_KEY_COUNT)); CHECK_FAIL_RETURN_STATUS(keys.size() == vals.size(), K_INVALID, "The number of key and value is not the same."); for (size_t i = 0; i < keys.size(); ++i) { CHECK_FAIL_RETURN_STATUS(!keys[i].empty(), K_INVALID, "The key should not be empty."); - CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(keys[i]), K_INVALID, - FormatString("The key %s contains illegal char(s).", keys[i])); + RETURN_IF_NOT_OK(CheckValidObjectKey(keys[i])); CHECK_FAIL_RETURN_STATUS(vals[i].data() != nullptr, K_INVALID, FormatString("The value associated with key %s should not be empty.", keys[i])); - CHECK_FAIL_RETURN_STATUS(vals[i].size() < workerApi_[LOCAL_WORKER]->GetShmThreshold(), K_INVALID, - FormatString("The size for the val must be less than %d Byte", - workerApi_[LOCAL_WORKER]->GetShmThreshold())); if (kv.find(keys[i]) == kv.end()) { kv[keys[i]] = vals[i]; } else { @@ -1928,8 +1917,7 @@ Status ObjectClientImpl::CheckMultiSetInputParamValidation(const std::vector keyRecord; for (size_t i = 0; i < keys.size(); ++i) { CHECK_FAIL_RETURN_STATUS(!keys[i].empty(), K_INVALID, "The key should not be empty."); - CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(keys[i]), K_INVALID, - FormatString("The key %s contains illegal char(s).", keys[i])); + RETURN_IF_NOT_OK(CheckValidObjectKey(keys[i])); CHECK_FAIL_RETURN_STATUS(vals[i].data() != nullptr, K_INVALID, FormatString("The value associated with key %s should not be empty.", keys[i])); CHECK_FAIL_RETURN_STATUS(kv.find(keys[i]) == kv.end(), K_INVALID, @@ -1946,7 +1934,10 @@ Status ObjectClientImpl::AllocateMemoryForMSet(const std::map> &bufferInfo, const CacheType &cacheType) { - CreateParam param{ .writeMode = writeMode, .consistencyType = ConsistencyType::CAUSAL, .cacheType = cacheType }; + FullParam param; + param.writeMode = writeMode, + param.consistencyType = ConsistencyType::CAUSAL; + param.cacheType = cacheType; int i = 0; for (const auto &keyValue : kv) { ObjectBufferInfo objInfo; @@ -1992,30 +1983,54 @@ Status ObjectClientImpl::MSet(const std::vector &keys, 
const std::v std::unique_ptr raii; RETURN_IF_NOT_OK(GetAvailableWorkerApi(workerApi, raii)); LOG(INFO) << "Begin to multiput object." << VectorToString(keys); - std::vector> buffers(kv.size(), nullptr); - std::vector> bufferInfo(kv.size(), nullptr); - CreateParam creatParam{ .writeMode = param.writeMode, - .consistencyType = ConsistencyType::CAUSAL, - .cacheType = param.cacheType }; - int i = 0; - for (const auto &keyValue : kv) { - ObjectBufferInfo objInfo; - // if is not transaction, the val of object must less than 500k, not ShmCreateable. - objInfo = - SetObjectBufferInfo(keyValue.first, reinterpret_cast(const_cast(keyValue.second.data())), - keyValue.second.size(), 0, creatParam, false, 0); - bufferInfo[i] = std::make_shared(objInfo); - i++; + FullParam creatParam; + creatParam.writeMode = param.writeMode; + creatParam.consistencyType = ConsistencyType::CAUSAL; + creatParam.cacheType = param.cacheType; + std::vector filteredKeys; + std::vector filteredValues; + filteredKeys.reserve(kv.size()); + filteredValues.reserve(kv.size()); + for (const auto &key : keys) { + if (kv.find(key) != kv.end()) { + filteredKeys.emplace_back(key); + filteredValues.emplace_back(kv[key]); + } + } + PerfPoint point(PerfKey::CLIENT_MSET_MULTICREATE); + std::vector dataSizeList; + for (const auto &val : filteredValues) { + dataSizeList.emplace_back(val.size()); + } + std::vector> bufferList; + std::vector exist; + RETURN_IF_NOT_OK(MultiCreate(filteredKeys, dataSizeList, creatParam, true, bufferList, exist)); + std::vector> bufferInfoList; + bufferInfoList.reserve(bufferList.size()); + point.RecordAndReset(PerfKey::CLIENT_MSET_MEMCOPY); + auto idx = 0; + for (auto &buffer : bufferList) { + if (buffer == nullptr) { + bufferInfoList.emplace_back(std::make_shared(SetObjectBufferInfo( + filteredKeys[idx], reinterpret_cast(const_cast(filteredValues[idx].data())), + filteredValues[idx].size(), 0, creatParam, false, 0))); + idx++; + continue; + } + RETURN_IF_NOT_OK(buffer->CheckDeprecated()); + CHECK_FAIL_RETURN_STATUS(!buffer->bufferInfo_->isSeal, K_OC_ALREADY_SEALED, "Client object is already sealed"); + RETURN_IF_NOT_OK(buffer->MemoryCopy(filteredValues[idx].data(), filteredValues[idx].size())); + bufferInfoList.emplace_back(buffer->bufferInfo_); + idx++; } + point.RecordAndReset(PerfKey::CLIENT_MSET_MULTI_PUBLSIH); MultiPublishRspPb rsp; PublishParam publishParam{ .isTx = false, .isReplica = false, .existence = param.existence, .ttlSecond = param.ttlSecond }; - RETURN_IF_NOT_OK(workerApi->MultiPublish(bufferInfo, publishParam, rsp)); + RETURN_IF_NOT_OK(workerApi->MultiPublish(bufferInfoList, publishParam, rsp)); for (const auto &objKey : rsp.failed_object_keys()) { outFailedKeys.emplace_back(objKey); } - LOG(INFO) << "Finish to multiset key: " << VectorToString(keys); Status recvRc(static_cast(rsp.last_rc().error_code()), rsp.last_rc().error_msg()); if (!outFailedKeys.empty() || recvRc.IsError()) { LOG(WARNING) << "Cannot set all the objects from worker, status:" << recvRc.ToString() @@ -2072,8 +2087,7 @@ Status ObjectClientImpl::MSet(const std::vector &keys, const std::v Status ObjectClientImpl::GenerateKey(std::string &key, const std::string &prefixKey) { - CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(prefixKey), K_INVALID, - "The objectKey contains illegal char(s), allowed regex format: " + Validator::idFormat); + RETURN_IF_NOT_OK(CheckValidObjectKey(prefixKey)); RETURN_IF_NOT_OK_PRINT_ERROR_MSG(IsClientReady(), "Generate key failed."); std::shared_ptr workerApi; @@ -2136,20 +2150,12 @@ std::shared_ptr
ObjectClientImpl::GetMemoryCopyThreadPool() return memoryCopyThreadPool_; } -Status ObjectClientImpl::CreateDevBuffer(const std::string &devObjKey, uint64_t size, void *devPtr, int32_t deviceIdx, - std::shared_ptr &deviceBuffer) +Status ObjectClientImpl::CreateDevBuffer(const std::string &devObjKey, const DeviceBlobList &devBlobList, + const CreateDeviceParam ¶m, std::shared_ptr &deviceBuffer) { RETURN_IF_NOT_OK(IsClientReady()); - return devOcImpl_->CreateDevBuffer(devObjKey, size, devPtr, deviceIdx, deviceBuffer); -} - -Status ObjectClientImpl::CreateDevBuffer(const std::string &devObjKey, const std::vector &dataInfoList, - int32_t deviceIdx, const CreateDeviceParam ¶m, - std::shared_ptr &deviceBuffer) -{ PerfPoint perfPoint(PerfKey::HETERO_CLIENT_CREATE_DEV_BUFFER); - RETURN_IF_NOT_OK(IsClientReady()); - return devOcImpl_->CreateDevBuffer(devObjKey, dataInfoList, deviceIdx, param, deviceBuffer); + return devOcImpl_->CreateDevBuffer(devObjKey, devBlobList, param, deviceBuffer); } Status ObjectClientImpl::PublishDeviceObject(std::shared_ptr buffer) @@ -2173,19 +2179,18 @@ Status ObjectClientImpl::GetSendStatus(const std::shared_ptr &buff return devOcImpl_->GetSendStatus(buffer, futureVec); } -Status ObjectClientImpl::GetDataInfo(const std::string &devObjKey, int32_t timeoutMs, std::vector &dataInfos) +Status ObjectClientImpl::GetBlobsInfo(const std::string &devObjKey, int32_t timeoutMs, std::vector &blobs) { RETURN_IF_NOT_OK(IsClientReady()); CHECK_FAIL_RETURN_STATUS(!devObjKey.empty(), K_INVALID, "The objectKey is empty"); - CHECK_FAIL_RETURN_STATUS(Validator::IsIdFormat(devObjKey), K_INVALID, - "The devObjKey maybe contains illegal char(s) or the length of id is > 255."); + RETURN_IF_NOT_OK(CheckValidObjectKey(devObjKey)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( Validator::IsInNonNegativeInt32(timeoutMs), K_INVALID, FormatString("timeoutMs %d is out of range., which should be between [%d, %d]", timeoutMs, 0, INT32_MAX)); std::shared_ptr workerApi; std::unique_ptr raii; RETURN_IF_NOT_OK(GetAvailableWorkerApi(workerApi, raii)); - return workerApi->GetDataInfo(devObjKey, timeoutMs, dataInfos); + return workerApi->GetBlobsInfo(devObjKey, timeoutMs, blobs); } Status ObjectClientImpl::RemoveP2PLocation(const std::string &objectKey, int32_t deviceId) @@ -2201,7 +2206,7 @@ Status ObjectClientImpl::GetObjMetaInfo(const std::string &tenantId, const std:: std::vector &objMetas) { RETURN_IF_NOT_OK(IsClientReady()); - RETURN_IF_NOT_OK(CheckStringVector(objectKeys)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(objectKeys)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(objectKeys.size() <= OBJ_META_MAX_SIZE_LIMIT, K_INVALID, FormatString("The objectKeys size exceed %d.", OBJ_META_MAX_SIZE_LIMIT)); std::shared_ptr workerApi; @@ -2228,7 +2233,7 @@ Status ObjectClientImpl::DeleteDevObjects(const std::vector &objKey { PerfPoint perfPoint(PerfKey::HETERO_CLIENT_DEV_DELETE); RETURN_IF_NOT_OK(IsClientReady()); - RETURN_IF_NOT_OK(CheckStringVector(objKeys)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(objKeys)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(objKeys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); std::shared_ptr workerApi; @@ -2274,8 +2279,13 @@ Status ObjectClientImpl::MultiPublish(const std::vector> } Status recvRc(static_cast(rsp.last_rc().error_code()), rsp.last_rc().error_msg()); + auto failedSet = std::set{ rsp.failed_object_keys().begin(), rsp.failed_object_keys().end() }; for (auto &buffer : bufferList) { if 
(buffer->isShm_) { + if (failedSet.find(buffer->bufferInfo_->objectKey) == failedSet.end()) { + memoryRefCount_.erase(buffer->bufferInfo_->shmId); + buffer->isReleased_ = true; + } buffer->SetVisibility(recvRc.IsOk()); } } @@ -2292,7 +2302,7 @@ Status ObjectClientImpl::MultiPublish(const std::vector> Status ObjectClientImpl::QuerySize(const std::vector &objectKeys, std::vector &outSizes) { RETURN_IF_NOT_OK(IsClientReady()); - RETURN_IF_NOT_OK(CheckStringVector(objectKeys)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(objectKeys)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(objectKeys.size() <= QUERY_SIZE_OBJECT_LIMIT, K_INVALID, FormatString("The objectKeys size exceed %d.", QUERY_SIZE_OBJECT_LIMIT)); std::shared_ptr workerApi; @@ -2336,10 +2346,9 @@ Status ObjectClientImpl::DevPublish(const std::vector &objectKeys, "The size of objectKeys and devBlobList does not match"); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(objectKeys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); - RETURN_IF_NOT_OK(CheckStringVector(objectKeys, true)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(objectKeys, true)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(objectKeys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); - std::vector> dataInfoList; std::vector> devBuffPtrList; CreateDeviceParam createParam = CreateDeviceParam{ LifetimeType::MOVE, false }; RETURN_IF_NOT_OK(ConvertToDevBufferPtrList(objectKeys, devBlobList, createParam, devBuffPtrList)); @@ -2371,10 +2380,9 @@ Status ObjectClientImpl::DevSubscribe(const std::vector &objectKeys "The size of objectKeys and devBlobList does not match"); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(objectKeys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); - RETURN_IF_NOT_OK(CheckStringVector(objectKeys, true)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(objectKeys, true)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(objectKeys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); - std::vector> dataInfoList; std::vector> devBuffPtrList; CreateDeviceParam createParam{ LifetimeType::MOVE, false }; RETURN_IF_NOT_OK(ConvertToDevBufferPtrList(objectKeys, devBlobList, createParam, devBuffPtrList)); @@ -2390,7 +2398,7 @@ Status ObjectClientImpl::DevLocalDelete(const std::vector &objectKe std::vector &failedObjectKeys) { PerfPoint perfPoint(PerfKey::HETERO_CLIENT_LOCAL_DELETE); - RETURN_IF_NOT_OK(CheckStringVector(objectKeys)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(objectKeys)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(objectKeys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); auto ret = Status::OK(); @@ -2432,6 +2440,7 @@ Status ObjectClientImpl::DevMSet(const std::vector &keys, const std "The size of keys and devBlobList does not match"); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(keys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(keys, true)); std::vector> devBuffPtrList; CreateDeviceParam createParam{ LifetimeType::REFERENCE, true }; RETURN_IF_NOT_OK(ConvertToDevBufferPtrList(keys, blob2dList, createParam, devBuffPtrList)); @@ -2452,7 +2461,7 @@ Status ObjectClientImpl::DevMGet(const 
std::vector &keys, const std FormatString("Got empty parameters : keys nums %d, blobList nums %d.", keys.size(), blob2dList.size())); CHECK_FAIL_RETURN_STATUS(keys.size() == blob2dList.size(), K_INVALID, "The size of objectKeys and blob2dList does not match"); - RETURN_IF_NOT_OK(CheckStringVector(keys, true)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(keys, true)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(keys.size()), K_INVALID, FormatString("The objectKeys size exceed %d.", OBJECT_KEYS_MAX_SIZE_LIMIT)); std::vector> devBuffPtrList; @@ -2467,11 +2476,11 @@ Status ObjectClientImpl::ConvertToDevBufferPtrList(const std::vector> &deviceBuffPtrList) { - std::vector> dataInfoList; - RETURN_IF_NOT_OK(ConvertToDataInfoList(blob2dList, dataInfoList)); - for (size_t i = 0; i < dataInfoList.size(); i++) { + for (size_t i = 0; i < blob2dList.size(); i++) { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CheckDeviceValid({ (uint32_t)blob2dList[i].deviceIdx }), + "Check device failed."); std::shared_ptr devBuff; - RETURN_IF_NOT_OK(CreateDevBuffer(keys[i], dataInfoList[i], dataInfoList[i][0].deviceIdx, createParam, devBuff)); + RETURN_IF_NOT_OK(CreateDevBuffer(keys[i], blob2dList[i], createParam, devBuff)); devBuff->bufferInfo_->autoRelease = false; devBuff->bufferInfo_->srcOffset = blob2dList[i].srcOffset; deviceBuffPtrList.emplace_back(devBuff); @@ -2523,7 +2532,7 @@ Status ObjectClientImpl::Exist(const std::vector &keys, std::vector { PerfPoint perfPoint(PerfKey::HETERO_CLIENT_EXIST); RETURN_IF_NOT_OK(IsClientReady()); - RETURN_IF_NOT_OK(CheckStringVector(keys)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(keys)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(keys.size() <= QUERY_SIZE_OBJECT_LIMIT, K_INVALID, FormatString("The objectKeys size exceed %d.", QUERY_SIZE_OBJECT_LIMIT)); std::shared_ptr workerApi; @@ -2538,7 +2547,7 @@ Status ObjectClientImpl::Expire(const std::vector &keys, uint32_t t { PerfPoint perfPoint(PerfKey::CLIENT_EXPIRE_OBJECT); RETURN_IF_NOT_OK(IsClientReady()); - RETURN_IF_NOT_OK(CheckStringVector(keys)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(keys)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(keys.size() <= QUERY_SIZE_OBJECT_LIMIT, K_INVALID, FormatString("The objectKeys size exceed %d.", QUERY_SIZE_OBJECT_LIMIT)); std::shared_ptr workerApi; @@ -2553,7 +2562,7 @@ Status ObjectClientImpl::GetMetaInfo(const std::vector &keys, const std::vector &metaInfos, std::vector &failKeys) { RETURN_IF_NOT_OK(IsClientReady()); - RETURN_IF_NOT_OK(CheckStringVector(keys)); + RETURN_IF_NOT_OK(CheckValidObjectKeyVector(keys)); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(keys.size() <= QUERY_SIZE_OBJECT_LIMIT, K_INVALID, FormatString("The objectKeys size exceed %d.", QUERY_SIZE_OBJECT_LIMIT)); std::shared_ptr workerApi; diff --git a/src/datasystem/client/object_cache/object_client_impl.h b/src/datasystem/client/object_cache/object_client_impl.h index b4ddeeb..40d9a80 100644 --- a/src/datasystem/client/object_cache/object_client_impl.h +++ b/src/datasystem/client/object_cache/object_client_impl.h @@ -71,6 +71,10 @@ struct P2PPeer { uint64_t count; }; +struct FullParam : public CreateParam { + WriteMode writeMode = WriteMode::NONE_L2_CACHE; +}; + using P2PPeerTable = tbb::concurrent_hash_map; class __attribute((visibility("default"))) ObjectClientImpl : public std::enable_shared_from_this { @@ -136,7 +140,7 @@ public: * K_RUNTIME_ERROR: client fd mmap failed. * K_DUPLICATED: the object already exists, no need to create. 
*/ - Status Create(const std::string &objectKey, uint64_t dataSize, const CreateParam &param, + Status Create(const std::string &objectKey, uint64_t dataSize, const FullParam &param, std::shared_ptr &buffer); /** @@ -192,7 +196,7 @@ public: * @param[in] existence Used by state api, to determine whether to set or not set the key if it does already exist. * @return K_OK on success; the error code otherwise. */ - Status Put(const std::string &objectKey, const uint8_t *data, uint64_t size, const CreateParam &param, + Status Put(const std::string &objectKey, const uint8_t *data, uint64_t size, const FullParam &param, const std::unordered_set &nestedObjectKeys, uint32_t ttlSecond = 0, int existence = 0); /** @@ -339,21 +343,6 @@ public: */ std::string GetFutureMapIdentifier(const std::string &devObjKey, std::shared_ptr deviceBuffer); - /** - * @brief Invoke worker client to create a device object. - * @param[in] objectKey The ID of the device object to create. ID should not be empty and should only contains - * english alphabetics (a-zA-Z), numbers and ~!@#$%^&*.-_ only. ID length should less than 256. - * @param[in] size The size in bytes of device object. - * @param[in] devPtr The device memory pointer. Pass the pointer if user want do malloc by self. - * Pass the nullptr then client will malloc device memory and free when DeviceBuffer is destructed. - * @param[in] deviceIdx The device index of the device memory. - * @param[out] deviceBuffer The device buffer for the object. - * @return Status K_OK on success; the error code otherwise. - */ - - Status CreateDevBuffer(const std::string &devObjKey, uint64_t size, void *devPtr, int32_t deviceIdx, - std::shared_ptr &deviceBuffer); - /** * @brief Publish device object to datasystem. * @param[in] buffer The device buffer ready to publish. @@ -376,13 +365,12 @@ public: * @brief Invoke worker client to create a device object with p2p. * @param[in] objectKey The ID of the device object to create. ID should not be empty and should only contains * english alphabetics (a-zA-Z), numbers and ~!@#$%^&*.-_ only. ID length should less than 256. - * @param[in] dataInfoList The list of data info. - * @param[in] deviceIdx The device index of the device memory. + * @param[in] devBlobList The list of blob info. * @param[in] param The create param of device object. * @param[out] deviceBuffer The device buffer for the object. * @return Status K_OK on success; the error code otherwise. */ - Status CreateDevBuffer(const std::string &devObjKey, const std::vector &dataInfoList, int32_t deviceIdx, + Status CreateDevBuffer(const std::string &devObjKey, const DeviceBlobList &devBlobList, const CreateDeviceParam &param, std::shared_ptr &deviceBuffer); /** @@ -413,15 +401,15 @@ public: Status GetSendStatus(const std::shared_ptr &buffer, std::vector &futureVec); /** - * @brief Obtains the DataInfos, including the number of DataInfo, and the count and DataType of each DataInfo. + * @brief Obtains the blob infos, including the number of blobs, and the count and data type of each blob. * @param[in] devObjKey The object key. ID should not be empty and should only contains english * alphabetics (a-zA-Z), numbers and ~!@#$%^&*.-_ only. ID length should less than 256. * @param[in] timeoutMs Waiting for the result return if object not ready. A positive integer number required. * 0 means no waiting time allowed. And the range is [0, INT32_MAX]. - * @param[out] dataInfos The list of data info. (Include pointer、count and data type) + * @param[out] blobs The list of data info.
(including pointer, count and data type) * @return K_OK on any object success; the error code otherwise. */ - Status GetDataInfo(const std::string &devObjKey, int32_t timeoutMs, std::vector &dataInfos); + Status GetBlobsInfo(const std::string &devObjKey, int32_t timeoutMs, std::vector &blobs); /** * @brief Remove the location of device object @@ -463,12 +451,12 @@ public: /** * @brief For device object, to async get multiple objects * @param[in] objectKeys multiple keys support - * @param[out] devBlobList vector of compose DataInfo + * @param[out] devBlobList vector of composed blob info * @param[in] timeoutMs max waiting time of getting data * @return future of AsyncResult, describing get status and failed list. */ std::shared_future MGetH2D(const std::vector &objectKeys, - const std::vector &devBlobList, uint64_t timeout); /** * @brief For device object, to invoke worker client to create and async publish multiple objects @@ -589,7 +577,7 @@ private: struct MGetAsyncRPCSource { std::future rpcFuture; std::promise promise; - std::vector> dataInfoList; + std::vector devBlobList; // hold the buffer to avoid it being destroyed before batch release. std::vector> bufferList; std::vector existBufferList; @@ -608,24 +596,22 @@ private: /** * @brief Check and construct the multi createParam. * @param[in] objectKeyList The vector of the object key that needs to create. * @param[in] dataSizeList The object sizes. - * @param[in] param The create param of device object. * @param[out] bufferList The buffer list needs to store data information. * @param[out] multiCreateParamList The list of objects create param. * @return Status of the result. */ Status ConstructMultiCreateParam(const std::vector &objectKeyList, - const std::vector &dataSizeList, const CreateParam &param, + const std::vector &dataSizeList, std::vector> &bufferList, std::vector &multiCreateParamList); /** * @brief For device object, copy host-side data into the user's device blobs after an async get. - * @param[in] dataInfoList The user dataInfo list of device data. + * @param[in] devBlobList The user blob info list of device data. * @param[in] existBufferList The temporary buffer list, which will be released after the memory copy * @return Status of the result. */ - Status HostDataCopy2Device(std::vector> &dataInfoList, - std::vector &existBufferList); + Status HostDataCopy2Device(std::vector &devBlobList, std::vector &existBufferList); /** * @brief Allocate shared memory for multiple objects and copy data from device. * @param... */ Status DeviceDataCreate(const std::vector &objectKeys, const std::vector &devBlobList, const SetParam &setParam, std::vector> &bufferList, - std::vector &destroyBufferList); + std::vector &exists); /** * @brief Create multiple objects at a time to the worker. * @param[in] objectKeyList The vector of the object key that needs to create. * @param[in] dataSizeList The object sizes. * @param[in] param The create param of device object. + * @param[in] skipCheckExistence Whether to skip checking key existence. * @param[out] bufferList The buffer list needs to store data information. + * @param[out] exists The existence flag for each key. * @return Status of the result. */ Status MultiCreate(const std::vector &objectKeyList, const std::vector &dataSizeList, - const CreateParam &param, std::vector> &bufferList); + const FullParam &param, const bool skipCheckExistence, + std::vector> &bufferList, std::vector &exists); /** * @brief Publish multiple objects at a time to the worker.
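Taken together, the signatures above replace the old "check existence, then create" device MSet flow with a single MultiCreate round trip that reports per-key existence, and they thread DeviceBlobList straight through where per-blob DataInfo used to be rebuilt. A minimal caller-side sketch of the resulting flow, with a hypothetical client handle, keys, sizes, and device pointers, and assuming Blob is an aggregate of { pointer, size } as the checks in the implementation suggest:

    std::vector<std::string> keys{ "layer0_kv", "layer1_kv" };
    std::vector<DeviceBlobList> devBlobList(2);
    devBlobList[0].deviceIdx = 0;
    devBlobList[0].blobs = { { devPtr0, size0 } };  // Blob{ pointer, size }, hypothetical device allocation
    devBlobList[1].deviceIdx = 0;
    devBlobList[1].blobs = { { devPtr1, size1 } };
    // Keys that already exist are filtered out and not copied again.
    auto setResult = client->MSet(keys, devBlobList, SetParam{}).get();
    // Fetch the values back into the same device blobs; failList names keys that could not be fetched.
    auto getResult = client->MGetH2D(keys, devBlobList, 5000 /* timeoutMs */).get();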
@@ -796,7 +785,7 @@ private: * @return ObjectBufferInfo The struct which stores buffer info. */ static ObjectBufferInfo SetObjectBufferInfo(const std::string &objectKey, uint8_t *pointer, uint64_t size, - uint64_t metaSize, const CreateParam ¶m, bool isSeal, + uint64_t metaSize, const FullParam ¶m, bool isSeal, uint32_t version, const std::string &shmId = {}, const std::shared_ptr &payloadPointer = nullptr, std::shared_ptr mmapEntry = nullptr); @@ -841,6 +830,13 @@ private: void GDecreaseRefRollback(const std::vector &rollbackObjectKeys, std::map &accessorTable); + /** + * @brief Check that the key is in the correct format. + * @param[in] key The key to check. + * @return K_OK on success; the error code otherwise. + */ + static Status CheckValidObjectKey(const std::string &key); + /** * @brief Check that the string inside the container is legitimate. * @param[in] vec Vector to check. @@ -848,16 +844,15 @@ private: * @return K_OK on success; the error code otherwise. */ template - Status CheckStringVector(const Vec &vec, bool nullable = false) + static Status CheckValidObjectKeyVector(const Vec &vec, bool nullable = false) { CHECK_FAIL_RETURN_STATUS(nullable || !vec.empty(), K_INVALID, "The keys are empty"); + size_t index = 0; for (const auto &objectKey : vec) { - CHECK_FAIL_RETURN_STATUS(!objectKey.empty(), K_INVALID, "The objectKey is empty"); CHECK_FAIL_RETURN_STATUS( - Validator::IsIdFormat(objectKey), K_INVALID, - FormatString( - "The key contains illegal char(s), allowed regex format: %s or the length of key is %d > 255.", - Validator::idFormat, objectKey.size())); + !objectKey.empty(), K_INVALID, FormatString("The objectKey at position %d is empty", index)); + RETURN_IF_NOT_OK(CheckValidObjectKey(objectKey)); + index++; } return Status::OK(); } @@ -980,7 +975,7 @@ private: * @param[in] existence Used by state api, to determine whether to set or not set the key if it does already exist. * @return K_OK on success; the error code otherwise. */ - Status ProcessShmPut(const std::string &objectKey, const uint8_t *data, uint64_t size, const CreateParam ¶m, + Status ProcessShmPut(const std::string &objectKey, const uint8_t *data, uint64_t size, const FullParam ¶m, const std::unordered_set &nestedObjectKeys, uint32_t ttlSecond, const std::shared_ptr &workerApi, int existence); @@ -1052,15 +1047,6 @@ private: return clientStateManager_->GetState(); } - /** - * @brief Convert devBlobList to dataInfoList. - * @param[in] devBlobList The blob list of device data. - * @param[out] dataInfoList The dataInfo list of device data. - * @return K_OK on success; the error code otherwise. - */ - Status ConvertToDataInfoList(const std::vector &devBlobList, - std::vector> &dataInfoList); - /** * @brief Convert a list of devBlobList to a list of device buffer pointers. * @param[in] keys A list of keys, each corresponding to a 2D device blob. diff --git a/src/datasystem/client/stream_cache/client_base_impl.cpp b/src/datasystem/client/stream_cache/client_base_impl.cpp new file mode 100644 index 0000000..de7d510 --- /dev/null +++ b/src/datasystem/client/stream_cache/client_base_impl.cpp @@ -0,0 +1,192 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Implement stream cache base class for producer and consumer. + */ + +#include "datasystem/client/stream_cache/client_base_impl.h" + +#include "datasystem/client/stream_cache/producer_consumer_worker_api.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/util/thread_local.h" + +namespace datasystem { +namespace client { +namespace stream_cache { +ClientBaseImpl::ClientBaseImpl(std::string streamName, std::string tenantId, + std::shared_ptr workerApi, + std::shared_ptr client, MmapManager *mmapManager, + std::shared_ptr listenWorker) + : streamName_(std::move(streamName)), + client_(std::move(client)), + workerApi_(std::move(workerApi)), + mmapManager_(mmapManager), + listenWorker_(std::move(listenWorker)), + lockId_(0), + state_(State::NORMAL), + tenantId_(std::move(tenantId)) +{ +} + +ClientBaseImpl::~ClientBaseImpl() = default; + +Status ClientBaseImpl::Init() +{ + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(workArea_ != ShmView(), K_RUNTIME_ERROR, "ShmView not initialized"); + lockId_ = workerApi_->GetLockId(); + // Set up the shared memory communication area + auto shmUnitInfo = std::make_shared(workArea_.fd, workArea_.mmapSz); + RETURN_IF_NOT_OK(mmapManager_->LookupUnitsAndMmapFd(tenantId_, shmUnitInfo)); + INJECT_POINT("ClientBaseImpl.init_fail_before_cursor"); + cursor_ = std::make_unique(static_cast(shmUnitInfo->GetPointer()) + workArea_.off, workArea_.sz, + lockId_); + RETURN_IF_NOT_OK(cursor_->Init(mmapManager_->GetMmapEntryByFd(shmUnitInfo->fd))); + // If the worker is not down level, write something into the eye catcher area so that the worker + // can also know our compatibility + if (WorkAreaIsV2()) { + cursor_->SetClientVersion(Cursor::K_CURSOR_SIZE_V2); + workerVersion_ = cursor_->GetWorkerVersion(); + } + return Status::OK(); +} + +const std::string &ClientBaseImpl::GetStreamName() +{ + return streamName_; +} + +void ClientBaseImpl::SetInactive() +{ + { + std::unique_lock xlock(recvFdsMutex_); + recvFds_.clear(); + } + LOG_IF_ERROR(ChangeState(State::CLOSE), "SetInactive"); +} + +bool ClientBaseImpl::IsActive() const +{ + return state_ == State::NORMAL; +} + +Status ClientBaseImpl::CheckState() const +{ + return listenWorker_->CheckWorkerAvailable(); +} + +Status ClientBaseImpl::CheckNormalState() const +{ + if (client_) { + RETURN_IF_NOT_OK(client_->CheckWorkerLost()); + } else { + return Status(K_RUNTIME_ERROR, "Client must not be null to do operations on consumer"); + } + if (state_ == State::CLOSE) { + RETURN_STATUS_LOG_ERROR( + StatusCode::K_SC_ALREADY_CLOSED, + FormatString("[%s] has been closed or inactive, please do not operate it", LogPrefix())); + } else if (state_ == State::RESET) { + RETURN_STATUS_LOG_ERROR( + StatusCode::K_SC_STREAM_IN_RESET_STATE, + FormatString("[%s] in Reset state, please do not operate it until Resume() is called", LogPrefix())); + } + return CheckState(); +} + +Status ClientBaseImpl::ChangeState(State newState) +{ + RETURN_OK_IF_TRUE(state_ == newState); // no-op if the state is the same as before + if (state_ == State::CLOSE) { + RETURN_STATUS(K_SC_ALREADY_CLOSED, + 
FormatString("[%s] has been closed or inactive, please do not operate it", LogPrefix())); + } + state_ = newState; + return Status::OK(); +} + +bool ClientBaseImpl::CheckStreamNameAndTenantId(const std::string &streamName, const std::string &tenantId) +{ + return GetStreamName() == streamName && GetTenantId() == tenantId; +} + +Status ClientBaseImpl::GetShmInfo(const ShmView &shmView, std::shared_ptr &out, + std::shared_ptr &mmapEntry) +{ + // Do a preliminary check if the fd makes any sense. There is no reason a file descriptor + // can grow beyond the size of an unsigned short. + if (static_cast(shmView.fd) > std::numeric_limits::max()) { + RETURN_STATUS(K_OUT_OF_RANGE, FormatString("fd out of range. ShmView %s", shmView.ToStr())); + } + std::unique_lock xlock(recvFdsMutex_); + auto it = recvFds_.find(shmView.fd); + if (it == recvFds_.end()) { + auto pageUnit = std::make_shared(shmView.fd, shmView.mmapSz); + // Also return OUT_OF_RANGE to the caller that the given shmView is stale. + auto rc = mmapManager_->LookupUnitsAndMmapFd(tenantId_, pageUnit); + if (rc.IsError()) { + RETURN_STATUS(K_OUT_OF_RANGE, FormatString("mmap error %s. ShmView %s", rc.GetMsg(), shmView.ToStr())); + } + bool success; + std::tie(it, success) = recvFds_.emplace(shmView.fd, pageUnit); + CHECK_FAIL_RETURN_STATUS(success, K_RUNTIME_ERROR, + FormatString("Fail to insert ShmView [%s] into the map", shmView.ToStr())); + } + auto &pageUnit = it->second; + out = std::make_shared(); + out->pointer = pageUnit->pointer; + out->fd = pageUnit->fd; + out->mmapSize = pageUnit->mmapSize; + out->size = shmView.sz; + out->offset = shmView.off; + mmapEntry = mmapManager_->GetMmapEntryByFd(pageUnit->fd); + return Status::OK(); +} + +Status ClientBaseImpl::CheckAndSetInUse() +{ + bool expected = false; + if (inUse_.compare_exchange_strong(expected, true)) { + INJECT_POINT("CheckAndSetInUse.success.sleep"); + return Status::OK(); + } + return Status( + K_SC_STREAM_IN_USE, + FormatString( + "[%s] Another thread is using the producer/consumer, producer/consumer does not support multithreading.", + LogPrefix())); +} + +void ClientBaseImpl::UnsetInUse() +{ + bool expected = true; + if (!inUse_.compare_exchange_strong(expected, false)) { + // Sanity check failed. + LOG(ERROR) << FormatString( + "[%s] Runtime error: producer/consumer thread safety is not under protection by internal logic.", + LogPrefix()); + } +} + +bool ClientBaseImpl::WorkAreaIsV2() const +{ + bool isV2 = workArea_.sz == Cursor::K_CURSOR_SIZE_V2; + INJECT_POINT_NO_RETURN("ClientBaseImpl.force_downlevel_client", [&isV2] { isV2 = false; }); + return isV2; +} +} // namespace stream_cache +} // namespace client +} // namespace datasystem diff --git a/src/datasystem/client/stream_cache/client_base_impl.h b/src/datasystem/client/stream_cache/client_base_impl.h new file mode 100644 index 0000000..2c03815 --- /dev/null +++ b/src/datasystem/client/stream_cache/client_base_impl.h @@ -0,0 +1,188 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Implement stream cache base class for producer and consumer. + */ +#ifndef DATASYSTEM_CLIENT_STREAM_CACHE_CLIENT_BASE_IMPL_H +#define DATASYSTEM_CLIENT_STREAM_CACHE_CLIENT_BASE_IMPL_H + +#include +#include "datasystem/client/listen_worker.h" +#include "datasystem/client/mmap_manager.h" +#include "datasystem/client/stream_cache/client_worker_api.h" +#include "datasystem/client/stream_cache/stream_client_impl.h" +#include "datasystem/client/stream_cache/producer_consumer_worker_api.h" +#include "datasystem/common/util/thread_local.h" + +namespace datasystem { +namespace client { +namespace stream_cache { +class FirstCallTracer { +public: + bool NeedWriteLog(bool ready) + { + bool needWriteLog = false; + switch (state_) { + case State::INIT: + if (ready) { + state_ = State::CONTINUE; + } else { + state_ = State::WAIT_FIRST; + } + needWriteLog = true; + break; + case State::WAIT_FIRST: + if (ready) { + state_ = State::CONTINUE; + needWriteLog = true; + } + break; + case State::CONTINUE: + break; + } + return needWriteLog; + } + +private: + enum class State { INIT, WAIT_FIRST, CONTINUE }; + State state_{ State::INIT }; +}; +class ClientBaseImpl { +public: + enum class State { CLOSE, NORMAL, RESET }; + ClientBaseImpl(std::string streamName, std::string tenantId, std::shared_ptr workerApi, + std::shared_ptr client, MmapManager *mmapManager, + std::shared_ptr listenWorker); + virtual ~ClientBaseImpl(); + + /** + * @brief Get the stream name for the consumer. + * @return stream name for the consumer. + */ + const std::string &GetStreamName(); + + /** + * @brief Set the worker disconnect. + */ + virtual void SetInactive(); + + /** + * @brief Getter of state_. + * @return Return state_. + */ + bool IsActive() const; + + /** + * @brief Log helper. Creates the prefix for log messages. + * @return The generated log prefix for this Producer. + */ + virtual std::string LogPrefix() const = 0; + + /** + * @brief Get the Tenant Id object. + * @return std::string Tenant id. + */ + std::string GetTenantId() + { + return tenantId_; + } + + /** + * @brief CheckStreamNameAndTenantId + * @param[in] streamName streamName + * @param[in] tenantId tenantId + * @return true if check success + */ + bool CheckStreamNameAndTenantId(const std::string &streamName, const std::string &tenantId); + + /** + * @brief Get the stream name + * @return Stream name. + */ + const std::string &GetStreamName() const + { + return streamName_; + } + + /** + * @brief check if another thread is using the client (producer or consumer), if not, the calling thread will set + * inUse_ to true and use the client. + * @return Status of the call. + */ + Status CheckAndSetInUse(); + + /** + * @brief Set inUse_ to false. The calling thread needs to call CheckAndSetInUse() before using the client + * (producer or consumer) again. + */ + void UnsetInUse(); + + /** + * @brief Check if the work area is downlevel + */ + bool WorkAreaIsV2() const; +protected: + /** + * Base class initialization + * @return + */ + virtual Status Init(); + + /** + * State changing function + */ + Status ChangeState(State newState); + + /** + * @brief Check the state_ is NORMAL + * @return Status of the call. + */ + virtual Status CheckNormalState() const; + + /** + * @brief Check some of the states are normal, used for certain situations. + * @return Status of the call. 
+ */ + virtual Status CheckState() const; + + Status GetShmInfo(const ShmView &shmView, std::shared_ptr &out, + std::shared_ptr &mmapEntry); + + const std::string streamName_; + std::shared_ptr client_; + std::shared_ptr workerApi_; + client::MmapManager *mmapManager_; + std::shared_ptr listenWorker_{ nullptr }; + uint32_t lockId_; + mutable std::mutex recvFdsMutex_; + std::unordered_map> recvFds_; + std::atomic state_; + // A work area that is shared between the corresponding worker::stream_cache::Consumer + // sz is the size of this work area. It is set up by the worker. + ShmView workArea_; + std::unique_ptr cursor_; + std::string tenantId_; + uint32_t workerVersion_ = 0; +private: + // True if there is at least 1 thread actively using the client (producer or consumer). + // There should be at most 1 thread actively using the client (producer or consumer) since it is not thread-safe. + std::atomic inUse_{ false }; +}; +} // namespace stream_cache +} // namespace client +} // namespace datasystem +#endif // DATASYSTEM_CLIENT_STREAM_CACHE_CLIENT_BASE_IMPL_H diff --git a/src/datasystem/client/stream_cache/client_worker_api.cpp b/src/datasystem/client/stream_cache/client_worker_api.cpp new file mode 100644 index 0000000..77024ee --- /dev/null +++ b/src/datasystem/client/stream_cache/client_worker_api.cpp @@ -0,0 +1,277 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Implement client to worker api. + */ +#include "datasystem/client/stream_cache/client_worker_api.h" + +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/perf/perf_manager.h" +#include "datasystem/common/rpc/plugin_generator/zmq_rpc_generator.h" +#include "datasystem/common/rpc/rpc_auth_key_manager.h" +#include "datasystem/common/rpc/rpc_channel.h" +#include "datasystem/common/rpc/rpc_unary_client_impl.h" +#include "datasystem/common/util/rpc_util.h" +#include "datasystem/common/rpc/unix_sock_fd.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/protos/rpc_option.pb.h" +#include "datasystem/protos/stream_posix.pb.h" +#include "datasystem/stream/stream_config.h" + +namespace datasystem { +namespace client { +namespace stream_cache { +ClientWorkerApi::ClientWorkerApi(const HostPort &hostPort, RpcCredential cred, + Signature *signature, std::string tenantId) + : ClientWorkerCommonApi(hostPort, cred, HeartbeatType::RPC_HEARTBEAT, signature, + std::move(tenantId)) +{ +} + +Status ClientWorkerApi::Init(int32_t timeoutMs) +{ + RETURN_IF_NOT_OK(ClientWorkerCommonApi::Init(timeoutMs)); + std::shared_ptr channel; + channel = std::make_shared(hostPort_, cred_); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Setting client-worker communication via Unix socket : %s", + (GetShmEnabled() ? "true" : "false")); + // We will enable uds after handshaking with the worker. 
+ if (GetShmEnabled()) { + channel->SetServiceUdsEnabled(ClientWorkerSCService_Stub::FullServiceName(), + GetServiceSockName(ServiceSocketNames::DEFAULT_SOCK)); + } + rpcSession_ = std::make_unique(channel); + return Status::OK(); +} + +Status ClientWorkerApi::CreateProducer(const std::string &streamName, const std::string &producerId, + const ProducerConf &producerConf, ShmView &outPageView, + DataVerificationHeader::SenderProducerNo &senderProducerNo, + bool &enableStreamDataVerification, uint64_t &streamNo, + bool &enableSharedPage, uint64_t &sharedPageSize, ShmView &outStreamMetaView) +{ + CreateProducerReqPb req; + req.set_stream_name(streamName); + req.set_page_size(producerConf.pageSize); + req.set_client_id(GetClientId()); + req.set_producer_id(producerId); + req.set_max_stream_size(producerConf.maxStreamSize); + req.set_auto_cleanup(producerConf.autoCleanup); + req.set_retain_num_consumer(producerConf.retainForNumConsumers); + req.set_encrypt_stream(producerConf.encryptStream); + req.set_reserve_size(producerConf.reserveSize); + req.set_stream_mode(producerConf.streamMode); + reqTimeoutDuration.Init(ClientGetRequestTimeout(timeoutMs_)); + RETURN_IF_NOT_OK(SetTokenAndTenantId(req)); + + PerfPoint point(PerfKey::RPC_WORKER_CREATE_PRODUCER); + RpcOptions opts; + opts.SetTimeout(timeoutMs_); + CreateProducerRspPb rsp; + RETURN_IF_NOT_OK(signature_->GenerateSignature(req)); + RETURN_IF_NOT_OK_EXCEPT(rpcSession_->CreateProducer(opts, req, rsp), StatusCode::K_DUPLICATED); + point.Record(); + + outPageView.off = static_cast(rsp.page_view().offset()); + outPageView.sz = rsp.page_view().size(); + outPageView.mmapSz = rsp.page_view().mmap_size(); + outPageView.fd = rsp.page_view().fd(); + + outStreamMetaView.sz = rsp.stream_meta_view().size(); + outStreamMetaView.mmapSz = rsp.stream_meta_view().mmap_size(); + outStreamMetaView.fd = rsp.stream_meta_view().fd(); + outStreamMetaView.off = static_cast(rsp.stream_meta_view().offset()); + + senderProducerNo = rsp.sender_producer_no(); + enableStreamDataVerification = rsp.enable_data_verification(); + streamNo = rsp.stream_no(); + enableSharedPage = rsp.enable_shared_page(); + sharedPageSize = rsp.shared_page_size(); + + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s, P:%s] Create producer success.", LogPrefix(), streamName, + producerId); + return Status::OK(); +} + +Status ClientWorkerApi::Subscribe(const std::string &streamName, const std::string &consumerId, + const SubscriptionConfig &config, SubscribeRspPb &rsp) +{ + auto configReqPtr = std::make_unique(); + configReqPtr->set_subscription_name(config.subscriptionName); + configReqPtr->set_subscription_type(SubscriptionTypePb(config.subscriptionType)); + RpcOptions opts; + opts.SetTimeout(timeoutMs_); + SubscribeReqPb req; + req.set_stream_name(streamName); + req.set_allocated_subscription_config(configReqPtr.release()); + req.set_client_id(GetClientId()); + req.set_consumer_id(consumerId); + reqTimeoutDuration.Init(ClientGetRequestTimeout(timeoutMs_)); + RETURN_IF_NOT_OK(SetTokenAndTenantId(req)); + + PerfPoint point(PerfKey::RPC_WORKER_CREATE_SUBSCRIBE); + RETURN_IF_NOT_OK(signature_->GenerateSignature(req)); + RETURN_IF_NOT_OK_EXCEPT(rpcSession_->Subscribe(opts, req, rsp), StatusCode::K_DUPLICATED); + point.Record(); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s, C:%s] Create consumer success.", LogPrefix(), streamName, + consumerId); + return Status::OK(); +} + +Status ClientWorkerApi::SetRpcTimeout(int64_t &requestedTimeout, int32_t &rpcTimeout, int64_t &adjustedTimeout) +{ 
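+    // Worked example with illustrative numbers: if rpcTimeoutMs_ = 2000 and requestedTimeout = 30000 ms,
+    // of which 1200 ms was already spent in the client, adjustedTimeout becomes 28800 for the worker and
+    // rpcTimeout becomes 30800 for the round trip; requestedTimeout = 0 keeps adjustedTimeout at 0 and
+    // uses the bare rpcTimeoutMs_. Both values are clamped so they never exceed INT_MAX.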
+Status ClientWorkerApi::SetRpcTimeout(int64_t &requestedTimeout, int32_t &rpcTimeout, int64_t &adjustedTimeout)
+{
+    // The user does not want to wait, but the rpc request still takes time
+    if (requestedTimeout == 0) {
+        rpcTimeout = rpcTimeoutMs_;
+        adjustedTimeout = requestedTimeout;
+        return Status::OK();
+    }
+
+    // Make sure the timeout value does not overflow
+    if (requestedTimeout >= INT_MAX || requestedTimeout < 0) {
+        requestedTimeout = INT_MAX;
+    }
+
+    // Subtract the time already spent in the client, leaving less time for the worker
+    adjustedTimeout = ClientGetRequestTimeout(requestedTimeout);
+
+    // Include request processing because this is an rpc round trip
+    rpcTimeout = (rpcTimeoutMs_ + adjustedTimeout < INT_MAX) ? rpcTimeoutMs_ + adjustedTimeout : INT_MAX;
+
+    return Status::OK();
+}
+
+Status ClientWorkerApi::DeleteStream(const std::string &streamName)
+{
+    RpcOptions opts;
+    opts.SetTimeout(timeoutMs_);
+    DeleteStreamReqPb req;
+    DeleteStreamRspPb rsp;
+    req.set_stream_name(streamName);
+    req.set_client_id(GetClientId());
+    reqTimeoutDuration.Init(ClientGetRequestTimeout(timeoutMs_));
+    RETURN_IF_NOT_OK(SetTokenAndTenantId(req));
+
+    PerfPoint point(PerfKey::RPC_WORKER_DELETE_STREAM);
+    RETURN_IF_NOT_OK(signature_->GenerateSignature(req));
+    RETURN_IF_NOT_OK(rpcSession_->DeleteStream(opts, req, rsp));
+    point.Record();
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Delete stream success.", LogPrefix(), streamName);
+    return Status::OK();
+}
+
+Status ClientWorkerApi::QueryGlobalProducersNum(const std::string &streamName, uint64_t &producerNum)
+{
+    LOG(INFO) << FormatString("[%s, Stream:%s], Start to query global producer count.", LogPrefix(), streamName);
+    QueryGlobalNumReqPb req;
+    QueryGlobalNumRsqPb rsp;
+    req.set_stream_name(streamName);
+    req.set_client_id(GetClientId());
+
+    reqTimeoutDuration.Init(ClientGetRequestTimeout(timeoutMs_));
+    RETURN_IF_NOT_OK(SetTokenAndTenantId(req));
+    RETURN_IF_NOT_OK(signature_->GenerateSignature(req));
+    RETURN_IF_NOT_OK(rpcSession_->QueryGlobalProducersNum(req, rsp));
+    producerNum = rsp.global_count();
+    return Status::OK();
+}
+
+Status ClientWorkerApi::QueryGlobalConsumersNum(const std::string &streamName, uint64_t &consumerNum)
+{
+    LOG(INFO) << FormatString("[%s, Stream:%s], Start to query global consumer count.", LogPrefix(), streamName);
+    QueryGlobalNumReqPb req;
+    QueryGlobalNumRsqPb rsp;
+    req.set_stream_name(streamName);
+
+    reqTimeoutDuration.Init(ClientGetRequestTimeout(timeoutMs_));
+    req.set_client_id(GetClientId());
+    RETURN_IF_NOT_OK(SetTokenAndTenantId(req));
+    RETURN_IF_NOT_OK(signature_->GenerateSignature(req));
+    RETURN_IF_NOT_OK(rpcSession_->QueryGlobalConsumersNum(req, rsp));
+    consumerNum = rsp.global_count();
+    return Status::OK();
+}
+
+Status ClientWorkerApi::ResetStreams(const std::vector<std::string> &streamNames)
+{
+    ResetOrResumeStreamsReqPb req;
+    ResetOrResumeStreamsRspPb rsp;
+
+    RpcOptions opts;
+    opts.SetTimeout(rpcTimeoutMs_);
+    for (auto &streamName : streamNames) {
+        req.add_stream_names(streamName);
+    }
+    req.set_client_id(GetClientId());
+
+    RETURN_IF_NOT_OK(SetTokenAndTenantId(req));
+
+    PerfPoint point(PerfKey::RPC_WORKER_RESET_STREAM);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
+        RetryOnError(
+            timeoutMs_,
+            [this, &opts, &req, &rsp](int32_t) {
+                RETURN_IF_NOT_OK(signature_->GenerateSignature(req));
+                return rpcSession_->ResetStreams(opts, req, rsp);
+            },
+            []() { return Status::OK(); }, { K_RPC_CANCELLED, K_RPC_DEADLINE_EXCEEDED, K_RPC_UNAVAILABLE }),
+        "Reset streams on worker error.");
+    point.Record();
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Reset streams on worker success.", LogPrefix());
+    return Status::OK();
+}
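
A worked example of the adjustment SetRpcTimeout performs above, with illustrative numbers; ClientGetRequestTimeout subtracts the time already spent in the client:

    // requestedTimeout = 30000 ms, 1200 ms already spent in the client:
    //   adjustedTimeout = ClientGetRequestTimeout(30000) = 28800 ms  (worker-side wait)
    //   rpcTimeout      = rpcTimeoutMs_ + 28800 ms                   (round-trip budget, clamped to INT_MAX)
    // requestedTimeout = 0: adjustedTimeout stays 0, but rpcTimeout keeps the bare
    // rpcTimeoutMs_ so the request itself can still complete.
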
+
+Status ClientWorkerApi::ResumeStreams(const std::vector<std::string> &streamNames)
+{
+    ResetOrResumeStreamsReqPb req;
+    ResetOrResumeStreamsRspPb rsp;
+
+    RpcOptions opts;
+    opts.SetTimeout(rpcTimeoutMs_);
+    for (auto &streamName : streamNames) {
+        req.add_stream_names(streamName);
+    }
+    req.set_client_id(GetClientId());
+    RETURN_IF_NOT_OK(SetTokenAndTenantId(req));
+
+    PerfPoint point(PerfKey::RPC_WORKER_RESUME_STREAM);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
+        RetryOnError(
+            timeoutMs_,
+            [this, &opts, &req, &rsp](int32_t) {
+                RETURN_IF_NOT_OK(signature_->GenerateSignature(req));
+                return rpcSession_->ResumeStreams(opts, req, rsp);
+            },
+            []() { return Status::OK(); }, { K_RPC_CANCELLED, K_RPC_DEADLINE_EXCEEDED, K_RPC_UNAVAILABLE }),
+        "Resume streams on worker error.");
+    point.Record();
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Resume streams on worker success.", LogPrefix());
+    return Status::OK();
+}
+
+std::string ClientWorkerApi::LogPrefix() const
+{
+    return FormatString("ClientWorkerApi, EndPoint:%s", hostPort_.ToString());
+}
+
+std::string ClientWorkerApi::GetClientId()
+{
+    return ClientWorkerCommonApi::GetClientId();
+}
+}  // namespace stream_cache
+}  // namespace client
+}  // namespace datasystem
diff --git a/src/datasystem/client/stream_cache/client_worker_api.h b/src/datasystem/client/stream_cache/client_worker_api.h
new file mode 100644
index 0000000..1424050
--- /dev/null
+++ b/src/datasystem/client/stream_cache/client_worker_api.h
@@ -0,0 +1,159 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Define api of stream cache.
+ */
+#ifndef DATASYSTEM_CLIENT_STREAM_CACHE_CLIENT_WORKER_API_H
+#define DATASYSTEM_CLIENT_STREAM_CACHE_CLIENT_WORKER_API_H
+
+#include <memory>
+#include <string>
+
+#include "datasystem/common/flags/flags.h"
+#include "datasystem/client/client_worker_common_api.h"
+#include "datasystem/client/mmap_manager.h"
+#include "datasystem/client/stream_cache/receive_element.h"
+#include "datasystem/common/shared_memory/shm_unit.h"
+#include "datasystem/common/shared_memory/shm_unit_info.h"
+#include "datasystem/common/stream_cache/stream_data_page.h"
+#include "datasystem/common/util/net_util.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/protos/worker_stream.pb.h"
+#include "datasystem/protos/stream_posix.stub.rpc.pb.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/stream/element.h"
+#include "datasystem/utils/optional.h"
+#include "datasystem/utils/sensitive_value.h"
+
+namespace datasystem {
+namespace client {
+namespace stream_cache {
+class ClientWorkerApi : public ClientWorkerCommonApi {
+public:
+    /**
+     * @brief Construct ClientWorkerApi.
+     * @param[in] hostPort The address of the worker node.
+     * @param[in] cred The authentication credentials.
+     * @param[in] signature Used for AK/SK authentication.
+     * @param[in] tenantId TenantId of the client user.
+     *
+     */
+    ClientWorkerApi(const HostPort &hostPort, RpcCredential cred,
+                    Signature *signature = nullptr, std::string tenantId = "");
+
+    /**
+     * @brief Initialize the ClientWorkerApi object.
+     * @param[in] timeoutMs Timeout in milliseconds.
+     * @return K_OK on success; the error code otherwise.
+     *         K_INVALID: the input ip or port is invalid.
+     */
+    Status Init(int32_t timeoutMs);
+
+    /**
+     * @brief Send rpc request to worker to create one producer.
+     * @param[in] streamName The name of the stream.
+     * @param[in] producerId The producer id generated by the client.
+     * @param[in] producerConf The stream configurations, including pageSize, maxStreamSize, etc.
+     * @param[out] outPageView ShmView of the cursor.
+     * @param[out] senderProducerNo The producer number generated by the master.
+     * @param[out] enableStreamDataVerification Whether stream data verification should be enabled.
+     * @param[out] streamNo The stream number assigned by the worker.
+     * @param[out] enableSharedPage Whether the shared page feature is enabled.
+     * @param[out] sharedPageSize The size of the shared page.
+     * @param[out] outStreamMetaView ShmView of streamMetaShm.
+     * @return Status of the call.
+     */
+    Status CreateProducer(const std::string &streamName, const std::string &producerId,
+                          const ProducerConf &producerConf, ShmView &outPageView,
+                          DataVerificationHeader::SenderProducerNo &senderProducerNo,
+                          bool &enableStreamDataVerification, uint64_t &streamNo, bool &enableSharedPage,
+                          uint64_t &sharedPageSize, ShmView &outStreamMetaView);
+
+    /**
+     * @brief Send rpc request to worker to create one consumer.
+     * @param[in] streamName The name of the stream.
+     * @param[in] consumerId The consumer id generated by the client.
+     * @param[in] config The configuration of the subscription, such as subscription name and subscription mode.
+     * @param[out] rsp The response of the subscription.
+     * @return Status of the call.
+     */
+    Status Subscribe(const std::string &streamName, const std::string &consumerId,
+                     const struct SubscriptionConfig &config, SubscribeRspPb &rsp);
+
+    /**
+     * @brief Send rpc request to worker to delete a stream.
+     * @param[in] streamName The name of the stream that will be deleted.
+     * @return Status of the call.
+     */
+    Status DeleteStream(const std::string &streamName);
+
+    /**
+     * @brief Query global producer count by streamName.
+     * @param[in] streamName Target stream.
+     * @param[out] producerNum The producer count.
+     * @return Status of the call.
+     */
+    Status QueryGlobalProducersNum(const std::string &streamName, uint64_t &producerNum);
+
+    /**
+     * @brief Query global consumer count by streamName.
+     * @param[in] streamName Target stream.
+     * @param[out] consumerNum The consumer count.
+     * @return Status of the call.
+     */
+    Status QueryGlobalConsumersNum(const std::string &streamName, uint64_t &consumerNum);
+
+    /**
+     * @brief Send rpc request to worker to reset the provided streams.
+     * @param[in] streamNames Target streams.
+     * @return Status of the call.
+     */
+    Status ResetStreams(const std::vector<std::string> &streamNames);
+
+    /**
+     * @brief Send rpc request to worker to resume the provided streams.
+     * @param[in] streamNames Target streams.
+     * @return Status of the call.
+     */
+    Status ResumeStreams(const std::vector<std::string> &streamNames);
+
+    /**
+     * @brief Construct log prefix.
+     * @return Return the log prefix.
+     */
+    std::string LogPrefix() const;
+
+    /**
+     * @brief Ensure that the latest client ID is obtained.
+     * @return Return the client id.
+     */
+    std::string GetClientId();
+
+private:
+    friend class ProducerConsumerWorkerApi;
+    /**
+     * @brief Check and adjust the timeout of a receive.
+     * @param[in] requestedTimeout The client specified timeout value for time spent in the system.
+     * @param[out] rpcTimeout The timeout used for the RPC.
+     * @param[out] adjustedTimeout The timeout used in the worker.
+     * @return Status of the call.
+     */
+    Status SetRpcTimeout(int64_t &requestedTimeout, int32_t &rpcTimeout, int64_t &adjustedTimeout);
+
+    std::unique_ptr<ClientWorkerSCService_Stub> rpcSession_;
+};
+}  // namespace stream_cache
+}  // namespace client
+}  // namespace datasystem
+#endif
diff --git a/src/datasystem/client/stream_cache/consumer.cpp b/src/datasystem/client/stream_cache/consumer.cpp
new file mode 100644
index 0000000..d3cd6a4
--- /dev/null
+++ b/src/datasystem/client/stream_cache/consumer.cpp
@@ -0,0 +1,90 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Define api of stream cache consumer.
+ */
+#include "datasystem/stream/consumer.h"
+
+#include "datasystem/client/stream_cache/consumer_impl.h"
+#include "datasystem/common/log/access_recorder.h"
+#include "datasystem/common/log/trace.h"
+#include "datasystem/utils/optional.h"
+
+namespace datasystem {
+
+Consumer::~Consumer()
+{
+    if (impl_->IsActive()) {
+        LOG(INFO) << FormatString("[%s] Implicit close consumer", impl_->LogPrefix());
+        Status rc = Close();
+        if (rc.IsError()) {
+            LOG(ERROR) << FormatString("[%s] Implicit close consumer failed %s.", impl_->LogPrefix(), rc.GetMsg());
+        }
+    }
+}
+
+Status Consumer::Receive(uint32_t expectNum, uint32_t timeoutMs, std::vector<Element> &outElements)
+{
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(impl_->CheckAndSetInUse(), "Receive");
+    Raii unsetRaii([this]() { impl_->UnsetInUse(); });
+    TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+    Optional<uint32_t> expectedNumber(expectNum);
+    return impl_->Receive(expectedNumber, timeoutMs, outElements);
+}
+
+Status Consumer::Receive(uint32_t timeoutMs, std::vector<Element> &outElements)
+{
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(impl_->CheckAndSetInUse(), "Receive");
+    Raii unsetRaii([this]() { impl_->UnsetInUse(); });
+    TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+    Optional<uint32_t> expectedNumber;
+    return impl_->Receive(expectedNumber, timeoutMs, outElements);
+}
+
+Status Consumer::Ack(uint64_t elementId)
+{
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(impl_->CheckAndSetInUse(), "Ack");
+    Raii unsetRaii([this]() { impl_->UnsetInUse(); });
+    TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+    return impl_->Ack(elementId);
+}
+
+Status Consumer::Close()
+{
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(impl_->CheckAndSetInUse(), "Close");
+    Raii unsetRaii([this]() { impl_->UnsetInUse(); });
+    TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+    AccessRecorder recorder(AccessRecorderKey::DS_STREAM_CLOSE_CONSUMER);
+    auto rc = impl_->Close();
+    StreamRequestParam reqParam;
+    reqParam.streamName = impl_->GetStreamName();
+    reqParam.consumerId = impl_->GetConsumerId();
+    StreamResponseParam rspParam;
+    rspParam.msg = rc.GetMsg();
+    recorder.Record(rc.GetCode(), reqParam, rspParam);
+    return rc;
+}
+
+Consumer::Consumer(std::unique_ptr<ConsumerImpl> impl) : impl_(std::move(impl))
+{
+}
+
+void Consumer::GetStatisticsMessage(uint64_t &totalElements, uint64_t &notProcessedElements)
+{
+    impl_->GetStatisticsMessage(totalElements, notProcessedElements);
+}
+}  // namespace datasystem
diff --git a/src/datasystem/client/stream_cache/consumer_impl.cpp b/src/datasystem/client/stream_cache/consumer_impl.cpp
new file mode 100644
index 0000000..1a8a070
--- /dev/null
+++ b/src/datasystem/client/stream_cache/consumer_impl.cpp
@@ -0,0 +1,744 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Implementation of stream cache consumer.
+ */
+#include "datasystem/client/stream_cache/consumer_impl.h"
+#include <algorithm>
+#include <numeric>
+#include "datasystem/client/listen_worker.h"
+#include "datasystem/client/stream_cache/client_worker_api.h"
+#include "datasystem/client/stream_cache/receive_element.h"
+#include "datasystem/client/stream_cache/stream_client_impl.h"
+#include "datasystem/common/constants.h"
+#include "datasystem/common/log/trace.h"
+#include "datasystem/common/perf/perf_manager.h"
+#include "datasystem/common/util/memory.h"
+#include "datasystem/common/util/queue/circular_queue.h"
+#include "datasystem/common/util/rpc_util.h"
+#include "datasystem/common/util/strings_util.h"
+#include "datasystem/protos/worker_stream.pb.h"
+#include "datasystem/common/inject/inject_point.h"
+
+namespace datasystem {
+namespace client {
+namespace stream_cache {
+ConsumerImpl::ConsumerImpl(std::string streamName, std::string tenantId, SubscriptionConfig config,
+                           std::string consumerId, const SubscribeRspPb &rsp,
+                           std::shared_ptr<ProducerConsumerWorkerApi> workerApi,
+                           std::shared_ptr<StreamClientImpl> client, client::MmapManager *mmapManager,
+                           std::shared_ptr<ListenWorker> listenWorker, bool autoAck)
+    : ClientBaseImpl(std::move(streamName), std::move(tenantId), std::move(workerApi), std::move(client), mmapManager,
+                     std::move(listenWorker)),
+      config_(std::move(config)),
+      consumerId_(std::move(consumerId)),
+      lastRecvCursor_(rsp.last_recv_cursor()),
+      pageBoundaryCursor_(lastRecvCursor_.load()),
+      consumedElements_(lastRecvCursor_.load()),
+      ackedElementId_(lastRecvCursor_.load()),
+      rsp_(rsp),
+      autoAck_(autoAck)
+{
+    receiveWp_.Set();
+    workArea_.fd = rsp_.worker_fd();
+    workArea_.mmapSz = rsp_.mmap_size();
+    workArea_.off = static_cast<uint64_t>(rsp_.offset());
+    workArea_.sz = rsp_.size();
+}
+
+ConsumerImpl::~ConsumerImpl()
+{
+    if (state_ == State::NORMAL) {
+        LOG(INFO) << FormatString("[%s] Implicit close consumer.", LogPrefix());
+        Status rc = Close();
+        if (rc.IsError()) {
+            LOG(ERROR) << "[" + LogPrefix() + "] Implicit close consumer failed " + rc.GetMsg();
+        }
+    }
+    client_->ClearConsumer(consumerId_);
+    LOG(INFO) << FormatString("[%s] Consumer destroy finish.", LogPrefix());
+}
+
+Status ConsumerImpl::Init()
+{
+    static const uint32_t SC_CACHE_MAX = 1048576;
+    static const uint32_t SC_CACHE_MIN = 64;
+    static const uint16_t PREFETCH_MAX = 100;  // 100 percent
+
+    RETURN_IF_NOT_OK(ClientBaseImpl::Init());
+    // Sanity check. The worker has already populated the lastAckCursor in the work area.
+    // It should match the one we received.
+    auto lastRecvCursor = GetWALastAckCursor();
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        lastRecvCursor == lastRecvCursor_, K_OUT_OF_RANGE,
+        FormatString("Work area eye catcher mismatch. Expect %zu but get %zu", lastRecvCursor_, lastRecvCursor));
+
+    // Init prefetch area and fields if needed
+    RETURN_RUNTIME_ERROR_IF_NULL(client_);
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(config_.cacheCapacity <= SC_CACHE_MAX && config_.cacheCapacity >= SC_CACHE_MIN,
+                                         K_INVALID, "Cache capacity must be between 64 and 1048576");
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        config_.cachePrefetchLWM <= PREFETCH_MAX, K_INVALID,
+        "Cache prefetch LWM must be between 0 and 100 (it is used as a percentage value).");
+
+    elementCacheQueue_ = std::make_unique<CircularQueue<DataElement>>(config_.cacheCapacity);
+    cachePrefetchLWM_ = (static_cast<float>(config_.cachePrefetchLWM) / PREFETCH_MAX) * config_.cacheCapacity;
+
+    if (cachePrefetchLWM_ != 0) {
+        RETURN_IF_NOT_OK(client_->CreatePrefetchPoolIfNotExist());
+    }
+    LOG(INFO) << "Consumer initialized. Prefetch LWM: " << cachePrefetchLWM_
+              << " Cache capacity: " << elementCacheQueue_->Capacity() << " LastAckCursor: " << lastRecvCursor;
+    return Status::OK();
+}
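
To make the low-water-mark arithmetic in Init() concrete, a worked example under the percentage semantics checked above:

    // config_.cachePrefetchLWM = 25 (percent), config_.cacheCapacity = 1024 slots
    // cachePrefetchLWM_ = (25 / 100.0) * 1024 = 256
    // => a background prefetch is submitted once fewer than 256 cached elements remain.
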
+
+void ConsumerImpl::SetInactive()
+{
+    // If we have a page, wake up any thread that is waiting on Receive
+    std::shared_lock lock(idxMutex_);
+    for (auto &ele : idx_) {
+        if (ele.second) {
+            ele.second->WakeUpConsumers();
+        } else {
+            constexpr int timeSec = 10;
+            LOG_EVERY_T(WARNING, timeSec) << "Page is nullptr!";
+        }
+    }
+    ClientBaseImpl::SetInactive();
+}
+
+Status ConsumerImpl::CheckNormalState() const
+{
+    // If the worker's state is unknown, do not try to access the shared memory work area
+    RETURN_IF_NOT_OK(ClientBaseImpl::CheckNormalState());
+    if (cursor_ == nullptr || cursor_->ForceClose()) {
+        RETURN_STATUS(K_SC_PRODUCER_NOT_FOUND, "Force close. No producer");
+    }
+    return Status::OK();
+}
+
+Status ConsumerImpl::ExtractBigElements(std::shared_ptr<StreamDataPage> &dataPage,
+                                        std::vector<DataElement> &recvElements)
+{
+    for (size_t i = 0; i < recvElements.size(); ++i) {
+        auto &ele = recvElements[i];
+        auto cursor = ele.id;
+        Status rc = ele.CheckAttribute();
+        if (rc.IsError()) {
+            RETURN_STATUS_LOG_ERROR(
+                rc.GetCode(), FormatString("[%s] Page<%s> Cursor %zu:", LogPrefix(), dataPage->GetPageId(), cursor));
+        }
+        if (ele.IsBigElement()) {
+            ShmView v;
+            RETURN_IF_NOT_OK(StreamDataPage::ParseShmViewPb(recvElements[i].ptr, recvElements[i].size, v));
+            std::shared_ptr<ShmUnitInfo> shmInfo;
+            std::shared_ptr<MmapEntry> mmapEntry;
+            RETURN_IF_NOT_OK(GetShmInfo(v, shmInfo, mmapEntry));
+            auto bigElementPage = std::make_shared<StreamDataPage>(shmInfo, true, mmapEntry);
+            RETURN_IF_NOT_OK(bigElementPage->Init());
+            // Replace the original pointer with the big element pointer
+            recvElements[i].ptr = reinterpret_cast<uint8_t *>(bigElementPage->GetPointer());
+            recvElements[i].size = bigElementPage->PageSize();
+            LOG(INFO) << FormatString("[%s] Page<%s> Cursor %zu BigElement<%s> Size %zu", LogPrefix(),
+                                      dataPage->GetPageId(), cursor, bigElementPage->GetPageId(),
+                                      bigElementPage->PageSize());
+        }
+    }
+    return Status::OK();
+}
+
+Status ConsumerImpl::PrefetchElements(uint32_t timeoutMs, std::shared_ptr<StreamDataPage> &dataPage, uint32_t targetNum,
+                                      uint32_t &totalFetched, bool nonBlockingFetch)
+{
+    while (totalFetched < targetNum) {
+        uint32_t numFetched = 0;
+        // If the cache has no room, do nothing.
+        RETURN_OK_IF_TRUE(elementCacheQueue_->IsFull());
+        auto remaining = elementCacheQueue_->Remaining();
+        uint64_t lastRecvCursor = lastRecvCursor_.load(std::memory_order_relaxed);
+        std::vector<DataElement> recvElements;
+        // Grab whatever is on the page.
+        RETURN_IF_NOT_OK(dataPage->Receive(lastRecvCursor, timeoutMs, recvElements, LogPrefix()));
+        RETURN_IF_NOT_OK(ExtractBigElements(dataPage, recvElements));
+        RETURN_IF_NOT_OK(ProcessHeaders(recvElements));
+        // Push them to the cache. Just copy up to remaining. We will resume next time
+        // from where we left off once we move up lastRecvCursor_.
+        auto numElementsToAdd = std::min(remaining, recvElements.size());
+        std::vector<DataElement> elementListToAdd(numElementsToAdd);
+        std::copy_n(recvElements.begin(), numElementsToAdd, elementListToAdd.begin());
+        CHECK_FAIL_RETURN_STATUS(elementCacheQueue_->BatchPush(elementListToAdd), StatusCode::K_RUNTIME_ERROR,
+                                 "Fail to batch push");
+        VLOG(SC_NORMAL_LOG_LEVEL) << "Prefetch added " << elementListToAdd.size() << " elements to the local cache.";
+        lastRecvCursor += elementListToAdd.size();
+        lastRecvCursor_.store(lastRecvCursor);
+        numFetched = static_cast<uint32_t>(elementListToAdd.size());
+        totalFetched += numFetched;
+        cursor_->IncrementElementCount(numFetched);
+        RETURN_OK_IF_TRUE(totalFetched >= targetNum);
+        // For a non-blocking fetch, re-check the same page in case there is more coming or a new
+        // page has been created to continue with.
+        if (nonBlockingFetch) {
+            continue;
+        } else {
+            // For a blocking fetch, exit from here; the caller has its own clock to time it.
+            break;
+        }
+    }
+    return Status::OK();
+}
+
+Status ConsumerImpl::GetDataPage(const ShmView &shmView, std::shared_ptr<StreamDataPage> &out)
+{
+    std::shared_ptr<ShmUnitInfo> pageUnit;
+    std::shared_ptr<MmapEntry> mmapEntry;
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(GetShmInfo(shmView, pageUnit, mmapEntry), "GetShmInfo");
+    auto page = std::make_shared<StreamDataPage>(pageUnit, lockId_, true, false, mmapEntry);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(page->Init(), "Page Init");
+    // Worker Eyecatcher V1 works without consumer side client ref count.
+    if (workerVersion_ < Cursor::K_WORKER_EYECATCHER_V1) {
+        RETURN_IF_NOT_OK_PRINT_ERROR_MSG(page->RefPage(LogPrefix()), "RefPage");
+    }
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Page<%s> acquired", LogPrefix(), page->GetPageId());
+    lastPage_ = page;
+    out = std::move(page);
+    return Status::OK();
+}
+
+Status ConsumerImpl::GetPrefetchPage(int64_t timeoutMs, std::shared_ptr<StreamDataPage> &out)
+{
+    std::shared_ptr<StreamDataPage> page;
+    GetDataPageReqPb req;
+    req.set_stream_name(streamName_);
+    req.set_subscription_name(config_.subscriptionName);
+    req.set_consumer_id(consumerId_);
+    req.set_last_recv_cursor(lastRecvCursor_);
+    req.set_timeout_ms(timeoutMs);
+    ShmView shmView;
+    RETURN_IF_NOT_OK(workerApi_->GetDataPage(req, shmView));
+    RETURN_IF_NOT_OK(GetDataPage(shmView, page));
+    out = page;
+    return Status::OK();
+}
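
GetDataPage and GetPrefetchPage above both resolve a ShmView into a mapped page. The field meanings below are inferred from the assignments in this patch rather than from the ShmView definition itself:

    // ShmView, as consumed here:
    //   fd      - shared-memory file descriptor handed over by the worker
    //   mmapSz  - size of the mapping backing that descriptor
    //   off     - offset of the page within the mapping
    //   sz      - size of the page itself
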
+
+Status ConsumerImpl::LocatePrefetchPage(int64_t timeoutMs, std::shared_ptr<StreamDataPage> &out)
+{
+    std::unique_lock xlock(idxMutex_);
+    if (!idx_.empty()) {
+        // Locate the page containing lastRecvCursor_;
+        // it should be the last page.
+        out = idx_.rbegin()->second;
+        return Status::OK();
+    }
+    // Without knowing where to start, send an rpc request to the worker.
+    std::shared_ptr<StreamDataPage> startPage;
+    RETURN_IF_NOT_OK(GetPrefetchPage(timeoutMs, startPage));
+    LOG(INFO) << FormatString("%s Fetch page [%s]", LogPrefix(), startPage->GetPageId());
+    // Use the current cursor as the key.
+    idx_.emplace(startPage->GetBegCursor() - 1, startPage);
+    out = startPage;
+    return Status::OK();
+}
+
+Status ConsumerImpl::PrefetchEntry(uint32_t targetNum, uint32_t timeoutMs)
+{
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("Fetch at least %zu elements within %zu ms", targetNum, timeoutMs);
+    Timer t(timeoutMs);
+    uint32_t totalFetched = 0;
+    Status rc;
+    std::shared_ptr<StreamDataPage> page;
+    const uint32_t futexWaitMs = 10;
+    // nonBlockingFetch is only true when the user sets timeoutMs to 0
+    bool nonBlockingFetch = timeoutMs == 0;
+    do {
+        if (page == nullptr) {
+            rc = LocatePrefetchPage(t.GetRemainingTimeMs(), page);
+            if (rc.GetCode() == K_NOT_FOUND) {
+                VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("Page for %zu not yet created", lastRecvCursor_);
+                return Status::OK();
+            }
+            RETURN_IF_NOT_OK(rc);
+        }
+        RETURN_IF_NOT_OK(CheckNormalState());
+        // There is a chance that the remaining time is cast to int 0 if it is less than 1.
+        // The bool nonBlockingFetch tells whether this is really a non-blocking fetch or not.
+        rc = PrefetchElements(futexWaitMs, page, targetNum, totalFetched, nonBlockingFetch);
+        if (rc.GetCode() == K_SC_END_OF_PAGE) {
+            // Without sending an rpc, follow the pointer on the page to get the next page.
+            ShmView nextPageView = page->GetNextPage();
+            std::shared_ptr<StreamDataPage> nextPage;
+            // Add the page to the stream index (similar to worker logic). This marks the page boundary
+            // where we will do ack or auto ack.
+            RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
+                GetDataPage(nextPageView, nextPage),
+                FormatString("[%s] Error acquiring the next page. Current page<%s>", LogPrefix(), page->GetPageId()));
+            {
+                std::unique_lock xlock(idxMutex_);
+                // Just like the worker's logic, we use the last slot of the previous page as the key
+                uint64_t key = page->GetLastCursor();
+                CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(key + 1 == nextPage->GetBegCursor(), K_OUT_OF_RANGE,
+                                                     FormatString("[%s] Expect begCursor %zu but get %zu", LogPrefix(),
+                                                                  key + 1, nextPage->GetBegCursor()));
+                bool success = idx_.emplace(key, nextPage).second;
+                CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(success, K_DUPLICATED,
+                                                     FormatString("[%s] Duplicate key %zu", LogPrefix(), key));
+                // It is a potential cursor to ack.
+                pageBoundaryCursor_.store(key, std::memory_order_relaxed);
+            }
+            page = nextPage;
+            rc = Status::OK();
+        }
+        RETURN_IF_NOT_OK_EXCEPT(rc, K_TRY_AGAIN);
+    } while (totalFetched < targetNum
+             && ((!nonBlockingFetch && t.ElapsedMilliSecond() < static_cast<int64_t>(timeoutMs))
+                 || (nonBlockingFetch && rc.GetCode() != K_TRY_AGAIN)));
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("Total element fetched %zu in %zu ms", totalFetched,
+                                                static_cast<uint64_t>(t.ElapsedMilliSecond()));
+    return Status::OK();
+}
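
PrefetchEntry maintains a strict chaining invariant on idx_, enforced by the begCursor check above:

    // idx_[key] == page, where key is the last cursor of the *previous* page and
    // page->GetBegCursor() == key + 1.
    // Example: a page holding cursors 101..200 is stored under key 100; the page
    // that follows it (201..300) is stored under key 200.
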
+
+Status ConsumerImpl::PrefetchReceive()
+{
+    PerfPoint point(PerfKey::CLIENT_PREFETCH_RECEIVE);
+    // Set up a receive call to get as much as possible, limited by cache capacity (how many slots remaining).
+    // This receive always uses a zero timeout setting. Either the data is there or it
+    // is not; do not wait for it.
+    uint32_t totalCacheSlots = elementCacheQueue_->Remaining();
+    VLOG(SC_NORMAL_LOG_LEVEL) << "Prefetch receive will be run. Total cache slots available: " << totalCacheSlots;
+    RETURN_IF_NOT_OK(PrefetchEntry(totalCacheSlots, 0));
+    return Status::OK();
+}
+
+Status ConsumerImpl::CacheFetch(uint32_t cacheLength, uint32_t receiveNum, std::vector<Element> &outElements)
+{
+    PerfPoint point(PerfKey::CLIENT_CACHE_FETCH);
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("%s In cache, current remain %zu element, expectNum:%zu", LogPrefix(),
+                                                cacheLength, receiveNum);
+    RETURN_IF_NOT_OK(GetElementFromCache(receiveNum, outElements));
+    RETURN_IF_NOT_OK(PostRecvAckHandler(outElements));
+    return Status::OK();
+}
+
+Status ConsumerImpl::CacheHandler(uint32_t &cacheLength, uint32_t receiveNum, uint32_t timeoutMs,
+                                  std::vector<Element> &outElements, bool &needsRecv)
+{
+    needsRecv = true;
+    // If there is enough data in the cache, then fetch this data and return.
+    if (cacheLength >= receiveNum) {
+        RETURN_IF_NOT_OK(CacheFetch(cacheLength, receiveNum, outElements));
+        needsRecv = false;
+
+        // Before returning, check if we should replenish the cache by invoking a prefetch.
+        // Conditions for submitting the prefetch task:
+        // - The cache is currently below the prefetch low water mark.
+        // - The prefetch thread pool is not at capacity (if too busy, just skip this prefetch).
+        // - There is no inflight (or recently completed but not yet handled) prefetch already.
+        ThreadPool *pfPool = client_->GetPrefetchPool();
+        uint64_t newCacheLength = elementCacheQueue_->Length();
+        if (newCacheLength < cachePrefetchLWM_ && pfPool->GetWaitingTasksNum() == 0 && !prefetchStatus_.valid()) {
+            prefetchStatus_ = client_->GetPrefetchPool()->Submit([this]() {
+                RETURN_IF_NOT_OK(PrefetchReceive());
+                return Status::OK();
+            });
+        }
+
+        return Status::OK();
+    }
+
+    // At this point, there is not enough data in the cache. If a prefetch is running in the background (from
+    // a previous receive), it does not make sense to invoke a new recv again. Wait for the prefetch to complete and
+    // then perform a cache fetch if possible.
+    if (cachePrefetchLWM_ > 0 && prefetchStatus_.valid()) {
+        prefetchStatus_.wait();
+        Status rc = prefetchStatus_.get();
+        LOG_IF_ERROR(rc, "A previous prefetch attempt got an error.");
+
+        // Refresh the cache length since we waited for a prefetch thread which may have put more in the cache.
+        // If there is now enough data in the cache queue after the prefetch task came back, fetch it and we're done.
+        cacheLength = elementCacheQueue_->Length();
+        if (cacheLength >= receiveNum) {
+            VLOG(SC_NORMAL_LOG_LEVEL) << "Waited for a prefetch task to complete. Fetch from cache now.";
+            RETURN_IF_NOT_OK(CacheFetch(cacheLength, receiveNum, outElements));
+            needsRecv = false;
+        } else if (timeoutMs == 0) {
+            // Not enough data in the cache, but a prefetch just completed (not enough data fetched).
+            // Do not tell the caller to receive if the timeout is 0, because we just did a fetch and didn't get
+            // enough data. No need to ask the worker again.
+            VLOG(SC_NORMAL_LOG_LEVEL) << "Waited for a prefetch task to complete. Not enough data.";
+            needsRecv = false;
+        }
+    }
+    return Status::OK();
+}
+
+Status ConsumerImpl::ReceiveImpl(Optional<uint32_t> expectNum, uint32_t timeoutMs, std::vector<Element> &outElements)
+{
+    INJECT_POINT("consumerImpl.receive.fail");
+    PerfPoint point(PerfKey::CLIENT_RECEIVE_ALL);
+    receiveWp_.Clear();
+    Raii receiveRaii([this]() { receiveWp_.Set(); });
+    RETURN_IF_NOT_OK(CheckNormalState());
+    if (expectNum) {
+        CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(expectNum.value() > 0, StatusCode::K_INVALID,
+                                             "The Receive expectNum must be greater than 0");
+    }
+
+    // If prefetching was enabled, collect the status of the previous prefetch if it was immediately ready.
+    if (cachePrefetchLWM_ > 0 && prefetchStatus_.valid()
+        && prefetchStatus_.wait_for(std::chrono::seconds(0)) == std::future_status::ready) {
+        Status rc = prefetchStatus_.get();
+        LOG_IF_ERROR(rc, "A previous prefetch attempt got an error.");
+    }
+    outElements.clear();
+    uint32_t cacheLength = elementCacheQueue_->Length();
+    auto receiveNum = (expectNum) ? expectNum.value() : (cacheLength > 0) ? cacheLength : 1u;
+
+    bool needsRecv;
+    RETURN_IF_NOT_OK(CacheHandler(cacheLength, receiveNum, timeoutMs, outElements, needsRecv));
+    RETURN_OK_IF_TRUE(!needsRecv);  // Early exit if the cache satisfied it. No need to rpc receive.
+
+    // x0 = expectNum - curCacheNum
+    auto expectElementNum = receiveNum - cacheLength;
+    VLOG(SC_NORMAL_LOG_LEVEL) << "Cache did not satisfy receive requirement. Invoke worker receive.";
+    RETURN_IF_NOT_OK(PrefetchEntry(expectElementNum, timeoutMs));
+    INJECT_POINT("consumer_after_get_datapage");
+    // Log status
+    LogConsumerCursors();
+
+    // Refresh the cache length
+    cacheLength = elementCacheQueue_->Length();
+    receiveNum = (expectNum) ? expectNum.value() : (cacheLength > 0) ? cacheLength : 1u;
+    auto reserveSize = std::min(receiveNum, cacheLength);
+    outElements.reserve(reserveSize);
+    // Fetch everything from the local cache; afterwards the queue is empty
+    if (cacheLength > 0) {
+        VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("%s, %zu element remain in cache", LogPrefix(), cacheLength);
+        RETURN_IF_NOT_OK(GetElementFromCache(reserveSize, outElements));
+    }
+    if (outElements.empty()) {
+        const int logPerCount = VLOG_IS_ON(SC_NORMAL_LOG_LEVEL) ? 1 : 1000;
+        LOG_EVERY_N(INFO, logPerCount) << FormatString("[%s] Consumer %s expect recv %d elements, got 0 elements",
+                                                       LogPrefix(), consumerId_, receiveNum);
+        return CheckNormalState();
+    }
+
+    RETURN_IF_NOT_OK(PostRecvAckHandler(outElements));
+
+    const int logPerCount = VLOG_IS_ON(SC_NORMAL_LOG_LEVEL) ? 1 : 1000;
+    LOG_EVERY_N(INFO, logPerCount) << FormatString(
+        "[%s] Consumer %s expect recv %d elements, got %d elements [id:%zu-%zu]", LogPrefix(), consumerId_, receiveNum,
+        outElements.size(), outElements.front().id, outElements.back().id);
+    return Status::OK();
+}
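
The receiveNum expression in ReceiveImpl resolves the optional expectNum as follows (illustrative cases):

    // expectNum set to 10                 -> receiveNum = 10
    // expectNum unset, 7 elements cached  -> receiveNum = 7   (drain the cache)
    // expectNum unset, cache empty        -> receiveNum = 1   (wait for at least one element)
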
+
+Status ConsumerImpl::Receive(Optional<uint32_t> expectNum, uint32_t timeoutMs, std::vector<Element> &outElements)
+{
+    cursor_->IncrementRequestCount();
+    PreRecvAckHandler();
+    Status rc = ReceiveImpl(expectNum, timeoutMs, outElements);
+    if (recvTracer_.NeedWriteLog(!outElements.empty())) {
+        LOG(INFO) << FormatString("[%s] Consumer first receive element count %zu with status %s", LogPrefix(),
+                                  outElements.size(), rc.ToString());
+    }
+    // A special return code K_OUT_OF_RANGE may indicate the worker has already restarted.
+    // Go check the state.
+    consumedElements_ += outElements.size();
+    if (rc.GetCode() == K_OUT_OF_RANGE) {
+        RETURN_IF_NOT_OK(CheckNormalState());
+    }
+    return rc;
+}
+
+Status ConsumerImpl::GetElementFromCache(uint32_t expectNum, std::vector<Element> &outElements)
+{
+    int elementNums = static_cast<int>(std::min(static_cast<uint64_t>(expectNum), elementCacheQueue_->Length()));
+    outElements.reserve(elementNums);
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(elementCacheQueue_->BatchFetchAndPop(outElements, elementNums),
+                                         StatusCode::K_RUNTIME_ERROR, "circular queue is empty");
+    DCHECK(!outElements.empty()) << "The element size should be greater than 0";
+    return Status::OK();
+}
+
+Status ConsumerImpl::Ack(uint64_t elementId)
+{
+    PerfPoint point(PerfKey::CLIENT_ACK_ALL);
+    RETURN_IF_NOT_OK(CheckNormalState());
+    // Make sure the data still in the cache cannot be acked.
+    // There is a race condition that makes the check below invalid
+    // when the prefetching thread is running in the background.
+    if (cachePrefetchLWM_ == 0) {
+        uint64_t cacheLength = elementCacheQueue_->Length();
+        uint64_t ackCheckNum = elementId + cacheLength;
+        auto lastRecvCursor = lastRecvCursor_.load();
+        CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+            ackCheckNum <= lastRecvCursor, StatusCode::K_INVALID,
+            FormatString("The ack cursor %zu should be less than or equal to lastRecvCursor %zu. cache size: %zu",
+                         elementId, lastRecvCursor, cacheLength));
+    }
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Consumer %s ack %zu", consumerId_, elementId);
+    // Go through the stream index, and release the pages that we no longer need
+    std::unique_lock xlock(idxMutex_);
+    auto it = idx_.begin();
+    while (it != idx_.end()) {
+        std::shared_ptr<StreamDataPage> page = it->second;
+        // Get the last cursor on the page, which is also the key of the next page in idx_
+        // if it exists.
+        uint64_t lastCursorOnPage = page->GetLastCursor();
+        // Worker Eyecatcher V1 works without consumer side client ref count.
+        if (workerVersion_ >= Cursor::K_WORKER_EYECATCHER_V1 && lastCursorOnPage < elementId) {
+            it = idx_.erase(it);
+            bytesSinceLastAck_ = 0;
+        } else if (workerVersion_ < Cursor::K_WORKER_EYECATCHER_V1 && lastCursorOnPage <= elementId) {
+            // One more check. If this page is only partially full, we may still need it
+            // in the future. A simpler way to check is whether we have reached the end of the
+            // index chain.
+            if (it->first == pageBoundaryCursor_) {
+                break;
+            }
+            RETURN_IF_NOT_OK(ReleasePage(page));
+            // Notify the worker via shared memory that this page is no longer needed.
+            UpdateWALastAckCursor(lastCursorOnPage);
+            it = idx_.erase(it);
+            bytesSinceLastAck_ = 0;
+        } else {
+            break;
+        }
+    }
+    // Finally unconditionally set the value in the shared memory work area.
+    // Future newly added consumer(s) will start from this elementId
+    if (elementId > ackedElementId_.load()) {
+        UpdateWALastAckCursor(elementId);
+        ackedElementId_.store(elementId);
+    }
+    return Status::OK();
+}
+
+Status ConsumerImpl::Close()
+{
+    PerfPoint point(PerfKey::CLIENT_CLOSE_CONSUMER_ALL);
+    // If consumer is already closed, return OK
+    RETURN_OK_IF_TRUE(state_ == State::CLOSE);
+    // If the current state is RESET, it is allowed to close
+    RETURN_IF_NOT_OK(CheckState());
+    {
+        std::unique_lock lock(idxMutex_);
+        for (auto &ele : idx_) {
+            ReleasePage(ele.second);
+        }
+        if (cursor_) {
+            UpdateWALastAckCursor(lastRecvCursor_);
+        }
+        idx_.clear();
+    }
+    INJECT_POINT("ConsumerImpl.CloseConsumerRPC.Fail");
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(workerApi_->CloseConsumer(streamName_, config_.subscriptionName, consumerId_),
+                                     FormatString("[%s] CloseConsumer request error", LogPrefix()));
+    RETURN_IF_NOT_OK(ChangeState(State::CLOSE));
+    LOG(INFO) << FormatString("[%s] Close consumer success!", LogPrefix());
+    return Status::OK();
+}
+
+Status ConsumerImpl::SetStateToReset()
+{
+    RETURN_IF_NOT_OK(ChangeState(State::RESET));
+    // Wake up consumer and wait for Receive in progress to complete
+    {
+        std::shared_lock lock(idxMutex_);
+        for (auto &ele : idx_) {
+            ele.second->WakeUpConsumers();
+        }
+    }
+    receiveWp_.Wait();
+    // Wait for prefetch in progress to complete
+    if (cachePrefetchLWM_ > 0 && prefetchStatus_.valid()) {
+        prefetchStatus_.wait();
+        prefetchStatus_.get();  // call get() so that prefetchStatus_ is no longer valid
+    }
+    return Status::OK();
+}
+
+Status ConsumerImpl::Reset()
+{
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        state_ == State::RESET, K_RUNTIME_ERROR,
+        FormatString("[%s] The consumer should be in reset state already to cleanup data and metadata", LogPrefix()));
+    elementCacheQueue_->Clear();
+    for (auto &e : idx_) {
+        RETURN_IF_NOT_OK(ReleasePage(e.second));
+    }
+    idx_.clear();
+    lastRecvCursor_ = 0;
+    UpdateWALastAckCursor(0);
+    pageBoundaryCursor_ = 0;
+    bytesSinceLastAck_ = 0;
+    avgEleSize_ = 0;
+    lastAvgCount_ = 0;
+    ackedElementId_ = 0;
+    consumedElements_ = 0;
+    return Status::OK();
+}
+
+Status ConsumerImpl::Resume()
+{
+    return ChangeState(State::NORMAL);
+}
+
+std::string ConsumerImpl::LogPrefix() const
+{
+    return FormatString("S:%s, Sub:%s, C:%s", streamName_, config_.subscriptionName, consumerId_);
+}
+
+void ConsumerImpl::GetStatisticsMessage(uint64_t &totalElements, uint64_t &notProcessedElements)
+{
+    if (lastPage_ == nullptr || cursor_ == nullptr) {
+        totalElements = 0;
+        notProcessedElements = 0;
+        return;
+    }
+    const uint64_t &proNum = cursor_->GetWALastAckCursor();
+    totalElements = lastPage_->GetLastCursor();
+    notProcessedElements = totalElements - proNum;
+}
+
+void ConsumerImpl::LogConsumerCursors()
+{
+    if (lastPage_) {
+        auto lastAppendCursor = lastPage_->GetLastCursor();
+        const int logPerCount = VLOG_IS_ON(SC_NORMAL_LOG_LEVEL) ? 1 : 1000;
+        LOG_EVERY_N(INFO, logPerCount) << FormatString(
+            "[%s] %d received elements, %d consumed elements, %d elements acked", LogPrefix(), lastAppendCursor,
+            consumedElements_, ackedElementId_);
+    }
+}
+
+void ConsumerImpl::PreRecvAckHandler()
+{
+    if (autoAck_) {
+        // Trigger ack for the previously received data
+        uint64_t cursorToAck = consumedElements_;
+        Status rc = Ack(cursorToAck);
+        if (rc.IsError()) {
+            LOG(ERROR) << FormatString("Ack failed: %s", rc.GetMsg());
+        }
+    }
+}
+
+Status ConsumerImpl::PostRecvAckHandler(const std::vector<Element> &fetchedElements)
+{
+    // Track the amount of bytes from the receive that was just done, and move the next ack position further along.
+    uint64_t elementByteSum = std::accumulate(fetchedElements.begin(), fetchedElements.end(), static_cast<uint64_t>(0),
+                                              [](uint64_t sum, const Element &e) { return e.size + sum; });
+    bytesSinceLastAck_ += elementByteSum;
+
+    // Compute a running average to learn what the average element size is. Implicit conversion rules convert all these
+    // types to match the float type of avgEleSize_.
+    avgEleSize_ = ((avgEleSize_ * lastAvgCount_) + elementByteSum) / (lastAvgCount_ + fetchedElements.size());
+    lastAvgCount_ += fetchedElements.size();
+
+    VLOG(SC_NORMAL_LOG_LEVEL) << "Post Recv Ack stats:"
+                              << "\nBytes since last ack     : " << bytesSinceLastAck_
+                              << "\nNext cursor ack          : " << pageBoundaryCursor_
+                              << "\nNew elements recv count  : " << fetchedElements.size()
+                              << "\nAverage element size     : " << avgEleSize_;
+
+    return Status::OK();
+}
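
A quick numeric check of the running average PostRecvAckHandler keeps above:

    // avgEleSize_ = 100.0 over lastAvgCount_ = 4 elements; a receive adds
    // 2 elements totalling 400 bytes:
    //   avgEleSize_   = (100.0 * 4 + 400) / (4 + 2) = 800 / 6 = ~133.3
    //   lastAvgCount_ = 6
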
+
+Status ConsumerImpl::ReleasePage(std::shared_ptr<StreamDataPage> &page) const
+{
+    // Worker Eyecatcher V1 works without consumer side client ref count.
+    RETURN_OK_IF_TRUE(workerVersion_ >= Cursor::K_WORKER_EYECATCHER_V1);
+    Status rc = page->ReleasePage(LogPrefix());
+    RETURN_OK_IF_TRUE(rc.IsOk());
+    // Map all other errors to OUT_OF_RANGE. Let the high level code
+    // detect if the worker has restarted.
+    RETURN_STATUS_LOG_ERROR(K_OUT_OF_RANGE, rc.ToString());
+}
+
+Status ConsumerImpl::ExtractVersion(DataElement &element, ElementHeader::Version &version)
+{
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        element.size > sizeof(ElementHeader::Version), K_OUT_OF_RANGE,
+        FormatString("[%s] Element (version + header + data) size %llu is not greater than header version size %lu",
+                     LogPrefix(), element.size, sizeof(ElementHeader::Version)));
+    version = *(element.ptr);
+    element.ptr++;
+    element.size--;
+    return Status::OK();
+}
+
+Status ConsumerImpl::ProcessHeaders(std::vector<DataElement> &recvElements)
+{
+    for (auto &element : recvElements) {
+        if (!element.HasHeader()) {
+            continue;
+        }
+        ElementHeader::Version version;
+        RETURN_IF_NOT_OK(ExtractVersion(element, version));
+        ElementHeader header;
+        if (version == DATA_VERIFICATION_HEADER) {
+            RETURN_IF_NOT_OK_PRINT_ERROR_MSG(DataVerificationHeader::ExtractHeader(element, header),
+                                             FormatString("[%s]", LogPrefix()));
+            RETURN_IF_NOT_OK(VerifyElement(element, header));
+        } else {
+            RETURN_STATUS_LOG_ERROR(
+                K_INVALID, FormatString("[%s] Does not support element's header version %u", LogPrefix(), version));
+        }
+    }
+    return Status::OK();
+}
+
+Status ConsumerImpl::VerifyElement(const DataElement &recvElement, const ElementHeader &recvHeader)
+{
+    struct DataVerificationHeader header(recvHeader);
+    const std::string key =
+        FormatString("%llu-%lu-%lu", header.GetSenderProducerNo(), header.GetAddress(), header.GetPort());
+    auto iter = producerLastSeqNoReceive_.find(key);
+    if (iter == producerLastSeqNoReceive_.end()) {
+        producerLastSeqNoReceive_.emplace(key, header.GetSeqNo());
+        LOG(INFO) << FormatString("[%s] Data Verification: First time receive from producer = %s, seqNo = %llu",
+                                  LogPrefix(), key, header.GetSeqNo());
+        return Status::OK();
+    }
+    const DataVerificationHeader::SeqNo previousSeqNo = iter->second;
+    iter->second = header.GetSeqNo();
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        header.GetSeqNo() == previousSeqNo + 1, K_DATA_INCONSISTENCY,
+        FormatString("[%s] Data Verification Failed: length = %zu, producer = %s, seqNo = %llu, "
+                     "expect seqNo = %llu",
+                     LogPrefix(), recvElement.size, key, header.GetSeqNo(), previousSeqNo + 1));
+    // Success
+    const int logPerCount = VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL) ? 1 : 1000;
+    LOG_EVERY_N(INFO, logPerCount) << FormatString(
+        "[%s] Data Verification Success: length = %zu, producer = %s, "
+        "seqNo = %llu",
+        LogPrefix(), recvElement.size, key, header.GetSeqNo());
+    INJECT_POINT("VerifyProducerNo", [&header, &recvElement]() {
+        std::string expected(reinterpret_cast<const char *>(recvElement.ptr), recvElement.size);
+        std::string producerNo = "producer" + std::to_string(header.GetSenderProducerNo());
+        if (expected == producerNo) {
+            return Status::OK();
+        }
+        RETURN_STATUS_LOG_ERROR(
+            K_DATA_INCONSISTENCY,
+            FormatString("DataVerification Failed: incorrect producerNo, get %s, expect %s", producerNo, expected));
+    });
+    return Status::OK();
+}
+}  // namespace stream_cache
+}  // namespace client
+}  // namespace datasystem
diff --git a/src/datasystem/client/stream_cache/consumer_impl.h b/src/datasystem/client/stream_cache/consumer_impl.h
new file mode 100644
index 0000000..eaac673
--- /dev/null
+++ b/src/datasystem/client/stream_cache/consumer_impl.h
@@ -0,0 +1,311 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Define api of stream cache consumer implementation.
+ */
+
+#ifndef DATASYSTEM_CLIENT_STREAM_CACHE_CONSUMER_IMPL_H
+#define DATASYSTEM_CLIENT_STREAM_CACHE_CONSUMER_IMPL_H
+
+#include "datasystem/client/stream_cache/client_base_impl.h"
+#include "datasystem/client/stream_cache/producer_consumer_worker_api.h"
+#include "datasystem/common/eventloop/timer_queue.h"
+#include "datasystem/common/shared_memory/shm_unit.h"
+#include "datasystem/common/stream_cache/cursor.h"
+#include "datasystem/common/util/queue/circular_queue.h"
+#include "datasystem/stream/element.h"
+#include "datasystem/utils/optional.h"
+#include "datasystem/utils/status.h"
+
+namespace datasystem {
+namespace client {
+namespace stream_cache {
+class __attribute__((visibility("default"))) ConsumerImpl : public ClientBaseImpl {
+public:
+    /**
+     * @brief Construct Consumer.
+     * @param[in] streamName The name of the stream.
+     * @param[in] tenantId TenantId of the client user.
+     * @param[in] config The configuration of the subscription, such as subscription name and subscription mode.
+     * @param[in] consumerId The uuid of the consumer which is generated by the worker.
+     * @param[in] rsp The response from the rpc request.
+     * @param[in] workerApi Used to call the worker service through rpc.
+     * @param[in] client The owning stream client implementation.
+     * @param[in] mmapManager Used to receive and mmap fds passed from the worker.
+     * @param[in] listenWorker Listens to the worker survival status.
+     * @param[in] autoAck Toggles whether auto ack is enabled for this consumer.
+     */
+    ConsumerImpl(std::string streamName, std::string tenantId, SubscriptionConfig config, std::string consumerId,
+                 const SubscribeRspPb &rsp, std::shared_ptr<ProducerConsumerWorkerApi> workerApi,
+                 std::shared_ptr<StreamClientImpl> client, client::MmapManager *mmapManager,
+                 std::shared_ptr<ListenWorker> listenWorker, bool autoAck);
+
+    ~ConsumerImpl() override;
+
+    /**
+     * @brief Initialize the ConsumerImpl.
+     * @return Status of the call.
+     */
+    Status Init() override;
+
+    /**
+     * @brief Sets the state of the consumer to CLOSE.
+     */
+    void SetInactive() override;
+
+    /**
+     * @brief Get expectNum elements from the subscription.
+     * @param[in] expectNum Parameter to indicate the number of expected elements to receive. If this optional
+     * parameter is not set, it will receive any amount.
+     * @param[in] timeoutMs The timeout in milliseconds for elements to be received.
+     * @param[out] outElements The received elements to be read.
+     * @return K_OK on success; the error code otherwise.
+     *         K_UNKNOWN_ERROR: see the returned message.
+     *         K_NOT_FOUND: the id of the stream is not found.
+     *         K_INVALID: invalid parameter.
+     *         K_RPC_UNAVAILABLE: didn't receive any response from the server.
+     *         K_DUPLICATED: the consumer already had a pending receive.
+     *         K_SC_PRODUCER_NOT_FOUND: one or more producers in the stream are dead.
+     *         K_SC_STREAM_IN_RESET_STATE: stream currently in reset state.
+     *         K_SC_ALREADY_CLOSED: the consumer is already closed.
+     */
+    Status Receive(Optional<uint32_t> expectNum, uint32_t timeoutMs, std::vector<Element> &outElements);
+
+    /**
+     * @brief Acknowledge elements that have been read by this consumer.
+     * @param[in] elementId The element id to be acknowledged.
+     * @return K_OK on success; the error code otherwise.
+     *         K_UNKNOWN_ERROR: see the returned message.
+     *         K_NOT_FOUND: the id of the stream is not found.
+     *         K_INVALID: invalid parameter.
+     *         K_SC_STREAM_IN_RESET_STATE: stream currently in reset state.
+     *         K_SC_ALREADY_CLOSED: the consumer is already closed.
+     */
+    Status Ack(uint64_t elementId);
+
+    /**
+     * @brief Close the consumer; after close it will no longer allow receiving and acking elements.
+     * Calling Close() on an already closed consumer will return K_OK.
+     * @return K_OK on success; the error code otherwise.
+     *         K_UNKNOWN_ERROR: see the returned message.
+     *         K_NOT_FOUND: the id of the stream is not found.
+     *         K_INVALID: invalid parameter.
+     *         K_RUNTIME_ERROR: deleting the sub node in global scope failed on the master process.
+     *         K_SC_STREAM_IN_RESET_STATE: stream currently in reset state.
+     */
+    Status Close();
+
+    /**
+     * @brief Reset the consumer by cleaning up data and metadata.
+     * @return K_OK on success; the error code otherwise.
+     *         K_RUNTIME_ERROR: consumer is not in reset state.
+     */
+    Status Reset();
+
+    /**
+     * @brief Set consumer state to RESET. Cancel prefetch of data. During reset it will not allow Receive or Ack.
+     * @return K_OK on success; the error code otherwise.
+     *         K_RUNTIME_ERROR: consumer is in close state already.
+     */
+    Status SetStateToReset();
+
+    /**
+     * @brief Resume the consumer, allowing Receive, Ack and Close again.
+     * @return K_OK on success; the error code otherwise.
+     *         K_RUNTIME_ERROR: consumer is not in reset state.
+     */
+    Status Resume();
+
+    /**
+     * @brief Get prefix of log entry.
+     */
+    std::string LogPrefix() const override;
+
+    /**
+     * @brief Get the amount of elements received since this consumer was constructed, and the amount of elements
+     * not processed.
+     * @param[out] totalElements The amount of elements received by this consumer.
+     * @param[out] notProcessedElements The amount of elements not processed (received but not ack-ed).
+     */
+    void GetStatisticsMessage(uint64_t &totalElements, uint64_t &notProcessedElements);
+
+    /**
+     * @brief Get the consumer id.
+     * @return Consumer id.
+     */
+    const std::string &GetConsumerId() const
+    {
+        return consumerId_;
+    }
+
+private:
+    /**
+     * @brief Check the state_ of the consumer and return status.
+     * @return Status of the call.
+     */
+    Status CheckNormalState() const override;
+
+    /**
+     * @brief Get expectNum elements from the local cache queue.
+     * @param[in] expectNum The number of elements to be read.
+     * @param[out] outElements The received elements to be read.
+     * @return Status of the call.
+     */
+    Status GetElementFromCache(uint32_t expectNum, std::vector<Element> &outElements);
+
+    /**
+     * @brief When the cache is sufficient to satisfy the receive, return these cached elements.
+     * @param[in] cacheLength The length of the cache.
+     * @param[in] receiveNum The number of elements to fetch.
+     * @param[out] outElements The fetched elements output.
+     */
+    Status CacheFetch(uint32_t cacheLength, uint32_t receiveNum, std::vector<Element> &outElements);
+
+    /**
+     * @brief Check if the cache can satisfy a receive and drive the cache fetch if needed. Also, invoke cache prefetch
+     * (if configured for it).
+     * @param[in/out] cacheLength The length of the cache.
+     * @param[in] receiveNum The number of elements to fetch.
+     * @param[in] timeoutMs The timeout for receives.
+     * @param[out] outElements The fetched elements output.
+     * @param[out] needsRecv Set to true if the cache handler determines that a worker receive is still needed.
+     */
+    Status CacheHandler(uint32_t &cacheLength, uint32_t receiveNum, uint32_t timeoutMs,
+                        std::vector<Element> &outElements, bool &needsRecv);
+
+    /**
+     * @brief Called from the receive codepath, this function decides if it should send an Ack or not and also gathers
+     * some statistics about Ack logistics to help decide on future Acks.
+     * @param[in] fetchedElements The element list that was just recently fetched from the worker.
+     * @return Status of the call.
+     */
+    Status PostRecvAckHandler(const std::vector<Element> &fetchedElements);
+
+    /**
+     * @brief Called from the receive codepath, this function sends an Ack before the receive if AutoAck is enabled.
+     */
+    void PreRecvAckHandler();
+
+    /**
+     * @brief Get the last ack cursor from the work area.
+     * @return The last ack cursor.
+     */
+    uint64_t GetWALastAckCursor() const
+    {
+        return cursor_->GetWALastAckCursor();
+    }
+
+    /**
+     * @brief Update the last ack cursor in the work area.
+     * @param[in] elementId The element id to record as the last ack cursor.
+     */
+    void UpdateWALastAckCursor(uint64_t elementId) const
+    {
+        cursor_->UpdateWALastAckCursor(elementId);
+    }
+
+    /**
+     * @brief Internal function to replenish the local client cache by calling receive to the worker.
+     * @return Status of the call.
+     */
+    Status PrefetchReceive();
+
+    /**
+     * @brief Get expectNum elements from the subscription.
+     * @param[in] expectNum Parameter to indicate the number of expected elements to receive. If this optional
+     * parameter is not set, it will receive any amount.
+     * @param[in] timeoutMs The timeout in milliseconds for elements to be received.
+     * @param[out] outElements The received elements to be read.
+     * @return K_OK on success; the error code otherwise.
+     *         K_UNKNOWN_ERROR: see the returned message.
+     *         K_NOT_FOUND: the id of the stream is not found.
+     *         K_INVALID: invalid parameter.
+     *         K_RPC_UNAVAILABLE: didn't receive any response from the server.
+     *         K_DUPLICATED: the consumer already had a pending receive.
+     *         K_SC_PRODUCER_NOT_FOUND: one or more producers in the stream are dead.
+     *         K_SC_STREAM_IN_RESET_STATE: stream currently in reset state.
+     */
+    Status ReceiveImpl(Optional<uint32_t> expectNum, uint32_t timeoutMs, std::vector<Element> &outElements);
+
+    Status PrefetchEntry(uint32_t targetNum, uint32_t timeoutMs);
+    Status GetDataPage(const ShmView &shmView, std::shared_ptr<StreamDataPage> &out);
+    Status PrefetchElements(uint32_t timeoutMs, std::shared_ptr<StreamDataPage> &dataPage, uint32_t targetNum,
+                            uint32_t &totalFetched, bool nonBlockingFetch);
+    Status GetPrefetchPage(int64_t timeoutMs, std::shared_ptr<StreamDataPage> &out);
+    Status LocatePrefetchPage(int64_t timeoutMs, std::shared_ptr<StreamDataPage> &out);
+    Status ReleasePage(std::shared_ptr<StreamDataPage> &page) const;
+    Status ExtractBigElements(std::shared_ptr<StreamDataPage> &dataPage, std::vector<DataElement> &recvElements);
+
+    /**
+     * @brief Logs the received elements, consumed elements, and acked (processed) elements.
+     */
+    void LogConsumerCursors();
+
+    /**
+     * @brief Extract the header's version from an element.
+     * @param[in] element Element (header + data).
+     * @param[out] version The header's version.
+     * @return K_OK on success.
+     */
+    Status ExtractVersion(DataElement &element, ElementHeader::Version &version);
+
+    /**
+     * @brief Extract headers from elements and process each header according to the header version.
+     * @param[in] elements Elements (header + data).
+     * @return K_OK on success.
+     */
+    Status ProcessHeaders(std::vector<DataElement> &elements);
+
+    /**
+     * @brief Verify that an element is received in order for its producer.
+     * @param[in] recvElement Element's data.
+     * @param[in] recvHeader Element's header.
+     * @return K_OK on success.
+     */
+    Status VerifyElement(const DataElement &recvElement, const ElementHeader &recvHeader);
+
+    // For make_shared to access the private/protected constructor.
+    friend std::shared_ptr<ConsumerImpl> std::make_shared<ConsumerImpl>();
+
+    const SubscriptionConfig config_;
+    const std::string consumerId_;
+    std::atomic<uint64_t> lastRecvCursor_;
+    std::atomic<uint64_t> pageBoundaryCursor_;
+    std::atomic<uint64_t> bytesSinceLastAck_{ 0 };
+    std::atomic<uint64_t> consumedElements_{ 0 };
+    float avgEleSize_{ 0 };
+    uint64_t lastAvgCount_{ 0 };
+    uint32_t cachePrefetchLWM_{ 0 };
+    std::atomic_uint64_t ackedElementId_;  // The ID of the last Ack'ed element.
+    mutable std::shared_timed_mutex idxMutex_;
+    std::map<uint64_t, std::shared_ptr<StreamDataPage>> idx_;
+    std::shared_ptr<StreamDataPage> lastPage_;  // last page. ref count > 0
+    const SubscribeRspPb rsp_;
+
+    std::unique_ptr<CircularQueue<DataElement>> elementCacheQueue_;
+    bool autoAck_{ false };
+    std::future<Status> prefetchStatus_;
+    WaitPost receiveWp_;
+
+    // Records the last element's seqNo from a particular producer; any normal update should be +1 only.
+    std::unordered_map<std::string, DataVerificationHeader::SeqNo> producerLastSeqNoReceive_;
+    FirstCallTracer recvTracer_;
+};
+}  // namespace stream_cache
+}  // namespace client
+}  // namespace datasystem
+
+#endif  // DATASYSTEM_CLIENT_STREAM_CACHE_CONSUMER_IMPL_H
diff --git a/src/datasystem/client/stream_cache/producer.cpp b/src/datasystem/client/stream_cache/producer.cpp
new file mode 100644
index 0000000..c9ec7b9
--- /dev/null
+++ b/src/datasystem/client/stream_cache/producer.cpp
@@ -0,0 +1,74 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Define stream cache producer.
+ */
+#include "datasystem/stream/producer.h"
+
+#include "datasystem/client/stream_cache/producer_impl.h"
+#include "datasystem/common/log/access_recorder.h"
+#include "datasystem/common/log/log.h"
+#include "datasystem/common/log/trace.h"
+#include "datasystem/common/rpc/rpc_constants.h"
+#include "datasystem/common/util/strings_util.h"
+
+namespace datasystem {
+Producer::Producer(std::shared_ptr<ProducerImpl> impl) : impl_(std::move(impl))
+{
+}
+
+Producer::~Producer()
+{
+    if (impl_->IsActive()) {
+        LOG(INFO) << FormatString("[%s] Implicit close producer", impl_->LogPrefix());
+        Status rc = Close();
+        if (rc.IsError()) {
+            LOG(ERROR) << FormatString("[%s] Implicit close producer failed %s.", impl_->LogPrefix(), rc.GetMsg());
+        }
+    }
+}
+
+Status Producer::Send(const Element &element)
+{
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(impl_->CheckAndSetInUse(), "Send");
+    Raii unsetRaii([this]() { impl_->UnsetInUse(); });
+    return impl_->Send(element, Optional<int64_t>());
+}
+
+Status Producer::Send(const Element &element, int64_t timeoutMs)
+{
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(impl_->CheckAndSetInUse(), "Send");
+    Raii unsetRaii([this]() { impl_->UnsetInUse(); });
+    return impl_->Send(element, Optional<int64_t>(timeoutMs));
+}
+
+Status Producer::Close()
+{
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(impl_->CheckAndSetInUse(), "Close");
+    Raii unsetRaii([this]() { impl_->UnsetInUse(); });
+    TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+    AccessRecorder recorder(AccessRecorderKey::DS_STREAM_CLOSE_PRODUCER);
+    auto rc = impl_->Close();
+    StreamRequestParam reqParam;
+    reqParam.streamName = impl_->GetStreamName();
+    reqParam.producerId = impl_->GetProducerId();
+    StreamResponseParam rspParam;
+    rspParam.msg = rc.GetMsg();
+    recorder.Record(rc.GetCode(), reqParam, rspParam);
+    return rc;
+}
+} // namespace datasystem
diff --git a/src/datasystem/client/stream_cache/producer_consumer_worker_api.cpp b/src/datasystem/client/stream_cache/producer_consumer_worker_api.cpp
new file mode 100644
index 0000000..7a6cf2b
--- /dev/null
+++ b/src/datasystem/client/stream_cache/producer_consumer_worker_api.cpp
@@ -0,0 +1,245 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Implement stream cache base class for producer and consumer.
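+ * ProducerConsumerWorkerApi wraps the worker RPCs shared by producers and consumers (data and LOB page
+ * allocation/release, close, and cursor queries), signing each request and attaching the tenant id
+ * before dispatch.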
+ */ + +#include "datasystem/client/stream_cache/producer_consumer_worker_api.h" +#include +#include "datasystem/client/stream_cache/client_worker_api.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/shared_memory/shm_unit_info.h" +#include "datasystem/common/util/rpc_util.h" +#include "datasystem/common/util/status_helper.h" + +namespace datasystem { +namespace client { +namespace stream_cache { +ProducerConsumerWorkerApi::ProducerConsumerWorkerApi(const std::string tenantId, + std::shared_ptr workerApi) + : tenantId_(tenantId), workerApi_(workerApi){}; + +Status ProducerConsumerWorkerApi::GetDataPage(GetDataPageReqPb &req, ShmView &outPage) +{ + int64_t timeoutMs = req.timeout_ms(); + // Compute the rpc timeout and pass in the adjustedTimeout that the worker will use for blocking/waiting. + // The adjustedTimeout will be slightly smaller than the regular outer rpc timeout. + int32_t rpcTimeout; + int64_t adjustedTimeout; + RETURN_IF_NOT_OK(workerApi_->SetRpcTimeout(timeoutMs, rpcTimeout, adjustedTimeout)); + RpcOptions opts; + opts.SetTimeout(rpcTimeout); + req.set_timeout_ms(adjustedTimeout); + + RETURN_IF_NOT_OK(SetTokenAndTenantId(req)); + req.set_client_id(workerApi_->GetClientId()); + RETURN_IF_NOT_OK(workerApi_->signature_->GenerateSignature(req)); + GetDataPageRspPb rsp; + + RETURN_IF_NOT_OK(workerApi_->rpcSession_->GetDataPage(opts, req, rsp)); + outPage.off = static_cast(rsp.page_view().offset()); + outPage.sz = rsp.page_view().size(); + outPage.mmapSz = rsp.page_view().mmap_size(); + outPage.fd = rsp.page_view().fd(); + return Status::OK(); +} + +Status ProducerConsumerWorkerApi::AllocBigElementMemory(const std::string &streamName, const std::string &producerId, + size_t sizeNeeded, int64_t timeoutMs, ShmView &outView) +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + CreateLobPageReqPb req; + req.set_stream_name(streamName); + req.set_producer_id(producerId); + req.set_page_size(sizeNeeded); + req.set_client_id(workerApi_->GetClientId()); + RETURN_IF_NOT_OK(SetTokenAndTenantId(req)); + + CreateLobPageRspPb rsp; + RpcOptions opts; + int32_t rpcTimeout; + int64_t adjustedTimeout; + RETURN_IF_NOT_OK(workerApi_->SetRpcTimeout(timeoutMs, rpcTimeout, adjustedTimeout)); + opts.SetTimeout(rpcTimeout); + req.set_sub_timeout(adjustedTimeout); + std::unordered_set retryCode = { StatusCode::K_RPC_CANCELLED, StatusCode::K_RPC_UNAVAILABLE, + StatusCode::K_RPC_DEADLINE_EXCEEDED }; + PerfPoint point(PerfKey::RPC_WORKER_CREATE_WRITE_PAGE); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(RetryOnError(adjustedTimeout, + [this, &opts, &req, &rsp](int32_t) { + // Set timestamp for worker to determine request order + req.set_timestamp(GetSystemClockTimeStampUs()); + RETURN_IF_NOT_OK(workerApi_->signature_->GenerateSignature(req)); + return workerApi_->rpcSession_->AllocBigShmMemory(opts, req, rsp); + }, + []() { return Status::OK(); }, retryCode), + FormatString("[%s] Client create big element page failed", producerId)); + point.Record(); + outView.off = static_cast(rsp.page_view().offset()); + outView.sz = rsp.page_view().size(); + outView.mmapSz = rsp.page_view().mmap_size(); + outView.fd = rsp.page_view().fd(); + LOG(INFO) << FormatString("[%s, S:%s, P:%s] Client created big element page success. 
ShmView %s", + workerApi_->LogPrefix(), streamName, producerId, outView.ToStr()); + return Status::OK(); +} + +Status ProducerConsumerWorkerApi::ReleaseBigElementMemory(const std::string &streamName, const std::string &producerId, + const ShmView &pageView) +{ + ReleaseLobPageReqPb req; + req.set_stream_name(streamName); + req.set_producer_id(producerId); + req.set_client_id(workerApi_->GetClientId()); + + RETURN_IF_NOT_OK(SetTokenAndTenantId(req)); + ShmViewPb pb; + pb.set_fd(pageView.fd); + pb.set_mmap_size(pageView.mmapSz); + pb.set_offset(pageView.off); + pb.set_size(pageView.sz); + req.mutable_page_view()->CopyFrom(pb); + RETURN_IF_NOT_OK(workerApi_->signature_->GenerateSignature(req)); + ReleaseLobPageRspPb rsp; + INJECT_POINT("ProducerConsumerWorkerApi.ReleaseBigElementMemory.preReleaseBigShmMemory"); + RETURN_IF_NOT_OK(workerApi_->rpcSession_->ReleaseBigShmMemory(req, rsp)); + LOG(INFO) << FormatString("[%s, S:%s, P:%s] Client release big element page success. ShmView %s", + workerApi_->LogPrefix(), streamName, producerId, pageView.ToStr()); + return Status::OK(); +} + +Status ProducerConsumerWorkerApi::CreateWritePage(const std::string &streamName, const std::string &producerId, + int64_t timeoutMs, const ShmView &curView, ShmView &outPage) +{ + CreateShmPageReqPb req; + CreateShmPageRspPb rsp; + + // Compute the rpc timeout and pass in the adjustedTimeout that the worker will use for blocking/waiting. + // The adjustedTimeout will be slightly smaller than the regular outer rpc timeout. + int32_t rpcTimeout; + int64_t adjustedTimeout; + RETURN_IF_NOT_OK(workerApi_->SetRpcTimeout(timeoutMs, rpcTimeout, adjustedTimeout)); + INJECT_POINT("client.CreateWritePage", [&rpcTimeout, timeoutMs]() { + rpcTimeout = timeoutMs; + return Status::OK(); + }); + + RpcOptions opts; + opts.SetTimeout(rpcTimeout); + + req.set_stream_name(streamName); + req.set_producer_id(producerId); + req.set_sub_timeout(adjustedTimeout); + req.set_client_id(workerApi_->GetClientId()); + RETURN_IF_NOT_OK(SetTokenAndTenantId(req)); + + ShmViewPb pb; + pb.set_fd(curView.fd); + pb.set_mmap_size(curView.mmapSz); + pb.set_offset(curView.off); + pb.set_size(curView.sz); + req.mutable_cur_view()->CopyFrom(pb); + + LOG(INFO) << "Client creating write page. Stream: " << streamName << " producer: " << producerId + << " adjusted timeout: " << adjustedTimeout << " user timeout " << timeoutMs + << " rpc timeout: " << rpcTimeout; + std::unordered_set retryCode = { StatusCode::K_RPC_CANCELLED, StatusCode::K_RPC_UNAVAILABLE, + StatusCode::K_RPC_DEADLINE_EXCEEDED }; + PerfPoint point(PerfKey::RPC_WORKER_CREATE_WRITE_PAGE); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(RetryOnError(adjustedTimeout, + [this, &opts, &req, &rsp](int32_t) { + // Set timestamp for worker to determine request order + req.set_timestamp(GetSystemClockTimeStampUs()); + RETURN_IF_NOT_OK(workerApi_->signature_->GenerateSignature(req)); + return workerApi_->rpcSession_->CreateShmPage(opts, req, rsp); + }, + []() { return Status::OK(); }, retryCode), + "CreateShmPage request error"); + point.Record(); + outPage.off = static_cast(rsp.last_page_view().offset()); + outPage.sz = rsp.last_page_view().size(); + outPage.mmapSz = rsp.last_page_view().mmap_size(); + outPage.fd = rsp.last_page_view().fd(); + LOG(INFO) << FormatString("[%s, S:%s, P:%s] Client created write page success. 
ShmView %s", + workerApi_->LogPrefix(), streamName, producerId, outPage.ToStr()); + return Status::OK(); +} + +Status ProducerConsumerWorkerApi::CloseProducer(const std::string &streamName, const std::string &producerId) +{ + RpcOptions opts; + opts.SetTimeout(workerApi_->timeoutMs_); + CloseProducerReqPb req; + req.set_stream_name(streamName); + req.set_producer_id(producerId); + req.set_client_id(workerApi_->GetClientId()); + reqTimeoutDuration.Init(workerApi_->ClientGetRequestTimeout(workerApi_->timeoutMs_)); + RETURN_IF_NOT_OK(SetTokenAndTenantId(req)); + + PerfPoint point(PerfKey::RPC_WORKER_CLOSE_PRODUCER); + CloseProducerRspPb rsp; + RETURN_IF_NOT_OK(workerApi_->signature_->GenerateSignature(req)); + RETURN_IF_NOT_OK_EXCEPT(workerApi_->rpcSession_->CloseProducer(opts, req, rsp), + StatusCode::K_SC_PRODUCER_NOT_FOUND); + point.Record(); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s, P:%s] Close producer success", workerApi_->LogPrefix(), + streamName, producerId); + return Status::OK(); +} + +Status ProducerConsumerWorkerApi::CloseConsumer(const std::string &streamName, const std::string &subscriptionName, + const std::string &consumerId) +{ + RpcOptions opts; + opts.SetTimeout(workerApi_->timeoutMs_); + CloseConsumerReqPb req; + req.set_stream_name(streamName); + req.set_subscription_name(subscriptionName); + req.set_consumer_id(consumerId); + req.set_client_id(workerApi_->GetClientId()); + reqTimeoutDuration.Init(workerApi_->ClientGetRequestTimeout(workerApi_->timeoutMs_)); + + RETURN_IF_NOT_OK(SetTokenAndTenantId(req)); + PerfPoint point(PerfKey::RPC_WORKER_CLOSE_CONSUMER); + CloseConsumerRspPb rsp; + RETURN_IF_NOT_OK(workerApi_->signature_->GenerateSignature(req)); + RETURN_IF_NOT_OK_EXCEPT(workerApi_->rpcSession_->CloseConsumer(opts, req, rsp), + StatusCode::K_SC_CONSUMER_NOT_FOUND); + point.Record(); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s, P:%s] Close consumer success", workerApi_->LogPrefix(), + streamName, consumerId); + return Status::OK(); +} + +Status ProducerConsumerWorkerApi::GetLastAppendCursor(const std::string &streamName, uint64_t &lastAppendCursor) +{ + RpcOptions opts; + opts.SetTimeout(workerApi_->rpcTimeoutMs_); + LastAppendCursorReqPb req; + req.set_stream_name(streamName); + req.set_client_id(workerApi_->clientId_); + RETURN_IF_NOT_OK(SetTokenAndTenantId(req)); + + RETURN_IF_NOT_OK(workerApi_->signature_->GenerateSignature(req)); + LastAppendCursorRspPb rsp; + RETURN_IF_NOT_OK(workerApi_->rpcSession_->GetLastAppendCursor(opts, req, rsp)); + lastAppendCursor = rsp.last_append_cursor(); + return Status::OK(); +} +} // namespace stream_cache +} // namespace client +} // namespace datasystem diff --git a/src/datasystem/client/stream_cache/producer_consumer_worker_api.h b/src/datasystem/client/stream_cache/producer_consumer_worker_api.h new file mode 100644 index 0000000..613afd0 --- /dev/null +++ b/src/datasystem/client/stream_cache/producer_consumer_worker_api.h @@ -0,0 +1,127 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Implement stream cache base class for producer and consumer.
+ */
+#ifndef DATASYSTEM_CLIENT_STREAM_CACHE_PRODUCER_CONSUMER_WORKER_API_H
+#define DATASYSTEM_CLIENT_STREAM_CACHE_PRODUCER_CONSUMER_WORKER_API_H
+
+#include <memory>
+#include "datasystem/client/stream_cache/client_worker_api.h"
+#include "datasystem/client/stream_cache/stream_client_impl.h"
+
+namespace datasystem {
+namespace client {
+namespace stream_cache {
+class ProducerConsumerWorkerApi {
+public:
+    ProducerConsumerWorkerApi(const std::string tenantId, std::shared_ptr<ClientWorkerApi> workerApi);
+
+    /**
+     * @brief Get the Lock Id object
+     * @return uint32_t lock id
+     */
+    uint32_t GetLockId()
+    {
+        return workerApi_->GetLockId();
+    }
+
+    /**
+     * @brief Get the Data Page object
+     * @param[in,out] req The GetDataPageReqPb request; its timeout is adjusted in place for the worker.
+     * @param[out] outPage The returned page view.
+     * @return Status of the call
+     */
+    Status GetDataPage(GetDataPageReqPb &req, ShmView &outPage);
+
+    /**
+     * @brief Allocate big element memory
+     * @param[in] streamName stream name
+     * @param[in] producerId producer id
+     * @param[in] sizeNeeded The required size in bytes.
+     * @param[in] timeoutMs Timeout in milliseconds.
+     * @param[out] outView The resulting ShmView.
+     * @return Status of the call
+     */
+    Status AllocBigElementMemory(const std::string &streamName, const std::string &producerId, size_t sizeNeeded,
+                                 int64_t timeoutMs, ShmView &outView);
+
+    /**
+     * @brief Release big element memory
+     * @param[in] streamName stream name
+     * @param[in] producerId producer id
+     * @param[in] pageView page view to release
+     * @return Status of the call
+     */
+    Status ReleaseBigElementMemory(const std::string &streamName, const std::string &producerId,
+                                   const ShmView &pageView);
+
+    /**
+     * @brief Send rpc request to worker to create WritePage for producer.
+     * @param[in] streamName The name of stream.
+     * @param[in] producerId The producer uuid.
+     * @param[in] timeoutMs The timeout for the call.
+     * @param[in] curView The view of the producer's current write page.
+     * @param[out] outPage The memory page that the producer will send elements to.
+     * @return Status of the call.
+     */
+    Status CreateWritePage(const std::string &streamName, const std::string &producerId, int64_t timeoutMs,
+                           const ShmView &curView, ShmView &outPage);
+
+    /**
+     * @brief Send rpc request to worker to close producer.
+     * @param[in] streamName The name of the stream that will be closed.
+     * @param[in] producerId The name of the producer that will be closed.
+     * @return Status of the call.
+     */
+    Status CloseProducer(const std::string &streamName, const std::string &producerId);
+
+    /**
+     * @brief Send rpc request to worker to close consumer.
+     * @param[in] streamName The name of the stream that will be closed.
+     * @param[in] subscriptionName The name of the subscription that will be closed.
+     * @param[in] consumerId The uuid of the consumer that will be closed.
+     * @return Status of the call.
+     */
+    Status CloseConsumer(const std::string &streamName, const std::string &subscriptionName,
+                         const std::string &consumerId);
+
+    /**
+     * @brief Get last append cursor in worker consumer.
+     * @param[in] streamName Target stream.
+     * @param[out] lastAppendCursor Last append cursor in worker consumer.
+     * @return Status of the call.
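+     * Note: unlike the page-creation RPCs, this request carries no sub-timeout; it is a plain read of
+     * the worker-side append cursor.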
+     */
+    Status GetLastAppendCursor(const std::string &streamName, uint64_t &lastAppendCursor);
+
+    template <typename ReqType>
+    Status SetTokenAndTenantId(ReqType &req)
+    {
+        req.set_tenant_id(tenantId_);
+        return Status::OK();
+    }
+
+private:
+    std::string tenantId_;
+    std::shared_ptr<ClientWorkerApi> workerApi_;
+};
+} // namespace stream_cache
+} // namespace client
+} // namespace datasystem
+#endif
\ No newline at end of file
diff --git a/src/datasystem/client/stream_cache/producer_impl.cpp b/src/datasystem/client/stream_cache/producer_impl.cpp
new file mode 100644
index 0000000..7357fc7
--- /dev/null
+++ b/src/datasystem/client/stream_cache/producer_impl.cpp
@@ -0,0 +1,573 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Implement stream cache producer.
+ */
+#include "datasystem/client/stream_cache/producer_impl.h"
+
+#include <algorithm>
+
+#include <limits>
+
+#include "datasystem/client/mmap_manager.h"
+#include "datasystem/client/stream_cache/client_worker_api.h"
+#include "datasystem/client/stream_cache/stream_client_impl.h"
+#include "datasystem/common/constants.h"
+#include "datasystem/common/inject/inject_point.h"
+#include "datasystem/common/log/trace.h"
+#include "datasystem/common/perf/perf_manager.h"
+#include "datasystem/common/stream_cache/cursor.h"
+#include "datasystem/common/stream_cache/stream_data_page.h"
+#include "datasystem/common/util/bitmask_enum.h"
+#include "datasystem/common/util/format.h"
+#include "datasystem/common/util/memory.h"
+#include "datasystem/common/util/raii.h"
+#include "datasystem/common/util/rpc_util.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/common/util/strings_util.h"
+#include "datasystem/common/util/uuid_generator.h"
+#include "datasystem/utils/status.h"
+
+namespace datasystem {
+namespace client {
+namespace stream_cache {
+
+ProducerImpl::ProducerImpl(std::string streamName, std::string tenantId, std::string producerId, int64_t delayFlushTime,
+                           int64_t pageSize, std::shared_ptr<ProducerConsumerWorkerApi> workerApi,
+                           std::shared_ptr<StreamClientImpl> client, MmapManager *mmapManager,
+                           std::shared_ptr<ListenWorker> listenWorker, const ShmView &workArea,
+                           uint64_t maxStreamSize, const DataVerificationHeader::SenderProducerNo senderProducerNo,
+                           const bool enableStreamDataVerification, const DataVerificationHeader::Address address,
+                           const DataVerificationHeader::Port port, StreamMode streamMode, uint64_t streamNo,
+                           bool enableSharedPage, uint64_t sharedPageSize, const ShmView &streamMetaView)
+    : ClientBaseImpl(std::move(streamName), std::move(tenantId), std::move(workerApi), std::move(client), mmapManager,
+                     std::move(listenWorker)),
+      producerId_(std::move(producerId)),
+      delayFlushTime_(delayFlushTime),
+      pageSize_(pageSize),
+      maxStreamSize_(maxStreamSize),
+      senderProducerNo_(senderProducerNo),
+      enableStreamDataVerification_(enableStreamDataVerification),
+      address_(address),
+      port_(port),
+      streamMode_(streamMode),
+      streamNo_(streamNo),
+      enableSharedPage_(enableSharedPage),
+      sharedPageSize_(sharedPageSize)
+{
+    workArea_ = workArea;
+    streamMetaView_ = streamMetaView;
+    maxElementSize_ = static_cast<uint64_t>(pageSize_) - StreamDataPage::PageOverhead(false);
+    if (enableSharedPage) {
+        uint64_t maxElementSizeForSharedPage = sharedPageSize - StreamDataPage::PageOverhead(true);
+        maxElementSize_ = std::min(maxElementSize_, maxElementSizeForSharedPage);
+    }
+}
+
+ProducerImpl::~ProducerImpl()
+{
+    if (delayFlushTimer_) {
+        TimerQueue::GetInstance()->Cancel(*delayFlushTimer_);
+    }
+    if (unfixTimer_) {
+        TimerQueue::GetInstance()->Cancel(*unfixTimer_);
+    }
+    client_->ClearProducer(producerId_);
+}
+
+Status ProducerImpl::Init()
+{
+    RETURN_IF_NOT_OK(ClientBaseImpl::Init());
+    if (enableSharedPage_) {
+        CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(streamMetaView_ != ShmView(), K_RUNTIME_ERROR,
+                                             "streamMetaView_ not initialized");
+        auto shmUnitInfo = std::make_shared<ShmUnitInfo>(streamMetaView_.fd, streamMetaView_.mmapSz);
+        RETURN_IF_NOT_OK(mmapManager_->LookupUnitsAndMmapFd(tenantId_, shmUnitInfo));
+        streamMetaShm_ = std::make_unique<StreamMetaShm>(
+            streamName_, static_cast<uint8_t *>(shmUnitInfo->GetPointer()) + streamMetaView_.off, streamMetaView_.sz,
+            maxStreamSize_);
+        RETURN_IF_NOT_OK(streamMetaShm_->Init(mmapManager_->GetMmapEntryByFd(shmUnitInfo->fd)));
+    }
+
+    unfixWaitPost_ = std::make_unique<WaitPost>();
+    return Status::OK();
+}
+
+void ProducerImpl::ExecAndCancelTimer()
+{
+    std::unique_lock flushLock(flushMutex_);
+    if (delayFlushTimer_) {
+        INJECT_POINT_NO_RETURN("ProducerImpl.ExecAndCancelTimer.sleep");
+        TimerQueue::GetInstance()->Cancel(*delayFlushTimer_);
+        delayFlushTimer_.reset();
+        flushLock.unlock();
+        // We do not use the function inside the timer because we can be in the destructor, and the function in the
+        // timer can only be executed while the producer still exists and is not in its destructor.
+        ExecFlush();
+    }
+}
+
+Status ProducerImpl::SetUnfixPageTimer()
+{
+    CHECK_FAIL_RETURN_STATUS(unfixTimer_ == nullptr, StatusCode::K_RUNTIME_ERROR, "Timer should be cancelled.");
+    const int DEFAULT_UNFIX_INTERVAL = 1000;
+    TimerQueue::TimerImpl timer;
+    unfixWaitPost_->Clear();
+    RETURN_IF_NOT_OK(TimerQueue::GetInstance()->AddTimer(
+        DEFAULT_UNFIX_INTERVAL,
+        [this]() {
+            ExecAndCancelTimer();
+            LOG_IF_ERROR(UnfixPage(), "");
+            unfixWaitPost_->Set();
+        },
+        timer));
+    unfixTimer_ = std::make_unique<TimerQueue::TimerImpl>(timer);
+    return Status::OK();
+}
+
+void ProducerImpl::UnsetUnfixPageTimer()
+{
+    // If timer is already set, cancel the timer
+    if (unfixTimer_) {
+        bool canceled = TimerQueue::GetInstance()->Cancel(*unfixTimer_);
+        // Wait for callback to finish if it is already executing
+        if (!canceled) {
+            unfixWaitPost_->Wait();
+        }
+        unfixTimer_ = nullptr;
+    }
+}
+
+void ProducerImpl::SetInactive()
+{
+    UnsetUnfixPageTimer();
+    ExecAndCancelTimer();
+    ClientBaseImpl::SetInactive();
+}
+
+Status ProducerImpl::CheckNormalState() const
+{
+    // If the worker's state is unknown, do not try to access the shared memory work area
+    RETURN_IF_NOT_OK(ClientBaseImpl::CheckNormalState());
+    if (cursor_ == nullptr || cursor_->ForceClose()) {
+        RETURN_STATUS(K_SC_CONSUMER_NOT_FOUND, "Force close. No consumer");
+    }
+    return Status::OK();
+}
+
+Status ProducerImpl::HandleNoSpaceFromInsert(int64_t timeoutMs, const Status &rc)
+{
+    if (rc.GetCode() != K_NO_SPACE) {
+        return rc;
+    }
+    ShmView nextPage;
+    bool isFreePage;
+    writePage_->nextPage_->GetView(nextPage, isFreePage, std::numeric_limits<uint64_t>::max());
+    // If there is no next page encoded on the page, just return no space
+    if (nextPage.fd <= 0) {
+        return rc;
+    }
+    // There is a next page.
Check if the current page has been sealed. + if (isFreePage) { + // Now we will seal the current page, acquire the next page. + auto func = [this](const ShmView &nextPage, std::shared_ptr &out) { + std::shared_ptr pageInfo; + std::shared_ptr mmapEntry; + RETURN_IF_NOT_OK(GetShmInfo(nextPage, pageInfo, mmapEntry)); + auto page = std::make_shared(pageInfo, lockId_, true, false, mmapEntry); + RETURN_IF_NOT_OK(page->Init()); + out = page; + return Status::OK(); + }; + RETURN_IF_NOT_OK_EXCEPT(writePage_->Seal(nextPage, timeoutMs, func, LogPrefix()), K_DUPLICATED); + } + RETURN_STATUS(K_SC_END_OF_PAGE, "New empty page is created"); +} + +Status ProducerImpl::InsertBigElement(const HeaderAndData &element, Optional timeoutMs) +{ + int64_t rpcTimeoutMs = timeoutMs ? timeoutMs.value() : RPC_TIMEOUT; + ShmView pageView; + RETURN_IF_NOT_OK( + workerApi_->AllocBigElementMemory(streamName_, producerId_, element.TotalSize(), rpcTimeoutMs, pageView)); + fixPageFromRpc_++; + InsertFlags flag = InsertFlags::BIG_ELEMENT; + // If we hit any error below, we will send a rpc to the worker to release the shared memory acquired above. + RaiiPlus raiiP([this, &pageView, &flag]() { + if (!TESTFLAG(flag, InsertFlags::INSERT_SUCCESS)) { + LOG_IF_ERROR(workerApi_->ReleaseBigElementMemory(streamName_, producerId_, pageView), + "ReleaseBigElementMemory"); + } + }); + std::shared_ptr pageUnit; + std::shared_ptr mmapEntry; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(GetShmInfo(pageView, pageUnit, mmapEntry), "GetShmInfo"); + auto page = std::make_shared(pageUnit, true, mmapEntry); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(page->Init(), "Page Init"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(page->Insert(element), "Insert"); + std::string pointerString; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(StreamDataPage::SerializeToShmViewPb(pageView, pointerString), "Serialization"); + HeaderAndData ele(reinterpret_cast(pointerString.data()), pointerString.size(), streamNo_); + if (enableSharedPage_) { + auto eleSize = ele.size; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(streamMetaShm_->TryIncUsage(eleSize), + "Failed to increase the usage of shared memory for stream: " + streamName_); + raiiP.AddTask([this, eleSize]() { + LOG_IF_ERROR(streamMetaShm_->TryDecUsage(eleSize), + "Failed to decrease the usage of shared memory for stream: " + streamName_); + }); + } + INJECT_POINT("ProducerImpl.ReleaseBigElementMemory"); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, %s, %zu] Inserting a BigElement [S:%zu] into BigElement page: %s", + streamName_, producerId_, (cursor_->GetElementCount() + 1), element.size, + page->GetPageId()); + flag |= (element.headerSize_ > 0) ? InsertFlags::HEADER : InsertFlags::NONE; + RETURN_IF_NOT_OK(SendImpl(ele, flag, timeoutMs)); + raiiP.ClearAllTask(); + return Status::OK(); +} + +Status ProducerImpl::SendImpl(const HeaderAndData &element, InsertFlags &flag, Optional userTimeoutMs) +{ + PerfPoint point(PerfKey::CLIENT_SEND_ALL); + // Check if the element is correct + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(element.size > 0, K_INVALID, "Element size should be greater than 0"); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(element.ptr != nullptr, K_INVALID, "Element ptr should not be a nullptr"); + // Cancel the timer before Send, and reset the timer after Send to account for the producer idle case + UnsetUnfixPageTimer(); + Raii unfixPage([this]() { LOG_IF_ERROR(SetUnfixPageTimer(), ""); }); + Status rc; + int64_t timeoutMs = userTimeoutMs ? 
userTimeoutMs.value() : RPC_TIMEOUT;
+    Timer t(timeoutMs);
+    // Locate the page to insert, which is most likely the last page.
+    // The page can be the current page.
+    INJECT_POINT("client.Producer.beforeCheckCursor", [] { return inject::Set("SharedMemViewImpl.GetView", "abort"); });
+    RETURN_IF_NOT_OK(FixPage(userTimeoutMs ? Optional<int64_t>(t.GetRemainingTimeMs()) : Optional<int64_t>(), false));
+
+    SETFLAG(flag, (delayFlushTime_ > 0) ? InsertFlags::DELAY_WAKE : InsertFlags::NONE);
+    SETFLAG(flag, (element.headerSize_ > 0) ? InsertFlags::HEADER : InsertFlags::NONE);
+
+    // Start the loop
+    do {
+        // The timeout here is the time we wait on a shared memory lock, not rpc
+        rc = writePage_->Insert(element, t.GetRemainingTimeMs(), flag, LogPrefix());
+        auto rcCode = rc.GetCode();
+        if (rcCode == K_OK) {
+            cursor_->IncrementElementCount();
+            lastSendElementSeqNo_++;
+            pageDirty_ = true;
+            INJECT_POINT("ProducerImpl.SendImpl.postInsertSuccess");
+            return DelayFlush();
+        }
+        Status status = HandleNoSpaceFromInsert(timeoutMs, rc);
+        rcCode = status.GetCode();
+        if (rcCode == K_SC_END_OF_PAGE || rcCode == K_NO_SPACE) {
+            RETURN_IF_NOT_OK(DelayFlush(pageDirty_)); // Ensure we wake up the reader before we move to a new page
+            // If there is a next page, we follow the pointer. Otherwise, we send another rpc request.
+            INJECT_POINT("client.Producer.beforeCheckNewPage",
+                         [] { return inject::Set("SharedMemViewImpl.GetView", "abort"); });
+            auto nextPage = writePage_->GetNextPage();
+            if (nextPage.fd > 0) {
+                RETURN_IF_NOT_OK(CreatePagePostProcessing(nextPage, fixPageFromNextPage_));
+            } else {
+                RETURN_IF_NOT_OK(
+                    CreateWritePage(userTimeoutMs ? Optional<int64_t>(t.GetRemainingTimeMs()) : Optional<int64_t>()));
+            }
+            continue;
+        } else if (rcCode == K_TRY_AGAIN) { // If we can't get a lock on the page
+            if (t.GetRemainingTimeMs() == 0) {
+                break;
+            }
+            // See if there is a newer page to try.
+            // Require a CreateWritePage if the cursor doesn't have a new page.
+            RETURN_IF_NOT_OK(
+                FixPage(userTimeoutMs ? Optional<int64_t>(t.GetRemainingTimeMs()) : Optional<int64_t>(), true));
+            continue;
+        }
+        return rc;
+    } while (true);
+    RETURN_STATUS(K_OUT_OF_MEMORY,
+                  FormatString("[%s] Producer unable to secure enough memory for the element within %zu ms",
+                               LogPrefix(), timeoutMs));
+}
+
+Status ProducerImpl::Send(const Element &element, Optional<int64_t> userTimeoutMs)
+{
+    TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+    cursor_->IncrementRequestCount();
+    RETURN_IF_NOT_OK(CheckNormalState());
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        !userTimeoutMs || *userTimeoutMs >= 0, K_INVALID,
+        FormatString("[%s] The send timeout must be greater than or equal to 0", LogPrefix()));
+    DataVerificationHeader dataVerificationHeader;
+    ElementHeader elementHeader;
+    if (enableStreamDataVerification_) {
+        dataVerificationHeader.Set(lastSendElementSeqNo_ + 1, senderProducerNo_, address_, port_);
+        INJECT_POINT("DataVerificationOutOfOrder", [&dataVerificationHeader](int num) {
+            dataVerificationHeader.hdr.seqNo += static_cast<uint64_t>(num);
+            return Status::OK();
+        });
+        elementHeader.Set(reinterpret_cast<const uint8_t *>(&dataVerificationHeader),
+                          dataVerificationHeader.HeaderSize(), DATA_VERIFICATION_HEADER);
+    }
+    HeaderAndData headerAndData(element, elementHeader, streamNo_);
+    uint64_t finalElementSize = headerAndData.TotalSize();
+    RaiiPlus raiiP;
+    Status rc(K_UNKNOWN_ERROR, "To initialize an error status, no special meaning");
+    if (enableSharedPage_) {
+        RETURN_IF_NOT_OK_PRINT_ERROR_MSG(streamMetaShm_->TryIncUsage(finalElementSize),
+                                         "Failed to increase the usage of shared memory for stream: " + streamName_);
+        raiiP.AddTask([this, &rc, &finalElementSize]() {
+            if (!rc) {
+                LOG_IF_ERROR(streamMetaShm_->TryDecUsage(finalElementSize),
+                             "Failed to decrease the usage of shared memory for stream: " + streamName_);
+            }
+        });
+    }
+    if (finalElementSize <= static_cast<uint64_t>(maxElementSize_)) {
+        auto flag = InsertFlags::NONE;
+        rc = SendImpl(headerAndData, flag, userTimeoutMs);
+    } else {
+        CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+            finalElementSize < maxStreamSize_, K_INVALID,
+            FormatString("[%s] Element size must be smaller than Stream size. [element size, stream size] : [%zu, %zu]",
+                         LogPrefix(), finalElementSize, maxStreamSize_));
+        rc = InsertBigElement(headerAndData, userTimeoutMs);
+    }
+    if (rc) {
+        raiiP.ClearAllTask();
+    }
+    if (sendTracer_.NeedWriteLog(rc.IsOk())) {
+        LOG(INFO) << FormatString("[%s] Producer first send element with status %s", LogPrefix(), rc.ToString());
+    }
+    // A special return code K_OUT_OF_RANGE may indicate the worker has already restarted.
+    // Go check the state.
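+    // CheckNormalState re-validates the worker connection and the shared-memory work area, so it can
+    // surface a more precise error (e.g. K_SC_CONSUMER_NOT_FOUND after a force close) than the raw code.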
+ if (rc.GetCode() == K_OUT_OF_RANGE) { + RETURN_IF_NOT_OK(CheckNormalState()); + } + return rc; +} + +Status ProducerImpl::DelayFlush(bool force) +{ + PerfPoint point(PerfKey::CLIENT_FLUSH_ELEMENT_ALL); + RETURN_IF_NOT_OK(CheckNormalState()); + RETURN_OK_IF_TRUE(delayFlushTime_ == 0); + // Synchronize with the timer queue which can be running at the same time + std::unique_lock flushLock(flushMutex_); + RETURN_OK_IF_TRUE(!pageDirty_); + if (force) { + if (delayFlushTimer_) { + TimerQueue::GetInstance()->Cancel(*delayFlushTimer_); + delayFlushTimer_ = nullptr; + } + pageDirty_ = false; + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[Cursor %zu] Wake up consumers", cursor_->GetElementCount()); + return writePage_->WakeUpConsumers(); + } + RETURN_OK_IF_TRUE(delayFlushTimer_ != nullptr); + TimerQueue::TimerImpl timer; + RETURN_IF_NOT_OK(TimerQueue::GetInstance()->AddTimer( + delayFlushTime_, + [w = this->weak_from_this()]() { + INJECT_POINT_NO_RETURN("ProducerImpl.ExecFlush.sleep"); + // Producer may be deallocated, check if the producer still there through weak pointer. + auto producerImpl = w.lock(); + if (producerImpl) { + producerImpl->ExecFlush(); + } + }, + timer)); + delayFlushTimer_ = std::make_unique(timer); + return Status::OK(); +} + +Status ProducerImpl::Close() +{ + PerfPoint point(PerfKey::CLIENT_CLOSE_PRODUCER_ALL); + // If producer is already closed, return OK + RETURN_OK_IF_TRUE(state_ == State::CLOSE); + // If the current state is RESET, it is allowed to close + RETURN_IF_NOT_OK(CheckState()); + UnsetUnfixPageTimer(); + ExecAndCancelTimer(); + RETURN_IF_NOT_OK(UnfixPage()); + std::string logStr = FormatString("%zu Elements sent.", lastSendElementSeqNo_); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(workerApi_->CloseProducer(streamName_, producerId_), + FormatString("[%s] CloseProducer request error", LogPrefix())); + RETURN_IF_NOT_OK(ChangeState(State::CLOSE)); + std::ostringstream oss; + auto totalGetPageCall = fixPageFromNextPage_ + fixPageFromWorkArea_ + fixPageFromRpc_; + oss << "[" << LogPrefix() << "] Close producer success! " << logStr << " Total get page count " << totalGetPageCall + << ". 
rpc: " << fixPageFromRpc_.load() << ", work area: " << fixPageFromWorkArea_.load() + << ", next pointer: " << fixPageFromNextPage_.load(); + LOG(INFO) << oss.str(); + return Status::OK(); +} + +Status ProducerImpl::SetStateToReset() +{ + RETURN_IF_NOT_OK(ChangeState(State::RESET)); + std::unique_lock flushLock(flushMutex_); + // Cancel any pending AutoFlush in the timer + if (delayFlushTimer_ != nullptr) { + TimerQueue::GetInstance()->Cancel(*delayFlushTimer_); + // Set timer_ to nullptr after canceling it + delayFlushTimer_ = nullptr; + } + return Status::OK(); +} + +Status ProducerImpl::Reset() +{ + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + state_ == State::RESET, K_RUNTIME_ERROR, + FormatString("[%s] The producer should be in reset state already to cleanup data and metadata", LogPrefix())); + { + std::unique_lock flushLock(flushMutex_); // Wait for in flight flush to finish + if (delayFlushTimer_) { + TimerQueue::GetInstance()->Cancel(*delayFlushTimer_); + } + delayFlushTimer_ = nullptr; + pageDirty_ = false; + } + RETURN_IF_NOT_OK(UnfixPage()); + pageUnit_ = nullptr; + return Status::OK(); +} + +Status ProducerImpl::Resume() +{ + return ChangeState(State::NORMAL); +} + +Status ProducerImpl::CreateWritePage(Optional timeoutMs) +{ + INJECT_POINT("ProducerImpl.beforeCreateWritePage"); + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + PerfPoint point(PerfKey::CLIENT_CREATE_WRITE_PAGE_ALL); + RETURN_IF_NOT_OK(CheckNormalState()); + // If we have a page, decrement its count before we send out the rpc request + if (writePage_) { + RETURN_IF_NOT_OK(UnfixPage()); + writePage_ = nullptr; + pageUnit_ = nullptr; + } + ShmView lastPageView; + RETURN_IF_NOT_OK(workerApi_->CreateWritePage(streamName_, producerId_, timeoutMs ? timeoutMs.value() : 0, curView_, + lastPageView)); + INJECT_POINT("ProducerImpl.afterCreateWritePage"); + return CreatePagePostProcessing(lastPageView, fixPageFromRpc_); +} + +Status ProducerImpl::CreatePagePostProcessing(const ShmView &lastPageView, std::atomic &fixPageCount) +{ + PerfPoint point(PerfKey::PAGE_POST_PROCESSING); + if (writePage_ == nullptr || pageUnit_->fd != lastPageView.fd || pageUnit_->offset != lastPageView.off) { + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Next page ShmInfo: %s", LogPrefix(), lastPageView.ToStr()); + std::shared_ptr mmapEntry; + RETURN_IF_NOT_OK(GetShmInfo(lastPageView, pageUnit_, mmapEntry)); + auto page = std::make_shared(pageUnit_, lockId_, true, false, mmapEntry); + RETURN_IF_NOT_OK(page->Init()); + RETURN_IF_NOT_OK(UnfixPage()); + writePage_ = std::move(page); + RETURN_IF_NOT_OK(writePage_->RefPage(LogPrefix())); + curView_ = writePage_->GetShmView(); + pageId_ = writePage_->GetPageId(); + fixPageCount++; + // We may (or may not) lock this page. 
But let's track it in the work area.
+        if (WorkAreaIsV2()) {
+            RETURN_IF_NOT_OK(cursor_->SetLastLockedPage(curView_, DEFAULT_TIMEOUT_MS));
+        }
+    }
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Acquire page id: %s, isSharedPage: %d, lastPageView: %s",
+                                                LogPrefix(), pageId_, writePage_->IsSharedPage(), lastPageView.ToStr());
+    return Status::OK();
+}
+
+Status ProducerImpl::UnfixPage()
+{
+    RETURN_OK_IF_TRUE(writePage_ == nullptr);
+    RETURN_IF_NOT_OK(writePage_->WakeUpConsumers());
+    PerfPoint point(PerfKey::PAGE_UNFIX);
+    RETURN_IF_NOT_OK(writePage_->ReleasePage(LogPrefix()));
+    // Clear the entry in the work area since we don't have any lock on this page
+    if (WorkAreaIsV2()) {
+        RETURN_IF_NOT_OK(cursor_->SetLastLockedPage(ShmView(), DEFAULT_TIMEOUT_MS));
+    }
+    point.Record();
+    writePage_.reset();
+    return Status::OK();
+}
+
+Status ProducerImpl::GetLastPageView(ShmView &lastPageView, bool &switchToSharedPage)
+{
+    // Client will time out when the worker has crashed
+    const uint64_t DEFAULT_TIMEOUT_MS = 1000;
+    if (!enableSharedPage_) {
+        return cursor_->GetLastPageView(lastPageView, DEFAULT_TIMEOUT_MS);
+    }
+
+    auto func = [this](const ShmView &lastPageRefView, std::shared_ptr<ShmUnitInfo> &shmUnitInfo,
+                       std::shared_ptr<MmapEntry> &mmapEntry) {
+        VLOG(SC_DEBUG_LOG_LEVEL) << LogPrefix() << " lastPageRefView:" << lastPageRefView.ToStr();
+        return GetShmInfo(lastPageRefView, shmUnitInfo, mmapEntry);
+    };
+    return cursor_->GetLastPageViewByRef(lastPageView, switchToSharedPage, DEFAULT_TIMEOUT_MS, std::move(func));
+}
+
+Status ProducerImpl::FixPage(Optional<int64_t> timeoutMs, const bool requireNewPage)
+{
+    // Look into the work area to see if there is a page.
+    // Otherwise, send a rpc request to the worker
+    ShmView lastPageView;
+    bool switchToSharedPage = false;
+    RETURN_IF_NOT_OK(GetLastPageView(lastPageView, switchToSharedPage));
+    if (requireNewPage && lastPageView == curView_) {
+        // Request a new page from the worker if the page has not changed.
+        return CreateWritePage(timeoutMs);
+    }
+    if (lastPageView.fd > 0) {
+        return CreatePagePostProcessing(lastPageView, fixPageFromWorkArea_);
+    }
+    // Check if the producer has already fixed a page and has not switched to the shared page.
+    if (writePage_ && !switchToSharedPage) {
+        return Status::OK();
+    }
+    return CreateWritePage(timeoutMs);
+}
+
+std::string ProducerImpl::LogPrefix() const
+{
+    return FormatString("S:%s, P:%s", streamName_, producerId_);
+}
+
+void ProducerImpl::ExecFlush()
+{
+    std::unique_lock flushLock(flushMutex_);
+    if (pageDirty_) {
+        VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[Cursor %zu] Wake up consumers", cursor_->GetElementCount());
+        pageDirty_ = false;
+        if (writePage_ != nullptr) {
+            writePage_->WakeUpConsumers();
+        }
+        delayFlushTimer_ = nullptr;
+    }
+}
+} // namespace stream_cache
+} // namespace client
+} // namespace datasystem
diff --git a/src/datasystem/client/stream_cache/producer_impl.h b/src/datasystem/client/stream_cache/producer_impl.h
new file mode 100644
index 0000000..43af25c
--- /dev/null
+++ b/src/datasystem/client/stream_cache/producer_impl.h
@@ -0,0 +1,248 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Implement stream cache producer.
+ */
+#ifndef DATASYSTEM_CLIENT_STREAM_CACHE_PRODUCER_IMPL_H
+#define DATASYSTEM_CLIENT_STREAM_CACHE_PRODUCER_IMPL_H
+
+#include <atomic>
+#include <memory>
+#include <mutex>
+#include <string>
+
+#include "datasystem/client/stream_cache/client_base_impl.h"
+#include "datasystem/client/stream_cache/producer_consumer_worker_api.h"
+#include "datasystem/common/eventloop/timer_queue.h"
+#include "datasystem/common/shared_memory/shm_unit.h"
+#include "datasystem/common/stream_cache/stream_data_page.h"
+#include "datasystem/common/stream_cache/stream_meta_shm.h"
+#include "datasystem/stream/element.h"
+#include "datasystem/stream/stream_config.h"
+#include "datasystem/utils/optional.h"
+#include "datasystem/utils/status.h"
+
+namespace datasystem {
+namespace client {
+namespace stream_cache {
+class ProducerImpl : public ClientBaseImpl, public std::enable_shared_from_this<ProducerImpl> {
+public:
+    /**
+     * @brief Construct Producer.
+     * @param[in] streamName The name of the stream.
+     * @param[in] producerId The uuid of the producer which is generated by Worker.
+     * @param[in] delayFlushTime The delay used for automatic flush after send; the default is 5 ms.
+     * @param[in] pageSize The size of the page
+     * @param[in] workerApi Used to call worker service through rpc.
+     * @param[in] mmapManager Used to receive and mmap fd that passed from worker.
+     * @param[in] listenWorker Listening to the worker survival status.
+     * @param[in] senderProducerNo The producer number generated by worker.
+     * @param[in] enableStreamDataVerification Should data verification be on.
+     * @param[in] address Local worker address.
+     * @param[in] port Local worker port.
+     */
+    ProducerImpl(std::string streamName, std::string tenantId, std::string producerId, int64_t delayFlushTime,
+                 int64_t pageSize, std::shared_ptr<ProducerConsumerWorkerApi> workerApi,
+                 std::shared_ptr<StreamClientImpl> client, MmapManager *mmapManager,
+                 std::shared_ptr<ListenWorker> listenWorker, const ShmView &workArea, uint64_t maxStreamSize,
+                 const DataVerificationHeader::SenderProducerNo senderProducerNo,
+                 const bool enableStreamDataVerification, const DataVerificationHeader::Address address,
+                 const DataVerificationHeader::Port port, StreamMode streamMode, uint64_t streamNo,
+                 bool enableSharedPage, uint64_t sharedPageSize, const ShmView &streamMetaView);
+
+    ~ProducerImpl() override;
+
+    /**
+     * @brief Send one element of the stream at a time.
+     * @param[in] element The element to be written.
+     * @param[in] timeoutMs The timeout for the call
+     * @return K_OK on success; the error code otherwise.
+     * K_RUNTIME_ERROR: depends on the error message.
+     * K_SC_ALREADY_CLOSED: producer is already closed/inactive.
+     */
+    Status Send(const Element &element, Optional<int64_t> timeoutMs);
+
+    /**
+     * @brief If elements have not been flushed, the local buffer may still hold some of them;
+     * the flush operation ensures that all elements are written to the stream.
+     * @return K_OK on success; the error code otherwise.
+     * K_RUNTIME_ERROR: depends on the error message.
+     * K_SC_ALREADY_CLOSED: producer is already closed/inactive.
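+     * This matters mainly when delayFlushTime is non-zero, since Send may defer waking consumers
+     * until the delay-flush timer fires.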
+     */
+    Status Flush();
+
+    /**
+     * @brief Close the producer; after close, no new elements may be sent.
+     * Calling Close() on an already closed producer will return K_OK.
+     * A flush is triggered if the local buffer still holds unflushed elements.
+     * @return Status of the call.
+     */
+    Status Close();
+
+    /**
+     * @brief Reset the producer by cleaning up data and metadata.
+     * @return K_OK on success; the error code otherwise.
+     * K_RUNTIME_ERROR: Producer is not in reset state.
+     */
+    Status Reset();
+
+    /**
+     * @brief Set producer state to Reset. Cancel the flush timer. During reset, Send and Flush are not allowed.
+     * @return K_OK on success; the error code otherwise.
+     * K_RUNTIME_ERROR: producer is in close state already.
+     */
+    Status SetStateToReset();
+
+    /**
+     * @brief Resume the producer, allowing Send, Flush and Close again.
+     * @return K_OK on success; the error code otherwise.
+     * K_RUNTIME_ERROR: producer is not in reset state.
+     */
+    Status Resume();
+
+    /**
+     * @brief Init producer and create the shared memory page.
+     * @return Status of the call.
+     */
+    Status Init() override;
+
+    /**
+     * @brief Sets the state of the Producer to CLOSE
+     */
+    void SetInactive() override;
+
+    /**
+     * @brief Log helper. Creates the prefix for log messages.
+     * @return The generated log prefix for this Producer.
+     */
+    std::string LogPrefix() const override;
+
+    /**
+     * @brief Get the producer id.
+     * @return Producer id.
+     */
+    const std::string &GetProducerId() const
+    {
+        return producerId_;
+    }
+
+    /**
+     * @brief Execute a flush.
+     */
+    void ExecFlush();
+
+private:
+    /**
+     * @brief Check the state_ of the producer and return status.
+     * @return Status of the call.
+     */
+    Status CheckNormalState() const override;
+
+    /**
+     * @brief Send request to create memory page for element.
+     * @param[in] timeoutMs Timeout for the call
+     * @return Status of the call.
+     */
+    Status CreateWritePage(Optional<int64_t> timeoutMs);
+
+    /**
+     * @brief Update write page address and shm info after successful write page creation.
+     * @param[in] lastPageView The view of the newly created write page.
+     * @param[in,out] fixPageCount Counter tracking how the page was acquired; incremented on success.
+     * @return Status of the call.
+     */
+    Status CreatePagePostProcessing(const ShmView &lastPageView, std::atomic<uint64_t> &fixPageCount);
+
+    /**
+     * @brief Send one element of the stream at a time.
+     * @param[in] element The element to be written (header + data)
+     * @param[in,out] flags Insert flags controlling header and delay-wake behavior.
+     * @param[in] timeoutMs The timeout for the call
+     * @return Status of the call.
+     */
+    Status SendImpl(const HeaderAndData &element, InsertFlags &flags, Optional<int64_t> timeoutMs);
+
+    /**
+     * @brief Locate page to insert, likely the last page.
+     * @param[in] timeoutMs The timeout for the call.
+     * @param[in] requireNewPage If true, request a new page from the worker when the cursor's page is still the
+     * current page.
+     * @return Status of the call.
+     */
+    Status FixPage(Optional<int64_t> timeoutMs, const bool requireNewPage);
+
+    Status UnfixPage();
+    Status DelayFlush(bool force = false);
+    void ExecAndCancelTimer();
+    Status HandleNoSpaceFromInsert(int64_t timeoutMs, const Status &rc);
+
+    /**
+     * @brief Set the timer to UnfixPage when producer is idle.
+     * @return Status of the call.
+     */
+    Status SetUnfixPageTimer();
+
+    /**
+     * @brief Cancel the timer to UnfixPage.
+     */
+    void UnsetUnfixPageTimer();
+
+    Status InsertBigElement(const HeaderAndData &element, Optional<int64_t> timeoutMs);
+
+    /**
+     * @brief Get the last page view from cursor.
+     * @param[out] lastPageView The last page view.
+     * @param[out] switchToSharedPage Whether to switch to the shared page.
+     * @return Status of this call.
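+     * With shared pages enabled, the view is resolved by reference via GetLastPageViewByRef, which may
+     * also report a switch to the shared page.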
+     */
+    Status GetLastPageView(ShmView &lastPageView, bool &switchToSharedPage);
+
+    const std::string producerId_;
+
+    int64_t delayFlushTime_;
+    int64_t pageSize_;
+    std::shared_ptr<StreamDataPage> writePage_{ nullptr };
+    std::unique_ptr<TimerQueue::TimerImpl> delayFlushTimer_;
+    std::unique_ptr<TimerQueue::TimerImpl> unfixTimer_;
+    std::unique_ptr<WaitPost> unfixWaitPost_{ nullptr };
+    std::shared_ptr<ShmUnitInfo> pageUnit_;
+    std::string pageId_;
+    ShmView curView_;
+    std::mutex flushMutex_; // Guarantee FIFO, single on-the-fly flush.
+    std::atomic<bool> pageDirty_{ false };
+    std::atomic<uint64_t> fixPageFromRpc_{ 0 };
+    std::atomic<uint64_t> fixPageFromWorkArea_{ 0 };
+    std::atomic<uint64_t> fixPageFromNextPage_{ 0 };
+    uint64_t maxStreamSize_;
+    uint64_t maxElementSize_;
+    const DataVerificationHeader::SenderProducerNo senderProducerNo_;
+    const bool enableStreamDataVerification_;
+    const DataVerificationHeader::Address address_;
+    const DataVerificationHeader::Port port_;
+    std::atomic<uint64_t> lastSendElementSeqNo_{ 0 };
+    FirstCallTracer sendTracer_;
+    StreamMode streamMode_;
+    uint64_t streamNo_;
+    bool enableSharedPage_;
+    uint64_t sharedPageSize_;
+
+    ShmView streamMetaView_;
+    std::unique_ptr<StreamMetaShm> streamMetaShm_;
+
+    friend class Producer;
+};
+} // namespace stream_cache
+} // namespace client
+} // namespace datasystem
+#endif
diff --git a/src/datasystem/client/stream_cache/receive_element.h b/src/datasystem/client/stream_cache/receive_element.h
new file mode 100644
index 0000000..71619f6
--- /dev/null
+++ b/src/datasystem/client/stream_cache/receive_element.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Define stream cache struct for receiving element.
+ */
+#ifndef DATASYSTEM_CLIENT_STREAM_CACHE_RECEIVE_ELEMENT_H
+#define DATASYSTEM_CLIENT_STREAM_CACHE_RECEIVE_ELEMENT_H
+
+#include <cstdint>
+
+namespace datasystem {
+namespace client {
+namespace stream_cache {
+struct ReceiveElement {
+    int workerFd;
+    uint64_t eleOffset;
+    uint64_t eleSize;
+};
+} // namespace stream_cache
+} // namespace client
+} // namespace datasystem
+#endif
\ No newline at end of file
diff --git a/src/datasystem/client/stream_cache/stream_client.cpp b/src/datasystem/client/stream_cache/stream_client.cpp
new file mode 100644
index 0000000..cdabee4
--- /dev/null
+++ b/src/datasystem/client/stream_cache/stream_client.cpp
@@ -0,0 +1,145 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +/** + * Description: Define stream cache client. + */ +#include "datasystem/stream_client.h" + +#include "datasystem/common/flags/flags.h" +#include "datasystem/client/stream_cache/stream_client_impl.h" +#include "datasystem/common/log/access_recorder.h" +#include "datasystem/common/perf/perf_manager.h" +#include "datasystem/utils/optional.h" + +namespace datasystem { +StreamClient::StreamClient(ConnectOptions connectOptions) + : ip_(std::move(connectOptions.host)), port_(connectOptions.port) +{ + impl_ = std::make_shared(connectOptions); +} + +StreamClient::~StreamClient() +{ + LOG(INFO) << "Destroy StreamClient"; + if (impl_) { + impl_.reset(); + } +} + +Status StreamClient::ShutDown() +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + if (impl_) { + bool needRollbackState; + auto rc = impl_->ShutDown(needRollbackState); + impl_->CompleteHandler(rc.IsError(), needRollbackState); + return rc; + } + return Status::OK(); +} + +Status StreamClient::Init(bool reportWorkerLost) +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + bool needRollbackState; + auto rc = impl_->Init(ip_, port_, needRollbackState, reportWorkerLost); + impl_->CompleteHandler(rc.IsError(), needRollbackState); + return rc; +} + +Status StreamClient::CreateProducer(const std::string &streamName, std::shared_ptr &outProducer, + ProducerConf producerConf) +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + PerfPoint point(PerfKey::CLIENT_CREATE_PRODUCER_ALL); + AccessRecorder recorder(AccessRecorderKey::DS_STREAM_CREATE_PRODUCER); + auto rc = impl_->CreateProducer(streamName, outProducer, producerConf); + StreamRequestParam reqParam; + reqParam.streamName = streamName; + reqParam.delayFlushTime = Optional(producerConf.delayFlushTime); + reqParam.pageSize = Optional(producerConf.pageSize); + reqParam.maxStreamSize = Optional(producerConf.maxStreamSize); + reqParam.autoCleanup = Optional(producerConf.autoCleanup); + reqParam.retainForNumConsumers = Optional(producerConf.retainForNumConsumers); + reqParam.encryptStream = Optional(producerConf.encryptStream); + reqParam.streamMode = Optional(producerConf.streamMode); + StreamResponseParam rspParam; + rspParam.msg = rc.GetMsg(); + recorder.Record(rc.GetCode(), reqParam, rspParam); + return rc; +} + +Status StreamClient::Subscribe(const std::string &streamName, const struct SubscriptionConfig &config, + std::shared_ptr &outConsumer, bool autoAck) +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + PerfPoint point(PerfKey::CLIENT_CREATE_SUB_ALL); + AccessRecorder recorder(AccessRecorderKey::DS_STREAM_SUBSCRIBE); + std::string consumerId; + auto rc = impl_->Subscribe(streamName, config, outConsumer, autoAck); + StreamRequestParam reqParam; + reqParam.streamName = streamName; + reqParam.subscriptionName = config.subscriptionName; + reqParam.autoAck = Optional(autoAck); + StreamResponseParam rspParam; + rspParam.msg = rc.GetMsg(); + recorder.Record(rc.GetCode(), reqParam, rspParam); + return rc; +} + +Status StreamClient::DeleteStream(const std::string &streamName) +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + PerfPoint point(PerfKey::CLIENT_DELETE_STREAM_ALL); + AccessRecorder recorder(AccessRecorderKey::DS_STREAM_DELETE_STREAM); + auto rc = impl_->DeleteStream(streamName); + StreamRequestParam reqParam; + reqParam.streamName = streamName; + StreamResponseParam rspParam; + rspParam.msg = rc.GetMsg(); + recorder.Record(rc.GetCode(), reqParam, rspParam); + return rc; +} + +Status 
StreamClient::QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum) +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + AccessRecorder recorder(AccessRecorderKey::DS_STREAM_QUERY_PRODUCERS_NUM); + auto rc = impl_->QueryGlobalProducersNum(streamName, gProducerNum); + StreamRequestParam reqParam; + reqParam.streamName = streamName; + StreamResponseParam rspParam; + rspParam.msg = rc.GetMsg(); + rspParam.count = Optional(gProducerNum); + recorder.Record(rc.GetCode(), reqParam, rspParam); + return rc; +} + +Status StreamClient::QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum) +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + AccessRecorder recorder(AccessRecorderKey::DS_STREAM_QUERY_CONSUMERS_NUM); + auto rc = impl_->QueryGlobalConsumersNum(streamName, gConsumerNum); + StreamRequestParam reqParam; + reqParam.streamName = streamName; + StreamResponseParam rspParam; + rspParam.msg = rc.GetMsg(); + rspParam.count = Optional(gConsumerNum); + recorder.Record(rc.GetCode(), reqParam, rspParam); + return rc; +} +} // namespace datasystem diff --git a/src/datasystem/client/stream_cache/stream_client_impl.cpp b/src/datasystem/client/stream_cache/stream_client_impl.cpp new file mode 100644 index 0000000..684fe30 --- /dev/null +++ b/src/datasystem/client/stream_cache/stream_client_impl.cpp @@ -0,0 +1,420 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Implement stream cache client. 
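+ * StreamClientImpl owns the worker RPC session, the mmap manager for shared-memory pages, and the
+ * listen-worker heartbeat used to detect worker loss and trigger reconnection.
+ *
+ * Typical subscribe flow (a sketch; only fields and signatures shown elsewhere in this patch are used):
+ *
+ *     StreamClient client(connectOptions);
+ *     RETURN_IF_NOT_OK(client.Init(false));
+ *     SubscriptionConfig config;
+ *     config.subscriptionName = "sub";
+ *     std::shared_ptr<Consumer> consumer;
+ *     RETURN_IF_NOT_OK(client.Subscribe("stream", config, consumer, true));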
+ */ +#include "datasystem/client/stream_cache/stream_client_impl.h" + +#include +#include + +#include "datasystem/client/client_flags_monitor.h" +#include "datasystem/client/stream_cache/client_worker_api.h" +#include "datasystem/client/stream_cache/consumer_impl.h" +#include "datasystem/client/stream_cache/producer_impl.h" +#include "datasystem/client/stream_cache/producer_consumer_worker_api.h" +#include "datasystem/common/log/log.h" +#include "datasystem/common/log/logging.h" +#include "datasystem/common/log/spdlog/provider.h" +#include "datasystem/common/util/container_util.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/common/util/raii.h" +#include "datasystem/common/util/rpc_util.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/common/util/thread_local.h" +#include "datasystem/common/util/uuid_generator.h" +#include "datasystem/common/util/validator.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/stream/producer.h" + +const std::string LOG_FILENAME = "ds_client"; + +namespace datasystem { +namespace client { +namespace stream_cache { + +StreamClientImpl::StreamClientImpl(const std::string &clientPublicKey, const SensitiveValue &clientPrivateKey, + const std::string &serverPublicKey, const std::string &accessKey, + const SensitiveValue &secretKey) + : signature_(std::make_unique(accessKey, secretKey)) +{ + (void)Provider::Instance(); + clientStateManager_ = std::make_unique(); + authKeys_.SetClientPublicKey(clientPublicKey); + authKeys_.SetClientPrivateKey(clientPrivateKey); + authKeys_.SetServerKey(WORKER_SERVER_NAME, serverPublicKey); +} + +StreamClientImpl::StreamClientImpl(const ConnectOptions &connectOptions) +{ + (void)Provider::Instance(); + clientStateManager_ = std::make_unique(); + signature_ = std::make_unique(connectOptions.accessKey, connectOptions.secretKey); + tenantId_ = connectOptions.tenantId; + timeoutMs_ = connectOptions.connectTimeoutMs; + authKeys_.SetClientPublicKey(connectOptions.clientPublicKey); + authKeys_.SetClientPrivateKey(connectOptions.clientPrivateKey); + authKeys_.SetServerKey(WORKER_SERVER_NAME, connectOptions.serverPublicKey); +} + +StreamClientImpl::~StreamClientImpl() +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + LOG(INFO) << "Destroy StreamClientImpl"; + auto shutdownFunc = std::bind(&StreamClientImpl::ShutDown, this, true, true); + clientStateManager_->ProcessDestruct(shutdownFunc); +} + +Status StreamClientImpl::ShutDown(bool &needRollbackState, bool isDestruct) +{ + INJECT_POINT("StreamClient.ShutDown.skip"); + // Step0: Check client's status to determine whether it meets the conditions for executing shutdown. + auto rc = clientStateManager_->ProcessShutdown(needRollbackState, isDestruct); + if (!needRollbackState) { + return rc; + } + + // Step1: Clear existing producers and consumers. + ClearProducerAndConsumer(); + + // Step2: Shutdown heartbeat. + if (listenWorker_ != nullptr) { + listenWorker_->RemoveCallBackFunc(this); + listenWorker_->StopListenWorker(true); + } + + // Step3: Send notice to worker before disconnection. 
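+    // clientWorkerApi_ can be null when Init never ran or failed partway, so guard before notifying the worker.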
+ if (clientWorkerApi_ != nullptr) { + RETURN_IF_NOT_OK(clientWorkerApi_->Disconnect()); + } + return Status::OK(); +} + +Status StreamClientImpl::Init(const std::string &ip, const int &port, bool &needRollbackState, bool reportWorkerLost) +{ + Logging::GetInstance()->Start(LOG_FILENAME, true); + FlagsMonitor::GetInstance()->Start(); + auto rc = clientStateManager_->ProcessInit(needRollbackState); + if (!needRollbackState) { + return rc; + } + CHECK_FAIL_RETURN_STATUS(Validator::IsIpv4OrUrl(ip, false), K_INVALID, "Invalid IP address."); + CHECK_FAIL_RETURN_STATUS(Validator::IsInPortRange(port, false), K_INVALID, "Invalid port number."); + + RpcCredential cred; + RETURN_IF_NOT_OK(RpcAuthKeyManager::CreateClientCredentials(authKeys_, WORKER_SERVER_NAME, cred)); + + clientWorkerApi_ = std::make_shared(HostPort(ip, port), cred, signature_.get(), tenantId_); + RETURN_IF_NOT_OK(clientWorkerApi_->Init(timeoutMs_)); + VLOG(SC_NORMAL_LOG_LEVEL) << "clientWorkerApi_ init success"; + mmapManager_ = std::make_unique( + std::dynamic_pointer_cast(clientWorkerApi_)); + listenWorker_ = std::make_shared(clientWorkerApi_, HeartbeatType::RPC_HEARTBEAT); + callBack_ = [this]() { + LOG(INFO) << "Disconnected from worker, clear mmap and try to reconnect..."; + if (reportWorkerLost_) { + workerWasLost_ = true; + } + ClearProducerAndConsumer(); + mmapManager_->CleanInvalidMmapTable(); + Status reconnectStatus = clientWorkerApi_->Reconnect(); + if (reconnectStatus.IsError()) { + LOG(ERROR) << "Reconnect to worker failed, please check network and worker status and restart client." + << reconnectStatus.ToString(); + return; + } + listenWorker_->SetWorkerAvailable(true); + LOG(INFO) << "Reconnect to worker success"; + }; + listenWorker_->StartListenWorker(); + listenWorker_->AddCallBackFunc(this, callBack_); + listenWorker_->SetReleaseFdCallBack( + [this](const std::vector &fds) { mmapManager_->ClearExpiredFds(fds); }); + reportWorkerLost_ = reportWorkerLost; + isInit_ = true; + return Status::OK(); +} + +Status StreamClientImpl::CreatePrefetchPoolIfNotExist() +{ + const int numPrefetchThreads = 8; + std::lock_guard lock(initMutex_); + if (!prefetchThdPool_) { + RETURN_IF_EXCEPTION_OCCURS(prefetchThdPool_ = std::make_unique(numPrefetchThreads)); + } + return Status::OK(); +} + +uint32_t StreamClientImpl::GetLockId() const +{ + return clientWorkerApi_->GetLockId(); +} + +Status StreamClientImpl::VerifyProducerConfig(const ProducerConf &producerConf) +{ + const static int64_t smallestPageSize = 4 * 1024; // 4KB + // The maxPageSize is determined by the maximum offset supported in a page slot. + const static int64_t maxPageSize = SLOT_VALUE_MASK + 1; // 16MB max limit + const static uint64_t minStreamSize = 64 * 1024; // 64K + const static uint64_t maxRetainForNumConsumers = 16; + const static uint64_t minNumPages = 2; + CHECK_FAIL_RETURN_STATUS(producerConf.pageSize > 0 && producerConf.pageSize % smallestPageSize == 0, K_INVALID, + FormatString("Page size not multiple of %d.", smallestPageSize)); + CHECK_FAIL_RETURN_STATUS(producerConf.pageSize <= maxPageSize, K_INVALID, + FormatString("Page size exceeds the maximum. [page size, max size] : [ %zu, %zu ]", + producerConf.pageSize, maxPageSize)); + CHECK_FAIL_RETURN_STATUS( + producerConf.retainForNumConsumers <= maxRetainForNumConsumers, K_INVALID, + FormatString("retainForNumConsumers exceeds the maximum. 
[retain value, max limit] : [ %zu, %zu ]", + producerConf.retainForNumConsumers, maxRetainForNumConsumers)); + CHECK_FAIL_RETURN_STATUS( + producerConf.maxStreamSize >= minStreamSize, K_INVALID, + FormatString("Stream size must be at least the minimum size. [stream size, min size] : [ %zu, %zu ]", + producerConf.maxStreamSize, minStreamSize)); + CHECK_FAIL_RETURN_STATUS( + static_cast(producerConf.pageSize) <= producerConf.maxStreamSize, K_INVALID, + FormatString("Page size exceeds the maximum stream size. [page size, max stream size] : [ %zu, %zu ]", + producerConf.pageSize, producerConf.maxStreamSize)); + CHECK_FAIL_RETURN_STATUS( + producerConf.reserveSize <= producerConf.maxStreamSize, K_INVALID, + FormatString("Reserve size exceeds the maximum stream size. [reserve size, max stream size] : [ %zu, %zu ]", + producerConf.reserveSize, producerConf.maxStreamSize)); + CHECK_FAIL_RETURN_STATUS( + producerConf.reserveSize % producerConf.pageSize == 0, K_INVALID, + FormatString("Reserve size not a multiple of page size. [page size, reserve size] : [ %zu, %zu ]", + producerConf.pageSize, producerConf.reserveSize)); + CHECK_FAIL_RETURN_STATUS( + producerConf.maxStreamSize / producerConf.pageSize >= minNumPages, K_INVALID, + FormatString("Stream size must be at least twice the page size. [page size, max stream size] : [ %zu, %zu ]", + producerConf.pageSize, producerConf.maxStreamSize)); + return Status::OK(); +} + +Status StreamClientImpl::CreateProducer(const std::string &streamName, std::shared_ptr &outProducer, + const ProducerConf &producerConf) +{ + CHECK_FAIL_RETURN_STATUS(Validator::IsRegexMatch(idRe_, streamName), K_INVALID, + "The streamName contains illegal char(s)."); + RETURN_IF_NOT_OK(VerifyProducerConfig(producerConf)); + RETURN_IF_NOT_OK(CheckConnectByUds()); + RETURN_IF_NOT_OK(IsClientReady()); + RETURN_IF_NOT_OK(CheckWorkerLost()); + RETURN_IF_NOT_OK(listenWorker_->CheckWorkerAvailable()); + std::string producerId = GetStringUuid(); + ShmView pageView, streamMetaView; + DataVerificationHeader::SenderProducerNo senderProducerNo; + DataVerificationHeader::Address address; + inet_pton(AF_INET, clientWorkerApi_->GetWorkHost().c_str(), &address); + DataVerificationHeader::Port port = static_cast(clientWorkerApi_->GetWorkPort()); + bool enableStreamDataVerification; + uint64_t streamNo; + bool enableSharedPage; + uint64_t sharedPageSize; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + clientWorkerApi_->CreateProducer(streamName, producerId, producerConf, pageView, senderProducerNo, + enableStreamDataVerification, streamNo, enableSharedPage, sharedPageSize, + streamMetaView), + "CreateProducer request error"); + INJECT_POINT("Mimic.Producer.Old.Version", [&enableStreamDataVerification, &senderProducerNo]() { + enableStreamDataVerification = false; + senderProducerNo = 0; + return Status::OK(); + }); + std::string tenantId = g_ContextTenantId.empty() ? 
tenantId_ : g_ContextTenantId;
+    std::shared_ptr clientWorkerApi =
+        std::make_shared(tenantId, clientWorkerApi_);
+    auto impl = std::make_shared(
+        streamName, tenantId, producerId, producerConf.delayFlushTime, producerConf.pageSize, clientWorkerApi,
+        shared_from_this(), mmapManager_.get(), listenWorker_, pageView, producerConf.maxStreamSize, senderProducerNo,
+        enableStreamDataVerification, address, port, producerConf.streamMode, streamNo, enableSharedPage,
+        sharedPageSize, streamMetaView);
+
+    Status rc = impl->Init();
+    class ProducerHelper : public Producer {
+    public:
+        explicit ProducerHelper(std::shared_ptr impl) : Producer(std::move(impl))
+        {
+        }
+    };
+    outProducer = std::make_shared(std::move(impl));
+    if (rc.IsError()) {
+        RETURN_IF_NOT_OK(outProducer->Close());
+        RETURN_STATUS(
+            StatusCode::K_RUNTIME_ERROR,
+            FormatString("Fail to init mmap memory for producer:<%s> with status: %s", producerId, rc.GetMsg()));
+    }
+    std::lock_guard lk(clearMutex_);
+    producers_.emplace(producerId, outProducer);
+    LOG(INFO) << FormatString(
+        "[%s] Create producer success. AutoCleanup is %s, senderProducerNo = %lu, "
+        "enableStreamDataVerification = %s, workerArea = %s, streamNo = %llu, enableSharedPage = %s, sharedPageSize = "
+        "%lu",
+        outProducer->impl_->LogPrefix(), producerConf.autoCleanup ? "true" : "false", senderProducerNo,
+        (enableStreamDataVerification ? "true" : "false"), (outProducer->impl_->WorkAreaIsV2() ? "V2" : "V1"), streamNo,
+        (enableSharedPage ? "true" : "false"), sharedPageSize);
+    return Status::OK();
+}
+
+Status StreamClientImpl::Subscribe(const std::string &streamName, const struct SubscriptionConfig &config,
+                                   std::shared_ptr &outConsumer, bool autoAck)
+{
+    RETURN_IF_NOT_OK(CheckConnectByUds());
+    RETURN_IF_NOT_OK(IsClientReady());
+    CHECK_FAIL_RETURN_STATUS(Validator::IsRegexMatch(idRe_, streamName), K_INVALID,
+                             "The streamName contains illegal char(s).");
+    CHECK_FAIL_RETURN_STATUS(Validator::IsRegexMatch(idRe_, config.subscriptionName), K_INVALID,
+                             "The subscriptionName contains illegal char(s).");
+    RETURN_IF_NOT_OK(CheckWorkerLost());
+    RETURN_IF_NOT_OK(listenWorker_->CheckWorkerAvailable());
+    std::string consumerId = GetStringUuid();
+    SubscribeRspPb rsp;
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(clientWorkerApi_->Subscribe(streamName, consumerId, config, rsp),
+                                     "Subscribe request error");
+    std::string tenantId = g_ContextTenantId.empty() ? tenantId_ : g_ContextTenantId;
+    std::shared_ptr clientWorkerApi =
+        std::make_shared(tenantId, clientWorkerApi_);
+    auto consumerImpl = std::make_unique(streamName, tenantId, config, consumerId, rsp, clientWorkerApi,
+                                          shared_from_this(), mmapManager_.get(), listenWorker_, autoAck);
+    Status rc = consumerImpl->Init();
+    if (rc.GetCode() == K_OUT_OF_RANGE) {
+        // This is a special return code indicating that the worker may have restarted.
+        // The callback function ClearProducerAndConsumer may set each producer/consumer to CLOSE state,
+        // but our consumer is not in the consumers_ map yet, so we need to check the state again.
+        RETURN_IF_NOT_OK(CheckWorkerLost());
+        RETURN_IF_NOT_OK(listenWorker_->CheckWorkerAvailable());
+    }
+    RETURN_IF_NOT_OK(rc);
+
+    // When the Consumer is initialized, the first element to receive has been determined.
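+    // Consumer does not expose a public constructor here, so a local helper
+    // subclass is used to allow construction via std::make_shared (mirroring
+    // ProducerHelper in CreateProducer above).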
+    class ConsumerHelper : public Consumer {
+    public:
+        explicit ConsumerHelper(std::unique_ptr impl) : Consumer(std::move(impl))
+        {
+        }
+    };
+    outConsumer = std::make_shared(std::move(consumerImpl));
+
+    std::lock_guard lk(clearMutex_);
+    consumers_.emplace(consumerId, outConsumer);
+    LOG(INFO) << FormatString("[%s] Create consumer success. AutoAck is %s, workerArea is %s",
+                              outConsumer->impl_->LogPrefix(), autoAck ? "true" : "false",
+                              (outConsumer->impl_->WorkAreaIsV2() ? "V2" : "V1"));
+    return Status::OK();
+}
+
+Status StreamClientImpl::DeleteStream(const std::string &streamName)
+{
+    RETURN_IF_NOT_OK(IsClientReady());
+    CHECK_FAIL_RETURN_STATUS(Validator::IsRegexMatch(idRe_, streamName), K_INVALID,
+                             "The streamName contains illegal char(s).");
+    RETURN_IF_NOT_OK(listenWorker_->CheckWorkerAvailable());
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(clientWorkerApi_->DeleteStream(streamName),
+                                     FormatString("[S:%s] Delete stream failed.", streamName));
+    LOG(INFO) << FormatString("[S:%s] Delete stream success.", streamName);
+    return Status::OK();
+}
+
+Status StreamClientImpl::QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum)
+{
+    RETURN_IF_NOT_OK(IsClientReady());
+    CHECK_FAIL_RETURN_STATUS(Validator::IsRegexMatch(idRe_, streamName), K_INVALID,
+                             "The streamName contains illegal char(s).");
+    gProducerNum = 0;
+    RETURN_IF_NOT_OK(CheckWorkerLost());
+    RETURN_IF_NOT_OK(listenWorker_->CheckWorkerAvailable());
+    return clientWorkerApi_->QueryGlobalProducersNum(streamName, gProducerNum);
+}
+
+Status StreamClientImpl::QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum)
+{
+    RETURN_IF_NOT_OK(IsClientReady());
+    CHECK_FAIL_RETURN_STATUS(Validator::IsRegexMatch(idRe_, streamName), K_INVALID,
+                             "The streamName contains illegal char(s).");
+    gConsumerNum = 0;
+    RETURN_IF_NOT_OK(CheckWorkerLost());
+    RETURN_IF_NOT_OK(listenWorker_->CheckWorkerAvailable());
+    return clientWorkerApi_->QueryGlobalConsumersNum(streamName, gConsumerNum);
+}
+
+void StreamClientImpl::CleanupProdsCons(std::vector> &resetProducers,
+                                        std::vector> &resetConsumers)
+{
+    for (auto &producer : resetProducers) {
+        if (auto ptr = producer.lock()) {
+            ptr->impl_->Reset();
+        }
+    }
+    for (auto &consumer : resetConsumers) {
+        if (auto ptr = consumer.lock()) {
+            ptr->impl_->Reset();
+        }
+    }
+}
+
+void StreamClientImpl::ClearProducer(const std::string &producerId)
+{
+    std::lock_guard lk(clearMutex_);
+    auto num = producers_.erase(producerId);
+    LOG_IF(WARNING, num == 0) << "Producer " << producerId << " not found in client.";
+}
+
+void StreamClientImpl::ClearConsumer(const std::string &consumerId)
+{
+    std::lock_guard lk(clearMutex_);
+    auto num = consumers_.erase(consumerId);
+    LOG_IF(WARNING, num == 0) << "Consumer " << consumerId << " not found in client.";
+}
+
+void StreamClientImpl::ClearProducerAndConsumer()
+{
+    auto func = [](auto &user) {
+        // Ensure that the consumer or producer being used is not destroyed.
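+        // producers_ and consumers_ hold weak references; lock() promotes an
+        // entry only if a user still owns the producer/consumer, so handles
+        // that have already been released are skipped safely.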
+        if (auto ptr = user.second.lock()) {
+            ptr->impl_->SetInactive();
+        }
+    };
+    std::lock_guard lk(clearMutex_);
+    LOG(INFO) << FormatString("Begin to clear %zu producers and %zu consumers", producers_.size(), consumers_.size());
+    if (!producers_.empty()) {
+        std::for_each(producers_.begin(), producers_.end(), func);
+        producers_.clear();
+    }
+    if (!consumers_.empty()) {
+        std::for_each(consumers_.begin(), consumers_.end(), func);
+        consumers_.clear();
+    }
+    LOG(INFO) << "Clear producer and consumer success";
+}
+
+inline Status StreamClientImpl::IsClientReady()
+{
+    uint16_t clientState = clientStateManager_->GetState();
+    CHECK_FAIL_RETURN_STATUS(clientState == (uint16_t)ClientState::INITIALIZED, StatusCode::K_NOT_READY,
+                             clientStateManager_->ToStringForUser(clientState));
+    return Status::OK();
+}
+
+Status StreamClientImpl::CheckConnectByUds()
+{
+    RETURN_OK_IF_TRUE(clientWorkerApi_->GetShmEnabled());
+    RETURN_STATUS(
+        StatusCode::K_RUNTIME_ERROR,
+        "Connection to worker is not using a unix domain socket; please check that the connection is to a local worker.");
+}
+} // namespace stream_cache
+} // namespace client
+} // namespace datasystem
diff --git a/src/datasystem/client/stream_cache/stream_client_impl.h b/src/datasystem/client/stream_cache/stream_client_impl.h
new file mode 100644
index 0000000..0192646
--- /dev/null
+++ b/src/datasystem/client/stream_cache/stream_client_impl.h
@@ -0,0 +1,237 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Implement stream cache client.
+ */
+#ifndef DATASYSTEM_CLIENT_STREAM_CACHE_STREAM_CLIENT_IMPL_H
+#define DATASYSTEM_CLIENT_STREAM_CACHE_STREAM_CLIENT_IMPL_H
+
+#include
+#include
+
+#include
+
+#include "datasystem/client/client_state_manager.h"
+#include "datasystem/client/listen_worker.h"
+#include "datasystem/client/mmap_manager.h"
+#include "datasystem/client/stream_cache/client_worker_api.h"
+#include "datasystem/common/ak_sk/signature.h"
+#include "datasystem/common/flags/flags.h"
+#include "datasystem/common/rpc/rpc_auth_key_manager.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream/stream_config.h"
+#include "datasystem/utils/sensitive_value.h"
+#include "datasystem/utils/status.h"
+#include "datasystem/utils/connection.h"
+
+namespace datasystem {
+namespace client {
+namespace stream_cache {
+class StreamClientImpl : public std::enable_shared_from_this<StreamClientImpl> {
+public:
+    explicit StreamClientImpl(const std::string &clientPublicKey = "", const SensitiveValue &clientPrivateKey = "",
+                              const std::string &serverPublicKey = "", const std::string &accessKey = "",
+                              const SensitiveValue &secretKey = "");
+
+    explicit StreamClientImpl(const ConnectOptions &connectOptions);
+
+    virtual ~StreamClientImpl();
+
+    /**
+     * @brief Shut down a stream client instance.
+     * @param[out] needRollbackState If the client status is successfully changed to INTERMEDIATE,
+     * the status needs to be rolled back based on the completion status when the request is completed.
+     * @param[in] isDestruct Since shutdown will also be called during the client's destruction,
+     * this parameter is used to avoid redundant log printing in the destruction scenario.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status ShutDown(bool &needRollbackState, bool isDestruct = false);
+
+    /**
+     * @brief Create one Producer to send elements.
+     * @param[in] streamName The name of the stream.
+     * @param[out] outProducer The output Producer that the user can use to send elements.
+     * @param[in] producerConf The producer configuration.
+     * @return Status of the call.
+     */
+    Status CreateProducer(const std::string &streamName, std::shared_ptr &outProducer,
+                          const ProducerConf &producerConf = {});
+
+    /**
+     * @brief Create the subscription relation and generate one Consumer to receive elements.
+     * @param[in] streamName The name of the stream.
+     * @param[in] config The config of the subscription.
+     * @param[out] outConsumer The output Consumer that the user can use to receive data elements.
+     * @param[in] autoAck Toggles whether autoAck is on or off.
+     * @return Status of the call.
+     */
+    Status Subscribe(const std::string &streamName, const struct SubscriptionConfig &config,
+                     std::shared_ptr &outConsumer, bool autoAck);
+
+    /**
+     * @brief Delete one stream.
+     * @param[in] streamName The name of the stream.
+     * @return Status of the call.
+     */
+    Status DeleteStream(const std::string &streamName);
+
+    /**
+     * @brief Query the number of global producers.
+     * @param[in] streamName The target stream.
+     * @param[out] gProducerNum The number of global producers.
+     * @return Status of the call.
+     */
+    Status QueryGlobalProducersNum(const std::string &streamName, uint64_t &gProducerNum);
+
+    /**
+     * @brief Query the number of global consumers.
+     * @param[in] streamName The target stream.
+     * @param[out] gConsumerNum The number of global consumers.
+     * @return Status of the call.
+     */
+    Status QueryGlobalConsumersNum(const std::string &streamName, uint64_t &gConsumerNum);
+
+    /**
+     * @brief Initialize the Ds client connector.
+     * @param[in] ip The worker ip address.
+     * @param[in] port The worker port.
+     * @param[out] needRollbackState If the client status is successfully changed to INTERMEDIATE,
+     * the status needs to be rolled back based on the completion status when the request is completed.
+     * @param[in] reportWorkerLost Whether to report to the caller when the worker has crashed or lost the client.
+     * @return K_OK on success; the error code otherwise.
+     *         K_INVALID: the input ip or port is invalid.
+     */
+    Status Init(const std::string &ip, const int &port, bool &needRollbackState, bool reportWorkerLost);
+
+    /**
+     * Fetch the lock id granted to this client.
+     * @return 4 byte lock id
+     */
+    uint32_t GetLockId() const;
+
+    /**
+     * @brief Set producers and consumers inactive and clear them.
+     */
+    void ClearProducerAndConsumer();
+
+    /**
+     * @brief Init/Shutdown complete handler.
+     * @param[in] failed Init/Shutdown success or not.
+     * @param[out] needRollbackState If the client status is successfully changed to INTERMEDIATE,
+     * the status needs to be rolled back based on the completion status when the request is completed.
+     */
+    void CompleteHandler(bool failed, bool needRollbackState)
+    {
+        clientStateManager_->CompleteHandler(failed, needRollbackState);
+    }
+
+    /**
+     * @brief Creates the prefetch pool if it is not created yet.
+     * @return Status of the call
+     */
+    Status CreatePrefetchPoolIfNotExist();
+
+    /**
+     * @brief Returns a pointer to the prefetch pool.
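+     * @note The pool is created lazily by CreatePrefetchPoolIfNotExist(), so this may
+     *       return nullptr before that call has succeeded.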
+     * @return pointer to the prefetch pool
+     */
+    ThreadPool *GetPrefetchPool()
+    {
+        return prefetchThdPool_.get();
+    }
+
+    /**
+     * @brief Check if the worker was lost between heartbeats.
+     * @return Status of the call
+     */
+    inline const Status CheckWorkerLost()
+    {
+        CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+            !workerWasLost_, StatusCode::K_SC_WORKER_WAS_LOST,
+            FormatString("Client %s detected worker %s was lost", clientWorkerApi_->GetClientId(),
+                         clientWorkerApi_->GetWorkHost()));
+        return Status::OK();
+    }
+
+    /**
+     * @brief Remove closed producer.
+     * @param[in] producerId Producer Id.
+     */
+    void ClearProducer(const std::string &producerId);
+
+    /**
+     * @brief Remove closed consumer.
+     * @param[in] consumerId Consumer Id.
+     */
+    void ClearConsumer(const std::string &consumerId);
+
+private:
+    /**
+     * @brief Check whether the client is ready to execute any api.
+     * @return Status
+     */
+    inline Status IsClientReady();
+
+    /**
+     * @brief Check whether the connection uses a unix domain socket.
+     * @return Status
+     */
+    Status CheckConnectByUds();
+
+    /**
+     * @brief Clear all data and metadata for the producers and consumers of a resetting stream.
+     * @param[in] resetProducers Pointer to the list of producers getting cleaned up
+     * @param[in] resetConsumers Pointer to the list of consumers getting cleaned up
+     */
+    void CleanupProdsCons(std::vector> &resetProducers,
+                          std::vector> &resetConsumers);
+
+    /**
+     * @brief Verify the ProducerConfig.
+     * @param[in] producerConf The producer config
+     * @return Status of this call.
+     */
+    static Status VerifyProducerConfig(const ProducerConf &producerConf);
+
+    std::mutex initMutex_;
+    std::unique_ptr signature_{ nullptr };
+    std::shared_ptr clientWorkerApi_;
+    std::unique_ptr mmapManager_;
+    std::function callBack_;  // Failure callback handle; if the worker disconnects, this function will be called.
+    std::shared_timed_mutex clearMutex_;  // Protect producers_ and consumers_.
+    std::unordered_map>
+        producers_;  // Ensure that the producer can be automatically destroyed. Key is producerId.
+    std::unordered_map>
+        consumers_;  // Ensure that the consumer can be automatically destroyed. Key is consumerId.
+    bool isInit_ = { false };
+    bool reportWorkerLost_{ false };
+    std::atomic workerWasLost_{ false };
+    std::unique_ptr clientStateManager_{ nullptr };
+    std::unique_ptr prefetchThdPool_{ nullptr };
+
+    // ListenWorker needs to be placed at the bottom to ensure that it is destructed first.
+    std::shared_ptr listenWorker_{ nullptr };
+
+    RpcAuthKeys authKeys_;
+    std::string tenantId_;
+    int32_t timeoutMs_ = RPC_TIMEOUT;
+    // Verify the stream name format.
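+    // The pattern accepts alphanumerics plus ~ . - / _ ! @ # % ^ & * ( ) + = : ;
+    // e.g. "my-stream/01" is accepted, while names containing spaces or commas are rejected.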
+ re2::RE2 idRe_{ "^[a-zA-Z0-9\\~\\.\\-\\/_!@#%\\^\\&\\*\\(\\)\\+\\=\\:;]*$" }; +}; +} // namespace stream_cache +} // namespace client +} // namespace datasystem +#endif diff --git a/src/datasystem/common/CMakeLists.txt b/src/datasystem/common/CMakeLists.txt index 62b3464..321130c 100644 --- a/src/datasystem/common/CMakeLists.txt +++ b/src/datasystem/common/CMakeLists.txt @@ -12,11 +12,11 @@ add_subdirectory(httpclient) add_subdirectory(iam) add_subdirectory(encrypt) add_subdirectory(immutable_string) +add_subdirectory(string_intern) add_subdirectory(object_cache) add_subdirectory(device) +add_subdirectory(stream_cache) add_subdirectory(l2cache) add_subdirectory(flags) add_subdirectory(signal) add_subdirectory(rdma) - - diff --git a/src/datasystem/common/constants.h b/src/datasystem/common/constants.h index 7d9bcba..1419354 100644 --- a/src/datasystem/common/constants.h +++ b/src/datasystem/common/constants.h @@ -48,8 +48,20 @@ static const std::string HEALTH_TABLE = "health_table"; / static const int CHECK_FILE_EXIST_INTERVAL_S = 5; // The time for back up replica check db path exsit. +// sc meta tables +static const std::string STREAM_TABLE_NAME = "stream_table"; +static const std::string PUB_TABLE_NAME = "pub_table"; +static const std::string SUB_TABLE_NAME = "sub_table"; +static const std::string NOTIFY_PUB_TABLE_NAME = "notify_pub_table"; +static const std::string NOTIFY_SUB_TABLE_NAME = "notify_sub_table"; +static const std::string STREAM_CON_CNT_TABLE_NAME = "stream_consumer_count"; +static const std::string STREAM_PRODUCER_COUNT = "stream_producer_count"; + static const int ASYNC_LOGGER_STOP_MAX_WAIT_SEC = 15; // The max wait time when AsyncLogger Stop. +// stream data object +static const uint64_t DEFAULT_TIMEOUT_MS = 1000; + // cluster info table in rocksDb static const std::string CLUSTER_TABLE = "cluster_table"; static const std::string HASHRING_TABLE = "hashring_table"; @@ -87,6 +99,10 @@ struct CreateDeviceParam { LifetimeType lifetime = LifetimeType::REFERENCE; bool cacheLocation = true; }; + +// ub device +static const std::string ENV_UB_DEVICE_NAME = "DS_UB_DEV_NAME"; +static const std::string DEFAULT_UB_DEVICE_NAME = "bonding_dev_0"; } // namespace datasystem #endif // DATASYSTEM_COMMON_CONSTANTS_H diff --git a/src/datasystem/common/device/ascend/acl_device_manager.cpp b/src/datasystem/common/device/ascend/acl_device_manager.cpp index 6f68fb1..1d1d835 100644 --- a/src/datasystem/common/device/ascend/acl_device_manager.cpp +++ b/src/datasystem/common/device/ascend/acl_device_manager.cpp @@ -27,6 +27,7 @@ #include #include "datasystem/common/ak_sk/hasher.h" +#include "datasystem/common/log/trace.h" #include "datasystem/common/perf/perf_manager.h" #include "datasystem/common/util/dlutils.h" #include "datasystem/common/util/file_util.h" @@ -68,10 +69,16 @@ AclDeviceManager *AclDeviceManager::Instance() void AclDeviceManager::Init() { #ifdef BUILD_HETERO - loadPluginThread_ = std::make_unique([this]() { LOG_IF_ERROR(this->LoadPlugin(), "Load plugin failed."); }); + auto traceId = Trace::Instance().GetTraceID(); + loadPluginThread_ = std::make_unique([this, traceId]() { + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + Status loadStatus = this->LoadPlugin(); + waitPost_->SetWithStatus(loadStatus); + }); #else - state_ = State::INIT_ERROR; - waitPost_->Set(); + waitPost_->SetWithStatus(Status(K_RUNTIME_ERROR, + "Heterogeneous api is currently unavailable. 
Ensure that the compilation switch '-X on' is enabled when " + "building datasystem.")); #endif } @@ -79,15 +86,12 @@ Status AclDeviceManager::LoadPlugin() { Dl_info dlInfo; Status lastRc = Status::OK(); - Raii raii([this]() { + Raii raii([this]() { waitPost_->Set(); }); if (dladdr(reinterpret_cast(AclDeviceManager::Instance), &dlInfo) == 0) { - lastRc = - Status(K_RUNTIME_ERROR, FormatString("Load Ascend plugin failed, get dladdr error: %s", GetDlErrorMsg())); - LOG(ERROR) << lastRc; - state_ = State::INIT_ERROR; - return lastRc; + RETURN_STATUS_LOG_ERROR( + K_RUNTIME_ERROR, FormatString("Load Ascend plugin failed, get dladdr error: %s", GetDlErrorMsg())); } std::string curSoPath = dlInfo.dli_fname; std::string aclPluginPath = std::string(dirname(const_cast(curSoPath.data()))) + "/" + AclPluginLibrary; @@ -95,14 +99,11 @@ Status AclDeviceManager::LoadPlugin() RETURN_IF_NOT_OK(VerifyingSha256(aclPluginPath)); pluginHandle_ = dlopen(aclPluginPath.c_str(), RTLD_LAZY | RTLD_LOCAL); if (pluginHandle_ == nullptr) { - lastRc = Status(K_INVALID, FormatString("Load Ascend plugin failed, dlopen error: %s", GetDlErrorMsg())); - LOG(ERROR) << lastRc; - state_ = State::INIT_ERROR; - return lastRc; + RETURN_STATUS_LOG_ERROR( + K_INVALID, FormatString("Load Ascend plugin failed, dlopen error: %s", GetDlErrorMsg())); } else { DlsymFuncObj(); } - state_ = State::INIT_OK; return Status::OK(); } @@ -169,45 +170,18 @@ void AclDeviceManager::DlsymFuncObj() Status AclDeviceManager::CheckPluginOk() { - std::call_once(hasLoadPlugin_, [this]() { - state_ = State::PENDING_INIT; - instance_->Init(); - }); - waitPost_->Wait(); - return CheckState(); -} - -Status AclDeviceManager::CheckState() -{ - switch (state_) { - case State::INIT_ERROR: - RETURN_STATUS( - K_RUNTIME_ERROR, - "The Ascend plugin is not load success, please check as follows: 1. Ensure that the compilation switch " - "'-X on' is enabled when building datasystem; 2. Ensure that the environment have ascendcl.so in " - "LD_LIBRARY_PATH and libacl_plugin.so in library path; 3. Please make sure that ascendcl.so and " - "libacl_plugin.so have been loaded correctly."); - case State::PENDING_INIT: - RETURN_STATUS(K_RUNTIME_ERROR, "AclDeviceManager is not init ready."); - case State::NOT_INIT: - RETURN_STATUS(K_INVALID, "AclDeviceManager is not init ready."); - case State::INIT_OK: - break; + std::call_once(hasLoadPlugin_, []() { instance_->Init(); }); + Status loadStatus = waitPost_->WaitAndGetStatus(); + if (!loadStatus.IsOk()) { + return loadStatus; } return Status::OK(); } Status AclDeviceManager::VerifyDeviceId(std::vector deviceIds) { - constexpr int normalCode = 0; // ACL_RT_DEVICE_STATUS_NORMAL - for (int devId : deviceIds) { - int32_t status; - auto rc = aclrtQueryDeviceStatus(devId, &status); - if (rc.IsError() || status != normalCode) { - std::string errorMsg = FormatString("Got Error/ABNORMAL device, deviceId: %d, code: %s, msg: %s", devId, - status, rc.ToString()); - return Status(K_INVALID, errorMsg); - } + for (const auto devId : deviceIds) { + RETURN_IF_NOT_OK(aclrtQueryDeviceStatus(devId)); } return Status::OK(); } @@ -241,7 +215,10 @@ Status AclDeviceManager::VerifyingSha256(const std::string &aclPluginPath) // Step 4: Check whether the hash values are consistent. if (ss.str() != ACL_PLUGIN_SHA256) { RETURN_STATUS_LOG_ERROR(K_NOT_AUTHORIZED, - "Load Ascend plugin failed, which fails to pass the integrity check."); + "Load Ascend plugin failed, which fails to pass the integrity check. 
" + "Possible causes and solutions: " + "1.This usually occurs when libacl_plugin.so and libdatasystem.so are from different versions. Ensure " + "both libraries are from the same version."); } return Status::OK(); } @@ -265,7 +242,15 @@ Status AclDeviceManager::GetDeviceIdx(int32_t &deviceIdx) { RETURN_IF_NOT_OK(CheckPluginOk()); RETURN_RUNTIME_ERROR_IF_NULL(GetDeviceIdxFunc_); - RETURN_ACL_RESULT(GetDeviceIdxFunc_(deviceIdx)); + int aclRet = GetDeviceIdxFunc_(deviceIdx); + constexpr int normalCode = 0; // ACL_RT_DEVICE_STATUS_NORMAL + if (aclRet != normalCode) { + RETURN_STATUS_LOG_ERROR(K_INVALID, + FormatString( + "May not create context or set device in this thread. Detail: acl api failed with error code %d", + aclRet)); + } + return Status::OK(); } Status AclDeviceManager::SetDeviceIdx(int32_t deviceId) @@ -335,7 +320,15 @@ Status AclDeviceManager::DSHcclGetRootInfo(HcclRootInfo *rootInfo) { RETURN_IF_NOT_OK(CheckPluginOk()); RETURN_RUNTIME_ERROR_IF_NULL(DSHcclGetRootInfoFunc_); - RETURN_HCCL_RESULT(DSHcclGetRootInfoFunc_(rootInfo)); + int hcclRet = DSHcclGetRootInfoFunc_(rootInfo); + if (hcclRet == 1) { // HCCL_E_PARA = 1 + RETURN_STATUS(K_HCCL_ERROR, + "HcclGetRootInfoapi failed with error code 1 (parameter error). Possible cause: HCCL failed to obtain IP " + "address (null). Please check Ascend logs for detailed error information. " + "Solution: If IP address acquisition failed, configure environment variable " + "HCCL_IF_IP with the current host IP address."); + } + RETURN_HCCL_RESULT(hcclRet); } Status AclDeviceManager::DSHcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, @@ -343,7 +336,7 @@ Status AclDeviceManager::DSHcclCommInitRootInfo(uint32_t nRanks, const HcclRootI { RETURN_IF_NOT_OK(CheckPluginOk()); RETURN_RUNTIME_ERROR_IF_NULL(DSHcclCommInitRootInfoFunc_); - RETURN_HCCL_RESULT(DSHcclCommInitRootInfoFunc_(nRanks, rootInfo, rank, comm)); + return HandleHcclResult(DSHcclCommInitRootInfoFunc_(nRanks, rootInfo, rank, comm)); } Status AclDeviceManager::DSHcclSend(void *sendBuf, uint64_t count, HcclDataType dataType, uint32_t destRank, @@ -352,7 +345,7 @@ Status AclDeviceManager::DSHcclSend(void *sendBuf, uint64_t count, HcclDataType PerfPoint point(PerfKey::DS_HCCL_SEND); RETURN_IF_NOT_OK(CheckPluginOk()); RETURN_RUNTIME_ERROR_IF_NULL(DSHcclSendFunc_); - RETURN_HCCL_RESULT(DSHcclSendFunc_(sendBuf, count, dataType, destRank, comm, stream)); + return HandleHcclResult(DSHcclSendFunc_(sendBuf, count, dataType, destRank, comm, stream)); } Status AclDeviceManager::DSHcclRecv(void *recvBuf, uint64_t count, HcclDataType dataType, uint32_t srcRank, @@ -361,7 +354,7 @@ Status AclDeviceManager::DSHcclRecv(void *recvBuf, uint64_t count, HcclDataType PerfPoint point(PerfKey::DS_HCCl_RECEIVE); RETURN_IF_NOT_OK(CheckPluginOk()); RETURN_RUNTIME_ERROR_IF_NULL(DSHcclRecvFunc_); - RETURN_HCCL_RESULT(DSHcclRecvFunc_(recvBuf, count, dataType, srcRank, comm, stream)); + return HandleHcclResult(DSHcclRecvFunc_(recvBuf, count, dataType, srcRank, comm, stream)); } Status AclDeviceManager::DSHcclCommDestroy(HcclComm comm) @@ -406,11 +399,13 @@ Status AclDeviceManager::DSAclrtDestroyEvent(aclrtEvent event) RETURN_ACL_RESULT(DSAclrtDestroyEventFunc_(event)); } -Status AclDeviceManager::DSHcclGetCommAsyncError(HcclComm comm, HcclResult *asyncError) +Status AclDeviceManager::DSHcclGetCommAsyncError(HcclComm comm) { RETURN_IF_NOT_OK(CheckPluginOk()); RETURN_RUNTIME_ERROR_IF_NULL(DSHcclGetCommAsyncErrorFunc_); - RETURN_HCCL_RESULT(DSHcclGetCommAsyncErrorFunc_(comm, 
asyncError)); + HcclResult asyncError; + RETURN_IF_NOT_OK(HandleHcclResult(DSHcclGetCommAsyncErrorFunc_(comm, &asyncError))); + return HandleHcclResult(asyncError); } Status AclDeviceManager::aclInit(const char *configPath) @@ -483,11 +478,22 @@ Status AclDeviceManager::aclrtGetDeviceCount(uint32_t *count) RETURN_ACL_RESULT(DSAclrtGetDeviceCountFunc_(count)); } -Status AclDeviceManager::aclrtQueryDeviceStatus(uint32_t deviceId, int32_t *deviceStatus) +Status AclDeviceManager::aclrtQueryDeviceStatus(uint32_t deviceId) { RETURN_IF_NOT_OK(CheckPluginOk()); RETURN_RUNTIME_ERROR_IF_NULL(DSAclrtQueryDeviceStatusFunc_); - RETURN_ACL_RESULT(DSAclrtQueryDeviceStatusFunc_(deviceId, deviceStatus)); + int32_t deviceStatus; + int aclRet = DSAclrtQueryDeviceStatusFunc_(deviceId, &deviceStatus); + constexpr int normalCode = 0; // ACL_RT_DEVICE_STATUS_NORMAL + if (aclRet != normalCode) { + RETURN_STATUS_LOG_ERROR(K_INVALID, + FormatString( + "Got Error/ABNORMAL device, deviceId: %d. Detail: acl api failed with error code %d, deviceStatus: %d", + deviceId, + aclRet, + deviceStatus)); + } + return Status::OK(); } AclDeviceManager::~AclDeviceManager() @@ -497,14 +503,10 @@ AclDeviceManager::~AclDeviceManager() void AclDeviceManager::Shutdown() { - if (loadPluginThread_ == nullptr) { - return; + if (loadPluginThread_ != nullptr) { + loadPluginThread_->join(); + loadPluginThread_.reset(); } - if (state_ == State::PENDING_INIT) { - waitPost_->Wait(); - } - loadPluginThread_->join(); - loadPluginThread_.reset(); } Status AclDeviceManager::DSAclrtQueryEventStatus(aclrtEvent event) @@ -561,11 +563,13 @@ Status AclDeviceManager::DSP2PRecv(void *recvBuf, uint64_t count, HcclDataType d RETURN_HCCL_RESULT(DSP2PRecvFunc_(recvBuf, count, dataType, comm, stream)); } -Status AclDeviceManager::DSP2PGetCommAsyncError(P2PComm comm, HcclResult *asyncError) +Status AclDeviceManager::DSP2PGetCommAsyncError(P2PComm comm) { RETURN_IF_NOT_OK(CheckPluginOk()); RETURN_RUNTIME_ERROR_IF_NULL(DSP2PGetCommAsyncErrorFunc_); - RETURN_HCCL_RESULT(DSP2PGetCommAsyncErrorFunc_(comm, asyncError)); + HcclResult asyncError; + RETURN_IF_NOT_OK(HandleHcclResult(DSP2PGetCommAsyncErrorFunc_(comm, &asyncError))); + return HandleHcclResult(asyncError); } Status AclDeviceManager::RtNotifyCreate(int32_t deviceId, void **notify) @@ -637,5 +641,35 @@ Status AclDeviceManager::AclrtUnSubscribeReport(uint64_t threadId, aclrtStream s RETURN_RUNTIME_ERROR_IF_NULL(DSAclrtUnSubscribeReportFunc_); RETURN_ACL_RESULT(DSAclrtUnSubscribeReportFunc_(threadId, stream)); } + +Status AclDeviceManager::HandleHcclResult(int hcclResult) +{ + constexpr int hcclEUnavail = 7; // HCCL_E_UNAVAIL + constexpr int hcclERemote = 21; // HCCL_E_REMOTE + constexpr int hcclESuspending = 22; // HCCL_E_SUSPENDING + switch (hcclResult) { + case hcclEUnavail: + return Status(StatusCode::K_HCCL_ERROR, __LINE__, __FILE__, + "HCCL api operation failed with error code: 7 (resource unavailable). " + "Possible causes: " + "1. NPUs are occupied or the device is unavailable."); + + case hcclERemote: + return Status(StatusCode::K_HCCL_ERROR, __LINE__, __FILE__, + "HCCL api operation failed with error code: 21 (error cqe). Indicates that an 'RDMA ERROR " + "CQE' error has occurred within this communication domain."); + + case hcclESuspending: + return Status(StatusCode::K_HCCL_ERROR, __LINE__, __FILE__, + "HCCL api operation failed with error code: 22 (error communicator suspending). " + "This usually occurs when the device state was reset unexpectedly, " + "causing the communication domain to be destroyed. " + "Possible causes: " + "1.
Device reset operation after HCCL communicator initialization."); + + default: + RETURN_HCCL_RESULT(hcclResult); + } +} } // namespace acl } // namespace datasystem diff --git a/src/datasystem/common/device/ascend/acl_device_manager.h b/src/datasystem/common/device/ascend/acl_device_manager.h index c18cd16..e454ebd 100644 --- a/src/datasystem/common/device/ascend/acl_device_manager.h +++ b/src/datasystem/common/device/ascend/acl_device_manager.h @@ -36,14 +36,14 @@ #include "datasystem/common/device/ascend/p2phccl_types.h" #include "datasystem/utils/status.h" -#define RETURN_CANN_RESULT(aclRet, interType) \ - do { \ - int _aclRet = (aclRet); \ - if (_aclRet != 0) { \ - std::string errMsg = FormatString("%s api failed with error code %d ", interType, _aclRet); \ - return Status(StatusCode::K_ACL_ERROR, __LINE__, __FILE__, errMsg); \ - } \ - return Status::OK(); \ +#define RETURN_CANN_RESULT(aclRet, interType) \ + do { \ + int _aclRet = (aclRet); \ + if (_aclRet != 0) { \ + return Status(StatusCode::K_ACL_ERROR, __LINE__, __FILE__, FormatString("%s api failed with error code %d" \ + ", please refer to %s documentation for detailed error information. ", interType, _aclRet, interType)); \ + } \ + return Status::OK(); \ } while (false) #define RETURN_ACL_RESULT(aclRet) \ @@ -76,12 +76,6 @@ public: */ virtual void Shutdown(); - /** - * @brief Check the plugin state. - * @return OK if plugin is ready. - */ - Status CheckState(); - /** * @brief Check the plugin is loaded ok. * @return OK if plugin is ready. @@ -299,7 +293,7 @@ public: * For details about other return values, see HcclResult Type. * @return Status of the call. */ - virtual Status DSHcclGetCommAsyncError(HcclComm comm, HcclResult *asyncError); + virtual Status DSHcclGetCommAsyncError(HcclComm comm); virtual Status aclInit(const char *configPath); @@ -309,7 +303,7 @@ public: virtual Status aclrtGetDeviceCount(uint32_t *count); - virtual Status aclrtQueryDeviceStatus(uint32_t deviceId, int32_t *deviceStatus); + virtual Status aclrtQueryDeviceStatus(uint32_t deviceId); virtual Status aclrtMemcpyAsync(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind, aclrtStream stream); @@ -387,7 +381,7 @@ public: * For details about other return values, see HcclResult Type. * @return Status of the call. */ - virtual Status DSP2PGetCommAsyncError(P2PComm comm, HcclResult *asyncError); + virtual Status DSP2PGetCommAsyncError(P2PComm comm); virtual Status RtNotifyCreate(int32_t deviceId, void **notify); virtual Status RtNotifyDestroy(void *notify); @@ -421,6 +415,18 @@ private: */ void DlsymFuncObj(); + void LoadResearchPlugin(); + + /** + * @brief Handle HCCL operation result and provide detailed error information + * @param hcclResult The HCCL operation result code to be checked + * @return Status Returns Status with detailed message if result is HCCL_E_SUSPENDING, + * otherwise returns the original HCCL result conversion + * @note Special handling for HCCL_E_SUSPENDING (22): Provides troubleshooting guidance for communicator suspension + * caused by device state reset or communication domain destruction + */ + Status HandleHcclResult(int hcclResult); + // Register plugin function as function pointer in class member.
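+    // Each REG_METHOD(Name, ...) entry is expected to declare a dlsym-resolved
+    // function pointer member (e.g. MallocDeviceMemoryFunc_) that DlsymFuncObj()
+    // binds from the plugin handle; callers null-check it via
+    // RETURN_RUNTIME_ERROR_IF_NULL before each invocation.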
REG_METHOD(MallocDeviceMemory, int, size_t, void *&); REG_METHOD(FreeDeviceMemory, int, void *); @@ -485,14 +491,6 @@ private: static std::once_flag hasLoadPlugin_; static std::unique_ptr instance_; - enum class State : int { - NOT_INIT = 0, - PENDING_INIT, - INIT_OK, - INIT_ERROR, - }; - - std::atomic state_{ State::NOT_INIT }; void *pluginHandle_{ nullptr }; std::unique_ptr waitPost_; std::unique_ptr loadPluginThread_{ nullptr }; diff --git a/src/datasystem/common/device/ascend/acl_pipeline_p2p_task.cpp b/src/datasystem/common/device/ascend/acl_pipeline_p2p_task.cpp index d664e86..a89bd64 100644 --- a/src/datasystem/common/device/ascend/acl_pipeline_p2p_task.cpp +++ b/src/datasystem/common/device/ascend/acl_pipeline_p2p_task.cpp @@ -30,7 +30,7 @@ namespace datasystem { namespace acl { namespace { -bool EnableMerge(const std::vector blobs) +bool EnableMerge(const std::vector blobs) { if (blobs.size() <= 1) { return false; @@ -42,7 +42,7 @@ bool EnableMerge(const std::vector blobs) return true; }); auto totalSize = std::accumulate(blobs.cbegin(), blobs.cend(), 0ul, - [](size_t total, const DataInfo &info) { return total + info.size; }); + [](size_t total, const Blob &info) { return total + info.size; }); return totalSize / blobs.size() <= mergeLimit; } } // namespace @@ -58,7 +58,7 @@ PipeLineP2PBase::~PipeLineP2PBase() } } -Status PipeLineP2PBase::AllocTransferBuffer(size_t objectSize, DataInfo &transBuffer, uint64_t &seq) +Status PipeLineP2PBase::AllocTransferBuffer(size_t objectSize, Blob &transBuffer, uint64_t &seq) { auto tryReuse = [this](size_t objectSize, uint64_t ackSeq, std::vector &transferVec) { const uint64_t cacheSize = 2; @@ -111,11 +111,8 @@ Status PipeLineP2PBase::AllocTransferBuffer(size_t objectSize, DataInfo &transBu CHECK_FAIL_RETURN_STATUS(transferVec.begin()->GetSize() >= objectSize, K_RUNTIME_ERROR, FormatString("The transfer memory chunk size %zu too small, expect size %zu.", transferVec.begin()->GetSize(), objectSize)); - transBuffer.devPtr = transferVec.begin()->GetPointer(); - transBuffer.dataType = DataType::DATA_TYPE_INT8; - transBuffer.count = objectSize; + transBuffer.pointer = transferVec.begin()->GetPointer(); transBuffer.size = objectSize; - transBuffer.deviceIdx = 0; std::lock_guard locker(mutex_); seq = seq_.fetch_add(1); (void)transferUnitPools_.emplace(seq, std::move(transferVec)); @@ -139,7 +136,7 @@ Status PipeLineP2PSend::Submit(P2PSendTask &&task) if (!EnableMerge(task.srcBuffers)) { VLOG(1) << "Direct P2PSend."; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(task.comm->P2PSend(task.srcBuffers, task.event, PrimaryStream()), - "P2PRecv failed"); + "P2PSend failed"); RETURN_IF_NOT_OK_PRINT_ERROR_MSG(task.event->RecordEvent(PrimaryStream()), "Record send event failed"); return Status::OK(); } @@ -165,11 +162,11 @@ Status PipeLineP2PSend::RunTaskPhaseOneImpl(size_t pipelineIndex, const P2PSendT auto &fftsDispatcher = GetResource()->fftsDispatcher; std::vector lastTaskId(MAX_FFTS_TASKS_COUNT, -1); size_t offset = 0; - auto transferPtr = task.transBuffer.devPtr; + auto transferPtr = task.transBuffer.pointer; auto transferSize = task.transBuffer.size; size_t srcCount = task.srcBuffers.size(); for (size_t n = 0; n < srcCount; n++) { - auto srcPtr = task.srcBuffers[n].devPtr; + auto srcPtr = task.srcBuffers[n].pointer; size_t srcSize = task.srcBuffers[n].size; void *destPtr = static_cast(static_cast(transferPtr) + offset); offset += srcSize; @@ -252,11 +249,11 @@ Status PipeLineP2PRecv::RunTaskPhaseTwoImpl(size_t pipelineIndex, const P2PRecvT auto &fftsDispatcher = 
GetResource()->fftsDispatcher; std::vector lastTaskId(MAX_FFTS_TASKS_COUNT, -1); size_t offset = 0; - auto transferPtr = task.transBuffer.devPtr; + auto transferPtr = task.transBuffer.pointer; auto transferSize = task.transBuffer.size; size_t srcCount = task.destBuffers.size(); for (size_t n = 0; n < srcCount; n++) { - auto destPtr = task.destBuffers[n].devPtr; + auto destPtr = task.destBuffers[n].pointer; size_t destSize = task.destBuffers[n].size; void *srcPtr = static_cast(static_cast(transferPtr) + offset); offset += destSize; diff --git a/src/datasystem/common/device/ascend/acl_pipeline_p2p_task.h b/src/datasystem/common/device/ascend/acl_pipeline_p2p_task.h index 9a9c8de..1ef4189 100644 --- a/src/datasystem/common/device/ascend/acl_pipeline_p2p_task.h +++ b/src/datasystem/common/device/ascend/acl_pipeline_p2p_task.h @@ -33,11 +33,11 @@ namespace datasystem { class CommWrapperBase; namespace acl { struct P2PSendTask { - std::vector srcBuffers; + std::vector srcBuffers; size_t totalSize; std::shared_ptr comm; std::shared_ptr event; - DataInfo transBuffer; + Blob transBuffer; uint64_t seq{ 0 }; }; @@ -53,7 +53,7 @@ public: static void NotifyCallback(void *userData); protected: - Status AllocTransferBuffer(size_t objectSize, DataInfo &transBuffer, uint64_t &seq); + Status AllocTransferBuffer(size_t objectSize, Blob &transBuffer, uint64_t &seq); AclResourceManager *aclResourceMgr_; @@ -101,11 +101,11 @@ private: }; struct P2PRecvTask { - std::vector destBuffers; + std::vector destBuffers; size_t totalSize; std::shared_ptr comm; std::shared_ptr event; - DataInfo transBuffer; + Blob transBuffer; uint64_t seq{ 0 }; }; diff --git a/src/datasystem/common/device/ascend/acl_resource_manager.cpp b/src/datasystem/common/device/ascend/acl_resource_manager.cpp index 9ea36ac..c6a8649 100644 --- a/src/datasystem/common/device/ascend/acl_resource_manager.cpp +++ b/src/datasystem/common/device/ascend/acl_resource_manager.cpp @@ -135,7 +135,7 @@ Status AclMemMgrBase::Allocate(const std::vector &bMeta, std::ve } size_t allocSize = (type_ == AllocateType::DEV_DEVICE) ? 
maxAllocateSize : bMeta[i].size; - rc = memoryPool[i].AllocateMemory(DEFAULT_TENANTID, allocSize, false, type_); + rc = memoryPool[i].AllocateMemory(DEFAULT_TENANTID, allocSize, false, ServiceType::OBJECT, type_); if (retryNums <= 0) { break; } diff --git a/src/datasystem/common/device/ascend/callback_thread.cpp b/src/datasystem/common/device/ascend/callback_thread.cpp index ade212b..b5bb813 100644 --- a/src/datasystem/common/device/ascend/callback_thread.cpp +++ b/src/datasystem/common/device/ascend/callback_thread.cpp @@ -20,12 +20,16 @@ #include "datasystem/common/device/ascend/callback_thread.h" +#include "datasystem/common/log/trace.h" + namespace datasystem { namespace acl { CallbackThread::CallbackThread() { aclDeviceManager_ = acl::AclDeviceManager::Instance(); - thread_ = std::make_unique([this] { + auto traceId = Trace::Instance().GetTraceID(); + thread_ = std::make_unique([this, traceId] { + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); constexpr uint32_t CALLBACK_TIMEOUT_MS = 100; wp_.Wait(); while (!exitFlag_) { diff --git a/src/datasystem/common/device/ascend/cann_types.h b/src/datasystem/common/device/ascend/cann_types.h index 3937a73..db0658a 100644 --- a/src/datasystem/common/device/ascend/cann_types.h +++ b/src/datasystem/common/device/ascend/cann_types.h @@ -92,6 +92,7 @@ typedef enum { HCCL_E_NETWORK = 19, /**< call network api fail */ HCCL_E_AGAIN = 20, /**< try again */ HCCL_E_REMOTE = 21, /**< error cqe */ + HCCL_E_SUSPENDING = 22, /**< error communicator suspending */ HCCL_E_RESERVED /**< reserved */ } HcclResult; diff --git a/src/datasystem/common/device/ascend/comm_wrapper_base.cpp b/src/datasystem/common/device/ascend/comm_wrapper_base.cpp index f93c1b9..e12f8bf 100644 --- a/src/datasystem/common/device/ascend/comm_wrapper_base.cpp +++ b/src/datasystem/common/device/ascend/comm_wrapper_base.cpp @@ -49,7 +49,7 @@ CommWrapperBase::CommWrapperBase(const std::string &commId, int localDeviceId, i }; pool_->Execute([func]() { (void)func(); }); - hcclDetailState_ = HCCL_SUCCESS; + hcclDetailState_ = Status::OK(); } CommWrapperBase::~CommWrapperBase() @@ -61,6 +61,80 @@ aclrtStream CommWrapperBase::GetStream() return resource_->PrimaryStream(); } +bool CommWrapperBase::IsCommReady() const +{ + return commReady_.load(); +} + +void CommWrapperBase::SetCommReady(bool ready) +{ + bool wasReady = commReady_.exchange(ready); + if (ready && !wasReady) { + ExecuteReadyCallbacks(); + } +} + +void CommWrapperBase::ExecuteReadyCallbacks() +{ + std::vector> callbacksToExecute; + { + std::lock_guard lock(stateMutex_); + + // Move all pending callbacks to local vector for execution + callbacksToExecute = std::move(readyCallbacks_); + readyCallbacks_.clear(); + + // Set flag to indicate callback execution is in progress + executingCallbacks_ = true; + } + + // Execute all callbacks without holding the lock + for (auto &callback : callbacksToExecute) { + callback(); + } + + { + std::lock_guard lock(stateMutex_); + + // Reset execution flag after all callbacks are completed + executingCallbacks_ = false; + } +} + +void CommWrapperBase::AddReadyCallback(std::function callback) +{ + bool shouldExecute = false; + { + std::lock_guard lock(stateMutex_); + + // If communication is ready, no callbacks are pending, and not currently executing, + // the callback can be executed immediately + if (IsCommReady() && !executingCallbacks_ && readyCallbacks_.empty()) { + shouldExecute = true; + } else { + // Add callback to the queue for ordered execution + 
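+            // Entries queued here are drained either by SetCommReady(true),
+            // which calls ExecuteReadyCallbacks(), or by the batch-execution
+            // path at the bottom of this function once readiness is observed.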
readyCallbacks_.push_back(callback); + + // If communication is ready and not currently executing callbacks, + // trigger execution after releasing the lock + if (IsCommReady() && !executingCallbacks_) { + shouldExecute = true; + } + } + } + + // Execute callback or trigger execution outside of lock + if (shouldExecute) { + if (readyCallbacks_.empty()) { + // Direct execution for immediate case + callback(); + } else { + // Batch execution for queued callbacks + ExecuteReadyCallbacks(); + } + } +} + void CommWrapperBase::SetStatus(const Status &commStatus) { if (commStatus.IsOk()) { @@ -70,8 +144,9 @@ void CommWrapperBase::SetStatus(const Status &commStatus) } } -HcclResult CommWrapperBase::GetDetailStatus() const +Status CommWrapperBase::GetDetailStatus() const { + std::lock_guard lock(hcclDetailStateMutex_); return hcclDetailState_; } @@ -80,9 +155,12 @@ HcclCommState CommWrapperBase::GetCommStatus() const return hcclCommState_; } -void CommWrapperBase::SetHcclDetailState(HcclResult result) +void CommWrapperBase::SetHcclDetailState(Status result) { - hcclDetailState_ = result; + std::lock_guard lock(hcclDetailStateMutex_); + if (hcclDetailState_.IsOk()) { + hcclDetailState_ = result; + } } int CommWrapperBase::GetLocalDeviceId() const @@ -144,15 +222,29 @@ Status CommWrapperBase::InitPipeline(HcclCommDirection direction) } } +Status CommWrapperBase::CheckTranPointer(const void *pointer, const std::string &pointerName) +{ + if (pointer == nullptr) { + auto rc = GetDetailStatus(); + std::string errMsg = FormatString("The pointer [%s] is null, " + "which usually indicates that the hccl communication domain creation failed. " + "Specifically: [%s]", + pointerName, + rc.GetMsg()); + return Status(rc.GetCode(), errMsg); + } + return Status::OK(); +} + Status CommWrapperBase::SubmitPipelineTask(acl::P2PSendTask task) { - RETURN_RUNTIME_ERROR_IF_NULL(sender_); + RETURN_IF_NOT_OK(CheckTranPointer(sender_.get(), "sender_")); return sender_->Submit(std::move(task)); } Status CommWrapperBase::SubmitPipelineTask(acl::P2PRecvTask task) { - RETURN_RUNTIME_ERROR_IF_NULL(receiver_); + RETURN_IF_NOT_OK(CheckTranPointer(receiver_.get(), "receiver_")); return receiver_->Submit(std::move(task)); } diff --git a/src/datasystem/common/device/ascend/comm_wrapper_base.h b/src/datasystem/common/device/ascend/comm_wrapper_base.h index c5613b9..8a61fe3 100644 --- a/src/datasystem/common/device/ascend/comm_wrapper_base.h +++ b/src/datasystem/common/device/ascend/comm_wrapper_base.h @@ -35,16 +35,6 @@ enum class HcclCommState { UNCREATE, CREATING, VALID, INVALID, DESTROY }; enum class HcclCommDirection { SEND, RECV }; constexpr int WARM_UP_DATA_COUNT = 1; -/** - * @brief Return the data info string. - * @param[in] info The DataInfo object. - * @return The format string of the data info. - */ -inline std::string DataInfoToString(const DataInfo &info) -{ - return FormatString("Data info: dataType [%d], count [%llu]", static_cast(info.dataType), info.count); -} - class CommWrapperBase : public AclPointerWrapper { public: explicit CommWrapperBase(const std::string &commId, int localDeviceId, int remoteDeviceId, @@ -62,6 +52,34 @@ public: pool_->Execute(std::forward(f), std::forward(args)...); } + /** + * @brief Checks if the communicator is ready for collective operations. + * @return true if communicator is initialized and ready, false otherwise. + */ + bool IsCommReady() const; + + /** + * @brief Sets the communicator ready state and triggers ready callbacks when becoming ready. 
+ * @param ready The new ready state to set. + * @note If transitioning from not-ready to ready state, all registered ready callbacks will be executed. + */ + void SetCommReady(bool ready); + + /** + * @brief Executes all registered ready callbacks in a thread-safe manner. + * @note Callbacks are moved to a local vector to minimize lock holding time. + * This ensures callbacks execute without holding the mutex. + */ + void ExecuteReadyCallbacks(); + + /** + * @brief Adds a callback to be executed when communicator becomes ready. + * @param callback The callback function to register. + * @note If communicator is already ready, the callback is executed immediately. + * Otherwise, it's queued for execution when SetCommReady(true) is called. + */ + void AddReadyCallback(std::function callback); + /** * @brief Get AclrtStream * @return The AclrtStream @@ -78,7 +96,7 @@ public: * @brief Get the status of hcclcomm. * @return The status of hcclcomm. */ - HcclResult GetDetailStatus() const; + Status GetDetailStatus() const; /** * @brief Get the lifetime state of hcclcomm. @@ -100,9 +118,9 @@ public: /** * @brief Sets the specific fault cause. - * @param[in] result Value of HcclResult + * @param[in] result The status of Hccl invocation */ - void SetHcclDetailState(HcclResult result); + void SetHcclDetailState(Status result); /** * @brief Check HcclComm health. @@ -124,30 +142,29 @@ public: /** * @brief P2P send the data to the receiving side. - * @param[in] dataInfos[in] The list of the data info. + * @param[in] blobs[in] The list of the blob info. * @param[in] comm[in] The hccl communicator. * @param[in] stream[in] The stream of acl context. * @return Status of the call */ - virtual Status P2PSend(const std::vector &dataInfos, const std::shared_ptr &event, + virtual Status P2PSend(const std::vector &blobs, const std::shared_ptr &event, aclrtStream stream) = 0; /** * @brief P2P recv the data from the sending side. - * @param[in] dataInfos The list of the data info. + * @param[in] blobs The list of the blob info. * @param[in] comm The hccl communicator. * @param[in] stream The stream of acl context. * @return Status of the call */ - virtual Status P2PRecv(const std::vector &dataInfos, const std::shared_ptr &event, + virtual Status P2PRecv(const std::vector &blobs, const std::shared_ptr &event, aclrtStream stream) = 0; /** * @brief Queries whether an error occurs in the communication domain. - * @return If the result is 0, no error occurs in the communication domain. For details about other return values, - * see HcclResult Type. + * @return The status of Hccl invocation */ - virtual HcclResult HcclGetCommAsyncError() = 0; + virtual Status HcclGetCommAsyncError() = 0; /** * @brief Init hccl communicator. @@ -187,6 +204,17 @@ public: Status SubmitPipelineTask(acl::P2PRecvTask task); private: + /** + * @brief Check if the communication pointer is valid and return corresponding error status if null. + * @param[in] pointer The communication pointer to be checked (sender_ or receiver_). + * @param[in] pointerName The name of the pointer for error message identification. + * @return Status::OK() if pointer is valid, otherwise returns error status with detailed message. + * @note This function is used to validate HCCL communication pointers that should be initialized + * during communication domain creation. A null pointer typically indicates HCCL communication + * domain creation failure. 
+ */ + Status CheckTranPointer(const void *pointer, const std::string &pointerName); + acl::AclDeviceManager *aclImpl_; AclResourceManager *aclResourceMgr_; std::shared_ptr resource_; @@ -198,12 +226,18 @@ private: std::shared_ptr pool_; std::chrono::steady_clock::time_point commConnectTimestamp_; std::atomic hcclCommState_; - HcclResult hcclDetailState_ = HcclResult::HCCL_E_RESERVED; + Status hcclDetailState_; + mutable std::mutex hcclDetailStateMutex_; // protect hcclDetailState_ std::shared_ptr hcclThreadControl_; int bindThreadId_; std::mutex mutex_; bool hasShutDown_ = false; + std::atomic commReady_{false}; // Atomic flag indicating communication domain readiness + std::mutex stateMutex_; // Mutex protecting callback queue and execution state + std::vector> readyCallbacks_; // Queue of callbacks waiting for communication readiness + bool executingCallbacks_ = false; // Flag indicating if callbacks are currently being executed + friend class HcclCommWrapper; friend class P2PHcclCommWrapper; }; diff --git a/src/datasystem/common/device/ascend/ffts_dispatcher.cpp b/src/datasystem/common/device/ascend/ffts_dispatcher.cpp index 2f1e61b..3465000 100644 --- a/src/datasystem/common/device/ascend/ffts_dispatcher.cpp +++ b/src/datasystem/common/device/ascend/ffts_dispatcher.cpp @@ -191,7 +191,7 @@ HcclResult FftsDispatcher::ConstructFftsSqe(rtFftsPlusSqe_t &fftsPlusSqe, uint16 // 0x5A: identifies the communication task and optimizes the FFTS+ scheduling performance. const uint8_t TASK_TYPE_AIV_AIC = 0x5B; const uint8_t TASK_TYPE_OTHER = 0x5A; - fftsPlusSqe.subType = argsHandleList_.empty() ? TASK_TYPE_OTHER : TASK_TYPE_AIV_AIC; + fftsPlusSqe.subType = TASK_TYPE_OTHER; return HCCL_SUCCESS; } @@ -201,11 +201,6 @@ HcclResult FftsDispatcher::ConstructFftsTask(rtFftsPlusTaskInfo_t &task, rtFftsP task.descBuf = fftsCtxsPtr_->contexts.data(); task.descBufLen = sizeof(rtFftsPlusComCtx_t) * fftsCtxsPtr_->ctxNum; task.descAddrType = 0; - if (!argsHandleList_.empty()) { - task.argsHandleInfoNum = argsHandleList_.size(); - task.argsHandleInfoPtr = argsHandleList_.data(); - argsHandleList_.clear(); - } return HCCL_SUCCESS; } diff --git a/src/datasystem/common/device/ascend/ffts_dispatcher.h b/src/datasystem/common/device/ascend/ffts_dispatcher.h index d7594c0..d8a2b90 100644 --- a/src/datasystem/common/device/ascend/ffts_dispatcher.h +++ b/src/datasystem/common/device/ascend/ffts_dispatcher.h @@ -127,7 +127,6 @@ private: HcclFftsContextsInfo *fftsCtxsPtr_; std::vector fftsCtxs_; - std::vector argsHandleList_; int32_t devLogID_; int64_t chipId_; acl::AclDeviceManager *aclDeviceManager_; diff --git a/src/datasystem/common/device/ascend/hccl_comm_wrapper.cpp b/src/datasystem/common/device/ascend/hccl_comm_wrapper.cpp index 783c802..47a8f4a 100644 --- a/src/datasystem/common/device/ascend/hccl_comm_wrapper.cpp +++ b/src/datasystem/common/device/ascend/hccl_comm_wrapper.cpp @@ -69,47 +69,43 @@ Status HcclCommWrapper::InitHcclComm(int numRanks, HcclRootInfo &rootInfo, int r return rc; } -Status HcclCommWrapper::P2PSend(const std::vector &dataInfos, const std::shared_ptr &event, +Status HcclCommWrapper::P2PSend(const std::vector &blobs, const std::shared_ptr &event, aclrtStream stream) { - LOG(INFO) << "hccl start to send " << (!dataInfos.empty() ? DataInfoToString(dataInfos[0]) : "") - << ", info num: " << dataInfos.size(); + LOG(INFO) << "hccl start to send " << (!blobs.empty() ? 
std::to_string(blobs[0].size) : "")
+ << ", info num: " << blobs.size();
 (void)event;
 auto &comm = GetRef();
 RETURN_IF_NOT_OK(CheckHcclCommPtr(comm));
- for (size_t i = 0; i < dataInfos.size(); i++) {
- RETURN_IF_NOT_OK(aclImpl_->DSHcclSend(dataInfos[i].devPtr, dataInfos[i].count,
- static_cast(dataInfos[i].dataType), P2P_RECV_RANK, comm,
- stream));
+ for (size_t i = 0; i < blobs.size(); i++) {
+ RETURN_IF_NOT_OK(aclImpl_->DSHcclSend(blobs[i].pointer, blobs[i].size, HcclDataType::HCCL_DATA_TYPE_INT8,
+ P2P_RECV_RANK, comm, stream));
 }

 VLOG(1) << "Send hccl ok";
 return Status::OK();
 }

-HcclResult HcclCommWrapper::HcclGetCommAsyncError()
+Status HcclCommWrapper::HcclGetCommAsyncError()
 {
- auto &comm = GetRef();
 // Don't check if comm is creating.
 if (hcclCommState_ == HcclCommState::CREATING || hcclCommState_ == HcclCommState::UNCREATE) {
- return HCCL_SUCCESS;
+ return Status::OK();
 }
- HcclResult asyncError;
- aclImpl_->DSHcclGetCommAsyncError(comm, &asyncError);
- return asyncError;
+ auto &comm = GetRef();
+ return aclImpl_->DSHcclGetCommAsyncError(comm);
 }

-Status HcclCommWrapper::P2PRecv(const std::vector &dataInfos, const std::shared_ptr &event,
+Status HcclCommWrapper::P2PRecv(const std::vector &blobs, const std::shared_ptr &event,
 aclrtStream stream)
 {
- LOG(INFO) << "hccl receiving " << (!dataInfos.empty() ? DataInfoToString(dataInfos[0]) : "")
- << ", info num: " << dataInfos.size();
+ LOG(INFO) << "hccl receiving " << (!blobs.empty() ? std::to_string(blobs[0].size) : "")
+ << ", info num: " << blobs.size();
 (void)event;
 auto &comm = GetRef();
 RETURN_IF_NOT_OK(CheckHcclCommPtr(comm));
- for (size_t i = 0; i < dataInfos.size(); i++) {
- RETURN_IF_NOT_OK(aclImpl_->DSHcclRecv(dataInfos[i].devPtr, dataInfos[i].count,
- static_cast(dataInfos[i].dataType), P2P_SEND_RANK, comm,
- stream));
+ for (size_t i = 0; i < blobs.size(); i++) {
+ RETURN_IF_NOT_OK(aclImpl_->DSHcclRecv(blobs[i].pointer, blobs[i].size, HcclDataType::HCCL_DATA_TYPE_INT8,
+ P2P_SEND_RANK, comm, stream));
 }

 VLOG(1) << "Recv hccl ok";
@@ -153,7 +149,9 @@ Status HcclCommWrapper::CreateRootInfo(HcclRootInfo &rootInfo)
 Status HcclCommWrapper::CheckHcclCommPtr(const void *ptr)
 {
 if (ptr == nullptr) {
- return { K_RUNTIME_ERROR, "HcclComm is nullptr, create HCCL communication domain failed." };
+ auto errorStatus = GetDetailStatus();
+ return {K_RUNTIME_ERROR,
+ FormatString("HcclComm is nullptr, create HCCL communication domain failed. Detail: %s",
+ errorStatus.ToString())};
 }
 return Status::OK();
 }
diff --git a/src/datasystem/common/device/ascend/hccl_comm_wrapper.h b/src/datasystem/common/device/ascend/hccl_comm_wrapper.h
index 5b02550..adf6795 100644
--- a/src/datasystem/common/device/ascend/hccl_comm_wrapper.h
+++ b/src/datasystem/common/device/ascend/hccl_comm_wrapper.h
@@ -57,29 +57,28 @@ public:
 Status InitCommunicator(HcclRootInfo &rootInfo, const HcclCommDirection direction, bool isSameNode) override;

 /**
 * @brief P2P send the data to the receiving side.
- * @param[in] dataInfos[in] The list of the data info.
+ * @param[in] blobs The list of the blob info.
 * @param[in] event The aclRtEvent wrapper.
 * @return Status of the call
 */
- Status P2PSend(const std::vector &dataInfos, const std::shared_ptr &event,
+ Status P2PSend(const std::vector &blobs, const std::shared_ptr &event,
 aclrtStream stream) override;

 /**
 * @brief P2P recv the data from the sending side.
- * @param[in] dataInfos The list of the data info.
+ * @param[in] blobs The list of the blob info.
 * @param[in] event The aclRtEvent wrapper.
* @param[in] stream The stream of acl context. * @return Status of the call */ - Status P2PRecv(const std::vector &dataInfos, const std::shared_ptr &event, + Status P2PRecv(const std::vector &blobs, const std::shared_ptr &event, aclrtStream stream) override; /** * @brief Queries whether an error occurs in the communication domain. - * @return If the result is 0, no error occurs in the communication domain. For details about other return values, - * see HcclResult Type. + * @return The status of Hccl invocation. */ - HcclResult HcclGetCommAsyncError() override; + Status HcclGetCommAsyncError() override; /** * @brief Warm up the hccl communicator wrapper in the send side. diff --git a/src/datasystem/common/device/ascend/p2phccl_comm_wrapper.cpp b/src/datasystem/common/device/ascend/p2phccl_comm_wrapper.cpp index 1d9f88f..09ce993 100644 --- a/src/datasystem/common/device/ascend/p2phccl_comm_wrapper.cpp +++ b/src/datasystem/common/device/ascend/p2phccl_comm_wrapper.cpp @@ -61,10 +61,10 @@ Status P2PHcclCommWrapper::InitP2PComm(const HcclRootInfo *rootInfo, P2pKind kin hcclCommState_ = HcclCommState::CREATING; Status rc; if (isSameNode) { - LOG(INFO) <<"InitP2PComm HCCS dir: "<< kind; + LOG(INFO) << "InitP2PComm HCCS dir: " << kind; rc = aclImpl_->DSP2PCommInitRootInfo(rootInfo, kind, P2pLink::P2P_LINK_HCCS, &GetRef()); } else { - LOG(INFO) <<"InitP2PComm ROCE dir: "<< kind; + LOG(INFO) << "InitP2PComm ROCE dir: " << kind; rc = aclImpl_->DSP2PCommInitRootInfo(rootInfo, kind, P2pLink::P2P_LINK_ROCE, &GetRef()); } @@ -73,17 +73,17 @@ Status P2PHcclCommWrapper::InitP2PComm(const HcclRootInfo *rootInfo, P2pKind kin return rc; } -Status P2PHcclCommWrapper::P2PSend(const std::vector &dataInfos, - const std::shared_ptr &event, aclrtStream stream) +Status P2PHcclCommWrapper::P2PSend(const std::vector &blobs, const std::shared_ptr &event, + aclrtStream stream) { - LOG(INFO) << "p2phccl start to send " << (dataInfos.size() > 0 ? DataInfoToString(dataInfos[0]) : "") - << ", info num: " << dataInfos.size(); + LOG(INFO) << "p2phccl start to send " << (blobs.size() > 0 ? std::to_string(blobs[0].size) : "") + << ", info num: " << blobs.size(); (void)event; auto &comm = GetRef(); if (comm == nullptr) { return { K_RUNTIME_ERROR, "HcclComm is nullptr" }; } - for (size_t i = 0; i < dataInfos.size(); i++) { + for (size_t i = 0; i < blobs.size(); i++) { auto injectTest = [] { INJECT_POINT("client.P2PSend.skip_DSHcclSend", [] { return true; }); return false; @@ -91,18 +91,18 @@ Status P2PHcclCommWrapper::P2PSend(const std::vector &dataInfos, if (injectTest()) { continue; } - RETURN_IF_NOT_OK(aclImpl_->DSP2PSend(dataInfos[i].devPtr, dataInfos[i].count, - static_cast(dataInfos[i].dataType), comm, stream)); + RETURN_IF_NOT_OK(aclImpl_->DSP2PSend(blobs[i].pointer, blobs[i].size, HcclDataType::HCCL_DATA_TYPE_INT8, + comm, stream)); } VLOG(1) << "Send hccl ok"; return Status::OK(); } -Status P2PHcclCommWrapper::P2PRecv(const std::vector &dataInfos, - const std::shared_ptr &event, aclrtStream stream) +Status P2PHcclCommWrapper::P2PRecv(const std::vector &blobs, const std::shared_ptr &event, + aclrtStream stream) { - LOG(INFO) << "p2phccl receiving " << (dataInfos.size() > 0 ? DataInfoToString(dataInfos[0]) : "") - << ", info num: " << dataInfos.size(); + LOG(INFO) << "p2phccl receiving " << (blobs.size() > 0 ? 
std::to_string(blobs[0].size) : "")
+ << ", info num: " << blobs.size();
 auto &comm = GetRef();
 if (comm == nullptr) {
 return { K_RUNTIME_ERROR, "HcclComm is nullptr" };
 }
@@ -119,9 +119,9 @@ Status P2PHcclCommWrapper::P2PRecv(const std::vector &dataInfos,
 int eightS = 8;
 std::this_thread::sleep_for(std::chrono::seconds(eightS));
 }
- for (size_t i = 0; i < dataInfos.size(); i++) {
- RETURN_IF_NOT_OK(aclImpl_->DSP2PRecv(dataInfos[i].devPtr, dataInfos[i].count,
- static_cast(dataInfos[i].dataType), comm, stream));
+ for (size_t i = 0; i < blobs.size(); i++) {
+ RETURN_IF_NOT_OK(
+ aclImpl_->DSP2PRecv(blobs[i].pointer, blobs[i].size, HcclDataType::HCCL_DATA_TYPE_INT8, comm, stream));
 }

 RETURN_IF_NOT_OK(event->RecordEvent(stream));
@@ -129,16 +129,14 @@ Status P2PHcclCommWrapper::P2PRecv(const std::vector &dataInfos,
 return Status::OK();
 }

-HcclResult P2PHcclCommWrapper::HcclGetCommAsyncError()
+Status P2PHcclCommWrapper::HcclGetCommAsyncError()
 {
- auto &comm = GetRef();
 // Don't check if comm is creating.
 if (hcclCommState_ == HcclCommState::CREATING || hcclCommState_ == HcclCommState::UNCREATE) {
- return HCCL_SUCCESS;
+ return Status::OK();
 }
- HcclResult asyncError;
- aclImpl_->DSP2PGetCommAsyncError(comm, &asyncError);
- return asyncError;
+ auto &comm = GetRef();
+ return aclImpl_->DSP2PGetCommAsyncError(comm);
 }

 Status P2PHcclCommWrapper::InitCommunicator(HcclRootInfo &rootInfo, const HcclCommDirection direction, bool isSameNode)
diff --git a/src/datasystem/common/device/ascend/p2phccl_comm_wrapper.h b/src/datasystem/common/device/ascend/p2phccl_comm_wrapper.h
index 61da877..66f7484 100644
--- a/src/datasystem/common/device/ascend/p2phccl_comm_wrapper.h
+++ b/src/datasystem/common/device/ascend/p2phccl_comm_wrapper.h
@@ -51,30 +51,29 @@ public:

 /**
 * @brief P2P send the data to the receiving side.
- * @param[in] dataInfos[in] The list of the data info.
+ * @param[in] blobs The list of the blob info.
 * @param[in] comm[in] The hccl communicator.
 * @param[in] stream[in] The stream of acl context.
 * @return Status of the call
 */
- Status P2PSend(const std::vector &dataInfos, const std::shared_ptr &event,
+ Status P2PSend(const std::vector &blobs, const std::shared_ptr &event,
 aclrtStream stream) override;

 /**
 * @brief P2P recv the data from the sending side.
- * @param[in] dataInfos The list of the data info.
+ * @param[in] blobs The list of the blob info.
 * @param[in] comm The hccl communicator.
 * @param[in] streamThe stream of acl context.
 * @return Status of the call
 */
- Status P2PRecv(const std::vector &dataInfos, const std::shared_ptr &event,
+ Status P2PRecv(const std::vector &blobs, const std::shared_ptr &event,
 aclrtStream stream) override;

 /**
 * @brief Queries whether an error occurs in the communication domain.
- * @return If the result is 0, no error occurs in the communication domain. For details about other return values,
- * see HcclResult Type.
+ * @return The status of Hccl invocation.
 */
- HcclResult HcclGetCommAsyncError() override;
+ Status HcclGetCommAsyncError() override;

 /**
 * @brief Warm up the hccl communicator wrapper in the send side.
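The ready-callback contract added to the communicator wrapper base earlier in this patch (SetCommReady / AddReadyCallback / ExecuteReadyCallbacks) reduces to the pattern below. A minimal, self-contained sketch, assuming only the members the patch names (commReady_, stateMutex_, readyCallbacks_); the real class also carries an executingCallbacks_ flag and the rest of the wrapper state:

#include <atomic>
#include <functional>
#include <mutex>
#include <vector>

class CommReadyNotifier {
public:
    // Register work to run once the communicator is usable. If it already is,
    // run it right away (outside the lock); otherwise queue it.
    void AddReadyCallback(std::function<void()> callback)
    {
        {
            std::lock_guard<std::mutex> lock(stateMutex_);
            if (!commReady_.load()) {
                readyCallbacks_.emplace_back(std::move(callback));
                return;
            }
        }
        callback();
    }

    // Flip readiness; only the not-ready -> ready edge fires the queue.
    void SetCommReady(bool ready)
    {
        bool wasReady = commReady_.exchange(ready);
        if (ready && !wasReady) {
            ExecuteReadyCallbacks();
        }
    }

private:
    // Swap the queue into a local vector so callbacks run without the mutex.
    void ExecuteReadyCallbacks()
    {
        std::vector<std::function<void()>> pending;
        {
            std::lock_guard<std::mutex> lock(stateMutex_);
            pending.swap(readyCallbacks_);
        }
        for (auto &cb : pending) {
            cb();
        }
    }

    std::atomic<bool> commReady_{false};
    std::mutex stateMutex_;
    std::vector<std::function<void()>> readyCallbacks_;
};

Swapping the queue into a local vector is what the patch's own comment about minimizing lock holding time refers to: callback bodies never execute under stateMutex_, so a callback that re-registers another callback cannot deadlock.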
diff --git a/src/datasystem/common/eventloop/timer_queue.cpp b/src/datasystem/common/eventloop/timer_queue.cpp
index f24f38b..56b5329 100644
--- a/src/datasystem/common/eventloop/timer_queue.cpp
+++ b/src/datasystem/common/eventloop/timer_queue.cpp
@@ -159,7 +159,8 @@ Status TimerQueue::AddTimer(const uint64_t &durationMs, const std::function id(1);
- uint64_t timeWatch = CurrentTimeMs() + durationMs;
+ auto currentTimeMs = CurrentTimeMs();
+ uint64_t timeWatch = durationMs > UINT64_MAX - currentTimeMs ? UINT64_MAX : currentTimeMs + durationMs;
 timer = TimerImpl(id.fetch_add(1), timeWatch, timeOutCallBack);
 VLOG(DEBUG_LOG_LEVEL) << FormatString("AddTimer with delay %llu at expire time %llu, with id %llu", durationMs,
 timeWatch, timer.GetId());
diff --git a/src/datasystem/common/httpclient/http_request.cpp b/src/datasystem/common/httpclient/http_request.cpp
index bbb3b4c..57681fc 100644
--- a/src/datasystem/common/httpclient/http_request.cpp
+++ b/src/datasystem/common/httpclient/http_request.cpp
@@ -91,7 +91,7 @@ void HttpRequest::ClearSensitiveInfo()

 void HttpRequest::SetAsyncElapseTime(int64_t asyncElapseTime)
 {
- asyncElapse_ = static_cast<uint64_t>(asyncElapseTime);
+ asyncElapse_ = asyncElapseTime >= 0 ? static_cast<uint64_t>(asyncElapseTime) : 0;
 }

 uint64_t HttpRequest::GetAsyncElapseTime()
diff --git a/src/datasystem/common/immutable_string/ref_count_string.cpp b/src/datasystem/common/immutable_string/ref_count_string.cpp
index 6d322a3..f97c900 100644
--- a/src/datasystem/common/immutable_string/ref_count_string.cpp
+++ b/src/datasystem/common/immutable_string/ref_count_string.cpp
@@ -148,4 +148,4 @@ RefCountStringHandle &RefCountStringHandle::operator=(const RefCountStringHandle
 const std::string RefCountStringHandle::default_ = "";

-} // namespace datasystem
\ No newline at end of file
+} // namespace datasystem
diff --git a/src/datasystem/common/kvstore/etcd/etcd_store.cpp b/src/datasystem/common/kvstore/etcd/etcd_store.cpp
index 805cd80..35ade6a 100644
--- a/src/datasystem/common/kvstore/etcd/etcd_store.cpp
+++ b/src/datasystem/common/kvstore/etcd/etcd_store.cpp
@@ -41,12 +41,12 @@
 DS_DEFINE_string(etcd_address, "", "Address of ETCD server");
 DS_DEFINE_validator(etcd_address, &Validator::ValidateEtcdAddresses);

-DS_DEFINE_string(other_az_names, "", "Specify other az names using the same etcd. Split by ','");
-DS_DEFINE_validator(other_az_names, &Validator::ValidateOtherAzNames);
+DS_DEFINE_string(other_cluster_names, "", "Specify other cluster names using the same etcd. Split by ','");
+DS_DEFINE_validator(other_cluster_names, &Validator::ValidateOtherAzNames);
 DS_DECLARE_uint32(node_timeout_s);
 DS_DECLARE_uint32(node_dead_timeout_s);
 DS_DECLARE_bool(auto_del_dead_node);
-DS_DECLARE_string(az_name);
+DS_DECLARE_string(cluster_name);

 namespace datasystem {
 EtcdStore::EtcdStore(const std::string &address) : address_(address)
@@ -156,15 +156,15 @@ Status EtcdStore::CreateTable(const std::string &tableName, const std::string &t
 std::lock_guard lck(mutex_);
 CHECK_FAIL_RETURN_STATUS(tableMap_.find(tableName) == tableMap_.end(), K_DUPLICATED, "The table already exists.
tableName:" + tableName); - if (!FLAGS_az_name.empty()) { - tableMap_.emplace(tableName, "/" + FLAGS_az_name + tablePrefix); + if (!FLAGS_cluster_name.empty()) { + tableMap_.emplace(tableName, "/" + FLAGS_cluster_name + tablePrefix); } else { tableMap_.emplace(tableName, tablePrefix); } - if (!FLAGS_other_az_names.empty()) { - for (auto &azName : Split(FLAGS_other_az_names, ",")) { - if (azName != FLAGS_az_name) { + if (!FLAGS_other_cluster_names.empty()) { + for (auto &azName : Split(FLAGS_other_cluster_names, ",")) { + if (azName != FLAGS_cluster_name) { std::lock_guard lck(otherAzTblMutex_); otherAzTableMap_[tableName].emplace_back("/" + azName + tablePrefix); } diff --git a/src/datasystem/common/kvstore/rocksdb/replica.cpp b/src/datasystem/common/kvstore/rocksdb/replica.cpp index fcdfbed..326c57d 100644 --- a/src/datasystem/common/kvstore/rocksdb/replica.cpp +++ b/src/datasystem/common/kvstore/rocksdb/replica.cpp @@ -40,17 +40,24 @@ #include "datasystem/common/util/status_helper.h" #include "datasystem/common/util/uri.h" #include "datasystem/common/log/log.h" +#include "datasystem/master/stream_cache/store/stream_transform.h" #include "datasystem/utils/status.h" namespace datasystem { const size_t MAX_REPLICATION_BYTES = 32 * 1024 * 1024; +const std::string STREAM_META_NAME = "stream_meta_data"; const std::string OBJECT_META_NAME = "object_metadata"; namespace { std::unordered_map GetTableOptions() { + auto tableList = { STREAM_TABLE_NAME, PUB_TABLE_NAME, SUB_TABLE_NAME }; rocksdb::ColumnFamilyOptions prefixOption = rocksdb::ColumnFamilyOptions(); + prefixOption.prefix_extractor = std::make_shared(); std::unordered_map tableOptions; + for (auto table : tableList) { + tableOptions[table] = prefixOption; + } return tableOptions; } } // namespace @@ -78,7 +85,9 @@ Status Replica::Init() scStore_ = ocStore_; } else { std::string objectPath = dbPath_ + "/" + OBJECT_META_NAME; + std::string streamPath = dbPath_ + "/" + STREAM_META_NAME; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitRocksStore(objectPath, ocStore_), "InitRocksStore for object failed"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitRocksStore(streamPath, scStore_), "InitRocksStore for stream failed"); } return Status::OK(); } @@ -94,27 +103,47 @@ Status Replica::CreateOcTable(RocksStore *store) std::vector tables; RETURN_IF_NOT_OK(store->ListTables(tables)); VLOG(1) << "Existing tables in rocksdb: " << VectorToString(tables); - RETURN_IF_NOT_OK(CreateTable(META_TABLE, store, tables)); - RETURN_IF_NOT_OK(CreateTable(LOCATION_TABLE, store, tables)); - RETURN_IF_NOT_OK(CreateTable(NESTED_TABLE, store, tables)); - RETURN_IF_NOT_OK(CreateTable(NESTED_COUNT_TABLE, store, tables)); - RETURN_IF_NOT_OK(CreateTable(ASYNC_WORKER_OP_TABLE, store, tables)); - RETURN_IF_NOT_OK(CreateTable(GLOBAL_REF_TABLE, store, tables)); - RETURN_IF_NOT_OK(CreateTable(GLOBAL_CACHE_TABLE, store, tables)); - RETURN_IF_NOT_OK(CreateTable(REMOTE_CLIENT_OBJ_REF_TABLE, store, tables)); - RETURN_IF_NOT_OK(CreateTable(REMOTE_CLIENT_REF_TABLE, store, tables)); - RETURN_IF_NOT_OK(CreateTable(HEALTH_TABLE, store, tables)); + RETURN_IF_NOT_OK(CreateTable(META_TABLE, store, false, tables)); + RETURN_IF_NOT_OK(CreateTable(LOCATION_TABLE, store, false, tables)); + RETURN_IF_NOT_OK(CreateTable(NESTED_TABLE, store, false, tables)); + RETURN_IF_NOT_OK(CreateTable(NESTED_COUNT_TABLE, store, false, tables)); + RETURN_IF_NOT_OK(CreateTable(ASYNC_WORKER_OP_TABLE, store, false, tables)); + RETURN_IF_NOT_OK(CreateTable(GLOBAL_REF_TABLE, store, false, tables)); + 
RETURN_IF_NOT_OK(CreateTable(GLOBAL_CACHE_TABLE, store, false, tables)); + RETURN_IF_NOT_OK(CreateTable(REMOTE_CLIENT_OBJ_REF_TABLE, store, false, tables)); + RETURN_IF_NOT_OK(CreateTable(REMOTE_CLIENT_REF_TABLE, store, false, tables)); + RETURN_IF_NOT_OK(CreateTable(HEALTH_TABLE, store, false, tables)); LOG(INFO) << FormatString("CreateOcTable success."); return Status::OK(); } -Status Replica::CreateTable(const std::string &tableName, RocksStore *store, std::vector &tables) +Status Replica::CreateScTable(RocksStore *store) +{ + std::vector tables; + RETURN_IF_NOT_OK(store->ListTables(tables)); + VLOG(1) << "Existing tables in rocksdb: " << VectorToString(tables); + RETURN_IF_NOT_OK(CreateTable(STREAM_TABLE_NAME, store, true, tables)); + RETURN_IF_NOT_OK(CreateTable(PUB_TABLE_NAME, store, true, tables)); + RETURN_IF_NOT_OK(CreateTable(SUB_TABLE_NAME, store, true, tables)); + RETURN_IF_NOT_OK(CreateTable(NOTIFY_PUB_TABLE_NAME, store, true, tables)); + RETURN_IF_NOT_OK(CreateTable(NOTIFY_SUB_TABLE_NAME, store, true, tables)); + RETURN_IF_NOT_OK(CreateTable(STREAM_CON_CNT_TABLE_NAME, store, true, tables)); + RETURN_IF_NOT_OK(CreateTable(STREAM_PRODUCER_COUNT, store, true, tables)); + LOG(INFO) << FormatString("CreateScTable success."); + return Status::OK(); +} + +Status Replica::CreateTable(const std::string &tableName, RocksStore *store, bool isSc, + std::vector &tables) { VLOG(1) << "tableName:" << tableName; bool exits = (std::find(tables.begin(), tables.end(), tableName) != tables.end()); if (!exits) { tables.emplace_back(tableName); auto options = rocksdb::ColumnFamilyOptions(); + if (isSc) { + options.prefix_extractor = std::make_shared(); + } RETURN_IF_NOT_OK(store->CreateTable(tableName, options)); LOG(INFO) << FormatString("Create table { %s } successfully.", tableName); } @@ -133,6 +162,7 @@ Status Replica::CreateRocksStoreInstanceAndTable(const std::string &dbPath, std: { RETURN_IF_NOT_OK(CreateRocksStoreInstance(dbPath, store)); RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CreateOcTable(store.get()), "Replica create oc table failed"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CreateScTable(store.get()), "Replica create sc table failed"); return Status::OK(); } @@ -356,7 +386,10 @@ Status Replica::RemoveRocksFromFileSystem(const std::string &dbRootPath, bool mu RETURN_IF_NOT_OK(RemoveAll(path)); } else { std::string objectPath = dbRootPath + "/" + OBJECT_META_NAME; + std::string streamPath = dbRootPath + "/" + STREAM_META_NAME; + LOG(INFO) << "try remove path " << objectPath << " and " << streamPath; RETURN_IF_NOT_OK(RemoveAll(objectPath)); + RETURN_IF_NOT_OK(RemoveAll(streamPath)); } return Status::OK(); } diff --git a/src/datasystem/common/kvstore/rocksdb/replica.h b/src/datasystem/common/kvstore/rocksdb/replica.h index 990f13c..95189a2 100644 --- a/src/datasystem/common/kvstore/rocksdb/replica.h +++ b/src/datasystem/common/kvstore/rocksdb/replica.h @@ -97,14 +97,22 @@ public: */ static Status CreateOcTable(RocksStore *store); + /** + * @brief Create stream cache tables. + * @return Status of this call. + */ + static Status CreateScTable(RocksStore *store); + /** * @brief Create table in rocksdb. * @param[in] table The Rocksdb table name to be created. * @param[in] store The rocks store pointer. + * @param[in] isSc Create for stream cache. * @param[in] tables The tables currently in rocksdb. * @return Status of the call. 
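+ * @code
+ * // Hedged sketch (illustration only): create one stream-cache table with
+ * // prefix support; "store" and "existing" are assumed to come from the
+ * // surrounding recovery path, as in CreateScTable above.
+ * std::vector<std::string> existing;
+ * RETURN_IF_NOT_OK(store->ListTables(existing));
+ * RETURN_IF_NOT_OK(CreateTable(STREAM_TABLE_NAME, store, true, existing));
+ * @endcode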
 */
- static Status CreateTable(const std::string &tableName, RocksStore *store, std::vector &tables);
+ static Status CreateTable(const std::string &tableName, RocksStore *store, bool isSc,
+ std::vector &tables);

 /**
 * @brief Get the object rocks store instance.
@@ -115,6 +123,15 @@ public:
 return ocStore_.get();
 }

+ /**
+ * @brief Get the stream rocks store instance.
+ * @return The rocks store instance.
+ */
+ RocksStore *GetStreamRocksStore()
+ {
+ return scStore_.get();
+ }
+
 /**
 * @brief Get the rocks db path.
 * @return The rocks db path.
diff --git a/src/datasystem/common/kvstore/rocksdb/rocks_store.cpp b/src/datasystem/common/kvstore/rocksdb/rocks_store.cpp
index 780bc4e..1bf521d 100644
--- a/src/datasystem/common/kvstore/rocksdb/rocks_store.cpp
+++ b/src/datasystem/common/kvstore/rocksdb/rocks_store.cpp
@@ -28,19 +28,26 @@
 #include "datasystem/common/flags/flags.h"
 #include "datasystem/common/perf/perf_manager.h"
+#include "datasystem/common/util/format.h"
 #include "datasystem/common/util/strings_util.h"
 #include "datasystem/common/util/thread_local.h"
 #include "datasystem/common/util/timer.h"
 #include "datasystem/common/util/uri.h"
 #include "datasystem/common/util/status_helper.h"
+#include "datasystem/common/util/validator.h"
 #include "datasystem/common/inject/inject_point.h"
 #include "datasystem/common/log/log.h"
 #include "datasystem/utils/connection.h"
+#include "datasystem/utils/status.h"

 DS_DEFINE_bool(rocksdb_sync_write, false, "Controls whether rocksdb sets sync to true when writing data.");
 DS_DEFINE_int32(rocksdb_max_open_file, 128, "Number of open files that can be used by the rocksdb");
 DS_DEFINE_int32(rocksdb_background_threads, 16,
 "Number of background threads rocksdb can use for flushing and compacting.");
+DS_DEFINE_string(rocksdb_write_mode, "async",
+ "Configure the rocksdb write mode, 'async' by default. Optional values: "
+ "'none', 'sync', 'async'.
This represents the method of writing metadata to rocksdb.");
+DS_DEFINE_validator(rocksdb_write_mode, &Validator::ValidateRocksdbModeType);

 namespace datasystem {
 std::mutex RocksStore::lck;
@@ -73,10 +80,16 @@ void InitRocksOptions(rocksdb::Options &options)
 }
 } // namespace

+RocksStore::RocksStore()
+{
+ mode_ = ParseRocksdbWriteMode();
+ InitializeAsyncThreadPool();
+}
+
 std::shared_ptr RocksStore::GetInstance(
 const std::string &dbPath, const std::unordered_map &tableOptions)
 {
- INJECT_POINT("master.disableRocksDb", [] () {
+ INJECT_POINT("master.disableRocksDb", []() {
 RocksStore::disableRocksDB = true;
 return nullptr;
 });
@@ -129,8 +142,7 @@ std::shared_ptr RocksStore::GetInstance(
 LOG(ERROR) << "Cannot create/open database: " + std::string(rc.getState());
 return nullptr;
 }
-
- LOG(INFO) << "Rocksdb get instance finished, dbPath:" << dbPath;
+ LOG(INFO) << "Rocksdb get instance finished, dbPath: " << dbPath << ", write mode: " << FLAGS_rocksdb_write_mode;
 return instance;
 }

@@ -152,12 +164,47 @@ void RocksStore::Close()
 db_->Close();
 delete db_;
 db_ = nullptr;
+ if (asyncThreadPool_) {
+ asyncThreadPool_.reset();
+ }
+}
+
+RocksdbWriteMode RocksStore::ParseRocksdbWriteMode()
+{
+ if (FLAGS_rocksdb_write_mode == "async") {
+ return RocksdbWriteMode::ASYNC;
+ } else if (FLAGS_rocksdb_write_mode == "sync") {
+ return RocksdbWriteMode::SYNC;
+ } else if (FLAGS_rocksdb_write_mode == "none") {
+ return RocksdbWriteMode::NONE;
+ } else {
+ LOG(INFO) << FormatString("Rocksdb write mode is: %s, will use none mode instead.", FLAGS_rocksdb_write_mode);
+ return RocksdbWriteMode::NONE;
+ }
+}
+
+bool RocksStore::IsClusterInfoTable(const std::string &tableName)
+{
+ auto it = std::find(clusterInfoTable_.begin(), clusterInfoTable_.end(), tableName);
+ if (it != clusterInfoTable_.end()) {
+ return true;
+ }
+ return false;
+}
+
+void RocksStore::InitializeAsyncThreadPool(size_t threadCount)
+{
+ if (!asyncThreadPool_ && mode_ == RocksdbWriteMode::ASYNC) {
+ asyncThreadPool_ = std::make_unique<ThreadPool>(threadCount);
+ LOG(INFO) << "Init rocksdb async thread pool.";
+ }
 }

 Status RocksStore::CreateTable(const std::string &tableName, const rocksdb::ColumnFamilyOptions &tableOptions,
 rocksdb::ColumnFamilyHandle **tableHandle)
 {
 RETURN_OK_IF_TRUE(disableRocksDB);
+ RETURN_OK_IF_TRUE(mode_ == RocksdbWriteMode::NONE);
 rocksdb::ColumnFamilyHandle *cf = nullptr;
 rocksdb::Status rc;
 auto item = tables_.find(tableName);
@@ -183,6 +230,7 @@ Status RocksStore::CreateTable(const std::string &tableName, const rocksdb::Colu
 Status RocksStore::DropTable(const std::string &tableName)
 {
 RETURN_OK_IF_TRUE(disableRocksDB);
+ RETURN_OK_IF_TRUE(mode_ == RocksdbWriteMode::NONE);
 rocksdb::Status rc;
 rocksdb::ColumnFamilyHandle *tableHandle = nullptr;
@@ -205,6 +253,7 @@ Status RocksStore::DropTable(const std::string &tableName)
 Status RocksStore::Put(const std::string &tableName, const std::string &key, const std::string &value)
 {
 RETURN_OK_IF_TRUE(disableRocksDB);
+ RETURN_OK_IF_TRUE(mode_ == RocksdbWriteMode::NONE);
 rocksdb::Status rc;
 CHECK_FAIL_RETURN_STATUS(db_, StatusCode::K_NOT_FOUND, "Database does not exist");
 rocksdb::ColumnFamilyHandle *tableHandle = nullptr;
@@ -215,37 +264,71 @@ Status RocksStore::Put(const std::string &tableName, const std::string &key, con
 CHECK_FAIL_RETURN_STATUS(iter != tables_.end(), StatusCode::K_NOT_FOUND, "Table " + tableName + " does not exist");
 tableHandle = iter->second;
 Timer timer;
- rc = Put(tableHandle, key, value, FLAGS_rocksdb_sync_write);
+ if (mode_ == RocksdbWriteMode::SYNC) {
+ rc
= Put(tableHandle, key, value, FLAGS_rocksdb_sync_write);
+ masterOperationTimeCost.Append("RocksDB Put", timer.ElapsedMilliSecond());
+ return CheckAndRemoveDbPath(rc);
+ } else if (mode_ == RocksdbWriteMode::ASYNC) {
+ auto future = asyncThreadPool_->Submit(key, [this, tableHandle, key, value]() {
+ rocksdb::Status rc = Put(tableHandle, key, value, FLAGS_rocksdb_sync_write);
+ if (!rc.ok()) {
+ LOG(ERROR) << FormatString("Async Put key %s failed: %s", key, rc.getState());
+ }
+ });
+ return Status::OK();
+ }
 masterOperationTimeCost.Append("RocksDB Put", timer.ElapsedMilliSecond());
- return CheckAndRemoveDbPath(rc);
+ return Status::OK();
 }

 Status RocksStore::BatchPut(const std::string &tableName, std::unordered_map &metaInfos)
 {
 RETURN_OK_IF_TRUE(disableRocksDB);
+ RETURN_OK_IF_TRUE(mode_ == RocksdbWriteMode::NONE);
 rocksdb::Status rc;
 CHECK_FAIL_RETURN_STATUS(db_, StatusCode::K_RUNTIME_ERROR, "Database does not exist");
 rocksdb::ColumnFamilyHandle *tableHandle = nullptr;
 auto iter = tables_.find(tableName);
 CHECK_FAIL_RETURN_STATUS(iter != tables_.end(), StatusCode::K_NOT_FOUND, "Table " + tableName + " does not exist");
 tableHandle = iter->second;
+ if (mode_ == RocksdbWriteMode::SYNC) {
+ rc = BatchPut(metaInfos, tableHandle, FLAGS_rocksdb_sync_write);
+ return CheckAndRemoveDbPath(rc);
+ } else if (mode_ == RocksdbWriteMode::ASYNC) {
+ auto future = asyncThreadPool_->Submit(tableName, [this, tableHandle, metaInfos]() {
+ rocksdb::Status rc = BatchPut(metaInfos, tableHandle, FLAGS_rocksdb_sync_write);
+ if (!rc.ok()) {
+ LOG(ERROR) << FormatString("Async BatchPut failed: %s", rc.getState());
+ }
+ });

- rc = BatchPut(metaInfos, tableHandle, FLAGS_rocksdb_sync_write);
- return CheckAndRemoveDbPath(rc);
+ return Status::OK();
+ }
+ return Status::OK();
 }

 Status RocksStore::BatchDelete(const std::string &tableName, std::unordered_map &metaInfos)
 {
 RETURN_OK_IF_TRUE(disableRocksDB);
- rocksdb::Status rc;
+ RETURN_OK_IF_TRUE(mode_ == RocksdbWriteMode::NONE);
 CHECK_FAIL_RETURN_STATUS(db_, StatusCode::K_NOT_FOUND, "Database does not exist");
 rocksdb::ColumnFamilyHandle *tableHandle = nullptr;
 auto iter = tables_.find(tableName);
 CHECK_FAIL_RETURN_STATUS(iter != tables_.end(), StatusCode::K_NOT_FOUND, "Table " + tableName + " does not exist");
 tableHandle = iter->second;
-
- rc = BatchDelete(metaInfos, tableHandle, FLAGS_rocksdb_sync_write);
- return CheckAndRemoveDbPath(rc);
+ if (mode_ == RocksdbWriteMode::SYNC) {
+ rocksdb::Status rc = BatchDelete(metaInfos, tableHandle, FLAGS_rocksdb_sync_write);
+ return CheckAndRemoveDbPath(rc);
+ } else if (mode_ == RocksdbWriteMode::ASYNC) {
+ auto future = asyncThreadPool_->Submit(tableName, [this, tableHandle, metaInfos]() {
+ rocksdb::Status rc = BatchDelete(metaInfos, tableHandle, FLAGS_rocksdb_sync_write);
+ if (!rc.ok()) {
+ LOG(ERROR) << FormatString("Async BatchDelete failed: %s", rc.getState());
+ }
+ });
+ return Status::OK();
+ }
+ return Status::OK();
 }

 Status RocksStore::CheckAndRemoveDbPath(rocksdb::Status rc)
@@ -275,7 +359,7 @@ Status RocksStore::ListTables(std::vector &tables)
 if (!rc.ok()) {
 // Set the database handle to a null pointer.
db_ = nullptr; - RETURN_STATUS_LOG_ERROR(StatusCode::K_KVSTORE_ERROR, "Error when listing table names:" + rc.ToString()); + RETURN_STATUS_LOG_ERROR(StatusCode::K_KVSTORE_ERROR, "Error when listing table names: " + rc.ToString()); } return Status::OK(); } @@ -283,7 +367,7 @@ Status RocksStore::ListTables(std::vector &tables) Status RocksStore::Get(const std::string &tableName, const std::string &key, std::string &value) { RETURN_OK_IF_TRUE(disableRocksDB); - rocksdb::Status rc; + RETURN_OK_IF_TRUE(mode_ == RocksdbWriteMode::NONE); CHECK_FAIL_RETURN_STATUS(db_, StatusCode::K_RUNTIME_ERROR, "Database does not exist"); rocksdb::ColumnFamilyHandle *tableHandle = nullptr; @@ -293,7 +377,18 @@ Status RocksStore::Get(const std::string &tableName, const std::string &key, std PerfPoint point(PerfKey::ROCKSDB_GET); tableHandle = iter->second; Timer timer; - rc = Get(tableHandle, key, value); + rocksdb::Status rc; + if (mode_ == RocksdbWriteMode::SYNC) { + rc = Get(tableHandle, key, value); + } else if (mode_ == RocksdbWriteMode::ASYNC) { + auto future = asyncThreadPool_->Submit(key, [this, key, tableHandle, &rc, &value]() { + rc = Get(tableHandle, key, value); + if (!rc.ok()) { + LOG(ERROR) << FormatString("Async Get key %s failed: %s", key, rc.getState()); + } + }); + future.wait(); + } masterOperationTimeCost.Append("RocksDB Get", timer.ElapsedMilliSecond()); if (!rc.ok()) { if (rc == rocksdb::Status::NotFound()) { @@ -308,6 +403,7 @@ Status RocksStore::Get(const std::string &tableName, const std::string &key, std Status RocksStore::GetAll(const std::string &tableName, std::vector> &outKeyValues) { RETURN_OK_IF_TRUE(disableRocksDB); + RETURN_OK_IF_TRUE(mode_ == RocksdbWriteMode::NONE); outKeyValues.clear(); rocksdb::Status rc; CHECK_FAIL_RETURN_STATUS(db_, StatusCode::K_NOT_FOUND, "Database does not exist"); @@ -320,13 +416,28 @@ Status RocksStore::GetAll(const std::string &tableName, std::vectorsecond; auto readOptions = rocksdb::ReadOptions(); Timer timer; - std::unique_ptr iter2(db_->NewIterator(readOptions, tableHandle)); - masterOperationTimeCost.Append("RocksDB GetAll", timer.ElapsedMilliSecond()); - iter2->SeekToFirst(); - while (iter2->Valid()) { - outKeyValues.emplace_back(std::make_pair(iter2->key().ToString(), iter2->value().ToString())); - iter2->Next(); + if (mode_ == RocksdbWriteMode::SYNC) { + std::unique_ptr iter2(db_->NewIterator(readOptions, tableHandle)); + iter2->SeekToFirst(); + while (iter2->Valid()) { + outKeyValues.emplace_back(std::make_pair(iter2->key().ToString(), iter2->value().ToString())); + iter2->Next(); + } + masterOperationTimeCost.Append("RocksDB GetAll", timer.ElapsedMilliSecond()); + return Status::OK(); + } else if (mode_ == RocksdbWriteMode::ASYNC) { + auto future = asyncThreadPool_->Submit(tableName, [this, readOptions, tableHandle, &outKeyValues]() { + std::unique_ptr iter2(db_->NewIterator(readOptions, tableHandle)); + iter2->SeekToFirst(); + while (iter2->Valid()) { + outKeyValues.emplace_back(std::make_pair(iter2->key().ToString(), iter2->value().ToString())); + iter2->Next(); + } + }); + future.wait(); + return Status::OK(); } + masterOperationTimeCost.Append("RocksDB GetAll", timer.ElapsedMilliSecond()); return Status::OK(); } @@ -334,6 +445,7 @@ Status RocksStore::PrefixSearch(const std::string &tableName, const std::string std::vector> &outKeyValues) { RETURN_OK_IF_TRUE(disableRocksDB); + RETURN_OK_IF_TRUE(mode_ == RocksdbWriteMode::NONE); rocksdb::Status rc; CHECK_FAIL_RETURN_STATUS(db_, StatusCode::K_NOT_FOUND, "Database does not exist"); 
rocksdb::ColumnFamilyHandle *tableHandle = nullptr; @@ -355,13 +467,23 @@ Status RocksStore::PrefixSearch(const std::string &tableName, const std::string CHECK_FAIL_RETURN_STATUS(result, StatusCode::K_KVSTORE_ERROR, "Table " + tableName + " was created with a prefix search pattern longer than " + "the input pattern \"" + prefixKey + "\""); - PrefixSearch(tableHandle, prefixKey, outKeyValues); + if (mode_ == RocksdbWriteMode::SYNC) { + PrefixSearch(tableHandle, prefixKey, outKeyValues); + return Status::OK(); + } else if (mode_ == RocksdbWriteMode::ASYNC) { + auto future = asyncThreadPool_->Submit(tableName, [this, prefixKey, tableHandle, &outKeyValues]() { + PrefixSearch(tableHandle, prefixKey, outKeyValues); + }); + future.wait(); + return Status::OK(); + } return Status::OK(); } Status RocksStore::Delete(const std::string &tableName, const std::string &key) { RETURN_OK_IF_TRUE(disableRocksDB); + RETURN_OK_IF_TRUE(mode_ == RocksdbWriteMode::NONE); rocksdb::Status rc; CHECK_FAIL_RETURN_STATUS(db_, StatusCode::K_NOT_FOUND, "Database does not exist"); rocksdb::ColumnFamilyHandle *tableHandle = nullptr; @@ -372,15 +494,28 @@ Status RocksStore::Delete(const std::string &tableName, const std::string &key) tableHandle = iter->second; Timer timer; PerfPoint point(PerfKey::ROCKSDB_DELETE); - rc = Delete(tableHandle, key, FLAGS_rocksdb_sync_write); + if (mode_ == RocksdbWriteMode::SYNC) { + rc = Delete(tableHandle, key, FLAGS_rocksdb_sync_write); + masterOperationTimeCost.Append("RocksDB Delete", timer.ElapsedMilliSecond()); + // Deleting a key that does not exist in the database will NOT yield an error. + return CheckAndRemoveDbPath(rc); + } else if (mode_ == RocksdbWriteMode::ASYNC) { + auto future = asyncThreadPool_->Submit(key, [this, tableHandle, key]() { + rocksdb::Status rc = Delete(tableHandle, key, FLAGS_rocksdb_sync_write); + if (!rc.ok()) { + LOG(ERROR) << FormatString("Async Delete key %s failed: %s", key, rc.getState()); + } + }); + return Status::OK(); + } masterOperationTimeCost.Append("RocksDB Delete", timer.ElapsedMilliSecond()); - // Deleting a key that does not exist in the database will NOT yield an error. 
- return CheckAndRemoveDbPath(rc); + return Status::OK(); } Status RocksStore::PrefixDelete(const std::string &tableName, const std::string &prefixKey) { RETURN_OK_IF_TRUE(disableRocksDB); + RETURN_OK_IF_TRUE(mode_ == RocksdbWriteMode::NONE); rocksdb::Status rc; CHECK_FAIL_RETURN_STATUS(db_, StatusCode::K_NOT_FOUND, "Database does not exist"); rocksdb::ColumnFamilyHandle *tableHandle = nullptr; @@ -394,12 +529,26 @@ Status RocksStore::PrefixDelete(const std::string &tableName, const std::string rocksdb::WriteOptions options; options.sync = FLAGS_rocksdb_sync_write; Timer timer; - rc = db_->DeleteRange(options, tableHandle, rocksdb::Slice(prefixKey), rocksdb::Slice(endKey)); - masterOperationTimeCost.Append("RocksDB DeleteRange", timer.ElapsedMilliSecond()); - if (rc != rocksdb::Status::OK()) { - RETURN_STATUS(StatusCode::K_KVSTORE_ERROR, - "Cannot delete prefix key: " + prefixKey + " Error: " + rc.ToString()); + if (mode_ == RocksdbWriteMode::SYNC) { + rc = db_->DeleteRange(options, tableHandle, rocksdb::Slice(prefixKey), rocksdb::Slice(endKey)); + masterOperationTimeCost.Append("RocksDB DeleteRange", timer.ElapsedMilliSecond()); + if (rc != rocksdb::Status::OK()) { + RETURN_STATUS(StatusCode::K_KVSTORE_ERROR, + "Cannot delete prefix key: " + prefixKey + " Error: " + rc.ToString()); + } + return Status::OK(); + } else if (mode_ == RocksdbWriteMode::ASYNC) { + auto future = asyncThreadPool_->Submit(tableName, [this, tableHandle, options, prefixKey, endKey]() { + rocksdb::Status rc = + db_->DeleteRange(options, tableHandle, rocksdb::Slice(prefixKey), rocksdb::Slice(endKey)); + if (!rc.ok()) { + LOG(ERROR) << FormatString("Async PrefixDelete prefixKey %s to endKey %s failed: %s", prefixKey, endKey, + rc.getState()); + } + }); + return Status::OK(); } + masterOperationTimeCost.Append("RocksDB DeleteRange", timer.ElapsedMilliSecond()); return Status::OK(); } } // namespace datasystem diff --git a/src/datasystem/common/kvstore/rocksdb/rocks_store.h b/src/datasystem/common/kvstore/rocksdb/rocks_store.h index 9c4ca58..e47275f 100644 --- a/src/datasystem/common/kvstore/rocksdb/rocks_store.h +++ b/src/datasystem/common/kvstore/rocksdb/rocks_store.h @@ -20,10 +20,13 @@ #ifndef DATASYSTEM_COMMON_KVSTORE_ROCKS_STORE_H #define DATASYSTEM_COMMON_KVSTORE_ROCKS_STORE_H +#include #include #include +#include #include #include +#include #include #include @@ -31,10 +34,14 @@ #include "rocksdb/options.h" #include "rocksdb/slice_transform.h" +#include "datasystem/common/constants.h" #include "datasystem/common/kvstore/kv_store.h" +#include "datasystem/common/util/thread_pool.h" #include "datasystem/utils/status.h" namespace datasystem { + +enum class RocksdbWriteMode { ASYNC, SYNC, NONE }; class RocksStore : public KvStore { public: /** @@ -122,7 +129,7 @@ public: * @param[in] sync The sync mode to delete or not. * @return Status of the call. */ - inline rocksdb::Status BatchPut(std::unordered_map &metaInfos, + inline rocksdb::Status BatchPut(const std::unordered_map &metaInfos, rocksdb::ColumnFamilyHandle *tableHandle, bool sync = false) { rocksdb::WriteOptions options; @@ -141,7 +148,7 @@ public: * @param[in] sync The sync mode to delete or not. * @return Status of the call. 
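+ * @code
+ * // Hedged sketch (illustration only) of how the public write path uses this
+ * // helper: sync mode calls it inline, async mode submits it to the pool
+ * // keyed by table name (single-key ops submit by key), which appears to
+ * // serialize writes sharing a key while letting callers return at once.
+ * if (mode_ == RocksdbWriteMode::SYNC) {
+ *     return CheckAndRemoveDbPath(BatchDelete(metaInfos, tableHandle, FLAGS_rocksdb_sync_write));
+ * }
+ * asyncThreadPool_->Submit(tableName, [this, tableHandle, metaInfos]() {
+ *     rocksdb::Status rc = BatchDelete(metaInfos, tableHandle, FLAGS_rocksdb_sync_write);
+ *     if (!rc.ok()) {
+ *         LOG(ERROR) << FormatString("Async BatchDelete failed: %s", rc.getState());
+ *     }
+ * });
+ * return Status::OK(); // async failures surface only in the log
+ * @endcode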
 */
- inline rocksdb::Status BatchDelete(std::unordered_map &metaInfos,
+ inline rocksdb::Status BatchDelete(const std::unordered_map &metaInfos,
 rocksdb::ColumnFamilyHandle *tableHandle, bool sync = false)
 {
 rocksdb::WriteOptions options;
@@ -269,15 +276,45 @@ public:
 RocksStore(const RocksStore &&) = delete;
 RocksStore &operator=(const RocksStore &&) = delete;

+ bool IsAsyncQueueEmpty()
+ {
+ if (!asyncThreadPool_) {
+ return true;
+ }
+ return asyncThreadPool_->AreAllQueuesEmpty();
+ }
+
 private:
 // Make the constructor private to force the user to call GetInstance to open a read-only RocksDB database.
- RocksStore() = default;
+ RocksStore();
+
+ /**
+ * @brief Check whether the table is a cluster table.
+ * @param[in] tableName The table name to check.
+ * @return True if the table is a cluster table.
+ */
+ bool IsClusterInfoTable(const std::string &tableName);
+
+ /**
+ * @brief Init async thread pool.
+ * @param[in] threadCount The thread number of the async thread pool.
+ */
+ void InitializeAsyncThreadPool(size_t threadCount = 16);
+
+ /**
+ * @brief Parse rocksdb write mode.
+ * @return RocksdbWriteMode enum.
+ */
+ RocksdbWriteMode ParseRocksdbWriteMode();

 rocksdb::DB *db_ = nullptr;
 std::string dbPath_;
 std::unordered_map tables_;
 static std::mutex lck;
 static bool disableRocksDB;
+ std::vector<std::string> clusterInfoTable_ = { CLUSTER_TABLE, HASHRING_TABLE, REPLICA_GROUP_TABLE, HEALTH_TABLE };
+ std::unique_ptr<ThreadPool> asyncThreadPool_;
+ RocksdbWriteMode mode_;
 };
 } // namespace datasystem
 #endif // DATASYSTEM_COMMON_KVSTORE_ROCKS_STORE_H
diff --git a/src/datasystem/common/log/access_point.def b/src/datasystem/common/log/access_point.def
index 098c009..38331fa 100644
--- a/src/datasystem/common/log/access_point.def
+++ b/src/datasystem/common/log/access_point.def
@@ -14,6 +14,13 @@ ACCESS_RECORDER_KEY_DEF(DS_OBJECT_CLIENT_CREATE, CLIENT)
 ACCESS_RECORDER_KEY_DEF(DS_OBJECT_CLIENT_QUERY_GLOBAL_REF_NUM, CLIENT)
 ACCESS_RECORDER_KEY_DEF(DS_OBJECT_CLIENT_PUBLISH, CLIENT)
+ACCESS_RECORDER_KEY_DEF(DS_STREAM_CREATE_PRODUCER, CLIENT)
+ACCESS_RECORDER_KEY_DEF(DS_STREAM_CLOSE_PRODUCER, CLIENT)
+ACCESS_RECORDER_KEY_DEF(DS_STREAM_SUBSCRIBE, CLIENT)
+ACCESS_RECORDER_KEY_DEF(DS_STREAM_CLOSE_CONSUMER, CLIENT)
+ACCESS_RECORDER_KEY_DEF(DS_STREAM_DELETE_STREAM, CLIENT)
+ACCESS_RECORDER_KEY_DEF(DS_STREAM_QUERY_PRODUCERS_NUM, CLIENT)
+ACCESS_RECORDER_KEY_DEF(DS_STREAM_QUERY_CONSUMERS_NUM, CLIENT)
 ACCESS_RECORDER_KEY_DEF(DS_HETERO_CLIENT_SHUTDOWN, CLIENT)
 ACCESS_RECORDER_KEY_DEF(DS_HETERO_CLIENT_INIT, CLIENT)
 ACCESS_RECORDER_KEY_DEF(DS_HETERO_CLIENT_MGETH2D, CLIENT)
diff --git a/src/datasystem/common/log/log.h b/src/datasystem/common/log/log.h
index d465848..800ce35 100644
--- a/src/datasystem/common/log/log.h
+++ b/src/datasystem/common/log/log.h
@@ -28,6 +28,7 @@
 #include "datasystem/common/log/spdlog/log_param.h"

 DS_DECLARE_int32(v);
+DS_DECLARE_int32(minloglevel);

 namespace datasystem {
 #define DS_LOGS_LEVEL_INFO datasystem::LogSeverity::INFO
@@ -37,13 +38,16 @@ namespace datasystem {

 static constexpr int32_t HEARTBEAT_LEVEL = 3; // Heartbeat log level

-// Basic Logging Macros
-#define LOG(severity) datasystem::LogMessage(DS_LOGS_LEVEL_##severity, __FILE__, __LINE__).Stream()
+// Basic Logging Macros Impl
+#define LOG_IMPL(severity) datasystem::LogMessage(DS_LOGS_LEVEL_##severity, __FILE__, __LINE__).Stream()

 // Conditional Logging Macros
 #define LOG_IF(severity, condition) \
 if (condition) \
- LOG(severity)
+ LOG_IMPL(severity)
+
+// Basic Logging Macros
+#define LOG(severity) LOG_IF(severity, FLAGS_minloglevel <=
DS_LOGS_LEVEL_##severity) // Frequency-Controlled Logging Macros #define LOG_EVERY_N(severity, n) \ @@ -57,7 +61,7 @@ static constexpr int32_t HEARTBEAT_LEVEL = 3; // Heartbeat log level auto LOG_EVERY_T_ELAPSED_##__LINE__ = std::chrono::duration_cast( \ LOG_EVERY_T_NOW_##__LINE__ - LOG_EVERY_T_LAST_TIME_##__LINE__) \ .count(); \ - if (LOG_EVERY_T_ELAPSED_##__LINE__ >= (seconds) * 1000 \ + if (LOG_EVERY_T_ELAPSED_##__LINE__ >= (seconds)*1000 \ && (LOG_EVERY_T_LAST_TIME_##__LINE__ = LOG_EVERY_T_NOW_##__LINE__, true)) \ LOG(severity) diff --git a/src/datasystem/common/log/log_helper.h b/src/datasystem/common/log/log_helper.h index 4a79847..37b2703 100644 --- a/src/datasystem/common/log/log_helper.h +++ b/src/datasystem/common/log/log_helper.h @@ -55,10 +55,17 @@ public: if (tokenField != nullptr) { copy.GetReflection()->SetString(©, tokenField, "***"); } - auto accessKeyField = copy.descriptor()->FindFieldByName("signature"); + + auto signatureField = copy.descriptor()->FindFieldByName("signature"); + if (signatureField != nullptr) { + copy.GetReflection()->SetString(©, signatureField, "***"); + } + + auto accessKeyField = copy.descriptor()->FindFieldByName("access_key"); if (accessKeyField != nullptr) { copy.GetReflection()->SetString(©, accessKeyField, "***"); } + return copy.ShortDebugString(); } diff --git a/src/datasystem/common/log/logging.cpp b/src/datasystem/common/log/logging.cpp index af6a769..815d324 100644 --- a/src/datasystem/common/log/logging.cpp +++ b/src/datasystem/common/log/logging.cpp @@ -90,7 +90,7 @@ DS_DEFINE_uint32(log_async_queue_size, DEFAULT_LOG_ASYNC_QUEUE_SIZE, "Size of as DS_DEFINE_validator(log_filename, &Validator::ValidateEligibleChar); DS_DECLARE_bool(log_monitor); -DS_DECLARE_string(az_name); +DS_DECLARE_string(cluster_name); using namespace std::chrono; @@ -372,7 +372,7 @@ Status Logging::WriteLogToFile(int lineOfCode, const std::string &fileNameOfCode auto pos = fileNameOfCode.find_last_of('/'); std::string name = pos == std::string::npos ? 
fileNameOfCode : fileNameOfCode.substr(pos + 1); ConstructLogPrefix(ss, logTime.getTm(), logTime.getUsec(), name.c_str(), lineOfCode, podName_.c_str(), level, - FLAGS_az_name); + FLAGS_cluster_name); ss << message; if (message.empty() || message[message.size() - 1] != '\n') { ss << '\n'; diff --git a/src/datasystem/common/log/spdlog/CMakeLists.txt b/src/datasystem/common/log/spdlog/CMakeLists.txt index afcc062..b72599a 100644 --- a/src/datasystem/common/log/spdlog/CMakeLists.txt +++ b/src/datasystem/common/log/spdlog/CMakeLists.txt @@ -8,7 +8,7 @@ set(SPDLOG_SRCS set(APPEND_LIB nlohmann_json::nlohmann_json - spdlog::spdlog + ds_spdlog::spdlog ds_flags common_perf ) diff --git a/src/datasystem/common/log/spdlog/log_message_impl.cpp b/src/datasystem/common/log/spdlog/log_message_impl.cpp index bdb258e..39d62ef 100644 --- a/src/datasystem/common/log/spdlog/log_message_impl.cpp +++ b/src/datasystem/common/log/spdlog/log_message_impl.cpp @@ -35,9 +35,10 @@ #include "datasystem/common/log/trace.h" DS_DEFINE_int32(v, 0, "Show all VLOG(m) messages for m <= this."); -DS_DEFINE_string(az_name, "", - "az_name is typically used in scenarios where multiple AZ datasystem share a single etcd cluster, " - "allowing different clusters to be distinguished by the az_name."); +DS_DEFINE_string( + cluster_name, "", + "cluster_name is typically used in scenarios where multiple AZ datasystem share a single etcd cluster, " + "allowing different clusters to be distinguished by the cluster_name."); namespace datasystem { // thread_local for store log info @@ -72,7 +73,7 @@ static void AppendLogMessageImplPrefix(const std::string &podName, std::ostream static thread_local pid_t tid = syscall(__NR_gettid); logStream << podName << " | " << pid << ":" << tid << " | " << Trace::Instance().GetTraceID() << " | " - << FLAGS_az_name << " | "; + << FLAGS_cluster_name << " | "; } static DsLogger GetMessageLogger() @@ -123,7 +124,7 @@ void LogMessageImpl::Init() void LogMessageImpl::ToSpdlog() { - logger_->log(sourceLoc_, level_, spdlog::string_view_t{g_ThreadLogData, msgSize_}); + logger_->log(sourceLoc_, level_, ds_spdlog::string_view_t{g_ThreadLogData, msgSize_}); if (level_ == SPDLOG_LEVEL_CRITICAL) { logger_->flush(); @@ -142,7 +143,7 @@ void LogMessageImpl::ToStderr() } ConstructLogPrefix(std::cerr, logTime.getTm(), logTime.getUsec(), baseFilename, sourceLoc_.line, podName_.c_str(), - LogSeverityName[0], FLAGS_az_name); + LogSeverityName[0], FLAGS_cluster_name); std::cerr.write(g_ThreadLogData, static_cast(msgSize_)); std::cerr << '\n'; diff --git a/src/datasystem/common/log/spdlog/log_message_impl.h b/src/datasystem/common/log/spdlog/log_message_impl.h index 5d2cea4..43e9dec 100644 --- a/src/datasystem/common/log/spdlog/log_message_impl.h +++ b/src/datasystem/common/log/spdlog/log_message_impl.h @@ -74,9 +74,9 @@ private: */ void ToStderr(); - std::shared_ptr logger_; - spdlog::level::level_enum level_; - spdlog::source_loc sourceLoc_; + std::shared_ptr logger_; + ds_spdlog::level::level_enum level_; + ds_spdlog::source_loc sourceLoc_; static std::string podName_; LogStreamBuf streamBuf_; std::ostream logStream_; diff --git a/src/datasystem/common/log/spdlog/log_param.h b/src/datasystem/common/log/spdlog/log_param.h index 9872c5c..4914374 100644 --- a/src/datasystem/common/log/spdlog/log_param.h +++ b/src/datasystem/common/log/spdlog/log_param.h @@ -39,7 +39,7 @@ constexpr uint32_t SIZE_MEGA_BYTES = 1024 * 1024; // 1 MB const std::string DEFAULT_FILE_LOG_LEVEL = "INFO"; const std::string DEFAULT_LOG_DIR = 
"/.datasystem/logs"; const std::string DEFAULT_LOG_PATTERN = - "%Y-%m-%dT%H:%M:%S.%6f | %^%L%$ | %s:%# | %v"; // %v = "pod_name | pid:tid | trace_id | az_name | message" + "%Y-%m-%dT%H:%M:%S.%6f | %^%L%$ | %s:%# | %v"; // %v = "pod_name | pid:tid | trace_id | cluster_name | message" const std::string DEFAULT_STDERR_LOG_LEVEL = "FATAL"; } // namespace log_param diff --git a/src/datasystem/common/log/spdlog/log_severity.h b/src/datasystem/common/log/spdlog/log_severity.h index 3712fcf..28eba5b 100644 --- a/src/datasystem/common/log/spdlog/log_severity.h +++ b/src/datasystem/common/log/spdlog/log_severity.h @@ -40,9 +40,9 @@ inline const char *GetLogSeverityName(const int &logLevel) return LogSeverityNames[logLevel]; } -inline static spdlog::level::level_enum ToSpdlogLevel(LogSeverity severity) +inline static ds_spdlog::level::level_enum ToSpdlogLevel(LogSeverity severity) { - return static_cast( + return static_cast( static_cast(severity) + 2 // INFO(0) → info(2) ); } diff --git a/src/datasystem/common/log/spdlog/logger_context.cpp b/src/datasystem/common/log/spdlog/logger_context.cpp index 6c484b6..0113f9c 100644 --- a/src/datasystem/common/log/spdlog/logger_context.cpp +++ b/src/datasystem/common/log/spdlog/logger_context.cpp @@ -37,7 +37,7 @@ namespace datasystem { -constexpr auto LOG_LEVEL_OFF = spdlog::level::off; +constexpr auto LOG_LEVEL_OFF = ds_spdlog::level::off; static std::vector GetLogFiles(const LogParam &logParam) { @@ -65,39 +65,40 @@ static void FlushLogger(DsLogger logger) } } -static const std::map &GetLogLevelMap() +static const std::map &GetLogLevelMap() { - static const std::map LOG_LEVEL_MAP = { { "INFO", spdlog::level::info }, - { "WARNING", spdlog::level::warn }, - { "ERROR", spdlog::level::err }, - { "FATAL", - spdlog::level::critical } }; + static const std::map LOG_LEVEL_MAP = { + { "INFO", ds_spdlog::level::info }, + { "WARNING", ds_spdlog::level::warn }, + { "ERROR", ds_spdlog::level::err }, + { "FATAL", ds_spdlog::level::critical } + }; return LOG_LEVEL_MAP; } -static spdlog::level::level_enum GetLogLevel(const std::string &level) +static ds_spdlog::level::level_enum GetLogLevel(const std::string &level) { auto iter = GetLogLevelMap().find(level); - return iter == GetLogLevelMap().end() ? spdlog::level::info : iter->second; + return iter == GetLogLevelMap().end() ? 
ds_spdlog::level::info : iter->second; } LoggerContext::LoggerContext(const GlobalLogParam &globalLogParam) noexcept : globalLogParam_(globalLogParam) { - spdlog::drop_all(); - if (!spdlog::thread_pool()) { - spdlog::init_thread_pool(static_cast(globalLogParam_.maxAsyncQueueSize), - static_cast(globalLogParam_.asyncThreadCount)); + ds_spdlog::drop_all(); + if (!ds_spdlog::thread_pool()) { + ds_spdlog::init_thread_pool(static_cast(globalLogParam_.maxAsyncQueueSize), + static_cast(globalLogParam_.asyncThreadCount)); } - spdlog::flush_every(std::chrono::seconds(globalLogParam_.logBufSecs)); + ds_spdlog::flush_every(std::chrono::seconds(globalLogParam_.logBufSecs)); } DsLogger LoggerContext::CreateLogger(const LogParam &logParam) { try { - std::vector sinks{}; + std::vector sinks{}; std::vector logFiles = GetLogFiles(logParam); for (size_t i = 0; i < logFiles.size(); ++i) { - auto rotatingSink = std::make_shared( + auto rotatingSink = std::make_shared( logFiles[i], logParam.maxSize * log_param::SIZE_MEGA_BYTES, logParam.maxFiles); const auto log2FileLevel = ToSpdlogLevel(LogSeverity(i % (NUM_SEVERITIES - 1))); rotatingSink->set_level(log2FileLevel); @@ -108,22 +109,22 @@ DsLogger LoggerContext::CreateLogger(const LogParam &logParam) ? GetLogLevel(logParam.stderrLogLevel) : (logParam.alsoLog2Stderr ? GetLogLevel(logParam.logLevel) : LOG_LEVEL_OFF); if (stderrLogLevel != LOG_LEVEL_OFF) { - auto errSink = std::make_shared(); + auto errSink = std::make_shared(); errSink->set_level(stderrLogLevel); sinks.emplace_back(errSink); } - std::shared_ptr logger; + std::shared_ptr logger; if (logParam.logAsync) { - logger = std::make_shared(DS_LOGGER_NAME, sinks.begin(), sinks.end(), - spdlog::thread_pool(), - spdlog::async_overflow_policy::overrun_oldest); + logger = std::make_shared(DS_LOGGER_NAME, sinks.begin(), sinks.end(), + ds_spdlog::thread_pool(), + ds_spdlog::async_overflow_policy::overrun_oldest); } else { - logger = std::make_shared(DS_LOGGER_NAME, sinks.begin(), sinks.end()); + logger = std::make_shared(DS_LOGGER_NAME, sinks.begin(), sinks.end()); } - spdlog::initialize_logger(logger); - logger->set_pattern(logParam.pattern, spdlog::pattern_time_type::utc); + ds_spdlog::initialize_logger(logger); + logger->set_pattern(logParam.pattern, ds_spdlog::pattern_time_type::utc); const auto logLevel = GetLogLevel(logParam.logLevel); logger->set_level(logLevel); @@ -140,22 +141,22 @@ DsLogger LoggerContext::CreateLogger(const LogParam &logParam) DsLogger LoggerContext::GetLogger(const std::string &loggerName) const noexcept { - return spdlog::get(loggerName); + return ds_spdlog::get(loggerName); } DsLogger LoggerContext::GetDefaultLogger() noexcept { - return spdlog::default_logger(); + return ds_spdlog::default_logger(); } void LoggerContext::DropLogger(const std::string &loggerName) const noexcept { - spdlog::drop(loggerName); + ds_spdlog::drop(loggerName); } bool LoggerContext::ForceFlush(std::chrono::microseconds) const noexcept { - spdlog::apply_all(FlushLogger); + ds_spdlog::apply_all(FlushLogger); return true; } diff --git a/src/datasystem/common/log/spdlog/logger_context.h b/src/datasystem/common/log/spdlog/logger_context.h index 9a8cfa2..68612fe 100644 --- a/src/datasystem/common/log/spdlog/logger_context.h +++ b/src/datasystem/common/log/spdlog/logger_context.h @@ -28,7 +28,7 @@ namespace datasystem { -using DsLogger = std::shared_ptr; +using DsLogger = std::shared_ptr; const std::string DS_LOGGER_NAME = "DsLogger"; diff --git 
a/src/datasystem/common/metrics/hard_disk_exporter/hard_disk_exporter.cpp b/src/datasystem/common/metrics/hard_disk_exporter/hard_disk_exporter.cpp
index ea30dc7..b2ed08e 100644
--- a/src/datasystem/common/metrics/hard_disk_exporter/hard_disk_exporter.cpp
+++ b/src/datasystem/common/metrics/hard_disk_exporter/hard_disk_exporter.cpp
@@ -29,7 +29,7 @@
 #include "datasystem/common/util/uri.h"
 #include "datasystem/common/log/log_time.h"

-DS_DECLARE_string(az_name);
+DS_DECLARE_string(cluster_name);
 DS_DECLARE_string(log_dir);
 DS_DECLARE_int32(logfile_mode);
 DS_DECLARE_uint32(log_size);
@@ -52,7 +52,7 @@ void HardDiskExporter::Send(const std::string &message, Uri &uri, int line)
 std::ostringstream constructStr;
 LogTime logTime;
 ConstructLogPrefix(constructStr, logTime.getTm(), logTime.getUsec(), uri.GetFileName().c_str(), line, podName_,
- 'I', FLAGS_az_name);
+ 'I', FLAGS_cluster_name);
 constructStr << std::string(message);
 WriteMessage(constructStr.str());
 }
diff --git a/src/datasystem/common/metrics/res_metric_collector.cpp b/src/datasystem/common/metrics/res_metric_collector.cpp
index 89fa11b..2e46f4f 100644
--- a/src/datasystem/common/metrics/res_metric_collector.cpp
+++ b/src/datasystem/common/metrics/res_metric_collector.cpp
@@ -33,7 +33,13 @@
 #include "datasystem/common/util/strings_util.h"
 #include "datasystem/common/util/validator.h"

+DS_DEFINE_int32(sc_regular_socket_num, 16,
+ "The number of regular backend sockets for stream cache. Must be greater than or equal to 0.");
+DS_DEFINE_int32(sc_stream_socket_num, 16,
+ "The number of stream backend sockets for stream cache. Must be greater than or equal to 0.");
 DS_DEFINE_int32(log_monitor_interval_ms, 10000, "The sleep time between iterations of observability collector scan");
+DS_DEFINE_validator(sc_regular_socket_num, &Validator::ValidateRpcThreadNum);
+DS_DEFINE_validator(sc_stream_socket_num, &Validator::ValidateRpcThreadNum);
 DS_DEFINE_validator(log_monitor_interval_ms, &Validator::ValidateInt32);
 DS_DECLARE_bool(log_monitor);
 DS_DECLARE_string(log_monitor_exporter);
diff --git a/src/datasystem/common/metrics/res_metric_collector.h b/src/datasystem/common/metrics/res_metric_collector.h
index fe21a6b..a681f45 100644
--- a/src/datasystem/common/metrics/res_metric_collector.h
+++ b/src/datasystem/common/metrics/res_metric_collector.h
@@ -33,8 +33,14 @@
 #include "datasystem/common/util/wait_post.h"

 DS_DECLARE_int32(log_monitor_interval_ms);
+DS_DECLARE_int32(sc_regular_socket_num);
+DS_DECLARE_int32(sc_stream_socket_num);

 namespace datasystem {
+inline bool EnableSCService()
+{
+ return FLAGS_sc_regular_socket_num > 0 && FLAGS_sc_stream_socket_num > 0;
+}
 const std::string RES_ETCD_DEFAULT_USAGE = "0/0/0";
 const std::string RES_THREAD_POOL_DEFAULT_USAGE = "0/0/0/0/0";
diff --git a/src/datasystem/common/object_cache/buffer_composer.cpp b/src/datasystem/common/object_cache/buffer_composer.cpp
index 31d6230..59b48dd 100644
--- a/src/datasystem/common/object_cache/buffer_composer.cpp
+++ b/src/datasystem/common/object_cache/buffer_composer.cpp
@@ -13,7 +13,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
- 
+
 /**
 * Description: Implementation of compose and decompose buffer data.
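+ *
+ * Hedged decode sketch (illustration only), reading back the header that
+ * ComposeBufferData writes below; "buf" is one entry of the composed list:
+ * @code
+ *   auto header = reinterpret_cast<const uint64_t *>(buf->MutableData());
+ *   uint64_t numBlobs = header[0];                  // NumOfBuffers (n)
+ *   uint64_t payloadStart = header[1];              // 64-byte aligned header size
+ *   uint64_t firstBlobSize = header[2] - header[1]; // prefix sums give sizes
+ * @endcode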
*/ @@ -26,34 +26,28 @@ namespace datasystem { namespace object_cache { -Status PrepareDataSizeList(std::vector &sizeList, const std::vector> &dataInfoList, +Status PrepareDataSizeList(std::vector &sizeList, const std::vector &devBlobList, BlobListInfo &blobInfo) { const uint64_t dataAlignSize = 64; - blobInfo.nonExistNums = dataInfoList.size(); + blobInfo.nonExistNums = devBlobList.size(); std::vector blobNumsList; std::vector blobSizeList; - for (const auto &info : dataInfoList) { + for (const auto &blobList : devBlobList) { // For Length, Prefix Sum Arr, in O(1) and O(num+1) space. // Round to 64x. + auto &info = blobList.blobs; uint64_t num = info.size(); uint64_t sz = sizeof(uint64_t) * (num + 2); sz = (sz + dataAlignSize - 1) / dataAlignSize * dataAlignSize; - - int prevDevIdx = -1; + blobNumsList.emplace_back(info.size()); for (auto &desc : info) { - if (prevDevIdx == -1) { - prevDevIdx = desc.deviceIdx; - } - if (desc.deviceIdx == -1 || prevDevIdx != desc.deviceIdx) { - return { K_INVALID, "Please set deviceIdx in datainfo." }; - } - sz += desc.Size(); - blobInfo.totalSize += desc.Size(); - blobSizeList.emplace_back(desc.Size()); + sz += desc.size; + blobInfo.totalSize += desc.size; + blobSizeList.emplace_back(desc.size); } - + sizeList.emplace_back(sz); } if (blobInfo.nonExistNums <= 0) { @@ -66,12 +60,11 @@ Status PrepareDataSizeList(std::vector &sizeList, const std::vector> &bufferList, - const std::vector> &dataInfoList) + +void ComposeBufferData(std::vector> &bufferList, const std::vector &devBlobList) { // Record MetaData of SubBuffers. // | NumOfBuffers (n) | Off0 | Off1 | Offn | Padding | Buf1 | Buf2 | ... Bufn | @@ -80,28 +73,19 @@ void ComposeBufferData(std::vector> &bufferList, for (uint64_t i = 0; i < bufferList.size(); i++) { auto &buf = bufferList[i]; - auto &dataInfos = dataInfoList[i]; + auto &blobs = devBlobList[i].blobs; auto prefixSumArr = reinterpret_cast(buf->MutableData()); - - uint64_t num = dataInfos.size(); + + uint64_t num = blobs.size(); uint64_t descSz = sizeof(uint64_t) * (num + preOccupySize); descSz = (descSz + dataAlignSize - 1) / dataAlignSize * dataAlignSize; - - prefixSumArr[0] = dataInfos.size(); + + prefixSumArr[0] = blobs.size(); prefixSumArr[1] = descSz; - for (uint64_t j = 0; j < dataInfos.size(); j++) { - prefixSumArr[j + preOccupySize] = prefixSumArr[j + 1] + dataInfos[j].size; + for (uint64_t j = 0; j < blobs.size(); j++) { + prefixSumArr[j + preOccupySize] = prefixSumArr[j + 1] + blobs[j].size; } } } - -size_t GetComposedBufferSize(void* mutableData) -{ - auto offsetArrPtr = reinterpret_cast(mutableData); - if (offsetArrPtr) { - return *offsetArrPtr; - } - return 0; -} -} -} +} // namespace object_cache +} // namespace datasystem diff --git a/src/datasystem/common/object_cache/buffer_composer.h b/src/datasystem/common/object_cache/buffer_composer.h index 12e397f..11cc3a8 100644 --- a/src/datasystem/common/object_cache/buffer_composer.h +++ b/src/datasystem/common/object_cache/buffer_composer.h @@ -24,8 +24,8 @@ #include #include +#include "datasystem/hetero/device_common.h" #include "datasystem/object/buffer.h" -#include "datasystem/client/hetero_cache/device_util.h" namespace datasystem { namespace object_cache { @@ -62,26 +62,20 @@ struct BlobListInfo { /** * @brief Prepare the data sizes by user data list * @param[out] sizeList The list of all data sizes - * @param[in] dataInfoList The user data list + * @param[in] devBlobList The user data list * @param[in] blobInfo The information of blob * @return K_OK on any object success; the 
error code otherwise. */ -Status PrepareDataSizeList(std::vector &sizeList, const std::vector> &dataInfoList, +Status PrepareDataSizeList(std::vector &sizeList, const std::vector &devBlobList, BlobListInfo &blobInfo); /** * @brief Compose buffer list by user data list * @param[out] bufferList Compose the user data to bufferList - * @param[in] dataInfoList The user data list + * @param[in] devBlobList The user data list */ void ComposeBufferData(std::vector> &bufferList, - const std::vector> &dataInfoList); -/** - * @brief Get the total data size of this buffer - * @param[in] mutableData Pointer to the composed data - * @return The size of composed size - */ -size_t GetComposedBufferSize(void* mutableData); + const std::vector &devBlobList); } // namespace object_cache } // namespace datasystem diff --git a/src/datasystem/common/object_cache/device_buffer.cpp b/src/datasystem/common/object_cache/device_buffer.cpp index 8f60316..88bc4b6 100644 --- a/src/datasystem/common/object_cache/device_buffer.cpp +++ b/src/datasystem/common/object_cache/device_buffer.cpp @@ -111,10 +111,10 @@ DeviceBuffer::~DeviceBuffer() uint64_t DeviceBuffer::Size() const { - DataInfo dataInfo; - Status rc = deviceMemUnit_->CheckAndGetSingleDataInfo(dataInfo); + Blob blob; + Status rc = deviceMemUnit_->CheckAndGetSingleBlob(blob); if (rc.IsOk()) { - return dataInfo.Size(); + return blob.size; } LOG(ERROR) << rc.ToString(); return 0; @@ -122,10 +122,10 @@ uint64_t DeviceBuffer::Size() const void *DeviceBuffer::Data() const { - DataInfo dataInfo; - Status rc = deviceMemUnit_->CheckAndGetSingleDataInfo(dataInfo); + Blob blob; + Status rc = deviceMemUnit_->CheckAndGetSingleBlob(blob); if (rc.IsOk()) { - return dataInfo.devPtr; + return blob.pointer; } LOG(ERROR) << rc.ToString(); return nullptr; @@ -140,8 +140,8 @@ Status DeviceBuffer::Publish() { TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); CHECK_FAIL_RETURN_STATUS(!bufferInfo_->isPublished, K_OC_ALREADY_SEALED, "Device object is already published"); - CHECK_FAIL_RETURN_STATUS(!deviceMemUnit_->GetDataInfoStorage().empty(), K_INVALID, - "The dataInfo can't be empty in device buffer"); + CHECK_FAIL_RETURN_STATUS(!deviceMemUnit_->GetBlobsStorage().empty(), K_INVALID, - "The blobs can't be empty in device buffer"); Status rc = clientImpl_->PublishDeviceObject(shared_from_this()); if (rc.IsOk()) { bufferInfo_->isPublished = true; @@ -149,9 +149,9 @@ Status DeviceBuffer::Publish() return rc; } -std::vector DeviceBuffer::GetDataInfoList() const +std::vector DeviceBuffer::GetDevBlobList() const { - return deviceMemUnit_->GetDataInfoStorage(); + return deviceMemUnit_->GetBlobsStorage(); } Status DeviceBuffer::GetSendStatus(std::vector &futureVec) diff --git a/src/datasystem/common/object_cache/safe_object.h b/src/datasystem/common/object_cache/safe_object.h index d597f4c..f0a1beb 100644 --- a/src/datasystem/common/object_cache/safe_object.h +++ b/src/datasystem/common/object_cache/safe_object.h @@ -147,6 +147,16 @@ public: */ Status TryRLock(bool nullable = false); + /** + * @brief Transfers write lock ownership to the calling thread. + * + * This function transfers write lock ownership from the thread that currently holds it to the thread that calls + * this function. It ensures that the write lock is held by the calling thread after the transfer. + * + * @return Status of the call. + */ + Status TransferWLockToCurrentThread(); + /** * @brief Releases a read lock on the SafeObject. 
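Editor's note: for illustration of the composed-buffer layout that PrepareDataSizeList() and ComposeBufferData() build above (| NumOfBuffers n | descSz | prefix-sum offsets | padding | payloads |), here is a self-contained sketch. The 64-byte alignment and the two reserved header slots match the patch; everything else is simplified:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    const uint64_t dataAlignSize = 64;  // matches dataAlignSize in the patch
    const uint64_t preOccupySize = 2;   // slots for n and descSz
    std::vector<uint64_t> blobSizes = { 100, 50, 200 };

    // Metadata header: sizeof(uint64_t) * (num + 2), rounded up to 64 bytes.
    uint64_t num = blobSizes.size();
    uint64_t descSz = sizeof(uint64_t) * (num + preOccupySize);
    descSz = (descSz + dataAlignSize - 1) / dataAlignSize * dataAlignSize;

    // prefixSum[0] = n, prefixSum[1] = descSz, then cumulative end offsets,
    // exactly as ComposeBufferData() writes them into the buffer head.
    std::vector<uint64_t> prefixSum(num + preOccupySize);
    prefixSum[0] = num;
    prefixSum[1] = descSz;  // first blob starts right after the padded header
    for (uint64_t j = 0; j < num; j++) {
        prefixSum[j + preOccupySize] = prefixSum[j + 1] + blobSizes[j];
    }
    // The last prefix sum is the total composed size (414 for this input).
    std::cout << "total composed bytes: " << prefixSum[num + 1] << '\n';
    return 0;
}
```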
*/ @@ -268,6 +278,17 @@ Status SafeObject::WLock(bool nullable) return Status::OK(); } +template +Status SafeObject::TransferWLockToCurrentThread() +{ + if (!wLocked_) { + RETURN_STATUS(StatusCode::K_RUNTIME_ERROR, "Write lock is not held by any thread."); + } + pid_t currentTid = syscall(__NR_gettid); + lastWriteThread_ = currentTid; + return Status::OK(); +} + template Status SafeObject::TryWLock(bool nullable) { diff --git a/src/datasystem/common/perf/perf_point.def b/src/datasystem/common/perf/perf_point.def index eaf24ba..5a5beb7 100644 --- a/src/datasystem/common/perf/perf_point.def +++ b/src/datasystem/common/perf/perf_point.def @@ -73,6 +73,9 @@ PERF_KEY_DEF(KV_CLIENT_DEL_OBJECT) PERF_KEY_DEF(RPC_CLIENT_DEL_OBJECT) PERF_KEY_DEF(KV_CLIENT_DEL_MUL_OBJECTS) PERF_KEY_DEF(KV_CLIENT_EXPIRE_OBJECT) +PERF_KEY_DEF(CLIENT_MSET_MULTICREATE) +PERF_KEY_DEF(CLIENT_MSET_MEMCOPY) +PERF_KEY_DEF(CLIENT_MSET_MULTI_PUBLSIH) PERF_KEY_DEF(CLIENT_MSET_D2H_ALL) PERF_KEY_DEF(CLIENT_MSET_CHECK_EXISTS) @@ -352,6 +355,7 @@ PERF_KEY_DEF(MASTER_ROCKSDB_CREATE_META) PERF_KEY_DEF(MASTER_ROCKSDB_ADD_OBJ_LOCATION) PERF_KEY_DEF(MASTER_ROCKSDB_REMOVE_OBJ_LOCATION) PERF_KEY_DEF(MASTER_DELETE_OBJECT) +PERF_KEY_DEF(MASTER_IS_EXIST_DEAD_LOCK) PERF_KEY_DEF(ROCKSDB_PUT) PERF_KEY_DEF(ROCKSDB_GET) diff --git a/src/datasystem/common/rdma/CMakeLists.txt b/src/datasystem/common/rdma/CMakeLists.txt index 7fba4d1..0e43263 100644 --- a/src/datasystem/common/rdma/CMakeLists.txt +++ b/src/datasystem/common/rdma/CMakeLists.txt @@ -9,11 +9,12 @@ set(URMA_DEPEND_LIBS common_rdma_util) if (BUILD_WITH_URMA) - list(APPEND URMA_SRCS urma_manager.cpp) + list(APPEND URMA_SRCS urma_manager.cpp urma_info.cpp) list(APPEND URMA_DEPEND_LIBS ${URMA_LIBRARY}) set(URMA_STUB_SRCS - urma_stub.cpp) + urma_stub.cpp + urma_info.cpp) add_library(common_stub_rdma STATIC ${URMA_STUB_SRCS}) endif() diff --git a/src/datasystem/common/rdma/rdma_util.cpp b/src/datasystem/common/rdma/rdma_util.cpp index 62d1b66..38dd9f3 100644 --- a/src/datasystem/common/rdma/rdma_util.cpp +++ b/src/datasystem/common/rdma/rdma_util.cpp @@ -34,9 +34,12 @@ #include #include "datasystem/common/log/log.h" +#include "datasystem/common/flags/flags.h" #include "datasystem/common/util/status_helper.h" #include "datasystem/common/util/strings_util.h" +DS_DECLARE_string(urma_mode); + namespace datasystem { static constexpr int RDMA_LOG_LEVEL = 3; const int OCTET = 8; @@ -218,4 +221,14 @@ Status EthToRdmaDevName(std::string ethDevName, std::string &rdmaDevName) VLOG(RDMA_LOG_LEVEL) << "Result RdmaDevName is " << rdmaDevName; return Status::OK(); } + +UrmaMode GetUrmaMode() +{ + if (FLAGS_urma_mode == "IB") { + return UrmaMode::IB; + } else if (FLAGS_urma_mode == "UB") { + return UrmaMode::UB; + } + return UrmaMode::UNKNOWN; +} } // namespace datasystem \ No newline at end of file diff --git a/src/datasystem/common/rdma/rdma_util.h b/src/datasystem/common/rdma/rdma_util.h index 0fa87b6..a9c07ee 100644 --- a/src/datasystem/common/rdma/rdma_util.h +++ b/src/datasystem/common/rdma/rdma_util.h @@ -45,5 +45,13 @@ int GetDevNameFromLocalIp(const std::string &ipAddr, std::string &devName); * @return Status of the call. */ Status EthToRdmaDevName(std::string ethDevName, std::string &rdmaDevName); + +enum class UrmaMode { IB = 0, UB = 1, UNKNOWN }; +/** + * @brief Get the URMA mode from the urma_mode flag. + * @return The urma mode: IB, UB, or UNKNOWN for any other value. 
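Editor's note: the TransferWLockToCurrentThread() template added above only re-tags which thread owns an already-held write lock; nothing is unlocked or re-locked. A simplified standalone model of that semantic (OwnedWriteLock is a stand-in, not the SafeObject implementation):

```cpp
#include <sys/syscall.h>
#include <unistd.h>
#include <iostream>
#include <thread>

// The lock records the owning thread id; "transfer" means the calling thread
// claims ownership while the lock stays held, mirroring the patch's
// lastWriteThread_ update.
class OwnedWriteLock {
public:
    bool wLocked{ false };
    pid_t owner{ -1 };

    void Lock()
    {
        wLocked = true;
        owner = syscall(__NR_gettid);
    }
    bool TransferToCurrentThread()
    {
        if (!wLocked) {
            return false;  // mirrors K_RUNTIME_ERROR: no thread holds the lock
        }
        owner = syscall(__NR_gettid);
        return true;
    }
};

int main()
{
    OwnedWriteLock lock;
    lock.Lock();  // owned by the main thread
    std::thread t([&lock] { lock.TransferToCurrentThread(); });
    t.join();     // ownership now tagged to t's thread id
    std::cout << "owner tid: " << lock.owner << '\n';
    return 0;
}
```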
+ */ +UrmaMode GetUrmaMode(); + } // namespace datasystem #endif // DATASYSTEM_COMMON_RDMA_RDMA_UTIL_H \ No newline at end of file diff --git a/src/datasystem/common/rdma/urma_info.cpp b/src/datasystem/common/rdma/urma_info.cpp new file mode 100644 index 0000000..d2ea6fc --- /dev/null +++ b/src/datasystem/common/rdma/urma_info.cpp @@ -0,0 +1,202 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "datasystem/common/rdma/urma_info.h" + +#include "datasystem/common/rdma/urma_manager.h" +#include "datasystem/utils/status.h" + +namespace datasystem { + +std::string UrmaJfrInfo::EidToFmtStr(const urma_eid_t &eid) +{ + char s[URMA_EID_STR_LEN + 1] = { 0 }; + int ret = sprintf_s(s, URMA_EID_STR_LEN + 1, EID_FMT, EID_ARGS(eid)); + return ret == -1 ? "invalid" : s; +} + +std::string UrmaJfrInfo::ToString() const +{ + std::stringstream oss; + oss << "address " << localAddress.ToString() << ", eid "; + // eid is not really printable as a string. So we will dump its context in hex + urma_eid_t e; + if (UrmaManager::StrToEid(eid, e).IsOk()) { + oss << EidToFmtStr(e); + } else { + oss << eid; + } + oss << " uasid " << uasid << " jfr_id ["; + bool first = true; + for (auto jfr_id : jfrIds) { + if (first) { + first = false; + } else { + oss << " "; + } + oss << jfr_id; + } + oss << "]"; + oss << " bondInfos ["; + first = true; +#ifdef URMA_OVER_UB + for (auto bondInfo : bondInfos) { + if (first) { + first = false; + } else { + oss << ","; + } + oss << "base_id:" << EidToFmtStr(bondInfo.base_id.eid) << "+" << bondInfo.base_id.uasid << "+" + << bondInfo.base_id.id << " "; + for (size_t index = 0; index < URMA_UBAGG_DEV_MAX_NUM; index++) { + if (bondInfo.slave_id[index].id > 0) { + oss << "slave_id[" << index << "]:" << EidToFmtStr(bondInfo.slave_id[index].eid) << "+" + << bondInfo.slave_id[index].uasid << "+" << bondInfo.slave_id[index].id << " "; + } + } + oss << "dev_num:" << bondInfo.dev_num << " "; + oss << "is_in_matrix_server:" << bondInfo.is_in_matrix_server << " "; + oss << "is_multipath:" << bondInfo.is_multipath; + } +#endif + oss << "]"; + return oss.str(); +} + +#ifdef URMA_OVER_UB +void UrmaJfrInfo::UrmaBondIdInfoToProto(const urma_bond_id_info_out &info, JfrBondInfo &proto) +{ + auto baseId = proto.mutable_base_id(); + baseId->set_eid(UrmaManager::EidToStr(info.base_id.eid)); + baseId->set_uasid(info.base_id.uasid); + baseId->set_id(info.base_id.id); + for (int index = 0; index < URMA_UBAGG_DEV_MAX_NUM; index++) { + auto slaveId = proto.add_slave_ids(); + slaveId->set_eid(UrmaManager::EidToStr(info.slave_id[index].eid)); + slaveId->set_uasid(info.slave_id[index].uasid); + slaveId->set_id(info.slave_id[index].id); + } + proto.set_dev_num(info.dev_num); + proto.set_is_in_matrix_server(info.is_in_matrix_server); + proto.set_is_multipath(info.is_multipath); +} + +Status UrmaJfrInfo::UrmaBondIdInfoFromProto(const JfrBondInfo &proto, urma_bond_id_info_out &info) +{ + const auto &baseId = 
proto.base_id(); + RETURN_IF_NOT_OK(UrmaManager::StrToEid(baseId.eid(), info.base_id.eid)); + info.base_id.uasid = baseId.uasid(); + info.base_id.id = baseId.id(); + auto slaveSize = proto.slave_ids_size(); + for (int index = 0; index < slaveSize; index++) { + const auto &slaveId = proto.slave_ids(index); + RETURN_IF_NOT_OK(UrmaManager::StrToEid(slaveId.eid(), info.slave_id[index].eid)); + info.slave_id[index].uasid = slaveId.uasid(); + info.slave_id[index].id = slaveId.id(); + } + info.dev_num = proto.dev_num(); + info.is_in_matrix_server = proto.is_in_matrix_server(); + info.is_multipath = proto.is_multipath(); + return Status::OK(); +} +#endif + +void UrmaSeg::ToProto(const urma_seg_t &seg, UrmaSegPb &proto) +{ + proto.set_eid(UrmaManager::EidToStr(seg.ubva.eid)); + proto.set_uasid(seg.ubva.uasid); + proto.set_va(seg.ubva.va); + proto.set_len(seg.len); + proto.set_attr(seg.attr.value); + proto.set_token_id(seg.token_id); +} + +Status UrmaSeg::FromProto(const UrmaSegPb &proto, urma_seg_t &seg) +{ + urma_eid_t eid; + RETURN_IF_NOT_OK(UrmaManager::StrToEid(proto.eid(), eid)); + seg.ubva.eid = eid; + seg.ubva.uasid = proto.uasid(); + seg.ubva.va = proto.va(); + seg.len = proto.len(); + seg.attr.value = proto.attr(); + seg.token_id = proto.token_id(); + return Status::OK(); +} + +std::string UrmaSeg::ToString(const urma_seg_t &seg) +{ + std::stringstream ss; + ss << "ubva: { eid: " << UrmaJfrInfo::EidToFmtStr(seg.ubva.eid); + ss << ", uasid: " << seg.ubva.uasid; + ss << ", va: " << seg.ubva.va; + ss << "}, len: " << seg.len; + ss << ", attr: " << seg.attr.value; + ss << ", token_id: " << seg.token_id; + return ss.str(); +} + +std::string UrmaSeg::ToString() +{ + return UrmaSeg::ToString(raw); +} + +void UrmaSeg::ToProto(UrmaSegPb &proto) const +{ + UrmaSeg::ToProto(raw, proto); +} + +Status UrmaSeg::FromProto(const UrmaSegPb &proto) +{ + return UrmaSeg::FromProto(proto, raw); +} + +#ifdef URMA_OVER_UB +std::string UrmaBondSegInfo::ToString() +{ + std::stringstream ss; + ss << "base: {" << UrmaSeg::ToString(raw.base) << "}"; + for (int index = 0; index < URMA_UBAGG_DEV_MAX_NUM; index++) { + ss << ", slaves[" << index << "]: {" << UrmaSeg::ToString(raw.slaves[index]) << "}"; + } + ss << ", dev_num: " << raw.dev_num; + return ss.str(); +} + +void UrmaBondSegInfo::ToProto(UrmaBondSegInfoPb &proto) const +{ + auto base = proto.mutable_base(); + UrmaSeg::ToProto(raw.base, *base); + for (int index = 0; index < URMA_UBAGG_DEV_MAX_NUM; index++) { + auto slave = proto.add_slaves(); + UrmaSeg::ToProto(raw.slaves[index], *slave); + } + proto.set_dev_num(raw.dev_num); +} + +Status UrmaBondSegInfo::FromProto(const UrmaBondSegInfoPb &proto) +{ + const auto &base = proto.base(); + RETURN_IF_NOT_OK(UrmaSeg::FromProto(base, raw.base)); + auto slaveSize = proto.slaves_size(); + for (int index = 0; index < slaveSize; index++) { + const auto &slave = proto.slaves(index); + RETURN_IF_NOT_OK(UrmaSeg::FromProto(slave, raw.slaves[index])); + } + raw.dev_num = proto.dev_num(); + return Status::OK(); +} +#endif +} // namespace datasystem diff --git a/src/datasystem/common/rdma/urma_info.h b/src/datasystem/common/rdma/urma_info.h new file mode 100644 index 0000000..e5d7d6a --- /dev/null +++ b/src/datasystem/common/rdma/urma_info.h @@ -0,0 +1,117 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASYSTEM_COMMON_RDMA_URMA_INFO_H +#define DATASYSTEM_COMMON_RDMA_URMA_INFO_H + +#include +#include + +#include +#ifdef URMA_OVER_UB +#include +#endif + +#include "datasystem/common/rdma/rdma_util.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/protos/meta_zmq.pb.h" +#include "datasystem/protos/utils.pb.h" + +namespace datasystem { +struct UrmaJfrInfo { + std::string eid; + uint32_t uasid{ 0 }; + std::vector jfrIds; + HostPort localAddress; +#ifdef URMA_OVER_UB + std::vector bondInfos; +#endif + + std::string ToString() const; + + static std::string EidToFmtStr(const urma_eid_t &eid); +#ifdef URMA_OVER_UB + static void UrmaBondIdInfoToProto(const urma_bond_id_info_out &info, JfrBondInfo &proto); + static Status UrmaBondIdInfoFromProto(const JfrBondInfo &proto, urma_bond_id_info_out &info); +#endif + + template + void ToProto(Proto &proto) const + { + proto.set_eid(eid); + proto.set_uasid(uasid); + for (auto jfrId : jfrIds) { + proto.add_jfr_ids(jfrId); + } + proto.mutable_address()->set_host(localAddress.Host()); + proto.mutable_address()->set_port(localAddress.Port()); +#ifdef URMA_OVER_UB + if (GetUrmaMode() == UrmaMode::UB) { + for (auto &bondInfo : bondInfos) { + auto info = proto.add_bond_infos(); + UrmaBondIdInfoToProto(bondInfo, *info); + } + } +#endif + } + + template + Status FromProto(const Proto &proto) + { + eid = proto.eid(); + uasid = proto.uasid(); + for (auto jfrId : proto.jfr_ids()) { + jfrIds.emplace_back(jfrId); + } + localAddress = HostPort(proto.address().host(), proto.address().port()); +#ifdef URMA_OVER_UB + if (GetUrmaMode() == UrmaMode::UB) { + auto size = proto.bond_infos_size(); + for (int i = 0; i < size; i++) { + bondInfos.emplace_back(); + RETURN_IF_NOT_OK(UrmaBondIdInfoFromProto(proto.bond_infos(i), bondInfos[i])); + } + } +#endif + return Status::OK(); + } +}; + +struct UrmaSeg { + urma_seg_t raw; + static void ToProto(const urma_seg_t &seg, UrmaSegPb &proto); + static Status FromProto(const UrmaSegPb &proto, urma_seg_t &seg); + static std::string ToString(const urma_seg_t &seg); + + std::string ToString(); + void ToProto(UrmaSegPb &proto) const; + Status FromProto(const UrmaSegPb &proto); +}; + +#ifdef URMA_OVER_UB +struct UrmaBondSegInfo { + urma_bond_seg_info_out raw; + + std::string ToString(); + void ToProto(UrmaBondSegInfoPb &proto) const; + Status FromProto(const UrmaBondSegInfoPb &proto); +}; +#endif + +} // namespace datasystem + +#endif diff --git a/src/datasystem/common/rdma/urma_manager.cpp b/src/datasystem/common/rdma/urma_manager.cpp index 1ae1541..7684351 100644 --- a/src/datasystem/common/rdma/urma_manager.cpp +++ b/src/datasystem/common/rdma/urma_manager.cpp @@ -19,14 +19,17 @@ */ #include "datasystem/common/rdma/urma_manager.h" +#include "datasystem/common/constants.h" #include "datasystem/common/log/log.h" #include "datasystem/common/flags/flags.h" +#include "datasystem/common/perf/perf_manager.h" #include "datasystem/common/rdma/rdma_util.h" #include "datasystem/common/rdma/urma_manager_wrapper.h" #include 
"datasystem/common/rpc/rpc_constants.h" #include "datasystem/common/util/raii.h" #include "datasystem/common/util/thread_local.h" -#include "datasystem/common/perf/perf_manager.h" +#include "datasystem/utils/status.h" +#include "urma_opcode.h" DS_DECLARE_uint32(urma_poll_size); DS_DECLARE_uint32(urma_connection_size); @@ -46,18 +49,27 @@ UrmaManager::UrmaManager() urmaToken_.token = DEFAULT_TOKEN; registerSegmentFlag_.bs.token_policy = URMA_TOKEN_PLAIN_TEXT; + registerSegmentFlag_.bs.token_id_valid = URMA_TOKEN_ID_INVALID; + LOG(INFO) << "registerSegmentFlag_.token_id_valid=" << URMA_TOKEN_ID_INVALID; registerSegmentFlag_.bs.cacheable = URMA_NON_CACHEABLE; - registerSegmentFlag_.bs.access = - URMA_ACCESS_LOCAL_WRITE | URMA_ACCESS_REMOTE_READ | URMA_ACCESS_REMOTE_WRITE | URMA_ACCESS_REMOTE_ATOMIC; registerSegmentFlag_.bs.reserved = 0; importSegmentFlag_.bs.cacheable = URMA_NON_CACHEABLE; - importSegmentFlag_.bs.access = - URMA_ACCESS_LOCAL_WRITE | URMA_ACCESS_REMOTE_READ | URMA_ACCESS_REMOTE_WRITE | URMA_ACCESS_REMOTE_ATOMIC; importSegmentFlag_.bs.mapping = URMA_SEG_NOMAP; importSegmentFlag_.bs.reserved = 0; + +#ifdef URMA_OVER_UB + registerSegmentFlag_.bs.access = URMA_ACCESS_READ | URMA_ACCESS_WRITE | URMA_ACCESS_ATOMIC; + importSegmentFlag_.bs.access = URMA_ACCESS_READ | URMA_ACCESS_WRITE | URMA_ACCESS_ATOMIC; +#else + registerSegmentFlag_.bs.access = + URMA_ACCESS_LOCAL_WRITE | URMA_ACCESS_REMOTE_READ | URMA_ACCESS_REMOTE_WRITE | URMA_ACCESS_REMOTE_ATOMIC; + importSegmentFlag_.bs.access = + URMA_ACCESS_LOCAL_WRITE | URMA_ACCESS_REMOTE_READ | URMA_ACCESS_REMOTE_WRITE | URMA_ACCESS_REMOTE_ATOMIC; +#endif localSegmentMap_ = std::make_unique(); remoteDeviceMap_ = std::make_unique(); + eventMap_ = std::make_unique(); } UrmaManager::~UrmaManager() @@ -66,6 +78,7 @@ UrmaManager::~UrmaManager() VLOG(RPC_LOG_LEVEL) << "UrmaManager::~UrmaManager()"; remoteDeviceMap_.reset(); localSegmentMap_.reset(); + eventMap_.reset(); urmaJfrVec_.clear(); urmaJfsVec_.clear(); urmaJfc_.reset(); @@ -86,17 +99,26 @@ Status UrmaManager::Stop() return Status::OK(); } -Status UrmaManager::Init(const std::string &host) +Status UrmaManager::Init(const HostPort &hostport) { - LOG(INFO) << "UrmaManager::Init(host = " << host << ")"; + LOG(INFO) << FormatString("UrmaManager::Init(hostport = %s)", hostport.ToString()); std::string deviceName; - CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(GetDevNameFromLocalIp(host, deviceName) == 0, K_INVALID, + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(GetDevNameFromLocalIp(hostport.Host(), deviceName) == 0, K_INVALID, "Invalid ip address to get device name"); - RETURN_IF_NOT_OK(EthToRdmaDevName(deviceName, deviceName)); LOG(INFO) << "deviceName = " << deviceName; + std::string urmaDeviceName; + if (GetUrmaMode() == UrmaMode::IB) { + RETURN_IF_NOT_OK(EthToRdmaDevName(deviceName, urmaDeviceName)); + } else if (GetUrmaMode() == UrmaMode::UB) { + urmaDeviceName = GetStringFromEnv(ENV_UB_DEVICE_NAME.c_str(), DEFAULT_UB_DEVICE_NAME.c_str()); + if (urmaDeviceName.empty()) { + RETURN_STATUS(K_INVALID, "env DS_URMA_DEV_NAME is empty"); + } + } + LOG(INFO) << "urmaDeviceName = " << urmaDeviceName; RETURN_IF_NOT_OK(UrmaInit()); urma_device_t *urmaDevice = nullptr; - RETURN_IF_NOT_OK(UrmaGetDeviceByName(deviceName, urmaDevice)); + RETURN_IF_NOT_OK(UrmaGetDeviceByName(urmaDeviceName, urmaDevice)); RETURN_IF_NOT_OK(UrmaQueryDevice(urmaDevice)); int eidIndex = -1; RETURN_IF_NOT_OK(GetEidIndex(urmaDevice, eidIndex)); @@ -112,10 +134,26 @@ Status UrmaManager::Init(const std::string &host) 
RETURN_IF_NOT_OK(UrmaCreateJfs(urmaJfc_, urmaJfsVec_[i])); RETURN_IF_NOT_OK(UrmaCreateJfr(urmaJfc_, urmaJfrVec_[i])); } + RETURN_IF_NOT_OK(InitLocalUrmaInfo(hostport)); serverEventThread_ = std::make_unique(&UrmaManager::ServerEventHandleThreadMain, this); return Status::OK(); } +Status UrmaManager::InitLocalUrmaInfo(const HostPort &hostport) +{ + localUrmaInfo_.eid = GetEid(); + localUrmaInfo_.uasid = GetUasid(); + localUrmaInfo_.jfrIds = GetJfrIds(); + localUrmaInfo_.localAddress = hostport; +#ifdef URMA_OVER_UB + if (GetUrmaMode() == UrmaMode::UB) { + RETURN_IF_NOT_OK(GetJfrInfoForBond(localUrmaInfo_.bondInfos)); + } +#endif + LOG(INFO) << "local urma info: " << localUrmaInfo_.ToString(); + return Status::OK(); +} + Status UrmaManager::UrmaInit() { LOG(INFO) << "UrmaManager::UrmaInit()"; @@ -158,7 +196,7 @@ Status UrmaManager::UrmaQueryDevice(urma_device_t *&urmaDevice) if (ret != URMA_SUCCESS) { RETURN_STATUS_LOG_ERROR(K_URMA_ERROR, FormatString("Failed to urma query device, ret = %d", ret)); } - LOG(INFO) << "urma query device success"; + LOG(INFO) << "urma query device success with dev type:" << urmaDevice->type; return Status::OK(); } @@ -195,12 +233,15 @@ Status UrmaManager::GetEidIndex(urma_device_t *&urmaDevice, int &eidIndex) RETURN_STATUS_LOG_ERROR(K_URMA_ERROR, "Failed to get eid index for device"); } -Status UrmaManager::UrmaCreateContext(urma_device_t *&urmaDevice, const uint32_t eidIndex) +Status UrmaManager::UrmaCreateContext(urma_device_t *&urmaDevice, uint32_t eidIndex) { - LOG(INFO) << "UrmaManager::UrmaCreateContext()"; + LOG(INFO) << "UrmaManager::UrmaCreateContext() with eidIndex:" << eidIndex; CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(!urmaContext_, K_DUPLICATED, "Failed to urma create context, context already exist"); - + if (GetUrmaMode() == UrmaMode::UB) { + LOG(INFO) << "force using eidIndex 0"; + eidIndex = 0; + } urmaContext_ = urma_create_context(urmaDevice, eidIndex); if (urmaContext_) { LOG(INFO) << "urma create context success"; @@ -256,8 +297,13 @@ Status UrmaManager::UrmaCreateJfc(custom_unique_ptr &out) urma_jfc_cfg_t jfcConfig; jfcConfig.depth = urmaDeviceAttribute_.dev_cap.max_jfc_depth; jfcConfig.flag.value = 0; - jfcConfig.jfce = urmaJfce_; + if (GetUrmaMode() == UrmaMode::IB) { + jfcConfig.jfce = urmaJfce_; + } else if (GetUrmaMode() == UrmaMode::UB) { + jfcConfig.jfce = nullptr; + } jfcConfig.user_ctx = 0; + jfcConfig.ceqn = 0; out = MakeCustomUnique(urma_create_jfc(urmaContext_, &jfcConfig), [](urma_jfc_t *p) { std::stringstream oss; @@ -302,6 +348,12 @@ Status UrmaManager::UrmaCreateJfs(const custom_unique_ptr &jfc, cust jfsConfig.err_timeout = URMA_TYPICAL_ERR_TIMEOUT; jfsConfig.jfc = jfc.get(); jfsConfig.user_ctx = 0; + jfsConfig.flag.value = 0; +#ifdef URMA_OVER_UB + if (GetUrmaMode() == UrmaMode::UB) { + jfsConfig.flag.bs.multi_path = 1; + } +#endif out = MakeCustomUnique(urma_create_jfs(urmaContext_, &jfsConfig), [](urma_jfs_t *p) { std::stringstream oss; @@ -327,6 +379,7 @@ Status UrmaManager::UrmaCreateJfr(const custom_unique_ptr &jfc, cust urma_jfr_cfg_t jfrConfig; jfrConfig.depth = JETTY_SIZE_; + jfrConfig.flag.value = 0; jfrConfig.flag.bs.tag_matching = URMA_NO_TAG_MATCHING; jfrConfig.trans_mode = URMA_TM_RM; jfrConfig.min_rnr_timer = URMA_TYPICAL_MIN_RNR_TIMER; @@ -334,6 +387,7 @@ Status UrmaManager::UrmaCreateJfr(const custom_unique_ptr &jfc, cust jfrConfig.token_value = urmaToken_; jfrConfig.id = 0; jfrConfig.max_sge = 1; + jfrConfig.user_ctx = (uint64_t)NULL; out = MakeCustomUnique(urma_create_jfr(urmaContext_, &jfrConfig), 
[](urma_jfr_t *p) { std::stringstream oss; @@ -384,16 +438,65 @@ Status UrmaManager::RegisterSegment(const uint64_t &segAddress, const uint64_t & return Status::OK(); } -Status UrmaManager::GetSegmentInfo(const uint64_t &segAddress, const uint64_t &segSize, uint64_t &segVA, - uint64_t &segLen, uint32_t &segFlag, uint32_t &segTokenId) +namespace { +#ifdef URMA_OVER_UB +Status GetUbBondSegInfo(urma_target_seg_t *tseg, urma_bond_seg_info_out_t &segInfoOut) +{ + urma_bond_seg_info_in_t segInfoIn; + segInfoIn.tseg = tseg; + urma_user_ctl_in_t userCtlIn; + userCtlIn.opcode = URMA_USER_CTL_BOND_GET_SEG_INFO; + userCtlIn.addr = (uint64_t)&segInfoIn; + userCtlIn.len = sizeof(urma_bond_seg_info_in_t); + urma_user_ctl_out_t userCtlOut; + userCtlOut.addr = (uint64_t)&segInfoOut; + userCtlOut.len = sizeof(urma_bond_seg_info_out_t); + CHECK_FAIL_RETURN_STATUS(urma_user_ctl(tseg->urma_ctx, &userCtlIn, &userCtlOut) == 0, K_RUNTIME_ERROR, + "Get segment UB bond info failed."); + return Status::OK(); +} + +Status AddUbBondSegInfo(urma_context_t *ctx, urma_bond_add_remote_seg_info_in_t &segInfo) +{ + urma_user_ctl_in_t userCtlIn; + userCtlIn.opcode = URMA_USER_CTL_BOND_ADD_REMOTE_SEG_INFO; + userCtlIn.addr = (uint64_t)&segInfo; + userCtlIn.len = sizeof(urma_bond_add_remote_seg_info_in_t); + urma_user_ctl_out_t userCtlOut{ 0 }; + CHECK_FAIL_RETURN_STATUS(urma_user_ctl(ctx, &userCtlIn, &userCtlOut) == 0, K_RUNTIME_ERROR, + "Failed to add seg info"); + return Status::OK(); +} +#endif +} // namespace + +Status UrmaManager::GetSegmentInfo(const uint64_t &segAddress, const uint64_t &segSize, const uint64_t &shmOffset, + const uint64_t &metaSz, const HostPort &localAddress, UrmaImportSegmentPb &segInfo) { SegmentMap::ConstAccessor constAccessor; RETURN_IF_NOT_OK(GetOrRegisterSegment(segAddress, segSize, constAccessor)); auto &localSegment = constAccessor.entry->data.segment_; - segVA = localSegment->seg.ubva.va; - segLen = localSegment->seg.len; - segFlag = localSegment->seg.attr.value; - segTokenId = localSegment->seg.token_id; + auto segPb = segInfo.mutable_seg(); + UrmaSeg::ToProto(localSegment->seg, *segPb); + LOG(INFO) << "local seg info: " << UrmaSeg::ToString(localSegment->seg); + + if (IsRegisterWholeArenaEnabled()) { + segInfo.set_seg_data_offset(shmOffset + metaSz); + } else { + segInfo.set_seg_data_offset(metaSz); + } + segInfo.mutable_request_address()->set_host(localAddress.Host()); + segInfo.mutable_request_address()->set_port(localAddress.Port()); +#ifdef URMA_OVER_UB + // UB bond prehandling for segment. 
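Editor's note: GetUbBondSegInfo() and AddUbBondSegInfo() above follow the generic urma_user_ctl calling convention: an opcode selects the operation, and addr/len pairs point at opcode-specific in/out structs. A toy model of that convention; every type here, including the fake user_ctl(), is a stand-in rather than the real URMA API:

```cpp
#include <cstdint>
#include <iostream>

struct UserCtlIn {
    uint32_t opcode;
    uint64_t addr;  // address of the opcode-specific input struct
    uint32_t len;   // size of that struct, so the driver can validate it
};

struct UserCtlOut {
    uint64_t addr;  // address of the output struct the driver fills in
    uint32_t len;
};

struct BondSegInfoOut {
    uint32_t devNum;
};

constexpr uint32_t kOpGetSegInfo = 1;  // hypothetical opcode value

int user_ctl(const UserCtlIn &in, UserCtlOut &out)
{
    if (in.opcode != kOpGetSegInfo || in.len == 0 || out.len < sizeof(BondSegInfoOut)) {
        return -1;  // a real driver rejects unknown opcodes and short buffers
    }
    reinterpret_cast<BondSegInfoOut *>(out.addr)->devNum = 2;  // fake result
    return 0;
}

int main()
{
    uint64_t fakeInput = 0;
    BondSegInfoOut segOut{};
    UserCtlIn in{ kOpGetSegInfo, reinterpret_cast<uint64_t>(&fakeInput),
                  static_cast<uint32_t>(sizeof(fakeInput)) };
    UserCtlOut out{ reinterpret_cast<uint64_t>(&segOut), static_cast<uint32_t>(sizeof(segOut)) };
    if (user_ctl(in, out) == 0) {
        std::cout << "dev_num: " << segOut.devNum << '\n';
    }
    return 0;
}
```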
+ if (GetUrmaMode() == UrmaMode::UB) { + UrmaBondSegInfo info; + RETURN_IF_NOT_OK(GetUbBondSegInfo(localSegment.get(), info.raw)); + auto *bondInfo = segInfo.mutable_bond_info(); + info.ToProto(*bondInfo); + LOG(INFO) << "local bond seg info: " << info.ToString(); + } +#endif return Status::OK(); } @@ -411,6 +514,7 @@ Status UrmaManager::GetOrRegisterSegment(const uint64_t &segAddress, const uint6 segmentConfig.flag = registerSegmentFlag_; segmentConfig.user_ctx = (uint64_t)NULL; segmentConfig.iova = 0; + segmentConfig.token_id = nullptr; PerfPoint point(PerfKey::URMA_REGISTER_SEGMENT); auto *segment = urma_register_seg(urmaContext_, &segmentConfig); point.Record(); @@ -487,34 +591,38 @@ Status UrmaManager::CheckAndNotify() void UrmaManager::DeleteEvent(uint64_t requestId) { - std::unique_lock lock(eventMapMutex_); - eventMap_.erase(requestId); + std::shared_lock lock(eventMapMutex_); + EventMap::Accessor accessor; + if (eventMap_->Find(accessor, requestId)) { + eventMap_->BlockingErase(accessor); + } } Status UrmaManager::GetEvent(uint64_t requestId, std::shared_ptr &event) { std::shared_lock lock(eventMapMutex_); - auto eventItr = eventMap_.find(requestId); - if (eventItr != eventMap_.end()) { - event = eventItr->second; + EventMap::Accessor accessor; + if (eventMap_->Find(accessor, requestId)) { + event = accessor.entry->data; return Status::OK(); } - // can happen if event is not yet inserted by sender thread - // so dont log the status + // Can happen if event is not yet inserted by sender thread. RETURN_STATUS(K_NOT_FOUND, FormatString("Request id %d doesnt exist in event map", requestId)); } Status UrmaManager::CreateEvent(uint64_t requestId, std::shared_ptr &event) { - std::unique_lock lock(eventMapMutex_); - auto result = eventMap_.emplace(requestId, std::make_shared(requestId)); - if (result.second) { - event = result.first->second; - return Status::OK(); - } else { - // If this happens that means requestId is duplicated + std::shared_lock lock(eventMapMutex_); + EventMap::Accessor accessor; + auto res = eventMap_->Insert(accessor, requestId); + if (!res) { + // If this happens that means requestId is duplicated. RETURN_STATUS_LOG_ERROR(K_DUPLICATED, FormatString("Request id %d already exists in event map", requestId)); + } else { + event = std::make_shared(requestId); + accessor.entry->data = event; } + return Status::OK(); } Status UrmaManager::WaitToFinish(uint64_t requestId, int64_t timeoutMs) @@ -614,8 +722,9 @@ Status UrmaManager::PollJfcWait(const custom_unique_ptr &jfc, const for (uint64_t i = 0; i < maxTryCount; ++i) { cnt = urma_poll_jfc(urmaJfc, numPollCRS, completeRecords); if (cnt == 0) { - // if there is ntg to poll, just sleep for 10us - usleep(10); + // If there is nothing to poll, just sleep. + // Note that it takes on average 50us to wake up with usleep(0), due to OS timerslack settings. + usleep(0); } else if (cnt < 0) { RETURN_STATUS_LOG_ERROR(K_URMA_ERROR, FormatString("Failed to poll jfc, ret = %d", cnt)); } else if (cnt > 0) { @@ -628,10 +737,12 @@ Status UrmaManager::PollJfcWait(const custom_unique_ptr &jfc, const RETURN_STATUS(K_TRY_AGAIN, FormatString("No Event present in JFC")); } -Status UrmaManager::ImportRemoteJfr(const RpcChannel::UrmaInfo &urmaInfo) +Status UrmaManager::ImportRemoteJfr(const UrmaJfrInfo &urmaInfo) { PerfPoint point1(PerfKey::URMA_CONNECT_WITH_REMOTE_DEVICE); - const std::string remoteDeviceId = urmaInfo.localAddress_.ToString(); + // Do not need to import jfr for the local node. 
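Editor's note: the polling change above swaps usleep(10) for usleep(0), on the observation (stated in the patch comment) that even usleep(0) wakes up after roughly 50us on average because of the kernel's timer-slack setting. The loop shape, sketched with a stand-in poller in place of urma_poll_jfc:

```cpp
#include <unistd.h>
#include <cstdint>
#include <iostream>

// Spin up to maxTryCount, yielding only when a poll returns nothing.
template <typename PollFn>
int PollWithBackoff(PollFn pollOnce, uint64_t maxTryCount)
{
    for (uint64_t i = 0; i < maxTryCount; ++i) {
        int cnt = pollOnce();
        if (cnt < 0) {
            return cnt;  // poll error, surface to the caller
        }
        if (cnt > 0) {
            return cnt;  // completions ready to process
        }
        usleep(0);       // nothing yet; yield briefly before retrying
    }
    return 0;            // no events within the budget (K_TRY_AGAIN upstream)
}

int main()
{
    int calls = 0;
    auto poller = [&calls] { return ++calls < 3 ? 0 : 1; };  // ready on 3rd try
    std::cout << "polled completions: " << PollWithBackoff(poller, 10) << '\n';
    return 0;
}
```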
+ RETURN_OK_IF_TRUE(localUrmaInfo_.localAddress == urmaInfo.localAddress); + const std::string remoteDeviceId = urmaInfo.localAddress.ToString(); std::shared_lock l(remoteMapMutex_); // Insert or update the import jfr (in case the sending worker restarts) RemoteDeviceMap::Accessor accessor; @@ -643,6 +754,24 @@ Status UrmaManager::ImportRemoteJfr(const RpcChannel::UrmaInfo &urmaInfo) device.Clear(); } device.urmaInfo_ = urmaInfo; + // UB bond handling before import jfr. +#ifdef URMA_OVER_UB + if (GetUrmaMode() == UrmaMode::UB) { + for (auto &bondInfo : urmaInfo.bondInfos) { + urma_user_ctl_in_t userCtlIn; + userCtlIn.opcode = URMA_USER_CTL_BOND_ADD_RJFR_ID_INFO; + userCtlIn.addr = (uint64_t)&bondInfo; + userCtlIn.len = sizeof(urma_bond_add_rjfr_id_info_in_t); + urma_user_ctl_out_t userCtlOut; + userCtlOut.addr = 0; + userCtlOut.len = 0; + if (urma_user_ctl(urmaContext_, &userCtlIn, &userCtlOut)) { + return Status(K_RUNTIME_ERROR, FormatString("Failed to add rjfr info, %s", urmaInfo.ToString())); + } + } + } +#endif + // Now we import a new jfr urma_rjfr_t remoteJfr; urma_eid_t eid; @@ -656,7 +785,7 @@ Status UrmaManager::ImportRemoteJfr(const RpcChannel::UrmaInfo &urmaInfo) remoteJfr.trans_mode = URMA_TM_RM; std::vector tjfrs; for (uint i = 0; i < FLAGS_urma_connection_size; ++i) { - remoteJfr.jfr_id.id = urmaInfo.jfr_ids[i]; + remoteJfr.jfr_id.id = urmaInfo.jfrIds[i]; PerfPoint point1a(PerfKey::URMA_IMPORT_JFR); auto *tjfr = urma_import_jfr(urmaContext_, &remoteJfr, &urmaToken_); point1a.Record(); @@ -686,6 +815,7 @@ Status UrmaManager::ImportSegAndWritePayload(const UrmaImportSegmentPb &urmaInfo // Note that the returned keys only contain the new key(s). keys.clear(); PerfPoint point(PerfKey::URMA_IMPORT_AND_WRITE_PAYLOAD); + auto segVa = urmaInfo.seg().va(); const HostPort requestAddress(urmaInfo.request_address().host(), urmaInfo.request_address().port()); const std::string remoteDeviceId = requestAddress.ToString(); PerfPoint point1(PerfKey::URMA_CONNECT_WITH_REMOTE_DEVICE); @@ -698,7 +828,7 @@ Status UrmaManager::ImportSegAndWritePayload(const UrmaImportSegmentPb &urmaInfo point1.Record(); PerfPoint point2(PerfKey::URMA_IMPORT_REMOTE_SEGMENT); SegmentMap::ConstAccessor remoteSegAccessor; - RETURN_IF_NOT_OK(constAccessor.entry->data.GetOrImportRemoteSeg(urmaInfo, remoteSegAccessor)); + RETURN_IF_NOT_OK(constAccessor.entry->data.GetOrImportRemoteSeg(urmaContext_, urmaInfo, remoteSegAccessor)); point2.Record(); PerfPoint point3(PerfKey::URMA_REGISTER_LOCAL_SEGMENT); @@ -707,7 +837,8 @@ Status UrmaManager::ImportSegAndWritePayload(const UrmaImportSegmentPb &urmaInfo point3.Record(); // Write payload - urma_jfs_wr_flag_t flag = { 0 }; + urma_jfs_wr_flag_t flag; + flag.value = 0; flag.bs.complete_enable = 1; PerfPoint point4(PerfKey::URMA_TOTAL_WRITE); @@ -720,11 +851,10 @@ Status UrmaManager::ImportSegAndWritePayload(const UrmaImportSegmentPb &urmaInfo urma_jfs_t *urmaJfs = urmaJfsVec_[index].get(); urma_target_jetty_t *importJfr = constAccessor.entry->data.importJfrs_[index].get(); PerfPoint point4a(PerfKey::URMA_WRITE); - urma_status_t ret = - urma_write(urmaJfs, importJfr, remoteSegAccessor.entry->data.segment_.get(), - localSegAccessor.entry->data.segment_.get(), - urmaInfo.seg_va() + urmaInfo.seg_data_offset() + readOffset + writtenSize, - localObjectAddress + readOffset + metaDataSize + writtenSize, writeSize, flag, key); + urma_status_t ret = urma_write( + urmaJfs, importJfr, remoteSegAccessor.entry->data.segment_.get(), + localSegAccessor.entry->data.segment_.get(), segVa + 
urmaInfo.seg_data_offset() + readOffset + writtenSize, + localObjectAddress + readOffset + metaDataSize + writtenSize, writeSize, flag, key); CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( ret == URMA_SUCCESS, K_RUNTIME_ERROR, FormatString("Failed to urma write object with key = %zu, ret = %d", key, ret)); @@ -774,11 +904,6 @@ Status UrmaManager::RemoveRemoteDevice(const HostPort &remoteAddress) remoteAddress.ToString())); } -std::string UrmaManager::EidToStr(const urma_eid_t &eid) -{ - return std::string(reinterpret_cast(eid.raw), URMA_EID_SIZE); -} - Status UrmaManager::StrToEid(const std::string &eid, urma_eid_t &out) { CHECK_FAIL_RETURN_STATUS(eid.size() == URMA_EID_SIZE, K_RUNTIME_ERROR, @@ -794,24 +919,39 @@ Status UrmaManager::ExchangeJfr(const UrmaHandshakeReqPb &req, UrmaHandshakeRspP if (UrmaManager::IsUrmaEnabled()) { auto &mgr = UrmaManager::Instance(); // Register the incoming jfr. - RpcChannel::UrmaInfo urmaInfo; - urmaInfo.eid = req.eid(); - urmaInfo.uasid = req.uasid(); - for (auto jfrId : req.jfr_ids()) { - urmaInfo.jfr_ids.emplace_back(jfrId); - } - urmaInfo.localAddress_ = HostPort(req.address().host(), req.address().port()); - LOG(INFO) << urmaInfo.ToString(); + UrmaJfrInfo urmaInfo; + RETURN_IF_NOT_OK(urmaInfo.FromProto(req)); + LOG(INFO) << "Start import remote jfr, remote urma info: " << urmaInfo.ToString(); LOG_IF_ERROR(mgr.ImportRemoteJfr(urmaInfo), "Error in import incoming jfr"); - // Return our own jfr. - rsp.set_eid(mgr.GetEid()); - rsp.set_uasid(mgr.GetUasid()); - for (auto jfrId : mgr.GetJfrIds()) { - rsp.add_jfr_ids(jfrId); + // Do not need to fill in jfr response for urma_write scenario. + } + return Status::OK(); +} + +#ifdef URMA_OVER_UB +Status UrmaManager::GetJfrInfoForBond(std::vector &infoOut) +{ + auto size = urmaJfrVec_.size(); + infoOut.resize(size); + for (size_t i = 0; i < size; i++) { + // UB bond prehandling for jfr. 
+ urma_bond_id_info_in_t in; + in.jfr = urmaJfrVec_[i].get(); + in.type = URMA_JFR; + urma_user_ctl_in_t userCtlIn; + userCtlIn.opcode = URMA_USER_CTL_BOND_GET_ID_INFO; + userCtlIn.addr = (uint64_t)&in; + userCtlIn.len = sizeof(urma_bond_id_info_in_t); + urma_user_ctl_out_t userCtlOut; + userCtlOut.addr = (uint64_t)&infoOut[i]; + userCtlOut.len = sizeof(urma_bond_id_info_out_t); + if (urma_user_ctl(urmaJfrVec_[i]->urma_ctx, &userCtlIn, &userCtlOut)) { + return Status(K_RUNTIME_ERROR, "Failed to urma user ctl"); } } return Status::OK(); } +#endif Segment::~Segment() { @@ -858,37 +998,42 @@ void RemoteDevice::SetJfrs(std::vector &jetties) } } -Status RemoteDevice::GetOrImportRemoteSeg(const UrmaImportSegmentPb &importSegmentInfo, +Status RemoteDevice::GetOrImportRemoteSeg(urma_context_t *urmaContext, const UrmaImportSegmentPb &importSegmentInfo, SegmentMap::ConstAccessor &constAccessor) { - if (!remoteSegments_.Find(constAccessor, importSegmentInfo.seg_va())) { + auto segVa = importSegmentInfo.seg().va(); + if (!remoteSegments_.Find(constAccessor, segVa)) { SegmentMap::Accessor accessor; - if (remoteSegments_.Insert(accessor, importSegmentInfo.seg_va())) { - urma_seg_t remoteSegment; - urma_eid_t eid; - auto rc = UrmaManager::StrToEid(urmaInfo_.eid, eid); - if (rc.IsError()) { - remoteSegments_.BlockingErase(accessor); - return rc; - } - remoteSegment.ubva.eid = eid; - remoteSegment.ubva.uasid = urmaInfo_.uasid; - remoteSegment.ubva.va = importSegmentInfo.seg_va(); - remoteSegment.len = importSegmentInfo.seg_len(); - remoteSegment.attr.value = importSegmentInfo.seg_flag(); - remoteSegment.token_id = importSegmentInfo.seg_token_id(); - auto *segment = UrmaManager::Instance().ImportSegment(remoteSegment); - if (segment == nullptr) { - remoteSegments_.BlockingErase(accessor); - return Status(K_RUNTIME_ERROR, - FormatString("Failed to import segment from device with eid %s, seg_va = %zu.", - urmaInfo_.eid, importSegmentInfo.seg_va())); + if (remoteSegments_.Insert(accessor, segVa)) { + bool needErase = true; + Raii eraseSegment([this, &accessor, &needErase]() { + if (needErase) { + remoteSegments_.BlockingErase(accessor); + } + }); + // UB bond handling before import segment. +#ifdef URMA_OVER_UB + if (GetUrmaMode() == UrmaMode::UB) { + UrmaBondSegInfo info; + RETURN_IF_NOT_OK(info.FromProto(importSegmentInfo.bond_info())); + LOG(INFO) << "add bond remote seg info: " << info.ToString(); + RETURN_IF_NOT_OK(AddUbBondSegInfo(urmaContext, info.raw)); } +#endif + + // Import segment + UrmaSeg remoteSegment; + RETURN_IF_NOT_OK(remoteSegment.FromProto(importSegmentInfo.seg())); + LOG(INFO) << "import remote seg info: " << remoteSegment.ToString(); + auto *segment = UrmaManager::Instance().ImportSegment(remoteSegment.raw); + CHECK_FAIL_RETURN_STATUS(segment != nullptr, K_RUNTIME_ERROR, + FormatString("Failed to import segment %s.", remoteSegment.ToString())); accessor.entry->data.Set(segment, false); + needErase = false; } accessor.Release(); // Switch to const accessor so it does not block the others. 
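Editor's note: GetOrImportRemoteSeg() above relies on a Raii guard plus a needErase flag so that the speculatively inserted map entry is removed on any failure path. The same pattern in miniature; ScopeGuard and std::map stand in for the project's Raii helper and the concurrent SegmentMap:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>

// Runs its callback at scope exit unless the flag it captures was cleared.
class ScopeGuard {
public:
    explicit ScopeGuard(std::function<void()> fn) : fn_(std::move(fn)) {}
    ~ScopeGuard()
    {
        if (fn_) {
            fn_();
        }
    }

private:
    std::function<void()> fn_;
};

int main()
{
    std::map<uint64_t, int> remoteSegments;
    const uint64_t segVa = 0x1000;
    remoteSegments.emplace(segVa, 0);  // speculative insert, as in the patch
    {
        bool needErase = true;
        ScopeGuard eraseSegment([&] {
            if (needErase) {
                remoteSegments.erase(segVa);
            }
        });
        bool importOk = false;  // pretend urma_import_seg returned nullptr
        if (importOk) {
            needErase = false;  // keep the entry only on success
        }
    }
    std::cout << "entries after failed import: " << remoteSegments.size() << '\n';
    return 0;
}
```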
- CHECK_FAIL_RETURN_STATUS(remoteSegments_.Find(constAccessor, importSegmentInfo.seg_va()), K_RUNTIME_ERROR, + CHECK_FAIL_RETURN_STATUS(remoteSegments_.Find(constAccessor, segVa), K_RUNTIME_ERROR, "Failed to operate on remote segment map."); } return Status::OK(); @@ -908,5 +1053,4 @@ Status RemoteDevice::UnimportRemoteSeg(const uint64_t segmentAddress) } RETURN_STATUS(K_NOT_FOUND, "Cannot unimport remote segment, remote segment is not imported"); } - } // namespace datasystem diff --git a/src/datasystem/common/rdma/urma_manager.h b/src/datasystem/common/rdma/urma_manager.h index 0da12b4..d39837b 100644 --- a/src/datasystem/common/rdma/urma_manager.h +++ b/src/datasystem/common/rdma/urma_manager.h @@ -25,12 +25,17 @@ #include #include -#include "urma_api.h" +#include +#ifdef URMA_OVER_UB +#include +#endif + #include "datasystem/common/flags/flags.h" +#include "datasystem/common/perf/perf_manager.h" +#include "datasystem/common/rdma/urma_info.h" #include "datasystem/common/rpc/rpc_channel.h" #include "datasystem/common/util/lock_map.h" #include "datasystem/common/util/net_util.h" -#include "datasystem/common/perf/perf_manager.h" #include "datasystem/protos/meta_zmq.pb.h" #include "datasystem/protos/utils.pb.h" #include "datasystem/utils/status.h" @@ -99,12 +104,14 @@ public: void SetJfrs(std::vector &jetties); /** - * @brief Get remote segment or import remote segment from the device - * @param[in] UrmaImportSegmentPb Pb with remote segment info - * @param[out] constAccessor Accessor in segment table + * @brief Get remote segment or import remote segment from the device. + * @param[in] urmaContext The urma context. + * @param[in] UrmaImportSegmentPb Pb with remote segment info. + * @param[out] constAccessor Accessor in segment table. * @return Status of the call. */ - Status GetOrImportRemoteSeg(const UrmaImportSegmentPb &urmaInfo, SegmentMap::ConstAccessor &constAccessor); + Status GetOrImportRemoteSeg(urma_context_t *urmaContext, const UrmaImportSegmentPb &urmaInfo, + SegmentMap::ConstAccessor &constAccessor); /** * @brief Unimport a remote segment @@ -117,7 +124,7 @@ public: * @brief Clears all remote Jfrs */ void Clear(); - RpcChannel::UrmaInfo urmaInfo_; + UrmaJfrInfo urmaInfo_; std::vector> importJfrs_; SegmentMap remoteSegments_; }; @@ -184,6 +191,8 @@ private: bool failed_{ false }; }; +using EventMap = LockMap>; + class UrmaManager { public: /** @@ -196,10 +205,10 @@ public: /** * @brief Init a Urma device - * @param[in] deviceName + * @param[in] hostport * @return Status of the call. */ - Status Init(const std::string &deviceName); + Status Init(const HostPort &hostport); /** * @brief Check if Urma worker flag is set @@ -252,18 +261,17 @@ public: Status RegisterSegment(const uint64_t &segAddress, const uint64_t &segSize); /** - * @brief Gets the segment if present or - * Registers the segment if address is not already registered - * @param[in] segAddress Starting address of the segment - * @param[in] segSize Size of the segment - * @param[out] segVA UB virtual address of the segment - * @param[out] segLen Size of the segment (==segSize) - * @param[out] segFlag Flags set for the segment - * @param[out] segTokenId Token provided for the segment + * @brief Fill segment info. Register the segment if not already registered. + * @param[in] segAddress Starting address of the segment. + * @param[in] segSize Size of the segment. + * @param[in] shmOffset The shared memory offset of the object. + * @param[in] metaSz The size of the shared memory metadata size. 
+ * @param[in] localAddress The local worker hostport. + * @param[out] segInfo The urma segment info for import purposes. * @return Status of the call. */ - Status GetSegmentInfo(const uint64_t &segAddress, const uint64_t &segSize, uint64_t &segVA, uint64_t &segLen, - uint32_t &segFlag, uint32_t &segTokenId); + Status GetSegmentInfo(const uint64_t &segAddress, const uint64_t &segSize, const uint64_t &shmOffset, + const uint64_t &metaSz, const HostPort &localAddress, UrmaImportSegmentPb &segInfo); /** * @brief Does a RDMA write to remote worker memory location @@ -298,7 +306,7 @@ public: * @param[in] urmaInfo local urma device info * @return Status of the call */ - Status ImportRemoteJfr(const RpcChannel::UrmaInfo &urmaInfo); + Status ImportRemoteJfr(const UrmaJfrInfo &urmaInfo); /** * @brief Import segment @@ -320,7 +328,10 @@ public: * @param[in] eid Urma device eid object * @return String */ - static std::string EidToStr(const urma_eid_t &eid); + static std::string EidToStr(const urma_eid_t &eid) + { + return std::string(reinterpret_cast(eid.raw), URMA_EID_SIZE); + } /** * @brief Converts a valid string to Urma Eid @@ -338,6 +349,20 @@ public: */ Status ExchangeJfr(const UrmaHandshakeReqPb &req, UrmaHandshakeRspPb &rsp); +#ifdef URMA_OVER_UB + /** + * @brief Get the jfr info for UB bond purposes. + * @param[out] infoOut A vector of bond info for each created jfr. + * @return Status of the call. + */ + Status GetJfrInfoForBond(std::vector &infoOut); +#endif + + const UrmaJfrInfo &GetLocalUrmaInfo() + { + return localUrmaInfo_; + } + private: UrmaManager(); @@ -524,6 +549,8 @@ private: */ void DeleteEvent(uint64_t requestId); + Status InitLocalUrmaInfo(const HostPort &hostport); + // Polling thread std::unique_ptr serverEventThread_{ nullptr }; @@ -538,6 +565,7 @@ private: uint32_t JETTY_SIZE_ = 256; urma_reg_seg_flag_t registerSegmentFlag_; urma_import_seg_flag_t importSegmentFlag_; + UrmaJfrInfo localUrmaInfo_; // protect for segment maps. mutable std::shared_timed_mutex localMapMutex_; @@ -547,7 +575,7 @@ private: // Eid to segment maps mapping for remote jfr and segment. 
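Editor's note: the inlined EidToStr() above serializes the eid as an opaque fixed-size byte string; the printable form comes from EidToFmtStr() instead. A sketch of the round trip with StrToEid()'s size check, assuming a 16-byte eid (URMA_EID_SIZE's actual value is not shown in this patch):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>

constexpr size_t kEidSize = 16;  // assumed stand-in for URMA_EID_SIZE

struct Eid {
    uint8_t raw[kEidSize];
};

std::string EidToStr(const Eid &eid)
{
    // Raw bytes, not printable text, exactly as the patch's EidToStr does.
    return std::string(reinterpret_cast<const char *>(eid.raw), kEidSize);
}

bool StrToEid(const std::string &s, Eid &out)
{
    if (s.size() != kEidSize) {
        return false;  // mirrors the K_RUNTIME_ERROR size check
    }
    std::memcpy(out.raw, s.data(), kEidSize);
    return true;
}

int main()
{
    Eid a{};
    a.raw[0] = 0xAB;
    Eid b{};
    std::cout << "round trip ok: " << (StrToEid(EidToStr(a), b) && b.raw[0] == 0xAB) << '\n';
    return 0;
}
```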
std::unique_ptr remoteDeviceMap_; mutable std::shared_timed_mutex eventMapMutex_; - std::unordered_map> eventMap_; + std::unique_ptr eventMap_; std::unordered_set finishedRequests_; std::unordered_set failedRequests_; std::atomic serverStop_{ false }; diff --git a/src/datasystem/common/rdma/urma_manager_wrapper.cpp b/src/datasystem/common/rdma/urma_manager_wrapper.cpp index e36fe8c..cd5dadd 100644 --- a/src/datasystem/common/rdma/urma_manager_wrapper.cpp +++ b/src/datasystem/common/rdma/urma_manager_wrapper.cpp @@ -26,12 +26,12 @@ bool IsUrmaEnabled() #endif } -Status InitializeUrmaManager(const std::string &host) +Status InitializeUrmaManager(const HostPort &hostport) { - (void)host; + (void)hostport; #ifdef USE_URMA if (UrmaManager::IsUrmaEnabled()) { - RETURN_IF_NOT_OK(UrmaManager::Instance().Init(host)); + RETURN_IF_NOT_OK(UrmaManager::Instance().Init(hostport)); } #endif return Status::OK(); @@ -110,25 +110,13 @@ Status FillUrmaInfo(std::shared_ptr shmUnit, const HostPort &localAddre (void)metaSz; (void)urmaInfo; #ifdef USE_URMA - uint64_t segAddress; - uint64_t segSize; - GetSegmentInfoFromShmUnit(shmUnit, reinterpret_cast(shmUnit->GetPointer()), segAddress, segSize); - uint64_t segVA; - uint64_t segLen; - uint32_t segFlag; - uint32_t segTokenId; - RETURN_IF_NOT_OK(UrmaManager::Instance().GetSegmentInfo(segAddress, segSize, segVA, segLen, segFlag, segTokenId)); - urmaInfo.set_seg_va(segVA); - urmaInfo.set_seg_len(segLen); - urmaInfo.set_seg_flag(segFlag); - urmaInfo.set_seg_token_id(segTokenId); - if (UrmaManager::IsRegisterWholeArenaEnabled()) { - urmaInfo.set_seg_data_offset(shmUnit->GetOffset() + metaSz); - } else { - urmaInfo.set_seg_data_offset(metaSz); + if (UrmaManager::IsUrmaEnabled()) { + uint64_t segAddress; + uint64_t segSize; + GetSegmentInfoFromShmUnit(shmUnit, reinterpret_cast(shmUnit->GetPointer()), segAddress, segSize); + RETURN_IF_NOT_OK(UrmaManager::Instance().GetSegmentInfo(segAddress, segSize, shmUnit->GetOffset(), metaSz, + localAddress, urmaInfo)); } - urmaInfo.mutable_request_address()->set_host(localAddress.Host()); - urmaInfo.mutable_request_address()->set_port(localAddress.Port()); #endif return Status::OK(); } diff --git a/src/datasystem/common/rdma/urma_manager_wrapper.h b/src/datasystem/common/rdma/urma_manager_wrapper.h index ed46460..f612c14 100644 --- a/src/datasystem/common/rdma/urma_manager_wrapper.h +++ b/src/datasystem/common/rdma/urma_manager_wrapper.h @@ -42,7 +42,7 @@ bool IsUrmaEnabled(); * @param[in] deviceName. * @return Status of the call. 
*/ -Status InitializeUrmaManager(const std::string &host); +Status InitializeUrmaManager(const HostPort &hostport); /** * @brief Remove Remote Device and all associated segments diff --git a/src/datasystem/common/rdma/urma_stub.cpp b/src/datasystem/common/rdma/urma_stub.cpp index 40d3ee5..81b16e8 100644 --- a/src/datasystem/common/rdma/urma_stub.cpp +++ b/src/datasystem/common/rdma/urma_stub.cpp @@ -34,7 +34,7 @@ std::vector __attribute__((weak)) UrmaManager::GetJfrIds() return std::vector(); } -Status __attribute__((weak)) UrmaManager::ImportRemoteJfr(const RpcChannel::UrmaInfo &urmaInfo) +Status __attribute__((weak)) UrmaManager::ImportRemoteJfr(const UrmaJfrInfo &urmaInfo) { (void)urmaInfo; return Status::OK(); diff --git a/src/datasystem/common/rpc/CMakeLists.txt b/src/datasystem/common/rpc/CMakeLists.txt index 15d94a7..c4013b3 100644 --- a/src/datasystem/common/rpc/CMakeLists.txt +++ b/src/datasystem/common/rpc/CMakeLists.txt @@ -102,6 +102,8 @@ set(RPC_STUB_CACHE_MGR_SRCS set(RPC_STUB_CACHE_MGR_DEPENDS_LIBS master_object_protos worker_object_protos + master_stream_protos + worker_stream_protos ) add_library(rpc_stub_cache_mgr STATIC ${RPC_STUB_CACHE_MGR_SRCS}) diff --git a/src/datasystem/common/rpc/rpc_channel.cpp b/src/datasystem/common/rpc/rpc_channel.cpp index 3aba6ce..353d066 100644 --- a/src/datasystem/common/rpc/rpc_channel.cpp +++ b/src/datasystem/common/rpc/rpc_channel.cpp @@ -95,64 +95,4 @@ const HostPort &RpcChannel::GetHostPort() const { return destAddr_; } - -void RpcChannel::SetLocalInfo(const HostPort &localAddress) -{ - (void)localAddress; -#ifdef USE_URMA - if (UrmaManager::IsUrmaEnabled()) { - auto &mgr = UrmaManager::Instance(); - localUrmaInfo_ = std::make_unique(); - localUrmaInfo_->eid = mgr.GetEid(); - localUrmaInfo_->uasid = mgr.GetUasid(); - localUrmaInfo_->jfr_ids = mgr.GetJfrIds(); - localUrmaInfo_->localAddress_ = localAddress; - LOG(INFO) << localUrmaInfo_->ToString(); - } -#endif -} - -#ifdef USE_URMA -void RpcChannel::GetLocalUrmaInfo(RpcChannel::UrmaInfo &out) const -{ - if (localUrmaInfo_) { - out.eid = localUrmaInfo_->eid; - out.uasid = localUrmaInfo_->uasid; - out.jfr_ids = localUrmaInfo_->jfr_ids; - out.localAddress_ = localUrmaInfo_->localAddress_; - } -} - -std::string RpcChannel::UrmaInfo::ToString() const -{ - std::stringstream oss; - oss << localAddress_.ToString() << " urma info. eid "; - // eid is not really printable as a string. So we will dump its context in hex - urma_eid_t e; - if (UrmaManager::StrToEid(eid, e).IsOk()) { - char s[URMA_EID_STR_LEN + 1]; - int ret = sprintf_s(s, URMA_EID_STR_LEN + 1, EID_FMT, EID_ARGS(e)); - if (ret == -1) { - LOG(WARNING) << "sprintf_s eid in EID_FMT failed"; - oss << eid; - } else { - oss << s; - } - } else { - oss << eid; - } - oss << " uasid " << uasid << " jfr_id ["; - bool first = true; - for (auto jfr_id : jfr_ids) { - if (first) { - first = false; - } else { - oss << " "; - } - oss << jfr_id; - } - oss << "]"; - return oss.str(); -} -#endif } // namespace datasystem diff --git a/src/datasystem/common/rpc/rpc_channel.h b/src/datasystem/common/rpc/rpc_channel.h index 6e595f1..52c7b38 100644 --- a/src/datasystem/common/rpc/rpc_channel.h +++ b/src/datasystem/common/rpc/rpc_channel.h @@ -31,17 +31,6 @@ namespace datasystem { class RpcChannel { public: -#ifdef USE_URMA - struct UrmaInfo { - std::string eid; - uint32_t uasid{ 0 }; - std::vector jfr_ids; - HostPort localAddress_; - - std::string ToString() const; - }; -#endif - /** * @brief This form of constructor takes a ZMQ transport directly. 
* @note A ZMQ transport begins with tcpip:// or ipc:// or inproc://. @@ -132,24 +121,6 @@ public: */ size_t GetServiceConnectPoolSize(const std::string &svcName); - /** - * @brief Set up local address info (for urma purposes). - * @param[in] localAddress Local host address. - */ - void SetLocalInfo(const HostPort &localAddress); - -#ifdef USE_URMA - /** - * @brief Get local jfr info from the channel - */ - void GetLocalUrmaInfo(UrmaInfo &out) const; - - bool UrmaEnabled() const - { - return localUrmaInfo_ != nullptr; - } -#endif - private: std::string endPoint_; RpcCredential cred_; @@ -158,9 +129,6 @@ private: std::map tcpDirect_; std::map connectPoolSize_; const HostPort destAddr_; -#ifdef USE_URMA - std::unique_ptr localUrmaInfo_{ nullptr }; -#endif }; } // namespace datasystem diff --git a/src/datasystem/common/rpc/rpc_stub_cache_mgr.cpp b/src/datasystem/common/rpc/rpc_stub_cache_mgr.cpp index 39fdc64..43e4338 100644 --- a/src/datasystem/common/rpc/rpc_stub_cache_mgr.cpp +++ b/src/datasystem/common/rpc/rpc_stub_cache_mgr.cpp @@ -21,10 +21,15 @@ #include "datasystem/common/util/status_helper.h" #include "datasystem/protos/worker_object.stub.rpc.pb.h" #include "datasystem/protos/master_object.stub.rpc.pb.h" +#include "datasystem/protos/stream_posix.stub.rpc.pb.h" +#include "datasystem/protos/master_stream.stub.rpc.pb.h" +#include "datasystem/protos/worker_stream.stub.rpc.pb.h" DS_DEFINE_int32(oc_worker_worker_pool_size, 3, "Number of parallel connections between worker/worker. Default is 3."); +DS_DEFINE_int32(sc_worker_worker_pool_size, 3, "Number of parallel connections between worker/worker. Default is 3."); DS_DECLARE_int32(oc_worker_worker_direct_port); +DS_DECLARE_int32(sc_worker_worker_direct_port); DS_DECLARE_uint32(node_timeout_s); namespace datasystem { @@ -54,9 +59,18 @@ Status RpcStubCacheMgr::CreateRpcStub(StubType type, const std::shared_ptr(channel, FLAGS_node_timeout_s * TO_MILLISECOND); break; + case StubType::WORKER_WORKER_SC_SVC: + stub = std::make_shared(channel); + break; + case StubType::WORKER_MASTER_SC_SVC: + stub = std::make_shared(channel); + break; case StubType::MASTER_WORKER_OC_SVC: stub = std::make_shared(channel); break; + case StubType::MASTER_WORKER_SC_SVC: + stub = std::make_shared(channel); + break; case StubType::MASTER_MASTER_OC_SVC: stub = std::make_shared(channel); break; @@ -89,19 +103,20 @@ bool RpcStubCacheMgr::EnableOcWorkerWorkerDirectPort() return FLAGS_oc_worker_worker_direct_port > 0; } +bool RpcStubCacheMgr::EnableScWorkerWorkerDirectPort() +{ + return FLAGS_sc_worker_worker_direct_port > 0; +} void RpcStubCacheMgr::InitCreators() { creators_.emplace( - StubType::WORKER_WORKER_OC_SVC, - [&localAddress = localAddress_](const HostPort &hostPort, std::shared_ptr &rpcStub) { + StubType::WORKER_WORKER_OC_SVC, [](const HostPort &hostPort, std::shared_ptr &rpcStub) { return CreatorTemplate( - [&hostPort, &localAddress](std::shared_ptr &channel) { + [&hostPort](std::shared_ptr &channel) { RETURN_IF_NOT_OK(CreateRpcChannel( hostPort, EnableOcWorkerWorkerDirectPort() ? WorkerWorkerOCService_Stub::FullServiceName() : "", channel, FLAGS_oc_worker_worker_pool_size)); - // Set local address info for URMA purposes. 
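Editor's note: the new SC stub types below slot into RpcStubCacheMgr's creator registry, where each StubType maps to a factory lambda, so adding a service is one more emplace. The pattern in miniature; all names here are simplified stand-ins for the project's stub and channel types:

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

enum class StubType { WORKER_WORKER_SC_SVC, WORKER_MASTER_SC_SVC };

struct Stub {
    std::string name;
};

using Creator = std::function<std::shared_ptr<Stub>()>;

int main()
{
    std::map<StubType, Creator> creators;
    creators.emplace(StubType::WORKER_WORKER_SC_SVC,
                     [] { return std::make_shared<Stub>(Stub{ "worker-worker-sc" }); });
    creators.emplace(StubType::WORKER_MASTER_SC_SVC,
                     [] { return std::make_shared<Stub>(Stub{ "worker-master-sc" }); });

    // Lookup by type and invoke the factory, as the cache manager does.
    auto stub = creators.at(StubType::WORKER_WORKER_SC_SVC)();
    std::cout << stub->name << '\n';
    return 0;
}
```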
- channel->SetLocalInfo(localAddress); return Status::OK(); }, StubType::WORKER_WORKER_OC_SVC, rpcStub); @@ -112,12 +127,34 @@ void RpcStubCacheMgr::InitCreators() [&hostPort](std::shared_ptr &channel) { return CreateRpcChannel(hostPort, "", channel); }, StubType::WORKER_MASTER_OC_SVC, rpcStub); }); + creators_.emplace( + StubType::WORKER_WORKER_SC_SVC, [](const HostPort &hostPort, std::shared_ptr &rpcStub) { + return CreatorTemplate( + [&hostPort](std::shared_ptr &channel) { + return CreateRpcChannel( + hostPort, EnableScWorkerWorkerDirectPort() ? WorkerWorkerSCService_Stub::FullServiceName() : "", + channel, FLAGS_sc_worker_worker_pool_size); + }, + StubType::WORKER_WORKER_SC_SVC, rpcStub); + }); + creators_.emplace( + StubType::WORKER_MASTER_SC_SVC, [](const HostPort &hostPort, std::shared_ptr &rpcStub) { + return CreatorTemplate( + [&hostPort](std::shared_ptr &channel) { return CreateRpcChannel(hostPort, "", channel); }, + StubType::WORKER_MASTER_SC_SVC, rpcStub); + }); creators_.emplace( StubType::MASTER_WORKER_OC_SVC, [](const HostPort &hostPort, std::shared_ptr &rpcStub) { return CreatorTemplate( [&hostPort](std::shared_ptr &channel) { return CreateRpcChannel(hostPort, "", channel); }, StubType::MASTER_WORKER_OC_SVC, rpcStub); }); + creators_.emplace( + StubType::MASTER_WORKER_SC_SVC, [](const HostPort &hostPort, std::shared_ptr &rpcStub) { + return CreatorTemplate( + [&hostPort](std::shared_ptr &channel) { return CreateRpcChannel(hostPort, "", channel); }, + StubType::MASTER_WORKER_SC_SVC, rpcStub); + }); creators_.emplace( StubType::MASTER_MASTER_OC_SVC, [](const HostPort &hostPort, std::shared_ptr &rpcStub) { return CreatorTemplate( @@ -185,6 +222,9 @@ StubPriority GetStubPriority(StubType type) return StubPriority::LOW; case StubType::WORKER_WORKER_OC_SVC: case StubType::WORKER_MASTER_OC_SVC: + case StubType::WORKER_WORKER_SC_SVC: + case StubType::WORKER_MASTER_SC_SVC: + case StubType::MASTER_WORKER_SC_SVC: case StubType::MASTER_MASTER_OC_SVC: return StubPriority::HIGH; #ifdef WITH_TESTS diff --git a/src/datasystem/common/rpc/rpc_stub_cache_mgr.h b/src/datasystem/common/rpc/rpc_stub_cache_mgr.h index 93296de..c6b9b67 100644 --- a/src/datasystem/common/rpc/rpc_stub_cache_mgr.h +++ b/src/datasystem/common/rpc/rpc_stub_cache_mgr.h @@ -48,6 +48,9 @@ enum class StubType : int { WORKER_WORKER_OC_SVC = 0, WORKER_MASTER_OC_SVC = 1, MASTER_WORKER_OC_SVC = 2, + WORKER_WORKER_SC_SVC = 3, + WORKER_MASTER_SC_SVC = 4, + MASTER_WORKER_SC_SVC = 5, MASTER_MASTER_OC_SVC = 6, #ifdef WITH_TESTS TEST_TYPE_1 = 1000, diff --git a/src/datasystem/common/rpc/zmq/zmq_stub_conn.cpp b/src/datasystem/common/rpc/zmq/zmq_stub_conn.cpp index 87e3971..926af62 100644 --- a/src/datasystem/common/rpc/zmq/zmq_stub_conn.cpp +++ b/src/datasystem/common/rpc/zmq/zmq_stub_conn.cpp @@ -100,18 +100,10 @@ Status ZmqFrontend::Init() Status ZmqFrontend::ExchangeJfr() { // Exchange jfr for this channel if needed - if (channel_->UrmaEnabled()) { + if (UrmaManager::IsUrmaEnabled()) { // Send our own jfr - RpcChannel::UrmaInfo localUrmaInfo; - channel_->GetLocalUrmaInfo(localUrmaInfo); UrmaHandshakeReqPb rq; - rq.set_eid(localUrmaInfo.eid); - rq.set_uasid(localUrmaInfo.uasid); - for (auto jfrId : localUrmaInfo.jfr_ids) { - rq.add_jfr_ids(jfrId); - } - rq.mutable_address()->set_host(localUrmaInfo.localAddress_.Host()); - rq.mutable_address()->set_port(localUrmaInfo.localAddress_.Port()); + UrmaManager::Instance().GetLocalUrmaInfo().ToProto(rq); MetaPb meta = CreateMetaData("", ZMQ_EXCHANGE_JFR_METHOD, ZMQ_INVALID_PAYLOAD_INX, 
GetStringUuid()); ZmqMsgFrames p; RETURN_IF_NOT_OK(PushFrontProtobufToFrames(meta, p)); @@ -127,16 +119,8 @@ Status ZmqFrontend::ExchangeJfr() reply.pop_front(); // Status UrmaHandshakeRspPb rsp; RETURN_IF_NOT_OK(ParseFromZmqMessage(reply.front(), rsp)); - // Import into UrmaManager - RpcChannel::UrmaInfo remoteUrmaInfo; - remoteUrmaInfo.eid = rsp.eid(); - remoteUrmaInfo.uasid = rsp.uasid(); - for (auto jfrId : rsp.jfr_ids()) { - remoteUrmaInfo.jfr_ids.emplace_back(jfrId); - } - remoteUrmaInfo.localAddress_ = channel_->GetHostPort(); - LOG(INFO) << remoteUrmaInfo.ToString(); - RETURN_IF_NOT_OK(UrmaManager::Instance().ImportRemoteJfr(remoteUrmaInfo)); + // The response does not need to be processed; + // in the urma_write scenario the stub side does not need to import the remote jfr. } return Status::OK(); } diff --git a/src/datasystem/common/rpc/zmq/zmq_stub_conn.cpp b/src/datasystem/common/shared_memory/allocator.cpp index 680438b..9e9bf4c 100644 --- a/src/datasystem/common/shared_memory/allocator.cpp +++ b/src/datasystem/common/shared_memory/allocator.cpp @@ -62,15 +62,17 @@ Allocator::~Allocator() noexcept LOG(INFO) << "Allocator destructor."; } -Status Allocator::InitSharedMemory(uint64_t size, int objectThreshold) +Status Allocator::InitSharedMemory(uint64_t size, int objectThreshold, int streamThreshold) { CHECK_FAIL_RETURN_STATUS((size > 0) && (size < UINT64_MAX / HUNDRED_PERCENT), K_INVALID, "the memory size should be greater than 0 and less than UINT64_MAX/100"); CHECK_FAIL_RETURN_STATUS( - (objectThreshold > 0 && objectThreshold <= HUNDRED_PERCENT), K_INVALID, - "the allocation threshold percentage should be greater than 0 and less than or equal to 100"); + (objectThreshold > 0 && objectThreshold <= HUNDRED_PERCENT) + && (streamThreshold > 0 && streamThreshold <= HUNDRED_PERCENT), + K_INVALID, "the allocation threshold percentage should be greater than 0 and less than or equal to 100"); physicalMemoryStats_ = std::make_unique(size); objectMemoryStats_ = std::make_unique((size * objectThreshold) / HUNDRED_PERCENT); + streamMemoryStats_ = std::make_unique((size * streamThreshold) / HUNDRED_PERCENT); return Status::OK(); } @@ -107,9 +109,9 @@ bool Allocator::IsDiskAvailable() } Status Allocator::Init(uint64_t shmSize, uint64_t shdSize, bool populate, bool scaling, ssize_t decayMs, - int objectThreshold) + int objectThreshold, int streamThreshold) { - RETURN_IF_NOT_OK(InitSharedMemory(shmSize, objectThreshold)); + RETURN_IF_NOT_OK(InitSharedMemory(shmSize, objectThreshold, streamThreshold)); RETURN_IF_NOT_OK(InitSharedDisk(shdSize)); if (arenaManager_) { @@ -175,8 +177,11 @@ uint64_t Allocator::GetMaxMemoryLimit(CacheType cacheType) const } } -ResourcePool *Allocator::GetResourcePoolByType(CacheType cacheType) const +ResourcePool *Allocator::GetResourcePoolByType(ServiceType serviceType, CacheType cacheType) const { + if (serviceType == ServiceType::STREAM) { + return streamMemoryStats_.get(); + } switch (cacheType) { case CacheType::DISK: return diskStats_.get(); @@ -208,7 +213,7 @@ void Allocator::Shutdown() } Status Allocator::AllocateMemory(const std::string &tenantId, uint64_t needSize, bool populate, void *&pointer, int &fd, - ptrdiff_t &offset, uint64_t &mmapSize, CacheType cacheType) + ptrdiff_t &offset, uint64_t &mmapSize, ServiceType serviceType, CacheType cacheType) { RETURN_RUNTIME_ERROR_IF_NULL(arenaManager_); INJECT_POINT("worker.Allocator.AllocateMemory"); @@ -221,15 +226,15 @@ Status Allocator::AllocateMemory(const std::string &tenantId, uint64_t needSize, } } - 
RETURN_IF_NOT_OK_PRINT_ERROR_MSG(IncrementMemoryUsage(needSize, cacheType), "ADD failed"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(IncrementMemoryUsage(needSize, serviceType, cacheType), "ADD failed"); std::shared_ptr arenaGroup; Status rc = arenaManager_->GetOrCreateArenaGroup({ tenantId, cacheType }, GetMaxMemoryLimit(cacheType), arenaGroup); uint64_t realSize; if (rc.IsOk()) { RETURN_RUNTIME_ERROR_IF_NULL(arenaGroup); - rc = arenaGroup->AllocateMemory(needSize, populate, realSize, pointer, fd, offset, mmapSize); + rc = arenaGroup->AllocateMemory(needSize, populate, realSize, pointer, fd, offset, mmapSize, serviceType); } - auto stats = GetResourcePoolByType(cacheType); + auto stats = GetResourcePoolByType(serviceType, cacheType); if (rc.IsError()) { stats->SubUsage(needSize); return rc; @@ -247,7 +252,7 @@ Status Allocator::AllocateMemory(const std::string &tenantId, uint64_t needSize, return Status::OK(); } -Status Allocator::IncrementMemoryUsage(uint64_t needSize, CacheType cacheType) +Status Allocator::IncrementMemoryUsage(uint64_t needSize, ServiceType serviceType, CacheType cacheType) { if (cacheType == CacheType::DISK) { return diskStats_->AddUsageCAS(needSize); @@ -256,23 +261,34 @@ Status Allocator::IncrementMemoryUsage(uint64_t needSize, CacheType cacheType) } else if (cacheType == CacheType::DEV_HOST) { return devHostMemStats_->AddUsageCAS(needSize); } - - return objectMemoryStats_->AddUsageCAS(needSize, physicalMemoryStats_->FootprintLimit()); + INJECT_POINT("worker.Allocator.MemoryAllocatedToStream", [this](int streamMemoryUsage) { + streamMemoryStats_->SetUsage(streamMemoryUsage); + streamMemoryStats_->SetRealUsage(streamMemoryUsage); + return Status::OK(); + }); + if (serviceType == ServiceType::OBJECT) { + return objectMemoryStats_->AddUsageCAS( + needSize, physicalMemoryStats_->FootprintLimit() - streamMemoryStats_->RealUsage()); + } else { + return streamMemoryStats_->AddUsageCAS( + needSize, physicalMemoryStats_->FootprintLimit() - objectMemoryStats_->RealUsage()); + } } -Status Allocator::FreeMemory(void *&pointer) +Status Allocator::FreeMemory(void *&pointer, ServiceType type) { - return FreeMemory(DEFAULT_TENANT_ID, pointer); + return FreeMemory(DEFAULT_TENANT_ID, pointer, type); } -Status Allocator::FreeMemory(const std::string &tenantId, void *&pointer, CacheType cacheType) +Status Allocator::FreeMemory(const std::string &tenantId, void *&pointer, ServiceType serviceType, CacheType cacheType) { RETURN_RUNTIME_ERROR_IF_NULL(arenaManager_); std::shared_ptr arenaGroup; uint64_t bytesFree = 0; uint64_t bytesRealFree = 0; RETURN_IF_NOT_OK(arenaManager_->GetArenaGroup({ tenantId, cacheType }, arenaGroup)); - auto stats = GetResourcePoolByType(cacheType); + auto stats = GetResourcePoolByType(serviceType, cacheType); RETURN_IF_NOT_OK(arenaGroup->FreeMemory(pointer, bytesFree, bytesRealFree, stats->Usage())); if (arenaGroup->GetMemoryUsage() == 0) { @@ -287,9 +303,9 @@ Status Allocator::FreeMemory(const std::string &tenantId, void *&pointer, CacheT return Status::OK(); } -uint64_t Allocator::GetMaxMemorySize(CacheType cacheType) const +uint64_t Allocator::GetMaxMemorySize(ServiceType serviceType, CacheType cacheType) const { - return GetResourcePoolByType(cacheType)->FootprintLimit(); + return GetResourcePoolByType(serviceType, cacheType)->FootprintLimit(); } uint64_t Allocator::GetMemoryUsage(const std::string &tenantId, CacheType cacheType) @@ -322,8 +338,8 @@ Status Allocator::FdToPointer(const ArenaGroupKey &key, int fd, std::pairUsage(); - 
shmMemStat.realMemoryUsage = objectMemoryStats_->RealUsage(); + shmMemStat.memoryUsage = objectMemoryStats_->Usage() + streamMemoryStats_->Usage(); + shmMemStat.realMemoryUsage = objectMemoryStats_->RealUsage() + streamMemoryStats_->RealUsage(); shmMemStat.objectMemoryUsage = objectMemoryStats_->Usage(); shmMemStat.physicalMemoryUsage = GetTotalPhysicalMemoryUsage(); shmMemStat.numOfFds = arenaManager_->GetArenaCounts(); @@ -358,7 +374,8 @@ uint64_t Allocator::GetTotalPhysicalMemoryUsage(CacheType cacheType) physicalMemoryStats_->SetRealUsage(usage); return 0; }); - return physicalMemoryStats_->GetOrUpdateRealUsage(objectMemoryStats_->RealUsage()); + return physicalMemoryStats_->GetOrUpdateRealUsage(objectMemoryStats_->RealUsage() + + streamMemoryStats_->RealUsage()); } bool Allocator::AddTotalPhysicalMemoryUsage(CacheType type, uint64_t size) diff --git a/src/datasystem/common/shared_memory/allocator.h b/src/datasystem/common/shared_memory/allocator.h index 2792ed2..b1f078e 100644 --- a/src/datasystem/common/shared_memory/allocator.h +++ b/src/datasystem/common/shared_memory/allocator.h @@ -57,10 +57,11 @@ public: * @param[in] scaling Shared memory need scaling or not. * @param[in] decayMs Decay clean dirty pages milliseconds. * @param[in] objectThreshold A limit to restrict the memory usage of object cache / kv service. + * @param[in] streamThreshold A limit to restrict the memory usage of stream cache service. * @return Status of the call. */ Status Init(uint64_t shmSize, uint64_t shdSize = 0, bool populate = false, bool scaling = true, - ssize_t decayMs = 5'000, int objectThreshold = 100); + ssize_t decayMs = 5'000, int objectThreshold = 100, int streamThreshold = 100); /** * @brief Pre allocate device memory size. The method will create a devHost mem and devDevice mem. @@ -111,10 +112,11 @@ public: /** * @brief Increase the memory usage for the given service type. * @param[in] needSize Memory size to be allocated in bytes. + * @param[in] serviceType The type of datasystem service for which memory usage is increased. * @param[in] cacheType The cache type. * @return Status of the call. */ - Status IncrementMemoryUsage(uint64_t needSize, CacheType cacheType); + Status IncrementMemoryUsage(uint64_t needSize, ServiceType serviceType, CacheType cacheType); /** * @brief Allocate memory from shared memory for the specific tenant. @@ -125,18 +127,21 @@ public: * @param[out] fd File descriptor of the allocated shared memory segments. * @param[out] offset Offset from the base of the shared memory mmap. * @param[out] mmapSize Total size of shared memory segments. + * @param[in] serviceType The type of datasystem service for this allocation request. * @param[in] cacheType The cache type, either MEMORY or DISK. * @return Status of the call. */ Status AllocateMemory(const std::string &tenantId, uint64_t needSize, bool populate, void *&pointer, int &fd, - ptrdiff_t &offset, uint64_t &mmapSize, CacheType cacheType = CacheType::MEMORY); + ptrdiff_t &offset, uint64_t &mmapSize, ServiceType serviceType = ServiceType::OBJECT, + CacheType cacheType = CacheType::MEMORY); /** * @brief Free memory from shared memory. * @param[in] pointer reference to the pointer to free. Sets pointer to nullptr after. + * @param[in] type The service type for which memory is getting freed. * @return Status of the call. */ - Status FreeMemory(void *&pointer); + Status FreeMemory(void *&pointer, ServiceType type = ServiceType::OBJECT); /** * @brief Free memory from shared memory for the specific tenant. 
@@ -146,14 +151,17 @@ public: * @param[in] cacheType The cache type, either MEMORY or DISK. * @return Status of the call. */ - Status FreeMemory(const std::string &tenantId, void *&pointer, CacheType cacheType = CacheType::MEMORY); + Status FreeMemory(const std::string &tenantId, void *&pointer, ServiceType serviceType = ServiceType::OBJECT, + CacheType cacheType = CacheType::MEMORY); /** * @brief Get max memory size for the requested service type. + * @param[in] serviceType The service type for which the max memory size is requested. * @param[in] cacheType The cache type. * @return max memory size in bytes for the requested type. */ - uint64_t GetMaxMemorySize(CacheType cacheType = CacheType::MEMORY) const; + uint64_t GetMaxMemorySize(ServiceType serviceType = ServiceType::OBJECT, + CacheType cacheType = CacheType::MEMORY) const; /** * @brief Get the Max Memory Limit size. @@ -215,22 +223,25 @@ public: /** * @brief Get the total memory usage for the given service type. + * @param[in] serviceType The service type for which total memory usage is requested. * @param[in] cacheType The cache type. * @return The total memory usage. */ - uint64_t GetTotalMemoryUsage(CacheType cacheType = CacheType::MEMORY) + uint64_t GetTotalMemoryUsage(ServiceType serviceType = ServiceType::OBJECT, CacheType cacheType = CacheType::MEMORY) { - return GetResourcePoolByType(cacheType)->Usage(); + return GetResourcePoolByType(serviceType, cacheType)->Usage(); } /** * @brief Get the total real memory usage. + * @param[in] serviceType The service type for which total real memory usage is requested. * @param[in] cacheType The cache type. * @return The total real memory usage. */ - uint64_t GetTotalRealMemoryUsage(CacheType cacheType = CacheType::MEMORY) + uint64_t GetTotalRealMemoryUsage(ServiceType serviceType = ServiceType::OBJECT, + CacheType cacheType = CacheType::MEMORY) { - return GetResourcePoolByType(cacheType)->RealUsage(); + return GetResourcePoolByType(serviceType, cacheType)->RealUsage(); } /** @@ -238,9 +249,14 @@ public: * @param[in] type The service type for which total real memory limit is requested. * @return The total real memory limit. */ - uint64_t GetTotalMemoryLimit() + uint64_t GetTotalMemoryLimit(ServiceType type = ServiceType::OBJECT) { - return std::min(objectMemoryStats_->FootprintLimit(), physicalMemoryStats_->FootprintLimit()); + if (type == ServiceType::OBJECT) { + return std::min(objectMemoryStats_->FootprintLimit(), + physicalMemoryStats_->FootprintLimit() - streamMemoryStats_->RealUsage()); + } + return std::min(streamMemoryStats_->FootprintLimit(), + physicalMemoryStats_->FootprintLimit() - objectMemoryStats_->RealUsage()); } /** @@ -257,7 +273,7 @@ public: realUsage = diskStats_->RealUsage(); } else { limit = physicalMemoryStats_->FootprintLimit(); - realUsage = objectMemoryStats_->RealUsage(); + realUsage = objectMemoryStats_->RealUsage() + streamMemoryStats_->RealUsage(); } return limit > realUsage ? limit - realUsage : 0; } @@ -277,7 +293,7 @@ public: realUsage = diskStats_->RealUsage(); } else { limit = physicalMemoryStats_->FootprintLimit(); - realUsage = objectMemoryStats_->RealUsage(); + realUsage = objectMemoryStats_->RealUsage() + streamMemoryStats_->RealUsage(); } if (limit == 0) { @@ -302,7 +318,7 @@ public: /** * @brief Obtains the usage of shared memory. 
* @return Usage: - * "memoryUsage/physicalMemoryUsage/totalLimit/workerShareMemoryUsage" + * "memoryUsage/physicalMemoryUsage/totalLimit/workerShareMemoryUsage/streamMemoryUsage/streamMemoryLimit" */ std::string GetMemoryStatistics() { @@ -310,10 +326,13 @@ public: return "0/0/0/0/0/0"; } auto objectMemoryUsage = objectMemoryStats_->RealUsage(); - auto memoryUsage = objectMemoryUsage; + auto streamMemoryUsage = streamMemoryStats_->RealUsage(); + auto memoryUsage = objectMemoryUsage + streamMemoryUsage; auto workerShareMemoryUsage = memoryUsage / static_cast(physicalMemoryStats_->FootprintLimit()); - return FormatString("%lu/%lu/%lu/%.3f", memoryUsage, physicalMemoryStats_->RealUsage(), - physicalMemoryStats_->FootprintLimit(), workerShareMemoryUsage); + auto streamMemoryLimit = GetTotalMemoryLimit(ServiceType::STREAM); + return FormatString("%lu/%lu/%lu/%.3f/%lu/%lu", memoryUsage, physicalMemoryStats_->RealUsage(), + physicalMemoryStats_->FootprintLimit(), workerShareMemoryUsage, streamMemoryUsage, + streamMemoryLimit); } /** @@ -383,13 +402,15 @@ private: /** * @brief Get the ResourcePool by type. + * @param[in] serviceType The service type. * @param[in] cacheType The cache type. * @return ResourcePool* The ResourcePool. */ - ResourcePool *GetResourcePoolByType(CacheType cacheType) const; + ResourcePool *GetResourcePoolByType(ServiceType serviceType, CacheType cacheType) const; /** * @brief Get the ResourcePool by type. + * @param[in] serviceType The service type. * @param[in] cacheType The cache type. * @return ResourcePool* The ResourcePool. */ @@ -399,9 +420,10 @@ private: * @brief Init shared memory. * @param[in] size The shared memory size. * @param[in] objectThreshold A limit to restrict the memory usage of object cache / kv service. + * @param[in] streamThreshold A limit to restrict the memory usage of stream cache service. * @return Status K_OK if success, the error otherwise. */ - Status InitSharedMemory(uint64_t size, int objectThreshold); + Status InitSharedMemory(uint64_t size, int objectThreshold, int streamThreshold); /** * @brief Init shared disk. @@ -431,7 +453,7 @@ private: std::atomic noRefPageCount_{ 0 }; // Record the shared memory already binding to physical memory. - // physicalMemoryStats_.realUsage_ = Size of the cache memory that has not been released + // physicalMemoryStats_.realUsage_ = realUsage_(object+stream) + Size of the cache memory that has not been released // after the memory is free. std::unique_ptr physicalMemoryStats_; @@ -451,6 +473,7 @@ private: // Record the shared memory real allocated in bytes among all arenas for different service types. 
// realUsage = usage + Additional memory for jemalloc alignment std::unique_ptr objectMemoryStats_; + std::unique_ptr streamMemoryStats_; std::function &)> checkIfAllFdReleasedHandler_; diff --git a/src/datasystem/common/shared_memory/arena.cpp b/src/datasystem/common/shared_memory/arena.cpp index 82f9dde..96f9bae 100644 --- a/src/datasystem/common/shared_memory/arena.cpp +++ b/src/datasystem/common/shared_memory/arena.cpp @@ -85,7 +85,7 @@ ArenaGroup::~ArenaGroup() } Status ArenaGroup::AllocateMemory(uint64_t size, bool populate, uint64_t &realSize, void *&pointer, int &fd, - ptrdiff_t &offset, uint64_t &mmapSize) + ptrdiff_t &offset, uint64_t &mmapSize, ServiceType type) { CHECK_FAIL_RETURN_STATUS(!destroyed_.load(), StatusCode::K_RUNTIME_ERROR, "ArenaGroup destroyed"); CHECK_FAIL_RETURN_STATUS(!arenas_.empty(), StatusCode::K_RUNTIME_ERROR, "arenas_ is empty"); @@ -110,13 +110,13 @@ Status ArenaGroup::AllocateMemory(uint64_t size, bool populate, uint64_t &realSi if (status.IsError()) { (void)memoryUsage_.fetch_sub(size, std::memory_order_relaxed); const int logFreq = 100; - LOG_EVERY_N(ERROR, logFreq) << "total size limit:" << Allocator::Instance()->GetMaxMemorySize(cacheType_) + LOG_EVERY_N(ERROR, logFreq) << "total size limit:" << Allocator::Instance()->GetMaxMemorySize(type, cacheType_) << ", total physical memory usage:" << Allocator::Instance()->GetTotalPhysicalMemoryUsage(cacheType_) << ", total real memory usage:" - << Allocator::Instance()->GetTotalRealMemoryUsage(cacheType_) + << Allocator::Instance()->GetTotalRealMemoryUsage(type, cacheType_) << ", total memory usage:" - << Allocator::Instance()->GetTotalMemoryUsage(cacheType_) + << Allocator::Instance()->GetTotalMemoryUsage(type, cacheType_) << ", try alloc size:" << size << ", cacheType:" << static_cast(cacheType_); return status; } @@ -306,6 +306,9 @@ ArenaManager::ArenaManager(bool populate, bool scaling, ssize_t decayMs) arenas_.resize(ARENAS_INIT_SIZE); Jemalloc::Init(&ArenaManager::AllocHook, &ArenaManager::DestroyHook, &ArenaManager::CommitHook); handleExpiredTenantThread_ = std::make_unique(handleExpiredTenantThreadNum_, 0, "TenantExpired"); + if (FLAGS_enable_huge_tlb) { + FLAGS_arena_per_tenant = 1; + } auto arenaNum = FLAGS_arena_per_tenant; if (!FLAGS_shared_disk_directory.empty()) { arenaNum += FLAGS_shared_disk_arena_per_tenant; diff --git a/src/datasystem/common/shared_memory/arena.h b/src/datasystem/common/shared_memory/arena.h index 53f876c..64fc869 100644 --- a/src/datasystem/common/shared_memory/arena.h +++ b/src/datasystem/common/shared_memory/arena.h @@ -39,6 +39,7 @@ #include "datasystem/common/util/wait_post.h" namespace datasystem { +enum class ServiceType { OBJECT, STREAM }; namespace memory { constexpr uint32_t TENANT_RESOURCE_RELEASE_DELAY_MS = 600'000; // 10min class Arena; @@ -59,13 +60,14 @@ public: * @param[out] fd File descriptor of the allocated shared memory segments. * @param[out] offset Offset from the base of the shared memory mmap. * @param[out] mmapSize Total size of shared memory segments. + * @param[in] type The type of datasystem service for this allocation request. * @return K_OK if success, the error otherwise. * K_OUT_OF_MEMORY: no enough memory can be allocated. * K_RUNTIME_ERROR: arena has not been initialized, pointer is null, or * memory info can not be found, it should not happen. 
*/ Status AllocateMemory(uint64_t size, bool populate, uint64_t &realSize, void *&pointer, int &fd, ptrdiff_t &offset, - uint64_t &mmapSize); + uint64_t &mmapSize, ServiceType type); /** * @brief Free memory from shared memory. diff --git a/src/datasystem/common/shared_memory/shm_unit.cpp b/src/datasystem/common/shared_memory/shm_unit.cpp index 3fd2eac..c3c4c38 100644 --- a/src/datasystem/common/shared_memory/shm_unit.cpp +++ b/src/datasystem/common/shared_memory/shm_unit.cpp @@ -67,6 +67,11 @@ void ShmUnit::SetHardFreeMemory() Status ShmUnit::FreeMemory() { RETURN_OK_IF_TRUE(pointer == nullptr); + // If a shm owner exists, the memory will be freed together when the shmOwner is destructed. + if (shmOwner_) { + shmOwner_.reset(); + return Status::OK(); + } // This call will set pointer to nullptr on success. VLOG(1) << "[ShmUnit] Arena FreeMemory, Tenant:" << (tenantId_.empty() ? "Default" : tenantId_) << ", needHardFree: " << needHardFree_; @@ -80,19 +85,45 @@ Status ShmUnit::FreeMemory() LOG(WARNING) << FormatString("[ShmId %s] memset failed, error code: %d.", id, ret); } } - return datasystem::memory::Allocator::Instance()->FreeMemory(tenantId_, pointer, cacheType_); + return datasystem::memory::Allocator::Instance()->FreeMemory(tenantId_, pointer, serviceType_, cacheType_); } -Status ShmUnit::AllocateMemory(const std::string &tenantId, uint64_t needSize, bool populate, +Status ShmUnit::AllocateMemory(const std::string &tenantId, uint64_t needSize, bool populate, ServiceType serviceType, memory::CacheType cacheType) { VLOG(1) << "[ShmUnit] AllocateMemory, Tenant: " << (tenantId.empty() ? "Default" : tenantId) << ", size: " << needSize << ", cachetype: " << static_cast(cacheType); + serviceType_ = serviceType; cacheType_ = cacheType; - RETURN_IF_NOT_OK(datasystem::memory::Allocator::Instance()->AllocateMemory(tenantId, needSize, populate, pointer, - fd, offset, mmapSize, cacheType_)); + RETURN_IF_NOT_OK(datasystem::memory::Allocator::Instance()->AllocateMemory( + tenantId, needSize, populate, pointer, fd, offset, mmapSize, serviceType_, cacheType_)); size = needSize; tenantId_ = tenantId; return Status::OK(); } + +Status ShmOwner::DistributeMemory(uint64_t shmSize, ShmUnit &shmUnit) +{ + // Distribute allocated memory to individual shmUnit. + // Note: Concurrent distribution is supported via an atomic cursor. 
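+ // Each call reserves the half-open range [cursor_, cursor_ + shmSize) inside the
+ // owner's allocation; the fetch_add in AllocatePosition() below makes concurrent
+ // calls race-free. Hypothetical usage sketch (tenantId, totalSz and elementSz are
+ // illustrative names): allocate one large owner unit, then carve views out of it:
+ //   auto owner = std::make_shared<ShmOwner>();
+ //   RETURN_IF_NOT_OK(owner->AllocateMemory(tenantId, totalSz, false, ServiceType::STREAM));
+ //   ShmUnit unit;
+ //   RETURN_IF_NOT_OK(owner->DistributeMemory(elementSz, unit));  // unit now views a slice of owner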
+ uint64_t positionCursor = AllocatePosition(shmSize); + CHECK_FAIL_RETURN_STATUS(positionCursor + shmSize <= size, K_RUNTIME_ERROR, + "Object needs more memory than available."); + shmUnit.size = shmSize; + shmUnit.pointer = reinterpret_cast(reinterpret_cast(pointer) + positionCursor); + shmUnit.fd = fd; + shmUnit.offset = offset + positionCursor; + shmUnit.mmapSize = mmapSize; + shmUnit.serviceType_ = serviceType_; + shmUnit.cacheType_ = cacheType_; + shmUnit.tenantId_ = tenantId_; + shmUnit.needHardFree_ = needHardFree_; + shmUnit.shmOwner_ = shared_from_this(); + return Status::OK(); +} + +uint64_t ShmOwner::AllocatePosition(uint64_t shmSize) +{ + return cursor_.fetch_add(shmSize, std::memory_order_acq_rel); +} } // namespace datasystem diff --git a/src/datasystem/common/shared_memory/shm_unit.h b/src/datasystem/common/shared_memory/shm_unit.h index b927453..dd08db4 100644 --- a/src/datasystem/common/shared_memory/shm_unit.h +++ b/src/datasystem/common/shared_memory/shm_unit.h @@ -27,6 +27,7 @@ #include "datasystem/common/shared_memory/shm_unit_info.h" namespace datasystem { +class ShmOwner; class ShmUnit : public ShmUnitInfo { public: @@ -72,10 +73,12 @@ public: * @param[in] tenantId The Id of the tenant owns the shm unit. * @param[in] needSize The requested size in bytes to allocate. * @param[in] populate Indicate need populate or not. + * @param[in] serviceType The type of datasystem service for this allocation request. * @param[in] cacheType The cache type. * @return Status of the call. */ Status AllocateMemory(const std::string &tenantId, uint64_t needSize, bool populate, + ServiceType serviceType = ServiceType::OBJECT, memory::CacheType cacheType = memory::CacheType::MEMORY); /** @@ -97,11 +100,38 @@ public: void SetHardFreeMemory(); private: + friend class ShmOwner; + + ServiceType serviceType_ = ServiceType::OBJECT; + memory::CacheType cacheType_ = memory::CacheType::MEMORY; std::string tenantId_; bool needHardFree_ = false; + + std::shared_ptr shmOwner_{ nullptr }; +}; + +class ShmOwner : public ShmUnit, public std::enable_shared_from_this { +public: + /** + * @brief Distribute allocated shared memory into the ShmUnit. + * @param[in] shmSize The required shared memory size. + * @param[out] shmUnit The shared memory unit. + * @return Status of the call. + */ + Status DistributeMemory(uint64_t shmSize, ShmUnit &shmUnit); + +private: + /** + * @brief Move up the cursor in ShmOwner to indicate some memory is distributed. + * @param[in] shmSize The required shared memory size. + * @return cursor position before the increment. + */ + uint64_t AllocatePosition(uint64_t shmSize); + + std::atomic cursor_{ 0 }; }; } // namespace datasystem diff --git a/src/datasystem/common/stream_cache/CMakeLists.txt b/src/datasystem/common/stream_cache/CMakeLists.txt new file mode 100644 index 0000000..65b7143 --- /dev/null +++ b/src/datasystem/common/stream_cache/CMakeLists.txt @@ -0,0 +1,13 @@ +set(SC_COMMON_SRCS + cursor.cpp + stream_data_page.cpp + stream_meta_shm.cpp) + +set(SC_COMMON_LIBS + common_log) + +set(COMMON_SC_DEPEND_LIBS + posix_protos_client + ) +add_library(common_sc STATIC ${SC_COMMON_SRCS}) +target_link_libraries(common_sc PRIVATE ${SC_COMMON_LIBS} ${COMMON_SC_DEPEND_LIBS}) diff --git a/src/datasystem/common/stream_cache/consumer_meta.h b/src/datasystem/common/stream_cache/consumer_meta.h new file mode 100644 index 0000000..3348b20 --- /dev/null +++ b/src/datasystem/common/stream_cache/consumer_meta.h @@ -0,0 +1,142 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. 
All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Consumer metadata shared between the worker and master components. + */ +#ifndef DATASYSTEM_COMMON_STREAM_CACHE_CONSUMER_META_H +#define DATASYSTEM_COMMON_STREAM_CACHE_CONSUMER_META_H + +#include "datasystem/common/util/format.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/protos/worker_stream.pb.h" +#include "datasystem/stream/stream_config.h" + +namespace datasystem { +inline SubscriptionType ToSubscriptionType(SubscriptionTypePb typePb) +{ + switch (typePb) { + case SubscriptionTypePb::STREAM_PB: + return SubscriptionType::STREAM; + case SubscriptionTypePb::KEY_PARTITIONS_PB: + return SubscriptionType::KEY_PARTITIONS; + case SubscriptionTypePb::ROUND_ROBIN_PB: + return SubscriptionType::ROUND_ROBIN; + default: + return SubscriptionType::UNKNOWN; + } +} + +inline SubscriptionTypePb ToSubscriptionTypePb(SubscriptionType type) +{ + switch (type) { + case SubscriptionType::STREAM: + return SubscriptionTypePb::STREAM_PB; + case SubscriptionType::KEY_PARTITIONS: + return SubscriptionTypePb::KEY_PARTITIONS_PB; + case SubscriptionType::ROUND_ROBIN: + return SubscriptionTypePb::ROUND_ROBIN_PB; + default: + return SubscriptionTypePb::SubscriptionTypePb_INT_MIN_SENTINEL_DO_NOT_USE_; + } +} + +/** + * @brief Consumer metadata. + * @details Consisting of stream name, consumer id, worker address, subscription configuration and consumer last acked + * cursor. This class is used in both worker and master components. + */ +class ConsumerMeta { +public: + ConsumerMeta(std::string streamName, std::string consumerId, HostPort workerAddress, + SubscriptionConfig subConfig, uint64_t lastAckCursor) + : streamName_(std::move(streamName)), + consumerId_(std::move(consumerId)), + workerAddress_(std::move(workerAddress)), + subConfig_(std::move(subConfig)), + lastAckCursor_(lastAckCursor) + { + } + + std::string ToString() const + { + std::string typeName = (subConfig_.subscriptionType == SubscriptionType::STREAM) ? 
"stream" : "queue"; + return FormatString("Stream: <%s>, Worker:<%s>, Subscription:<%s>, mode:<%s>, Consumer:<%s>", streamName_, + workerAddress_.ToString(), subConfig_.subscriptionName, typeName, consumerId_); + } + + std::string StreamName() const + { + return streamName_; + } + + HostPort WorkerAddress() const + { + return workerAddress_; + } + + std::string ConsumerId() const + { + return consumerId_; + } + + SubscriptionConfig SubConfig() const + { + return subConfig_; + } + + uint64_t LastAckCursor() const + { + return lastAckCursor_; + } + + inline bool operator<(const ConsumerMeta &other) const + { + return this->consumerId_ < other.consumerId_; + } + + ConsumerMetaPb SerializeToPb() const + { + ConsumerMetaPb consumerMetaPb; + consumerMetaPb.set_stream_name(streamName_); + consumerMetaPb.mutable_worker_address()->set_host(workerAddress_.Host()); + consumerMetaPb.mutable_worker_address()->set_port(workerAddress_.Port()); + consumerMetaPb.set_consumer_id(consumerId_); + consumerMetaPb.mutable_sub_config()->set_subscription_name(subConfig_.subscriptionName); + consumerMetaPb.mutable_sub_config()->set_subscription_type(ToSubscriptionTypePb(subConfig_.subscriptionType)); + consumerMetaPb.set_last_ack_cursor(lastAckCursor_); + return consumerMetaPb; + } + + void ParseFromPb(const ConsumerMetaPb &metaPb) + { + streamName_ = metaPb.stream_name(); + workerAddress_ = HostPort(metaPb.worker_address().host(), metaPb.worker_address().port()); + consumerId_ = metaPb.consumer_id(); + subConfig_ = SubscriptionConfig(metaPb.sub_config().subscription_name(), + ToSubscriptionType(metaPb.sub_config().subscription_type())); + lastAckCursor_ = metaPb.last_ack_cursor(); + } + +private: + std::string streamName_; + std::string consumerId_; + HostPort workerAddress_; + SubscriptionConfig subConfig_; + uint64_t lastAckCursor_ = 0; // Generated on worker side +}; +} // namespace datasystem +#endif // DATASYSTEM_COMMON_STREAM_CACHE_CONSUMER_META_H diff --git a/src/datasystem/common/stream_cache/cursor.cpp b/src/datasystem/common/stream_cache/cursor.cpp new file mode 100644 index 0000000..a2b1f39 --- /dev/null +++ b/src/datasystem/common/stream_cache/cursor.cpp @@ -0,0 +1,518 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "datasystem/common/stream_cache/cursor.h" + +#include +#include +#include + +#include "datasystem/common/log/log.h" +#include "datasystem/common/constants.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/object_cache/lock.h" +#include "datasystem/common/util/safe_shm_lock.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/utils/status.h" + +namespace datasystem { +SharedMemView::SharedMemView() : lock_(0), fd_(0), mmapSz_(0), offset_(0), sz_(0) +{ +} + +void SharedMemView::CopyTo(ShmView &v) const +{ + v.mmapSz = mmapSz_; + v.off = offset_; + v.sz = sz_; + if (fd_ == 0) { + v.fd = -1; + } else { + v.fd = static_cast(fd_); + } +} + +void SharedMemView::CopyFrom(const std::shared_ptr &shmInfo) +{ + mmapSz_ = shmInfo->mmapSize; + offset_ = shmInfo->offset; + sz_ = shmInfo->size; + if (shmInfo->fd < 0) { + fd_ = 0; + } else { + fd_ = static_cast(shmInfo->fd); + } +} + +void SharedMemView::CopyFrom(const ShmView &v) +{ + mmapSz_ = v.mmapSz; + offset_ = v.off; + sz_ = v.sz; + if (v.fd < 0) { + fd_ = 0; + } else { + fd_ = static_cast(v.fd); + } +} + +SharedMemViewLock::SharedMemViewLock(uint32_t *lockWord) : lockWord_(lockWord) +{ +} + +Status SharedMemViewLock::LockExclusiveAndExec(const std::function &writeFunc, uint64_t timeoutMs) +{ + Timer timer; + bool isFirstTimeout = false; + Status rc; + do { + uint32_t val = __atomic_load_n(lockWord_, __ATOMIC_ACQUIRE); + uint32_t expected = val & ~WRITER; + if (!__atomic_compare_exchange_n(lockWord_, &expected, val | WRITER, true, __ATOMIC_ACQUIRE, + __ATOMIC_RELAXED)) { + if (timer.ElapsedMilliSecond() > TIMEOUT_WARNING_LIMIT_MS && !isFirstTimeout) { + isFirstTimeout = true; + LOG(WARNING) << "Fetching a write-lock on shared memory takes more than " << TIMEOUT_WARNING_LIMIT_MS + << " ms, waiting for writer to release the lock."; + } + // If timeout send an error + CHECK_FAIL_RETURN_STATUS(timer.ElapsedMilliSecond() < timeoutMs, K_TRY_AGAIN, + FormatString("[%s:%s] Timeout after %zu ms", __FUNCTION__, __LINE__, timeoutMs)); + continue; + } + // Write bit has been set, we must unset the writer bit before going out of scope. + while (val & ~WRITER) { + // Wait for all readers to go away + val = __atomic_load_n(lockWord_, __ATOMIC_ACQUIRE); + if (timer.ElapsedMilliSecond() > TIMEOUT_WARNING_LIMIT_MS && !isFirstTimeout) { + isFirstTimeout = true; + LOG(WARNING) << "Fetching a write-lock on shared memory takes more than " << TIMEOUT_WARNING_LIMIT_MS + << " ms, waiting for readers to release the lock."; + } + // If timeout send an error + if (timer.ElapsedMilliSecond() >= timeoutMs) { + // Unset the writer bit before returning error. + __atomic_fetch_sub(lockWord_, WRITER, __ATOMIC_RELEASE); + RETURN_STATUS(K_TRY_AGAIN, + FormatString("[%s:%s] Timeout after %zu ms", __FUNCTION__, __LINE__, timeoutMs)); + } + } + // catch exceptions to avoid the lock not being released. 
+ try { + // Execute the user function after we get the lock in X (exclusive) mode + writeFunc(); + } catch (const std::exception &e) { + auto msg = FormatString("Exception when executing writeFunc: %s", e.what()); + rc = Status(K_RUNTIME_ERROR, msg); + } + __atomic_fetch_sub(lockWord_, WRITER, __ATOMIC_RELEASE); + if (isFirstTimeout) { + LOG(WARNING) << "Fetching a write-lock on shared memory takes " << timer.ElapsedMilliSecond() << " ms"; + } + if (rc.IsError()) { + LOG(ERROR) << rc.GetMsg(); + } + return rc; + } while (true); +} + +Status SharedMemViewLock::LockSharedAndExec(const std::function &readFunc, uint64_t timeoutMs) +{ + Timer timer; + bool isFirstTimeout = false; + Status rc; + do { + while (__atomic_load_n(lockWord_, __ATOMIC_ACQUIRE) & WRITER) { + // Block on writer + if (timer.ElapsedMilliSecond() > TIMEOUT_WARNING_LIMIT_MS && !isFirstTimeout) { + isFirstTimeout = true; + LOG(WARNING) << "Fetching a read-lock on shared memory takes more than " << TIMEOUT_WARNING_LIMIT_MS + << " ms, waiting for writer to release the lock"; + } + + // If timeout send an error + CHECK_FAIL_RETURN_STATUS(timer.ElapsedMilliSecond() < timeoutMs, K_TRY_AGAIN, + FormatString("[%s:%s] Timeout after %zu ms", __FUNCTION__, __LINE__, timeoutMs)); + } + if ((__atomic_add_fetch(lockWord_, READER, __ATOMIC_ACQUIRE) & WRITER) == 0) { + // catch exceptions to avoid the lock not being released. + try { + // Execute user function after we get the lock in shared mode + readFunc(); + } catch (const std::exception &e) { + auto msg = FormatString("Exception when executing readFunc: %s", e.what()); + rc = Status(K_RUNTIME_ERROR, msg); + } + + __atomic_fetch_sub(lockWord_, READER, __ATOMIC_RELEASE); + if (isFirstTimeout) { + LOG(WARNING) << "Fetching a read-lock on shared memory takes " << timer.ElapsedMilliSecond() << " ms"; + } + if (rc.IsError()) { + LOG(ERROR) << rc.GetMsg(); + } + return rc; + } + __atomic_fetch_sub(lockWord_, READER, __ATOMIC_RELEASE); // A writer beats us. 
retry again + // If timeout send an error + CHECK_FAIL_RETURN_STATUS(timer.ElapsedMilliSecond() < timeoutMs, K_TRY_AGAIN, + FormatString("[%s:%s] Timeout after %zu ms", __FUNCTION__, __LINE__, timeoutMs)); + } while (true); +} + +uint64_t Cursor::GetWALastAckCursor() const +{ + if (lastAckCursor_) { + return __atomic_load_n(lastAckCursor_, __ATOMIC_SEQ_CST); + } + return 0; +} + +void Cursor::UpdateWALastAckCursor(uint64_t elementId) const +{ + if (lastAckCursor_) { + __atomic_store_n(lastAckCursor_, elementId, __ATOMIC_SEQ_CST); + return; + } + LOG(ERROR) << "Cursor not initialized"; +} + +Status Cursor::GetLastPageView(ShmView &shm, uint64_t timeoutMs) const +{ + return GetPageView(lastPageShmView_, shm, timeoutMs); +} + +Status Cursor::SetLastPage(const ShmView &shm, uint64_t timeoutMs) +{ + return SetPage(lastPageShmView_, shm, timeoutMs); +} + +Status Cursor::SetLastPageRef(const ShmView &shm, uint64_t timeoutMs, bool isTagged) +{ + return lastPageShmView_->SetView(shm, isTagged, timeoutMs); +} + +Status Cursor::GetLastLockedPageView(ShmView &shm, uint64_t timeoutMs) const +{ + return GetPageView(lastLockedShmView_, shm, timeoutMs); +} + +Status Cursor::SetLastLockedPage(const ShmView &shm, uint64_t timeoutMs) +{ + return SetPage(lastLockedShmView_, shm, timeoutMs); +} + +void Cursor::InitFutexArea() +{ + __atomic_store_n(futexWord_, AckVal::NONE, __ATOMIC_RELAXED); +} + +Status Cursor::Wait(uint64_t timeoutMs, int32_t &val) +{ + auto t = MilliSecondsToTimeSpec(timeoutMs); + auto res = syscall(SYS_futex, futexWord_, FUTEX_WAIT, AckVal::NONE, &t, nullptr, 0); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + res != -1 || errno == EAGAIN || errno == ETIMEDOUT || errno == EINTR, K_RUNTIME_ERROR, + FormatString("Futex wait error. Errno = %d. Message %s", errno, StrErr(errno))); + if (res == 0 || errno == EAGAIN) { + auto fetchVal = __atomic_load_n(futexWord_, __ATOMIC_RELAXED); + uint32_t checkBit = fetchVal & Cursor::SHIFT; + val = static_cast(fetchVal >> Cursor::SHIFT); + CHECK_FAIL_RETURN_STATUS(checkBit == Cursor::AckVal::DONE, K_RUNTIME_ERROR, + FormatString("Handshake error. Expect %d but get %d", Cursor::AckVal::DONE, checkBit)); + return Status::OK(); + } + RETURN_STATUS(K_TRY_AGAIN, FormatString("Time out within allowed time %zu ms", timeoutMs)); +} + +Status Cursor::Wake(const int32_t val, size_t &numWaiter) +{ + __atomic_store_n(futexWord_, static_cast(val) << SHIFT | AckVal::DONE, __ATOMIC_SEQ_CST); + auto res = syscall(SYS_futex, futexWord_, FUTEX_WAKE, INT_MAX, nullptr, nullptr, 0); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + res != -1, K_RUNTIME_ERROR, FormatString("Futex wake error. Errno = %d. 
Message %s", errno, StrErr(errno))); + numWaiter = static_cast(res); + return Status::OK(); +} + +bool Cursor::ForceClose() const +{ + uint32_t val = __atomic_load_n(forceClose_, __ATOMIC_RELAXED); + return val > 0; +} + +void Cursor::SetForceClose() +{ + __atomic_store_n(forceClose_, 1, __ATOMIC_RELAXED); +} + +void Cursor::SetElementCount(uint64_t val) +{ + __atomic_store_n(elementCount_, val, __ATOMIC_RELAXED); +} + +uint64_t Cursor::IncrementElementCount(uint64_t inc) +{ + return __atomic_add_fetch(elementCount_, inc, __ATOMIC_RELAXED); +} + +uint64_t Cursor::GetElementCountAndReset() +{ + uint64_t val = __atomic_load_n(elementCount_, __ATOMIC_RELAXED); + while (!__atomic_compare_exchange_n(elementCount_, &val, 0, true, __ATOMIC_SEQ_CST, __ATOMIC_RELAXED)) { + val = __atomic_load_n(elementCount_, __ATOMIC_RELAXED); + } + return val; +} + +uint64_t Cursor::GetElementCount() const +{ + return __atomic_load_n(elementCount_, __ATOMIC_RELAXED); +} + +uint64_t Cursor::IncrementRequestCount() +{ + return __atomic_add_fetch(requestCount_, 1, __ATOMIC_RELAXED); +} + +uint64_t Cursor::GetRequestCountAndReset() +{ + uint64_t val = __atomic_load_n(requestCount_, __ATOMIC_RELAXED); + while (!__atomic_compare_exchange_n(requestCount_, &val, 0, true, __ATOMIC_SEQ_CST, __ATOMIC_RELAXED)) { + val = __atomic_load_n(requestCount_, __ATOMIC_RELAXED); + } + return val; +} + +uint32_t Cursor::GetEyeCatcher() const +{ + return __atomic_load_n(eyeCatcher_, __ATOMIC_RELAXED); +} + +uint32_t Cursor::GetClientVersion() const +{ + return GetEyeCatcher() & CLIENT_EYECATCHER_MASK; +} + +uint32_t Cursor::GetWorkerVersion() const +{ + return GetEyeCatcher() & WORKER_EYECATCHER_MASK; +} + +Status Cursor::ForceUnLock(uint32_t lockId, const std::string &msg) +{ + Status lastRc; + if (lastPageShmView_ != nullptr) { + lastRc = lastPageShmView_->ForceUnLock(lockId, msg); + } + + if (lastLockedShmView_ != nullptr) { + auto rc = lastLockedShmView_->ForceUnLock(lockId, msg); + lastRc = rc.IsError() ? rc : lastRc; + } + return lastRc; +} + +Status Cursor::SetClientVersion(uint32_t val) +{ + return SetEyeCatcherHelper(val, CLIENT_EYECATCHER_MASK); +} + +Status Cursor::SetWorkerVersion(uint32_t val) +{ + return SetEyeCatcherHelper(val, WORKER_EYECATCHER_MASK); +} + +Status Cursor::SetEyeCatcherHelper(uint32_t val, uint32_t mask) +{ + RETURN_OK_IF_TRUE(val == 0); + CHECK_FAIL_RETURN_STATUS((val & mask) == val, K_RUNTIME_ERROR, + FormatString("Invalid eye catcher version %zu given mask %x", val, mask)); + uint32_t current; + uint32_t newVal; + do { + current = __atomic_load_n(eyeCatcher_, __ATOMIC_RELAXED); + RETURN_OK_IF_TRUE((current & mask) == val); + CHECK_FAIL_RETURN_STATUS((current & mask) == 0, K_RUNTIME_ERROR, + "Client or worker eye catcher version is to be set only once"); + newVal = current | val; + } while (!__atomic_compare_exchange_n(eyeCatcher_, ¤t, newVal, true, __ATOMIC_SEQ_CST, __ATOMIC_RELAXED)); + return Status::OK(); +} + +Status Cursor::GetPageView(const std::shared_ptr &impl, ShmView &shm, uint64_t timeoutMs) +{ + ShmView v; + bool isTagged; + RETURN_IF_NOT_OK(impl->GetView(v, isTagged, timeoutMs)); + // We never tag the view in the mailbox area. 
But for safety, just return null + if (isTagged) { + shm = {}; + return Status::OK(); + } + shm = v; + return Status::OK(); +} + +Status Cursor::SetPage(std::shared_ptr &impl, const ShmView &shm, uint64_t timeoutMs) +{ + return impl->SetView(shm, false, timeoutMs); +} + +Cursor::Cursor(void *ptr, size_t sz, uint32_t lockId) : ptr_(reinterpret_cast(ptr)), sz_(sz), lockId_(lockId) +{ +} + +Status Cursor::Init(std::shared_ptr mmapEntry) +{ + RETURN_RUNTIME_ERROR_IF_NULL(ptr_); +#define CURSOR_INIT_FIELD(start, cur, field) \ + do { \ + (field) = reinterpret_cast(cur); \ + (cur) += sizeof(*(field)); \ + CHECK_FAIL_RETURN_STATUS(static_cast((cur) - (start)) <= sz_, K_RUNTIME_ERROR, \ + "Work area size too small"); \ + } while (false) + + auto *data = ptr_; + CURSOR_INIT_FIELD(ptr_, data, lastAckCursor_); + CURSOR_INIT_FIELD(ptr_, data, lastPage_); + CURSOR_INIT_FIELD(ptr_, data, futexWord_); + CURSOR_INIT_FIELD(ptr_, data, forceClose_); + CURSOR_INIT_FIELD(ptr_, data, elementCount_); + CURSOR_INIT_FIELD(ptr_, data, requestCount_); + lastPageShmView_ = std::make_shared(lastPage_, sizeof(*lastPage_), lockId_); + RETURN_IF_NOT_OK(lastPageShmView_->Init(false)); + + if (mmapEntry != nullptr) { + mmapEntry_ = std::move(mmapEntry); + } + // Clear the wait area + __atomic_store_n(futexWord_, 0, __ATOMIC_SEQ_CST); + // Clear the stream state + __atomic_store_n(forceClose_, 0, __ATOMIC_SEQ_CST); + // Clear the request count + __atomic_store_n(requestCount_, 0, __ATOMIC_SEQ_CST); + + // This starts V2 where another 64 bytes is added after requestCount + RETURN_OK_IF_TRUE(sz_ == K_CURSOR_SIZE_V1); + // Continue the new fields added in V2 + CURSOR_INIT_FIELD(ptr_, data, eyeCatcher_); + CURSOR_INIT_FIELD(ptr_, data, waitCount_); + CURSOR_INIT_FIELD(ptr_, data, lastLockedPage_); + // Initialize the shm view + lastLockedShmView_ = std::make_shared(lastLockedPage_, sizeof(*lastLockedPage_), lockId_); + RETURN_IF_NOT_OK(lastLockedShmView_->Init(false)); + // Clear the wait area + __atomic_store_n(waitCount_, 0, __ATOMIC_SEQ_CST); + + return Status::OK(); +} + +Status SharedMemViewImpl::Init(bool clearFields) +{ + RETURN_RUNTIME_ERROR_IF_NULL(view_); + CHECK_FAIL_RETURN_STATUS(sz_ >= sizeof(SharedMemView), K_RUNTIME_ERROR, + FormatString("Not enough size. Need at least %zu", sizeof(SharedMemView))); + if (clearFields) { + auto rc = memset_s(view_, sz_, 0, sizeof(SharedMemView)); + CHECK_FAIL_RETURN_STATUS(rc == 0, K_RUNTIME_ERROR, FormatString("memset_s fails. 
Errno = %d", errno)); + } + return Status::OK(); +} + +Status SharedMemViewImpl::LockExclusiveAndExec(const std::function &writeFunc, uint64_t timeoutMs) +{ + RETURN_RUNTIME_ERROR_IF_NULL(view_); + SharedMemViewLock lock(&view_->lock_); + return lock.LockExclusiveAndExec(writeFunc, timeoutMs); +} + +Status SharedMemViewImpl::LockSharedAndExec(const std::function &readFunc, uint64_t timeoutMs) const +{ + RETURN_RUNTIME_ERROR_IF_NULL(view_); + SharedMemViewLock lock(&view_->lock_); + return lock.LockSharedAndExec(readFunc, timeoutMs); +} + +Status SharedMemViewImpl::LockAndExec(const std::function &func, uint64_t timeoutMs) +{ + RETURN_RUNTIME_ERROR_IF_NULL(view_); + SafeShmLock xlocker(&view_->lock_, lockId_); + RETURN_IF_NOT_OK(xlocker.Lock(timeoutMs)); + Status rc; + try { + func(); + } catch (const std::exception &e) { + auto msg = FormatString("Exception when execute func get: %s", e.what()); + rc = Status(K_RUNTIME_ERROR, msg); + } + xlocker.UnLock(); + if (rc.IsError()) { + LOG(ERROR) << rc.GetMsg(); + } + return rc; +} + +Status SharedMemViewImpl::SetView(const ShmView &shm, bool isTagged, uint64_t timeoutMs) +{ + auto func = [this, isTagged, &shm]() { + INJECT_POINT("SharedMemViewImpl.SetView", [] { throw std::bad_function_call(); }); + view_->CopyFrom(shm); + if (isTagged) { + view_->fd_ |= PAGE_VIEW_TAG; + } else { + view_->fd_ &= ~PAGE_VIEW_TAG; + } + }; + INJECT_POINT("MemView.Lock.OldVersion", [&] { return LockExclusiveAndExec(func, timeoutMs); }); + return LockAndExec(func, timeoutMs); +} + +Status SharedMemViewImpl::GetView(ShmView &shm, bool &isTagged, uint64_t timeoutMs) +{ + auto func = [this, &isTagged, &shm]() { + INJECT_POINT("SharedMemViewImpl.GetView", [] { throw std::bad_function_call(); }); + isTagged = (view_->fd_ & PAGE_VIEW_TAG); + if (isTagged) { + auto fd = view_->fd_; + fd &= ~PAGE_VIEW_TAG; + shm = { .fd = static_cast(fd), .mmapSz = view_->mmapSz_, .off = view_->offset_, .sz = view_->sz_ }; + } else { + INJECT_POINT("producer_crash_getview", [] {}); + view_->CopyTo(shm); + } + }; + INJECT_POINT("MemView.Lock.OldVersion", [&] { return LockSharedAndExec(func, timeoutMs); }); + return LockAndExec(func, timeoutMs); +} + +Status SharedMemViewImpl::ForceUnLock(uint32_t lockId, const std::string &msg) +{ + CHECK_FAIL_RETURN_STATUS( + lockId_ == WORKER_LOCK_ID, K_RUNTIME_ERROR, + FormatString("Only worker can call ForceUnLock, invalid lockId_ %zu, lockId %zu", lockId_, lockId)); + + CHECK_FAIL_RETURN_STATUS(lockId > WORKER_LOCK_ID, K_RUNTIME_ERROR, FormatString("Invalid lockId", lockId)); + + if (view_ != nullptr && SafeShmLock::ForceUnlock(&view_->lock_, lockId)) { + LOG(INFO) << FormatString("[%s] ForceUnLock for lockId %zu, PageViewInfo: %s", msg, lockId, view_->ToString()); + } + return Status::OK(); +} + +} // namespace datasystem diff --git a/src/datasystem/common/stream_cache/cursor.h b/src/datasystem/common/stream_cache/cursor.h new file mode 100644 index 0000000..8090a28 --- /dev/null +++ b/src/datasystem/common/stream_cache/cursor.h @@ -0,0 +1,383 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASYSTEM_STREAM_CACHE_CURSOR_H +#define DATASYSTEM_STREAM_CACHE_CURSOR_H + +#include +#include +#include +#include +#include "datasystem/common/shared_memory/shm_unit.h" +#include "datasystem/client/mmap_manager.h" + +namespace datasystem { +class SharedMemViewLock { +public: + explicit SharedMemViewLock(uint32_t *lockWord); + Status LockExclusiveAndExec(const std::function &writeFunc, uint64_t timeoutMs); + Status LockSharedAndExec(const std::function &readFunc, uint64_t timeoutMs); + +private: + uint32_t *lockWord_; + constexpr static const uint32_t WRITER = 1; + constexpr static const uint32_t READER = 2; + constexpr static const int TIMEOUT_WARNING_LIMIT_MS = 3000; +}; + +// We store ShmView on the page. To avoid the size changes on different platforms +// and avoid any gap between fields, we will mirror ShmView but enforce the whole +// structure to be 32 bytes in size without any gap +// Some fields will be updated atomically. +struct SharedMemView { + uint32_t lock_; // 4 bytes + uint32_t fd_; // 4 bytes (unlikely a file descriptor needs a full 8 byte to store) + uint64_t mmapSz_; // 8 bytes + int64_t offset_; // 8 bytes + uint64_t sz_; // 8 bytes + ~SharedMemView() = default; + SharedMemView(); + void CopyTo(ShmView &v) const; + void CopyFrom(const std::shared_ptr &shmInfo); + void CopyFrom(const ShmView &v); + std::string ToString() + { + std::stringstream ss; + ss << "fd:" << fd_; + ss << ", mmapSz:" << mmapSz_; + ss << ", offset:" << offset_; + ss << ", sz:" << sz_; + return ss.str(); + } +}; + +class SharedMemViewImpl { +public: + SharedMemViewImpl() : view_(nullptr), sz_(0) + { + } + SharedMemViewImpl(void *ptr, size_t sz, uint32_t lockId) + : view_(reinterpret_cast(ptr)), sz_(sz), lockId_(lockId) + { + } + ~SharedMemViewImpl() = default; + + /** + * Init. Check size + * @return + */ + Status Init(bool clearFields = false); + + /** + * Set the ShmView of the last page + * @param[in] shm The shme view . + * @param[in] isTagged The tag bit for shm page. + * @param[in] timeoutMs The timeout in ms. + * @return Status of this call. + */ + Status SetView(const ShmView &shm, bool isTagged, uint64_t timeoutMs); + + /** + * Get ShmView of the last page + * @param[out] shm The shme view . + * @param[out] isTagged The tag bit for shm page. + * @param[in] timeoutMs The timeout in ms. + * @return Status of this call. + */ + Status GetView(ShmView &shm, bool &isTagged, uint64_t timeoutMs); + + /** + * @brief Force unlock the shm view. + * @param[in] lockId The lock id. + * @param[in] msg The message for log. + * @return Status of this call. + */ + Status ForceUnLock(uint32_t lockId, const std::string &msg); + +protected: + // The high bit of fd_ field is used for different purpose. 
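+ // When this bit is set on fd_, the stored view is a tagged reference to another
+ // SharedMemView (see Cursor::GetLastPageViewByRef) rather than a direct data-page view.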
+ static uint32_t constexpr PAGE_VIEW_TAG = 0x80000000; + +private: + friend class StreamDataPage; + SharedMemView *view_; + size_t sz_; + uint32_t lockId_{ 0 }; + Status LockExclusiveAndExec(const std::function &func, uint64_t timeoutMs); + Status LockSharedAndExec(const std::function &func, uint64_t timeoutMs) const; + Status LockAndExec(const std::function &func, uint64_t timeoutMs); +}; + +static constexpr uint64_t ONE_K = 1'000ul; +static constexpr uint64_t ONE_M = 1'000'000ul; +timespec inline MilliSecondsToTimeSpec(uint64_t timeoutMs) +{ + timespec t{ .tv_sec = static_cast<__time_t>(timeoutMs / ONE_K), + .tv_nsec = static_cast<__syscall_slong_t>((timeoutMs % ONE_K) * ONE_M) }; + return t; +} + +// A work area that is shared between +// (1) client::stream_cache::ConsumerImpl and the corresponding worker::stream_cache::Consumer +// (2) client::stream_cache::ProducerImpl and the corresponding worker::stream_cache::Producer +// sz is the size of this work area. For version 1, the size is 64 bytes. +// (a) The first 8 bytes is used for fetching the lastAckCursor from the client +// (b) Next 32 bytes is used storing ShmView of the last page (if it exists) +// (c) Next 4 bytes is used for futex +// (d) Next 4 bytes is used to signal force close +// (e) Next 8 bytes for element count (pushing/receiving) +// (f) Next 8 bytes for request count (send/receive requests called by client) +// (g) The size of the work area is 64 bytes, and we have 0 bytes left if this is version 1. +// Local producer/consumer should use the inherited function WorkAreaIsV2 (see client_base_impl.h) +// for compatibility if the worker is of lower level +// For version 2, another 64 bytes is allocated and the total structure size is now 128 bytes +// (h) Some eye catcher field (4 bytes) so that worker can tell if it is a V1 or V2 client. +// V1 client will not write past the 64 bytes. So if this field is set, it is a V2 client +// (i) Next 4 bytes is for alignment and can be combined with futex area (c) above as a wait count area +// and call the static function PageLock::FutexWake and PageLock::FutexWait to improve performance +// (j) Next 32 bytes is used to store ShmView of the last page locked by the producer +// (j) 24 bytes is left for future use. +class Cursor { +public: + // V1 of Cursor area has a size of 64 bytes + constexpr static size_t K_CURSOR_SIZE_V1 = 64; + // V2 of Cursor area is extending to 128 bytes + constexpr static size_t K_CURSOR_SIZE_V2 = 128; + // EyeCatcher masks for client and worker version. + const uint32_t CLIENT_EYECATCHER_MASK = static_cast(0x0000FFFF); + const uint32_t WORKER_EYECATCHER_MASK = static_cast(0xFFFF0000); + // Client EyeCatcher V2 is K_CURSOR_SIZE_V2. + constexpr static uint32_t K_WORKER_EYECATCHER_V1 = static_cast(0x00010000); + enum AckVal : uint32_t { NONE = 0, DONE = 1 }; + constexpr static uint32_t SHIFT = 1; + + Cursor(void *ptr, size_t sz, uint32_t lockId); + ~Cursor() = default; + + /** + * @brief Get the last ack cursor from the work area + * @return last ack cursor + */ + uint64_t GetWALastAckCursor() const; + + /** + * @brief Update the last ack cursor in work area + * @param elementId + */ + void UpdateWALastAckCursor(uint64_t elementId) const; + + /** + * Initialization + */ + Status Init(std::shared_ptr mmapEntry = nullptr); + + /** + * Get ShmView of the last page + */ + Status GetLastPageView(ShmView &shm, uint64_t timeoutMs) const; + + /** + * @brief Get the last page if enable shared page. + * @param[out] shm The shm view of ths last page. 
+ * @param[out] switchToSharedPage Whether switched to shared page. + * @param[in] timeoutMs The timeout in ms. + * @param[in] toShmInfo The shm mapping function + * @return Status of this call + */ + template + Status GetLastPageViewByRef(ShmView &shm, bool &switchToSharedPage, uint64_t timeoutMs, F &&toShmInfo) + { + switchToSharedPage = false; + if (lastPageRefShmView_) { + return Cursor::GetPageView(lastPageRefShmView_, shm, timeoutMs); + } + + ShmView lastPageRefView; + RETURN_IF_NOT_OK(lastPageShmView_->GetView(lastPageRefView, switchToSharedPage, timeoutMs)); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(lastPageRefView.fd > 0 && lastPageRefView.sz == sizeof(SharedMemView), + K_RUNTIME_ERROR, + FormatString("Invalid lastPageRefView %s", lastPageRefView.ToStr())); + std::shared_ptr mmapEntry; + std::shared_ptr lastPageRefUnit; + RETURN_IF_NOT_OK(toShmInfo(lastPageRefView, lastPageRefUnit, mmapEntry)); + SharedMemView *memView = reinterpret_cast( + reinterpret_cast(lastPageRefUnit->GetPointer()) + lastPageRefUnit->offset); + auto lastPageRefShmView = std::make_shared(memView, sizeof(SharedMemView), lockId_); + RETURN_IF_NOT_OK(lastPageRefShmView->Init(false)); + if (switchToSharedPage) { + refMmapEntry_ = std::move(mmapEntry); + lastPageRefShmView_ = lastPageRefShmView; + } + return Cursor::GetPageView(lastPageRefShmView, shm, timeoutMs); + } + + /** + * Set the ShmView of the last page + * @param shm + */ + Status SetLastPage(const ShmView &shm, uint64_t timeoutMs); + + /** + * Set the ShmView of the last page + * @param[in] shm The shm view of last page ref. + * @param[in] timeoutMs The timeout in ms. + * @param[in] isTagged The page has tag bit. + * @return The status of this call. + */ + Status SetLastPageRef(const ShmView &shm, uint64_t timeoutMs, bool isTagged); + + void InitFutexArea(); + + /** + * @brief Wait for event + * @param timeoutMs timeout in milliseconds + * @param val Updated value from futex word + * @return + */ + Status Wait(uint64_t timeoutMs, int32_t &val); + + /** + * @brief Wake up waiters + * @param val Value to write to the futex word + * @return numWaiter Number of waiters on the futex word + */ + Status Wake(int32_t val, size_t &numWaiter); + + /** + * Check for interrupt + * @return true is interrupted. + * @note If a consumer is interrupted, it will return K_SC_NO_PRODUCER + * If a producer is interrupted, it will return K_SC_NO_CONSUMER + */ + bool ForceClose() const; + + /** + * Interrupt a consumer when a producer is gone, or + * interrupt a producer when a consumer is gone + */ + void SetForceClose(); + + /** + * Set the element count + * @param val + */ + void SetElementCount(uint64_t val); + + /** + * Increment the element count + * @param inc + * @return value before increment + */ + uint64_t IncrementElementCount(uint64_t inc = 1); + + /** + * Get the element count + */ + uint64_t GetElementCount() const; + + /** + * @brief Get the element count and reset it to 0. 
+
+    /**
+     * @brief Get the element count and reset it to 0.
+     * @return
+     */
+    uint64_t GetElementCountAndReset();
+
+    /**
+     * Increment the request count
+     * @param inc
+     * @return value before increment
+     */
+    uint64_t IncrementRequestCount();
+
+    /**
+     * Get the request count and reset it to 0
+     */
+    uint64_t GetRequestCountAndReset();
+
+    /**
+     * Get ShmView of the last locked page
+     */
+    Status GetLastLockedPageView(ShmView &shm, uint64_t timeoutMs) const;
+
+    /**
+     * Set the ShmView of the last locked page
+     * @param shm
+     */
+    Status SetLastLockedPage(const ShmView &shm, uint64_t timeoutMs);
+
+    /**
+     * @brief Update the eye catcher field for the client version.
+     */
+    Status SetClientVersion(uint32_t val);
+
+    /**
+     * @brief Update the eye catcher field for the worker version.
+     */
+    Status SetWorkerVersion(uint32_t val);
+
+    /**
+     * @brief Retrieve the eye catcher version for the client.
+     */
+    uint32_t GetClientVersion() const;
+
+    /**
+     * @brief Retrieve the eye catcher version for the worker.
+     */
+    uint32_t GetWorkerVersion() const;
+
+    Status ForceUnLock(uint32_t lockId, const std::string &msg);
+
+private:
+    uint8_t *ptr_;
+    uint64_t *lastAckCursor_{ nullptr };
+    SharedMemView *lastPage_{ nullptr };
+    uint32_t *futexWord_{ nullptr };
+    uint32_t *forceClose_{ nullptr };
+    uint64_t *elementCount_{ nullptr };
+    uint64_t *requestCount_{ nullptr };
+    uint32_t *eyeCatcher_{ nullptr };
+    uint32_t *waitCount_{ nullptr };
+    SharedMemView *lastLockedPage_{ nullptr };
+    const size_t sz_;
+    const uint32_t lockId_;
+    std::shared_ptr<SharedMemViewImpl> lastPageShmView_;
+    std::shared_ptr<SharedMemViewImpl> lastLockedShmView_;
+    std::shared_ptr<MmapEntry> mmapEntry_;
+
+    std::shared_ptr<SharedMemViewImpl> lastPageRefShmView_;
+    std::shared_ptr<MmapEntry> refMmapEntry_;
+
+    /**
+     * @brief Get ShmView from a SharedMemViewImpl
+     */
+    static Status GetPageView(const std::shared_ptr<SharedMemViewImpl> &impl, ShmView &shm, uint64_t timeoutMs);
+
+    /**
+     * @brief Set the ShmView to a SharedMemViewImpl
+     */
+    static Status SetPage(std::shared_ptr<SharedMemViewImpl> &impl, const ShmView &shm, uint64_t timeoutMs);
+
+    /**
+     * @brief Helper function to update the eye catcher field for client or worker depending on the mask given.
+     */
+    Status SetEyeCatcherHelper(uint32_t val, uint32_t mask);
+
+    /**
+     * @brief Retrieve the eye catcher
+     */
+    uint32_t GetEyeCatcher() const;
+};
+} // namespace datasystem
+#endif
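+// Version-detection sketch built on the eye catcher contract above (annotation
+// only, not shipped code): a V1 client never writes past the first 64 bytes,
+// so a worker can take a zero client version to mean V1, e.g.
+//   bool isV2 = cursor.GetClientVersion() == Cursor::K_CURSOR_SIZE_V2;
+// where GetClientVersion() masks the eye catcher with CLIENT_EYECATCHER_MASK.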
diff --git a/src/datasystem/common/stream_cache/stream_data_page.cpp b/src/datasystem/common/stream_cache/stream_data_page.cpp
new file mode 100644
index 0000000..c8c6f05
--- /dev/null
+++ b/src/datasystem/common/stream_cache/stream_data_page.cpp
@@ -0,0 +1,1153 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <linux/futex.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+#include <algorithm>
+#include <cerrno>
+#include <chrono>
+#include <limits>
+#include <numeric>
+#include <sstream>
+#include <thread>
+
+#include "securec.h"
+
+#include "datasystem/common/util/raii.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/common/inject/inject_point.h"
+#include "datasystem/common/perf/perf_manager.h"
+#include "datasystem/common/stream_cache/stream_data_page.h"
+#include "datasystem/common/util/bitmask_enum.h"
+#include "datasystem/common/util/format.h"
+#include "datasystem/common/util/memory.h"
+#include "datasystem/common/util/strings_util.h"
+#include "datasystem/stream/stream_config.h"
+
+namespace datasystem {
+Status PageLock::FutexWait(uint32_t *lockArea, uint32_t *waitCount, uint32_t val, uint64_t timeoutMs)
+{
+    auto t = MilliSecondsToTimeSpec(timeoutMs);
+    auto fetchVal1 = __atomic_fetch_add(waitCount, 1, __ATOMIC_SEQ_CST);
+    auto res = syscall(SYS_futex, lockArea, FUTEX_WAIT, val, &t, nullptr, 0);
+    auto fetchVal2 = __atomic_fetch_sub(waitCount, 1, __ATOMIC_SEQ_CST);
+    // Always log if the wait count is abnormally large; it can be a sign of a problem.
+    const int warningVal = 1000;
+    LOG_IF(INFO, fetchVal2 > warningVal) << FormatString(
+        "Wait count before increment: %zu, wait count before decrement: %zu", fetchVal1, fetchVal2);
+    // Examine the return code of res. EAGAIN is actually ok for FUTEX_WAIT.
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        res != -1 || errno == EAGAIN || errno == ETIMEDOUT || errno == EINTR, K_RUNTIME_ERROR,
+        FormatString("Futex wait error. Errno = %d. Message %s", errno, StrErr(errno)));
+    RETURN_OK_IF_TRUE(res == 0 || errno == EAGAIN || errno == EINTR);
+    RETURN_STATUS(K_TRY_AGAIN, FormatString("[%s:%d] Timeout after %zu ms", __FUNCTION__, __LINE__, timeoutMs));
+}
+
+Status PageLock::FutexWake(uint32_t *lockArea, uint32_t *waitCount, int numToWakeUp)
+{
+    PerfPoint point1(PerfKey::PAGE_WAKE_CONSUMER);
+    // The futex syscall is not cheap. Only call it when there are waiters. It is OK if there is none.
+    auto numWaiter = __atomic_load_n(waitCount, __ATOMIC_SEQ_CST);
+    RETURN_OK_IF_TRUE(numWaiter == 0);
+    PerfPoint point(PerfKey::PAGE_FUTEX_WAKE);
+    auto res = syscall(SYS_futex, lockArea, FUTEX_WAKE, numToWakeUp, nullptr, nullptr, 0);
+    point.Record();
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        res != -1, K_RUNTIME_ERROR, FormatString("futex wake error. Errno = %d. Message %s", errno, StrErr(errno)));
+    VLOG_IF(SC_INTERNAL_LOG_LEVEL, res > 0) << FormatString("Wake up %zu waiters", res);
+    return Status::OK();
+}
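+
+// Minimal pairing sketch for the two helpers above (illustrative only; `word`
+// and `waiters` stand for a mapped futex word and its wait counter): the
+// waiter sleeps on the observed value, the waker publishes and then wakes.
+inline void FutexPairExample(uint32_t *word, uint32_t *waiters)
+{
+    // Waiter side: sleeps only while *word still holds the observed value 0;
+    // EAGAIN/EINTR are treated as OK by FutexWait.
+    (void)PageLock::FutexWait(word, waiters, 0, 100);
+    // Waker side: publish the new value first, then wake a single waiter.
+    __atomic_store_n(word, 1, __ATOMIC_SEQ_CST);
+    (void)PageLock::FutexWake(word, waiters, 1);
+}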
+
+PageLock::PageLock(uint32_t *lockArea, uint32_t *waitArea, uint32_t lockId)
+    : lockFlag_(lockArea), waitCount_(waitArea), lockId_(lockId)
+{
+}
+
+Status PageLock::Lock(uint64_t timeoutMs)
+{
+    const uint64_t minTimeoutMs = 5;
+    timeoutMs = std::max(minTimeoutMs, timeoutMs);
+    PerfPoint point(PerfKey::PAGE_INSERT_GET_LOCK);
+    Timer timer;
+    uint64_t useTimeMs = 0;
+    const uint64_t futexThreshold = 10;
+    auto lockFunc = [this]() {
+        uint32_t val = __atomic_load_n(lockFlag_, __ATOMIC_SEQ_CST);
+        if (val & WRITE_LOCK_NUM) {
+            return false;
+        }
+        // We only need the lower order bit for locked or unlocked. The rest of the bits are used to
+        // store the lock id.
+        uint32_t lockVal = (lockId_ << SHIFT) | WRITE_LOCK_NUM;
+        // Compare and set the lock area
+        return __atomic_compare_exchange_n(lockFlag_, &val, lockVal, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST);
+    };
+    do {
+        // We use a hybrid approach. If we can't get the lock, spin for a certain number of times.
+        // The reason is that a producer will not hold the page lock for a long time. More precisely,
+        // the producer only holds the lock to get the offset, which shouldn't take long.
+        // After a certain number of spins, if we still can't get the lock, we will do a futex wait.
+        RETURN_OK_IF_TRUE(lockFunc());
+        useTimeMs = static_cast<uint64_t>(timer.ElapsedMilliSecond());
+        CHECK_FAIL_RETURN_STATUS(useTimeMs < timeoutMs, K_TRY_AGAIN,
+                                 FormatString("[%s:%d] Timeout after %zu ms", __FUNCTION__, __LINE__, timeoutMs));
+        if (static_cast<uint64_t>(timer.ElapsedMilliSecond()) >= futexThreshold) {
+            auto remainingMs = timeoutMs - useTimeMs;
+            Status rc = PageLock::FutexWait(lockFlag_, waitCount_, WRITE_LOCK_NUM, remainingMs);
+            useTimeMs = static_cast<uint64_t>(timer.ElapsedMilliSecond());
+            if (rc.IsOk()) {
+                continue;
+            }
+            RETURN_IF_NOT_OK_EXCEPT(rc, K_TRY_AGAIN);
+        }
+    } while (useTimeMs < timeoutMs);
+    RETURN_STATUS(K_TRY_AGAIN, FormatString("[%s:%d] Timeout after %zu ms", __FUNCTION__, __LINE__, timeoutMs));
+}
+
+void PageLock::Unlock()
+{
+    PerfPoint point(PerfKey::PAGE_INSERT_RELEASE_LOCK);
+    if (__atomic_load_n(lockFlag_, __ATOMIC_SEQ_CST) & WRITE_LOCK_NUM) {
+        uint32_t expectedVal = (lockId_ << SHIFT) | WRITE_LOCK_NUM;
+        if (__atomic_compare_exchange_n(lockFlag_, &expectedVal, NO_LOCK_NUM, false, __ATOMIC_SEQ_CST,
+                                        __ATOMIC_SEQ_CST)) {
+            VLOG(SC_DEBUG_LOG_LEVEL) << "Succeeded in unlocking the write lock";
+            // There is no need to wake up all producers. Only one of them can write
+            LOG_IF_ERROR(PageLock::FutexWake(lockFlag_, waitCount_, 1), "Futex unlock");
+        }
+    }
+}
+
+bool PageLock::TryUnlockByLockId(uint32_t lockId)
+{
+    // If the page is locked with lockId, construct the expected value
+    uint32_t expectedVal = (lockId << SHIFT) | WRITE_LOCK_NUM;
+    // Switch the lock ownership to this worker (which has lock id 0).
+    uint32_t newVal = WRITE_LOCK_NUM;
+    return __atomic_compare_exchange_n(lockFlag_, &expectedVal, newVal, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST);
+}
+
+StreamPageLock::StreamPageLock(std::shared_ptr<StreamDataPage> page) : pageLocked_(false), page_(std::move(page))
+{
+}
+
+StreamPageLock::~StreamPageLock()
+{
+    if (pageLocked_) {
+        page_->Unlock();
+    }
+}
+
+Status StreamPageLock::Lock(uint64_t timeoutMs)
+{
+    RETURN_RUNTIME_ERROR_IF_NULL(page_);
+    RETURN_IF_NOT_OK(page_->Lock(timeoutMs));
+    pageLocked_ = true;
+    return Status::OK();
+}
+
+void StreamPageLock::Unlock()
+{
+    if (pageLocked_) {
+        page_->Unlock();
+        pageLocked_ = false;
+    }
+}
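+
+// Typical RAII use of the guard above (a sketch; `page` is a hypothetical
+// page pointer): the destructor releases the lock on every early-return path,
+// so an explicit Unlock() is only needed to drop the lock early.
+inline Status LockedMutationExample(std::shared_ptr<StreamDataPage> page)
+{
+    StreamPageLock guard(std::move(page));
+    RETURN_IF_NOT_OK(guard.Lock(1000));
+    // ... mutate the page while holding the exclusive lock ...
+    guard.Unlock(); // optional: let other producers in before the guard goes out of scope
+    return Status::OK();
+}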
+
+void ElementHeader::Set(Ptr ptr, Size size, Version version)
+{
+    headerPtr_ = ptr;
+    headerSize_ = size;
+    headerVersion_ = version;
+}
+
+HeaderAndData::HeaderAndData(const Element &element, const ElementHeader &header, uint64_t streamNo)
+    : Element(std::move(element)), ElementHeader(std::move(header)), streamNo(streamNo)
+{
+}
+
+HeaderAndData::HeaderAndData(const Ptr ptr, const Size size, uint64_t streamNo) : Element(ptr, size), streamNo(streamNo)
+{
+}
+
+HeaderAndData::Size HeaderAndData::TotalSize() const
+{
+    if (headerSize_) {
+        return size + headerSize_ + sizeof(Version);
+    }
+    return size;
+}
+
+Status HeaderAndData::MemoryCopyTo(Ptr dest) const
+{
+    if (headerSize_) {
+        // Copy the header version number
+        *dest = headerVersion_;
+        dest++;
+        // Copy the header
+        RETURN_IF_NOT_OK(HugeMemoryCopy(dest, headerSize_, headerPtr_, headerSize_));
+        dest += headerSize_;
+    }
+    // Copy the raw data
+    return HugeMemoryCopy(dest, size, ptr, size);
+}
+
+DataVerificationHeader::DataVerificationHeader(SeqNo seqNo, SenderProducerNo senderProducerNo, Address address,
+                                               Port port)
+{
+    hdr.seqNo = seqNo;
+    hdr.senderProducerNo = senderProducerNo;
+    hdr.address = address;
+    hdr.port = port;
+}
+
+DataVerificationHeader::DataVerificationHeader(const ElementHeader &ele)
+{
+    HugeMemoryCopy(bytes, sizeof(bytes), ele.headerPtr_, sizeof(bytes));
+}
+
+DataVerificationHeader::SeqNo DataVerificationHeader::GetSeqNo() const
+{
+    return hdr.seqNo;
+}
+
+DataVerificationHeader::SenderProducerNo DataVerificationHeader::GetSenderProducerNo() const
+{
+    return hdr.senderProducerNo;
+}
+
+DataVerificationHeader::Address DataVerificationHeader::GetAddress() const
+{
+    return hdr.address;
+}
+
+DataVerificationHeader::Port DataVerificationHeader::GetPort() const
+{
+    return hdr.port;
+}
+
+DataVerificationHeader::Size DataVerificationHeader::HeaderSize() const
+{
+    return sizeof(bytes);
+}
+
+void DataVerificationHeader::Set(SeqNo seqNo, SenderProducerNo senderProducerNo, Address address, Port port)
+{
+    hdr.seqNo = seqNo;
+    hdr.senderProducerNo = senderProducerNo;
+    hdr.address = address;
+    hdr.port = port;
+}
+
+Status DataVerificationHeader::ExtractHeader(DataElement &element, ElementHeader &header)
+{
+    CHECK_FAIL_RETURN_STATUS(
+        element.size > sizeof(bytes), K_OUT_OF_RANGE,
+        FormatString("Element (header + data) size %llu is not greater than DataVerificationHeader size %lu",
+                     element.size, sizeof(bytes)));
+    header.Set(element.ptr, sizeof(bytes), DATA_VERIFICATION_HEADER);
+    element.ptr += sizeof(bytes);
+    element.size -= sizeof(bytes);
+    return Status::OK();
+}
+
+std::string StreamPageBase::CreatePageId(const std::shared_ptr<ShmUnitInfo> &shmInfo)
+{
+    return FormatString("F:%zu-M:%zu-O:%zu-S:%zu", shmInfo->fd, shmInfo->mmapSize, shmInfo->offset, shmInfo->size);
+}
+
+StreamPageBase::StreamPageBase(std::shared_ptr<ShmUnitInfo> shmInfo)
+{
+    pageUnit_ = std::move(shmInfo);
+    // Use the ShmView as the id string rather than generating a uuid
+    pageUnit_->id = CreatePageId(pageUnit_);
+}
+
+StreamPageBase::StreamPageBase(std::shared_ptr<ShmUnitInfo> shmInfo, std::shared_ptr<MmapEntry> mmapEntry)
+    : StreamPageBase(std::move(shmInfo))
+{
+    mmapEntry_ = std::move(mmapEntry);
+}
+
+void StreamPageBase::Init(bool isClient)
+{
+    startOfPage_ = reinterpret_cast<uint8_t *>(pageUnit_->pointer) + ((isClient) ? pageUnit_->offset : 0);
+}
+
+ShmView StreamPageBase::GetShmView() const
+{
+    ShmView v = { .fd = pageUnit_->fd, .mmapSz = pageUnit_->mmapSize, .off = pageUnit_->offset, .sz = pageUnit_->size };
+    return v;
+}
+
+std::shared_ptr<ShmUnitInfo> StreamPageBase::GetShmUnitInfo() const
+{
+    // Return a new copy of pageUnit_, not pageUnit_ itself.
+    return std::make_shared<ShmUnitInfo>(pageUnit_->id, GetShmView(), pageUnit_->pointer);
+}
+
+StreamLobPage::StreamLobPage(std::shared_ptr<ShmUnitInfo> shmInfo, bool isClient)
+    : StreamPageBase(std::move(shmInfo)), isClient_(isClient)
+{
+}
+
+StreamLobPage::StreamLobPage(std::shared_ptr<ShmUnitInfo> shmInfo, bool isClient,
+                             std::shared_ptr<MmapEntry> mmapEntry)
+    : StreamPageBase(std::move(shmInfo), std::move(mmapEntry)), isClient_(isClient)
+{
+}
Size %zu", GetPageId(), element.size); + return Status::OK(); +} + +Status StreamLobPage::Init() +{ + RETURN_RUNTIME_ERROR_IF_NULL(pageUnit_->pointer); + StreamPageBase::Init(isClient_); + return Status::OK(); +} + +StreamDataPage::StreamDataPage(std::shared_ptr shmInfo, uint32_t lockId, bool isClient, bool isSharedPage, + std::shared_ptr mmapEntry) + : StreamPageBase(std::move(shmInfo), std::move(mmapEntry)), + lockId_(lockId), + isClient_(isClient), + maxElementSize_(0), + isSharedPage_(isSharedPage) +{ +} + +Status StreamDataPage::Init() +{ + RETURN_RUNTIME_ERROR_IF_NULL(pageUnit_->pointer); + StreamPageBase::Init(isClient_); + // We are going to traverse the shared memory page to set up various pointers + // based on known offsets. + auto *data = startOfPage_; + pageHeader_ = reinterpret_cast(data); + // Tail leading to the next StreamDataPage. It is at the end of the page. 32 bytes. + tail_ = reinterpret_cast(data + PageSize() - sizeof(SharedMemView)); + // First area is the lock area. Always round up to 8 bytes in size. + pageLock_ = std::make_shared(&pageHeader_->lockArea_, &pageHeader_->lockWait_, lockId_); + // Start of the slot directory. Each slot is 4 byte. Slot directory grows forward and Element data are + // packed at the end of the page before the tail, and grows backward. + slotDir_ = reinterpret_cast(&pageHeader_->slot0_); + if (isClient_) { + auto slot0 = __atomic_load_n(slotDir_, __ATOMIC_SEQ_CST); + isSharedPage_ = slot0 & PAGE_SHARED_BIT; + } + // Compute how much space left. Slot 0 is in use. Space for slot 1 is pre-allocated. + // The remaining space (not counting the tail) should be at least big enough to hold 1 byte of element + maxElementSize_ = static_cast(PageSize()) - static_cast(PageOverhead(isSharedPage_)); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(maxElementSize_ > 0, K_INVALID, + FormatString("Page size %zu too small", PageSize())); + nextPage_ = std::make_shared(tail_, sizeof(SharedMemView), lockId_); + RETURN_IF_NOT_OK(nextPage_->Init()); + return Status::OK(); +} + +Status StreamDataPage::ResetToEmpty() +{ + // This is very much like InitEmptyPage except the underlying page has been + // called InitEmptyPage() earlier. We only reset a few things to allow to + // be reused as empty page. + CHECK_FAIL_RETURN_STATUS(!isClient_, K_INVALID, "Only worker can init the page"); + // Lock the page + StreamPageLock xlock(shared_from_this()); + RETURN_IF_NOT_OK(xlock.Lock(std::numeric_limits::max())); + // Wait until the reference drop to 1. This is to solve a racing condition that a producer fixed + // a page that has been recycled. We must drain the producers. These late producers will time + // out waiting for the lock and unfix the current page. + do { + uint32_t expected = __atomic_load_n(&pageHeader_->refCount_, __ATOMIC_RELAXED); + if (expected == 1) { + break; + } + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Waiting for page<%s> to be unreferenced. Current ref count %zu", + GetPageId(), expected); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } while (true); + auto totalFreeSpace = PagePayloadSize(); + __atomic_store_n(&pageHeader_->totalFreeSpace_, totalFreeSpace, __ATOMIC_RELAXED); + // Slot count back to 0. + __atomic_store_n(&pageHeader_->slotCount_, 0, __ATOMIC_RELAXED); + // begCursor is 0 so function like Insert can detect the page has been recycled. 
+
+Status StreamDataPage::ResetToEmpty()
+{
+    // This is very much like InitEmptyPage except that InitEmptyPage() has been
+    // called on the underlying page earlier. We only reset a few things to allow it
+    // to be reused as an empty page.
+    CHECK_FAIL_RETURN_STATUS(!isClient_, K_INVALID, "Only the worker can init the page");
+    // Lock the page
+    StreamPageLock xlock(shared_from_this());
+    RETURN_IF_NOT_OK(xlock.Lock(std::numeric_limits<uint64_t>::max()));
+    // Wait until the reference count drops to 1. This is to solve a racing condition where a producer fixed
+    // a page that has been recycled. We must drain the producers. These late producers will time
+    // out waiting for the lock and unfix the current page.
+    do {
+        uint32_t expected = __atomic_load_n(&pageHeader_->refCount_, __ATOMIC_RELAXED);
+        if (expected == 1) {
+            break;
+        }
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Waiting for page<%s> to be unreferenced. Current ref count %zu",
                                                  GetPageId(), expected);
+        std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    } while (true);
+    auto totalFreeSpace = PagePayloadSize();
+    __atomic_store_n(&pageHeader_->totalFreeSpace_, totalFreeSpace, __ATOMIC_RELAXED);
+    // Slot count back to 0.
+    __atomic_store_n(&pageHeader_->slotCount_, 0, __ATOMIC_RELAXED);
+    // begCursor is 0 so that functions like Insert can detect that the page has been recycled.
+    __atomic_store_n(&pageHeader_->begCursor_, 0, __ATOMIC_RELAXED);
+    // Clear the next pointer
+    nextPage_->SetView(ShmView(), false, std::numeric_limits<uint64_t>::max());
+    // Clear the page level BigElement bit
+    UnsetPageHasBigElement();
+    return Status::OK();
+}
+
+Status StreamDataPage::InitEmptyPage()
+{
+    CHECK_FAIL_RETURN_STATUS(!isClient_, K_INVALID, "Only the worker can init the page");
+    size_t freeSpace = PageSize();
+    // Clear everything up to slotDir_
+    size_t destSz = reinterpret_cast<uint8_t *>(slotDir_) - startOfPage_;
+    freeSpace -= destSz; // before slot0
+    auto rc = memset_s(startOfPage_, PageSize(), 0, destSz);
+    CHECK_FAIL_RETURN_STATUS(rc == 0, K_RUNTIME_ERROR, FormatString("memset_s fails. Errno = %d", errno));
+    // Clear the tail pointer which is at the end of the page
+    uint8_t *endOfPage = startOfPage_ + PageSize();
+    auto *nextPtr = reinterpret_cast<uint8_t *>(tail_);
+    destSz = endOfPage - nextPtr;
+    RETURN_IF_NOT_OK(nextPage_->Init(true));
+    freeSpace -= destSz; // SharedMemView
+    // Manually set those fields that are not zero.
+    freeSpace -= GetMetaSize(isSharedPage_); // slot 0
+    // Free space is the amount of space left.
+    // TryUnlockByLockId recalculates the total free space of a page independently;
+    // for an empty page the two computations should match.
+    auto totalFreeSpace = PagePayloadSize();
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        freeSpace == totalFreeSpace, K_RUNTIME_ERROR,
+        FormatString("Free space mismatch. Expect %zu but get %zu", freeSpace, totalFreeSpace));
+    // Must be able to insert one element with the max size
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        static_cast<size_t>(maxElementSize_) + GetMetaSize(isSharedPage_) == freeSpace, K_RUNTIME_ERROR,
+        FormatString("Expect max element size %zu + one slot == free space %zu", maxElementSize_, freeSpace));
+    __atomic_store_n(&pageHeader_->totalFreeSpace_, freeSpace, __ATOMIC_SEQ_CST);
+    // Init the first cursor on this page to 1 for now. It will be updated later by the caller.
+    __atomic_store_n(&pageHeader_->begCursor_, 1, __ATOMIC_SEQ_CST);
+    // One reference to this page
+    __atomic_store_n(&pageHeader_->refCount_, 1, __ATOMIC_SEQ_CST);
+    // The slot directory always holds one more entry than the number of elements so that
+    // the size of an element can be inferred from the neighbouring offset.
+    auto offset = static_cast<SlotOffset>(nextPtr - reinterpret_cast<uint8_t *>(slotDir_));
+    auto slotFlag = isSharedPage_ ? PAGE_SHARED_BIT : 0;
+    auto slotAddr = GetSlotAddr(0);
+    slotAddr->StoreAll(isSharedPage_, slotFlag, offset, 0);
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString(
+        "Init page<%s> success. freeSpace = %zu, maxElementSize = %zu, slotDir_[0] = %d", GetPageId(), freeSpace,
+        maxElementSize_, slotDir_[0]);
+    return Status::OK();
+}
+
+Status StreamDataPage::Lock(uint64_t timeoutMs)
+{
+    return pageLock_->Lock(timeoutMs);
+}
+
+void StreamDataPage::Unlock()
+{
+    pageLock_->Unlock();
+}
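+
+// Lock-word encoding used by PageLock (sketch mirroring the constants in
+// stream_data_page.h): bit 0 is the locked flag, the remaining bits carry the
+// owner's lockId so that the worker can recover after a client crash.
+inline uint32_t EncodeLockWordExample(uint32_t lockId)
+{
+    return (lockId << PageLock::SHIFT) | PageLock::WRITE_LOCK_NUM;
+}
+inline uint32_t DecodeLockOwnerExample(uint32_t lockWord)
+{
+    return lockWord >> PageLock::SHIFT;
+}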
+
+void StreamDataPage::TryUnlockByLockId(uint32_t lockId)
+{
+    if (pageLock_->TryUnlockByLockId(lockId)) {
+        // If we can unlock, we now own the lock. The previous producer held the lock while it crashed;
+        // the page can be in an inconsistent state, so let's fix it up (if possible) before we unlock.
+        // (a) Decrement the reference count.
+        (void)ReleasePage(FormatString("%s:%d", __FUNCTION__, __LINE__));
+        // (b) The in-flight slot directory and the totalFreeSpace can't be trusted anymore
+        // if the producer crashed after it took the lock but before it could release it.
+        // What we can do is roll the slot count back to the last slot that carries the
+        // consistency bit. The total free space must be recalculated.
+        uint32_t slotCount = 0;
+        auto pendingSlotCount = GetSlotCount();
+        auto totalFreeSpace = PagePayloadSize();
+        for (uint32_t i = 0; i < pendingSlotCount; ++i) {
+            auto slot = i + 1; // the slot directory is always offset by one.
+            if (GetSlotFlag(slot) & ELEMENT_DATA_CONSISTENT) {
+                ++slotCount;
+                totalFreeSpace -= GetMetaSize(isSharedPage_);
+                // Keep in mind the slot directory grows forward but elements grow backward. So an element's
+                // length is calculated from the previous slot.
+                totalFreeSpace -= (GetSlotOffset(slot - 1) - GetSlotOffset(slot));
+            } else {
+                break;
+            }
+        }
+        __atomic_store_n(&pageHeader_->slotCount_, slotCount, __ATOMIC_SEQ_CST);
+        __atomic_store_n(&pageHeader_->totalFreeSpace_, totalFreeSpace, __ATOMIC_SEQ_CST);
+        auto begCursor = pageHeader_->begCursor_;
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString(
+            "[Page:%s] Page recovery success. begCursor = %zu, slot count = %zu, freeSpace = %zu", GetPageId(),
+            begCursor, slotCount, totalFreeSpace);
+        // Let go of the lock
+        Unlock();
+    }
+}
+
+uint64_t StreamDataPage::GetBegCursor() const
+{
+    auto begCursor = __atomic_load_n(&pageHeader_->begCursor_, __ATOMIC_RELAXED);
+    return begCursor;
+}
+
+uint32_t StreamDataPage::GetSlotCount() const
+{
+    return __atomic_load_n(&pageHeader_->slotCount_, __ATOMIC_ACQUIRE);
+}
+
+uint64_t StreamDataPage::GetLastCursor() const
+{
+    return GetBegCursor() + GetSlotCount() - 1;
+}
+
+bool StreamDataPage::Empty() const
+{
+    return GetSlotCount() == 0;
+}
+
+void StreamDataPage::UpdateSlotConsistentBit(uint32_t slot)
+{
+    auto slotAddr = GetSlotAddr(slot);
+    slotAddr->SetFlagBit(ELEMENT_DATA_CONSISTENT);
+}
+
+uint32_t StreamDataPage::GetRefCount() const
+{
+    return __atomic_load_n(&pageHeader_->refCount_, __ATOMIC_RELAXED);
+}
+
+Status StreamDataPage::RefPage(const std::string &logPrefix)
+{
+    PerfPoint point(PerfKey::PAGE_REF_INC);
+    auto curCount = __atomic_fetch_add(&pageHeader_->refCount_, 1, __ATOMIC_RELAXED);
+    if (VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL)) {
+        VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, Page:%s] refCount after increase: %zu", logPrefix,
+                                                    GetPageId(), 1 + curCount);
+    }
+    return Status::OK();
+}
+
+Status StreamDataPage::ReleasePage(const std::string &logPrefix)
+{
+    // All callers of ReleasePage pass in a log prefix containing the stream name
+    PerfPoint point(PerfKey::PAGE_REF_DEC);
+    constexpr static uint64_t MIN_REF_COUNT = 2;
+    bool success = false;
+    uint32_t curCount = 0;
+    do {
+        curCount = __atomic_load_n(&pageHeader_->refCount_, __ATOMIC_SEQ_CST);
+        // The initial reference count is always 1 when the page is created.
+        // RefPage/ReleasePage must be called in the correct order.
+        // We can't call ReleasePage first followed by RefPage.
+        // So in this case, we expect the reference count to be at least two.
+        CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+            curCount >= MIN_REF_COUNT, K_RUNTIME_ERROR,
+            FormatString("[%s, Page:%s] Unexpected reference count %zu", logPrefix, GetPageId(), curCount));
+        success = __atomic_compare_exchange_n(&pageHeader_->refCount_, &curCount, curCount - 1, false, __ATOMIC_RELAXED,
+                                              __ATOMIC_RELAXED);
+    } while (!success);
+    // Always log if the ref count is abnormally large; it can be a sign of a problem.
+    const int warningVal = 1000;
+    if (curCount > warningVal) {
+        LOG(INFO) << FormatString("[%s, Page:%s] refCount after decrease: %zu", logPrefix, GetPageId(), curCount - 1);
+    } else {
+        const int logPerCount = VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL) ? 1 : 1000;
+        LOG_EVERY_N(INFO, logPerCount) << FormatString("[%s, Page:%s] refCount after decrease: %zu", logPrefix,
+                                                       GetPageId(), curCount - 1);
+    }
+    return Status::OK();
+}
+
+bool StreamDataPage::HasNextPage() const
+{
+    ShmView v = GetNextPage();
+    return v.fd != 0 && v.fd != -1;
+}
+
+void StreamDataPage::SetNextPage(const ShmView &shm)
+{
+    nextPage_->SetView(shm, false, std::numeric_limits<uint64_t>::max());
+}
+
+ShmView StreamDataPage::GetNextPage() const
+{
+    ShmView v;
+    bool isFreePage;
+    nextPage_->GetView(v, isFreePage, std::numeric_limits<uint64_t>::max());
+    // This form of GetNext only returns a view if the page is in use.
+    if (isFreePage) {
+        return {};
+    }
+    return v;
+}
+
+Status StreamDataPage::WakeUpConsumers()
+{
+    return PageLock::FutexWake(&pageHeader_->slotCount_, &pageHeader_->slotWait_);
+}
+
+inline void SetAttributeBits(InsertFlags flags, SlotFlag &offset)
+{
+    if (TESTFLAG(flags, InsertFlags::REMOTE_ELEMENT)) {
+        offset |= REMOTE_ELEMENT_BIT;
+    }
+    if (TESTFLAG(flags, InsertFlags::BIG_ELEMENT)) {
+        offset |= BIG_ELEMENT_BIT;
+    }
+    if (TESTFLAG(flags, InsertFlags::HEADER)) {
+        offset |= HEADER_BIT;
+    }
+}
+
+void StreamDataPage::SetPageHasBigElement()
+{
+    // It is an expensive operation to traverse each slot to check if it is
+    // a big element row. To optimize the work, we steal the high bits of
+    // slot0, which are never used otherwise. If the bit is set, there
+    // exists at least one big element
+    SlotFlagOffset offset = __atomic_load_n(slotDir_, __ATOMIC_ACQUIRE);
+    offset |= BIG_ELEMENT_BIT;
+    __atomic_store_n(slotDir_, offset, __ATOMIC_RELEASE);
+}
+
+void StreamDataPage::UnsetPageHasBigElement()
+{
+    // It is an expensive operation to traverse each slot to check if it is
+    // a big element row. To optimize the work, we steal the high bits of
+    // slot0, which are never used otherwise. If the bit is set, there
+    // exists at least one big element
+    SlotFlagOffset offset = __atomic_load_n(slotDir_, __ATOMIC_ACQUIRE);
+    offset &= ~BIG_ELEMENT_BIT;
+    __atomic_store_n(slotDir_, offset, __ATOMIC_RELEASE);
+}
+
+bool StreamDataPage::PageHasBigElement()
+{
+    SlotFlagOffset offset = __atomic_load_n(slotDir_, __ATOMIC_RELAXED);
+    return TESTFLAG(offset, BIG_ELEMENT_BIT);
+}
+
+size_t StreamDataPage::GetMetaSize(bool isSharedPage)
+{
+    return isSharedPage ? sizeof(SlotType) : sizeof(SlotFlagOffset);
+}
+
+size_t StreamDataPage::PageOverhead(bool isSharedPage)
+{
+    // Everything up to and including slot0, and we can use slot1 as a reference.
+    // Also take account of the next page pointer, and reserve space for slot 1.
+    // To prevent the compiler from adding any padding, we explicitly use offsetof plus
+    // sizeof rather than sizeof(StreamPageHeader). Without the 'packed' attribute,
+    // the C++ compiler can pad the structure with another 4 bytes at the end
+    // of the struct.
+    return offsetof(StreamPageHeader, slot0_) + sizeof(SharedMemView) + GetMetaSize(isSharedPage) +
+           GetMetaSize(isSharedPage);
+}
+
+size_t StreamDataPage::PagePayloadSize()
+{
+    return PageSize() - offsetof(StreamPageHeader, slot0_) - sizeof(SharedMemView) - GetMetaSize(isSharedPage_);
+}
+
+size_t StreamDataPage::GetFreeSpaceSize()
+{
+    return __atomic_load_n(&pageHeader_->totalFreeSpace_, __ATOMIC_RELAXED);
+}
+
+SlotFlag StreamDataPage::GetSlotFlag(size_t index)
+{
+    auto addr = GetSlotAddr(index);
+    return isSharedPage_ ? addr->value.flag : (addr->flagWithOffset & ~SLOT_VALUE_MASK);
+}
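+
+// Slot encoding on an exclusive page (sketch; mirrors StoreAll/LoadOffset):
+// the low 24 bits hold the element offset and the high 8 bits hold per-element
+// flags, so a single 32-bit atomic store publishes both at once.
+inline SlotFlagOffset PackSlotExample(SlotFlag flags, SlotOffset offset)
+{
+    return (flags & ~SLOT_VALUE_MASK) | (offset & SLOT_VALUE_MASK);
+}
+inline SlotOffset UnpackOffsetExample(SlotFlagOffset slot)
+{
+    return slot & SLOT_VALUE_MASK;
+}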
+
+SlotOffset StreamDataPage::GetSlotOffset(size_t index)
+{
+    auto addr = GetSlotAddr(index);
+    return isSharedPage_ ? addr->value.offset : (addr->flagWithOffset & SLOT_VALUE_MASK);
+}
+
+SlotType *StreamDataPage::GetSlotAddr(size_t index)
+{
+    if (!isSharedPage_) {
+        return reinterpret_cast<SlotType *>(slotDir_ + index);
+    }
+    // (flag0, offset0, streamNo0), (flag1, offset1, streamNo1), ...
+    return reinterpret_cast<SlotType *>(slotDir_) + index;
+}
+
+Status StreamDataPage::ExtractBigElementsUpTo(uint64_t ackCursor, std::vector<std::pair<uint64_t, ShmView>> &bigId,
+                                              bool deCouple)
+{
+    auto begCursor = GetBegCursor();
+    RETURN_OK_IF_TRUE(ackCursor < begCursor);
+    // Because we are modifying the page, we need to lock it to block
+    // other producers.
+    StreamPageLock pageLock(shared_from_this());
+    const uint64_t DEFAULT_TIMEOUT_MS = 1000;
+    RETURN_IF_NOT_OK(pageLock.Lock(DEFAULT_TIMEOUT_MS));
+    RETURN_OK_IF_TRUE(!PageHasBigElement());
+    size_t offset1 = reinterpret_cast<uint8_t *>(slotDir_) - startOfPage_;
+    auto slotCount = GetSlotCount();
+    for (size_t i = 0; i < slotCount; ++i) {
+        uint64_t cursor = begCursor + i;
+        if (cursor > ackCursor) {
+            break;
+        }
+        auto slotAddr = GetSlotAddr(i + 1);
+        DataElement ele;
+        ele.attr_ = slotAddr->LoadFlag(isSharedPage_);
+        // Like Receive, if the data is not ready, break out of the loop. We will resume again next time.
+        if (!ele.DataIsReady()) {
+            break;
+        }
+        if (!ele.IsBigElement()) {
+            continue;
+        }
+        auto offset = slotAddr->LoadOffset(isSharedPage_);
+        SlotOffset b4 = GetSlotAddr(i)->LoadOffset(isSharedPage_);
+        ele.size = b4 - offset;
+        ele.ptr = startOfPage_ + offset1 + offset;
+        ele.id = cursor;
+        // We are going to decouple the pointer to the big element page, and turn off the BigElement bit.
+        // If we need to revisit this page again, we will then ignore it.
+        ShmView pageView;
+        RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ParseShmViewPb(ele.ptr, ele.size, pageView), "ReleaseBigElementsUpTo");
+        if (deCouple) {
+            // We no longer consider this as a big element row.
+            slotAddr->ClearFlagBit(BIG_ELEMENT_BIT);
+        }
+        bigId.emplace_back(ele.id, pageView);
+    }
+    return Status::OK();
+}
+
+Status StreamDataPage::Insert(const HeaderAndData &element, uint64_t timeoutMs, InsertFlags &flags,
+                              const std::string &logPrefix)
+{
+    INJECT_POINT("producer_insert");
+    PerfPoint point(PerfKey::PAGE_INSERT_ELEMENT);
+    auto *totalFreeSpace_ = &pageHeader_->totalFreeSpace_;
+    auto *slotCount_ = &pageHeader_->slotCount_;
+    size_t finalElementSize = element.TotalSize();
+    size_t spaceNeeded = GetMetaSize(isSharedPage_) + finalElementSize;
+    // Make sure this element is not exceeding the maximum free space. There is no way
+    // any page can hold the big element. We need to account for a new slot for the Element.
+    // Slot 0 is in use. Space for slot 1 is pre-allocated,
+    // and so we can allow an element of the max stream element size.
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        finalElementSize <= static_cast<size_t>(maxElementSize_), K_INVALID,
+        FormatString("Element size %zu (plus internal overhead) is exceeding the maximum free space %zu",
+                     finalElementSize, maxElementSize_));
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(element.size > 0, K_INVALID, "Element size should be greater than 0");
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(element.ptr != nullptr, K_INVALID, "Element ptr should not be a nullptr");
+    // The maximum length we can support is 24 bits (SLOT_VALUE_MASK)
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        (finalElementSize & ~(static_cast<size_t>(SLOT_VALUE_MASK))) == 0, K_INVALID,
+        FormatString("Element size %zu is exceeding the maximum length", finalElementSize));
+    StreamPageLock pageLock(shared_from_this());
+    // A few shortcuts before we try to hold the lock for insert.
+    // These checks involve looking at some atomic fields.
+    // (a) If the totalFreeSpace is too small, don't bother to get the lock
+    auto totalFreeSpace = __atomic_load_n(totalFreeSpace_, __ATOMIC_RELAXED);
+    CHECK_FAIL_RETURN_STATUS(spaceNeeded <= totalFreeSpace, K_NO_SPACE, "Not enough space");
+    // (b) Check if this page has a next pointer. If there is one, just follow the pointer.
+    CHECK_FAIL_RETURN_STATUS(!HasNextPage(), K_SC_END_OF_PAGE, "Check next page for new elements");
+    // Now we lock the page for insert because this page is shared by producers.
+    // Consumers on the other hand traverse the page without any lock.
+    if (!TESTFLAG(flags, InsertFlags::SKIP_LOCK)) {
+        RETURN_IF_NOT_OK(pageLock.Lock(timeoutMs));
+    }
+    INJECT_POINT("producer_obtained_lock");
+    // There is a racing condition where a page can be recycled and put onto the free list.
+    // One way to detect that this page is on a free list is to check the begCursor_.
+    auto begCursor = __atomic_load_n(&pageHeader_->begCursor_, __ATOMIC_RELAXED);
+    CHECK_FAIL_RETURN_STATUS(begCursor > 0, K_TRY_AGAIN, "Page is already recycled");
+    // After we get the lock, do the same checks again
+    CHECK_FAIL_RETURN_STATUS(!HasNextPage(), K_SC_END_OF_PAGE, "Check next page for new elements.");
+    totalFreeSpace = __atomic_load_n(totalFreeSpace_, __ATOMIC_RELAXED);
+    CHECK_FAIL_RETURN_STATUS(spaceNeeded <= totalFreeSpace, K_NO_SPACE, "Not enough space");
+    totalFreeSpace = __atomic_sub_fetch(totalFreeSpace_, spaceNeeded, __ATOMIC_RELAXED);
+    INJECT_POINT("producer_update_free_space");
+    auto numElement = __atomic_load_n(slotCount_, __ATOMIC_ACQUIRE);
+
+    SlotOffset offset = GetSlotOffset(numElement) - static_cast<SlotOffset>(finalElementSize);
+    uint8_t *dest = reinterpret_cast<uint8_t *>(slotDir_) + offset;
+    // We will do an asynchronous memory copy by letting go of the page lock
+    // once we know where we will write the data to.
+    // We need to distinguish whether an inserted element is from a local producer or a remote producer.
+    // If it comes from a remote producer, set the high bit. Same for a big element
+    SlotFlag slotFlag = 0;
+    SetAttributeBits(flags, slotFlag);
+    // The slot directory always holds one more entry than the number of elements.
+    auto slotAddr = GetSlotAddr(numElement + 1);
+    slotAddr->StoreAll(isSharedPage_, slotFlag, offset, element.streamNo);
+    if (TESTFLAG(flags, InsertFlags::BIG_ELEMENT)) {
+        SetPageHasBigElement();
+    }
+    INJECT_POINT("producer_update_slot_directory");
+    // The slot count is the last step to update. A consumer may futex sleep on slotCount_,
+    // and we should only update it when the free space and slot offset are set.
+    __atomic_store_n(slotCount_, 1 + numElement, __ATOMIC_RELEASE);
+    INJECT_POINT("producer_update_pending_slot_count_holding_lock");
+    // Let go of the lock
+    pageLock.Unlock();
+    INJECT_POINT("producer_update_pending_slot_count_without_lock");
+    // Caution! After this point, we no longer hold any lock.
+    PerfPoint perfPoint(PerfKey::PAGE_ELEMENT_MEMORY_COPY);
+    RETURN_IF_NOT_OK(element.MemoryCopyTo(dest));
+    perfPoint.RecordAndReset(PerfKey::PAGE_CAS_SLOT_COUNT);
+    // Update the slot with the consistent bit
+    UpdateSlotConsistentBit(numElement + 1);
+    perfPoint.Record();
+    if (!TESTFLAG(flags, InsertFlags::DELAY_WAKE)) {
+        RETURN_IF_NOT_OK(WakeUpConsumers());
+    }
+    const int logPerCount = VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL) ? 1 : 1000;
+    LOG_EVERY_N(INFO, logPerCount) << FormatString(
+        "[%sCursor %zu] Add element success. slot = %zu, offset = %zu, length = %zu, freeSpace = %zu, bigEle = %s, "
+        "header = %s, sharedPage = %s, streamNo = %zu, pageId = %s",
+        (!logPrefix.empty() ? logPrefix + " " : ""), pageHeader_->begCursor_ + numElement, numElement + 1, offset,
+        finalElementSize, totalFreeSpace, BoolToString(TESTFLAG(flags, InsertFlags::BIG_ELEMENT)),
+        BoolToString(TESTFLAG(flags, InsertFlags::HEADER)), BoolToString(isSharedPage_), element.streamNo, GetPageId());
+    SETFLAG(flags, InsertFlags::INSERT_SUCCESS);
+    return Status::OK();
+}
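+
+// How a reader recovers an element's length (sketch; this is the same
+// arithmetic Receive() uses): offsets grow downward, so element i occupies
+// [offset(i), offset(i-1)) and its size is the difference of two neighbouring
+// slot offsets.
+inline size_t ElementSizeFromSlotsExample(SlotOffset prevOffset, SlotOffset thisOffset)
+{
+    // prevOffset belongs to slot i-1, thisOffset to slot i; data grows backward.
+    return prevOffset - thisOffset;
+}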
+
+Status CalcMaxAllowRows(void *buf, std::vector<size_t> &sz, const size_t totalFreeSpace, bool isSharedPage,
+                        StreamMetaShm *streamMetaShm, uint8_t *&src, size_t &spaceNeeded, size_t &numInsert,
+                        size_t &totalLength)
+{
+    size_t bufSz = std::accumulate(sz.begin(), sz.end(), 0ul);
+    // Elements are packed in reverse order. We will find out how many elements we can insert,
+    // and the caller can continue from where we left off next time.
+    src = reinterpret_cast<uint8_t *>(buf) + bufSz;
+    spaceNeeded = 0;
+    for (size_t i = 0; i < sz.size(); ++i) {
+        auto eleSz = sz[i];
+        auto sizeNeeded = eleSz + StreamDataPage::GetMetaSize(isSharedPage); // account for one slot
+        if (spaceNeeded + sizeNeeded > totalFreeSpace) {
+            break;
+        }
+        if (streamMetaShm != nullptr) {
+            RETURN_IF_NOT_OK(streamMetaShm->TryIncUsage(eleSz));
+        }
+        // We can take this element and fit it on the page
+        spaceNeeded += sizeNeeded;
+        src -= eleSz; // Elements are packed in reverse order
+        totalLength += eleSz;
+        ++numInsert;
+    }
+    return Status::OK();
+}
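+
+// Reverse-order packing consumed by CalcMaxAllowRows (annotation): the caller
+// hands in a buffer whose elements are laid out in reverse, and src starts one
+// past the last byte, stepping backward as each element is accepted:
+//
+//   buf: [ e2 | e1 | e0 ],  sz = { |e0|, |e1|, |e2| }
+//   src = buf + |e0|+|e1|+|e2|;  accept e0: src -= |e0| (src now points at e0);
+//   accept e1: src -= |e1| (src now points at e1); and so on.
+//
+// After accepting numInsert elements, [src, src + totalLength) is exactly the
+// block that BatchInsert copies onto the page.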
+
+Status StreamDataPage::BatchInsert(void *buf, std::vector<size_t> &sz, uint64_t timeoutMs,
+                                   std::pair<size_t, size_t> &res, InsertFlags flags,
+                                   const std::vector<bool> &headerBits, StreamMetaShm *streamMetaShm)
+{
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(!isSharedPage_, K_RUNTIME_ERROR,
+                                         FormatString("BatchInsert is not allowed for shared page %s", GetPageId()));
+    size_t numInsert = 0;
+    size_t totalLength = 0;
+    Raii raii([&res, &numInsert, &totalLength]() {
+        res.first = numInsert;
+        res.second = totalLength;
+    });
+    // This is a special form of batch insert mainly for the remote worker.
+    auto *totalFreeSpace_ = &pageHeader_->totalFreeSpace_;
+    auto *slotCount_ = &pageHeader_->slotCount_;
+    StreamPageLock pageLock(shared_from_this());
+    if (!TESTFLAG(flags, InsertFlags::SKIP_LOCK)) {
+        RETURN_IF_NOT_OK(pageLock.Lock(timeoutMs));
+    }
+    // There is a racing condition where a page can be recycled and put onto the free list.
+    // One way to detect that this page is on a free list is to check the begCursor_.
+    auto begCursor = __atomic_load_n(&pageHeader_->begCursor_, __ATOMIC_RELAXED);
+    CHECK_FAIL_RETURN_STATUS(begCursor > 0, K_TRY_AGAIN, "Page is already recycled");
+    // (a) Check if this page has a next pointer. If there is one, just follow the pointer.
+    CHECK_FAIL_RETURN_STATUS(!HasNextPage(), K_SC_END_OF_PAGE, "Check next page for new elements");
+    auto totalFreeSpace = __atomic_load_n(totalFreeSpace_, __ATOMIC_RELAXED);
+    // Elements are packed in reverse order. We will find out how many elements we can insert,
+    // and the caller can continue from where we left off next time.
+    uint8_t *src = nullptr;
+    size_t spaceNeeded = 0;
+    auto rc = CalcMaxAllowRows(buf, sz, totalFreeSpace, isSharedPage_, streamMetaShm, src, spaceNeeded, numInsert,
+                               totalLength);
+    // (b) If the totalFreeSpace is too small, go to the next page.
+    if (numInsert == 0) {
+        return rc.IsError() ? rc : Status(K_NO_SPACE, "Not enough space");
+    }
+    uint32_t numElement = __atomic_load_n(slotCount_, __ATOMIC_ACQUIRE);
+    totalFreeSpace = __atomic_sub_fetch(totalFreeSpace_, spaceNeeded, __ATOMIC_RELAXED);
+    for (size_t i = 0; i < numInsert; ++i) {
+        auto slot = numElement + i + 1; // the slot directory always holds one more entry
+        SlotFlagOffset offset = GetSlotOffset(slot - 1) - static_cast<SlotFlagOffset>(sz[i]);
+        if (headerBits[i]) {
+            SetAttributeBits(flags | InsertFlags::HEADER, offset);
+        } else {
+            SetAttributeBits(flags, offset);
+        }
+        __atomic_store_n(slotDir_ + slot, offset, __ATOMIC_RELEASE);
+    }
+    if (TESTFLAG(flags, InsertFlags::BIG_ELEMENT)) {
+        SetPageHasBigElement();
+    }
+    uint8_t *dest = reinterpret_cast<uint8_t *>(slotDir_) + GetSlotOffset(numElement + numInsert);
+    // The slot count is the last step to update. A consumer may futex sleep on slotCount_,
+    // and we should only update it when the free space and slot offsets are set.
+    __atomic_store_n(slotCount_, numElement + static_cast<uint32_t>(numInsert), __ATOMIC_RELEASE);
+    // Just like the Insert case, we let go of the lock once we have all the offsets
+    pageLock.Unlock();
+    // Caution! After this point, we no longer hold any lock.
+    RETURN_IF_NOT_OK(HugeMemoryCopy(dest, totalLength, src, totalLength));
+    for (size_t i = 0; i < numInsert; ++i) {
+        auto slot = numElement + i + 1; // the slot directory always holds one more entry
+        UpdateSlotConsistentBit(slot);
+    }
+    // Wake up (futex) any reader
+    if (!TESTFLAG(flags, InsertFlags::DELAY_WAKE)) {
+        RETURN_IF_NOT_OK(WakeUpConsumers());
+    }
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString(
+        "Batch add %zu element into %s success. "
+        "begCursor = %zu, freeSpace = %zu, bigEle = %s",
+        numInsert, GetPageId(), pageHeader_->begCursor_ + numElement, totalFreeSpace,
+        BoolToString(TESTFLAG(flags, InsertFlags::BIG_ELEMENT)));
+    return Status::OK();
+}
slotCount = " << slotCount; + while (!HasNextPage()) { + std::this_thread::yield(); + } + // Wait a bit before we wait on the futex + const auto sleepMs = 5'000ul; + std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs)); + return Status::OK(); + }); + Status rc = PageLock::FutexWait(&pageHeader_->slotCount_, &pageHeader_->slotWait_, slotCount, timeoutMs); + if (rc.IsOk()) { + // Update the slotCount again for spurious wake up + slotCount = GetSlotCount(); + endCursor = pageHeader_->begCursor_ + slotCount; + RETURN_OK_IF_TRUE(lastRecvCursor + 1 < endCursor); + // We can also be waked up CreateNewPage. But let the caller handle it. + RETURN_STATUS(K_TRY_AGAIN, "No new element inserted"); + } + return rc; +} + +Status StreamDataPage::ParseShmViewPb(const void *ptr, size_t sz, ShmView &out) +{ + ShmViewPb pb; + bool success = pb.ParseFromArray(ptr, sz); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(success, K_OUT_OF_RANGE, "ShmViewPb parse error"); + ShmView v{ .fd = pb.fd(), .mmapSz = pb.mmap_size(), .off = static_cast(pb.offset()), .sz = pb.size() }; + out = v; + return Status::OK(); +} + +Status StreamDataPage::SerializeToShmViewPb(const ShmView &pageView, std::string &out) +{ + ShmViewPb pb; + pb.set_fd(pageView.fd); + pb.set_mmap_size(pageView.mmapSz); + pb.set_offset(pageView.off); + pb.set_size(pageView.sz); + bool rc = pb.SerializeToString(&out); + CHECK_FAIL_RETURN_STATUS(rc, K_RUNTIME_ERROR, "Serialization error"); + return Status::OK(); +} + +Status StreamDataPage::Receive(uint64_t lastRecvCursor, uint64_t timeoutMs, std::vector &out, + const std::string &logPrefix) +{ + const auto &begCursor = pageHeader_->begCursor_; + auto startCursor = lastRecvCursor + 1; + if (startCursor < begCursor) { + RETURN_STATUS_LOG_ERROR( + K_OUT_OF_RANGE, FormatString("[P:%s] Starting read position %zu not on this page [%zu, %zu)", GetPageId(), + startCursor, begCursor, begCursor + GetSlotCount())); + } + INJECT_POINT("StreamDataPage::Receive.sleep"); + // Unlike producers, consumers do not hold the lock to read. We rely on + // certain fields are updated atomically. + RETURN_IF_NOT_OK(WaitForNewElement(lastRecvCursor, timeoutMs)); + uint32_t slotCount = GetSlotCount(); + auto endCursor = begCursor + slotCount; + // All slot offsets are relative to the slot directory + size_t offset1 = reinterpret_cast(slotDir_) - startOfPage_; + // Go over the cursor in [startCursor .. endCursor) + for (auto i = startCursor; i < endCursor; ++i) { + auto slot = i - begCursor; + auto slotAddr = GetSlotAddr(slot + 1); + DataElement ele; + ele.attr_ = slotAddr->LoadFlag(isSharedPage_, __ATOMIC_ACQUIRE); + INJECT_POINT("StreamDataPage::Receive.fake.BIG_ELEMENT", [startCursor, begCursor, &ele]() { + LOG(INFO) << "startCursor = " << startCursor << " begCursor = " << begCursor; + if (startCursor != begCursor) { + ele.attr_ |= BIG_ELEMENT_BIT; + ele.attr_ |= ELEMENT_DATA_CONSISTENT; + } + return Status::OK(); + }); + // If the data is still in flight, return whatever we have so far. + if (!ele.DataIsReady()) { + break; + } + auto offset = slotAddr->LoadOffset(isSharedPage_, __ATOMIC_ACQUIRE); + // The slot directory is always one plus the number of elements. + // Data grows backward. So its size is calculated from the offset + // of slot before it. + SlotOffset b4 = GetSlotAddr(slot)->LoadOffset(isSharedPage_); + ele.size = b4 - offset; + // Set up the pointer + ele.ptr = startOfPage_ + offset1 + offset; + // Pass the cursor as well + ele.id = i; + // get stream no. 
+
+Status StreamDataPage::Receive(uint64_t lastRecvCursor, uint64_t timeoutMs, std::vector<DataElement> &out,
+                               const std::string &logPrefix)
+{
+    const auto &begCursor = pageHeader_->begCursor_;
+    auto startCursor = lastRecvCursor + 1;
+    if (startCursor < begCursor) {
+        RETURN_STATUS_LOG_ERROR(
+            K_OUT_OF_RANGE, FormatString("[P:%s] Starting read position %zu not on this page [%zu, %zu)", GetPageId(),
+                                         startCursor, begCursor, begCursor + GetSlotCount()));
+    }
+    INJECT_POINT("StreamDataPage::Receive.sleep");
+    // Unlike producers, consumers do not hold the lock to read. We rely on
+    // certain fields being updated atomically.
+    RETURN_IF_NOT_OK(WaitForNewElement(lastRecvCursor, timeoutMs));
+    uint32_t slotCount = GetSlotCount();
+    auto endCursor = begCursor + slotCount;
+    // All slot offsets are relative to the slot directory
+    size_t offset1 = reinterpret_cast<uint8_t *>(slotDir_) - startOfPage_;
+    // Go over the cursors in [startCursor .. endCursor)
+    for (auto i = startCursor; i < endCursor; ++i) {
+        auto slot = i - begCursor;
+        auto slotAddr = GetSlotAddr(slot + 1);
+        DataElement ele;
+        ele.attr_ = slotAddr->LoadFlag(isSharedPage_, __ATOMIC_ACQUIRE);
+        INJECT_POINT("StreamDataPage::Receive.fake.BIG_ELEMENT", [startCursor, begCursor, &ele]() {
+            LOG(INFO) << "startCursor = " << startCursor << " begCursor = " << begCursor;
+            if (startCursor != begCursor) {
+                ele.attr_ |= BIG_ELEMENT_BIT;
+                ele.attr_ |= ELEMENT_DATA_CONSISTENT;
+            }
+            return Status::OK();
+        });
+        // If the data is still in flight, return whatever we have so far.
+        if (!ele.DataIsReady()) {
+            break;
+        }
+        auto offset = slotAddr->LoadOffset(isSharedPage_, __ATOMIC_ACQUIRE);
+        // The slot directory always holds one more entry than the number of elements.
+        // Data grows backward, so an element's size is calculated from the offset
+        // of the slot before it.
+        SlotOffset b4 = GetSlotAddr(slot)->LoadOffset(isSharedPage_);
+        ele.size = b4 - offset;
+        // Set up the pointer
+        ele.ptr = startOfPage_ + offset1 + offset;
+        // Pass the cursor as well
+        ele.id = i;
+        // Get the stream number
+        ele.streamNo_ = slotAddr->LoadStreamNo(isSharedPage_);
+        // Tag where the elements come from
+        const int logPerCount = VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL) ? 1 : 1000;
+        LOG_EVERY_N(INFO, logPerCount) << FormatString(
+            "[%sCursor %zu] Fetch element success. slot = %zu, offset = %zu, length = %zu, local = %s, "
+            "bigElement = %s, header = %s, sharedPage = %s, streamNo = %zu, pageId = %s",
+            (!logPrefix.empty() ? logPrefix + " " : ""), ele.id, slot + 1, offset, ele.size,
+            BoolToString(!ele.IsRemote()), BoolToString(ele.IsBigElement()), BoolToString(ele.HasHeader()),
+            BoolToString(isSharedPage_), ele.streamNo_, GetPageId());
+        out.emplace_back(ele);
+    }
+    return Status::OK();
+}
+
+Status StreamDataPage::Seal(const ShmView &nextPage, uint64_t timeoutMs,
+                            std::function<Status(const ShmView &, std::shared_ptr<StreamDataPage> &)> locatePage,
+                            const std::string &logPrefix)
+{
+    // Do not seal a page with a null pointer
+    CHECK_FAIL_RETURN_STATUS(nextPage.fd > 0, K_INVALID,
+                             FormatString("[%s] Seal a page with invalid pointer %s", GetPageId(), nextPage.ToStr()));
+    // To seal a page, we update the next pointer. All producers will then stop inserting
+    // into the current page. The last cursor on the page is used as an index key to the next page.
+    // Some racing conditions to consider:
+    // (a) The page has no room for producer A to insert a new element.
+    // (b) A updates the next pointer and uses the last cursor as the key.
+    // (c) Producer B however can insert a small element into the page.
+    // (d) The begCursor of the new page will then have the same cursor value as the element inserted by B.
+    // So when we *seal* a page, we must block any producer from inserting any more
+    // new elements into the old page.
+    StreamPageLock pageLock(shared_from_this());
+    RETURN_IF_NOT_OK(pageLock.Lock(timeoutMs));
+    // Never seal an empty page
+    CHECK_FAIL_RETURN_STATUS(!Empty(), K_RUNTIME_ERROR, FormatString("[%s] Empty page", GetPageId()));
+    // If we have a pointer and it is in use already, report an error unless it is the same given ShmView
+    ShmView v = GetNextPage();
+    if (v.fd <= 0) {
+        // This is the index key to the next page and to the index chain.
+        uint64_t lastAppendCursor = GetLastCursor();
+        std::shared_ptr<StreamDataPage> page;
+        RETURN_IF_NOT_OK(locatePage(nextPage, page));
+        // The main reason we need to lock the current page is to set the starting cursor of the next page
+        // and to update the two atomic fields together.
+        auto func = [this, &nextPage, lastAppendCursor, &page]() {
+            __atomic_store_n(&page->pageHeader_->begCursor_, lastAppendCursor + 1, __ATOMIC_SEQ_CST);
+            tail_->CopyFrom(nextPage);
+        };
+        nextPage_->LockExclusiveAndExec(func, std::numeric_limits<uint64_t>::max());
+        VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Chain page<%s> [%zu, ) to page<%s> [%zu, %zu]", logPrefix,
+                                                    page->GetPageId(), lastAppendCursor + 1, GetPageId(),
+                                                    GetBegCursor(), lastAppendCursor);
+        // Wake up anyone waiting on the old page
+        RETURN_IF_NOT_OK(WakeUpConsumers());
+        return Status::OK();
+    }
+    CHECK_FAIL_RETURN_STATUS(v == nextPage, K_RUNTIME_ERROR,
+                             FormatString("Page<%s> is sealed already. Next page %s", GetPageId(), v.ToStr()));
+    RETURN_STATUS(K_DUPLICATED, FormatString("Page<%s> is sealed already", GetPageId()));
+}
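+
+// Page-chaining invariant established by Seal() (annotation): if page A holds
+// cursors [b, L], the next page B starts at L + 1, and both updates happen
+// under A's page lock so no producer can slip another element into A afterwards.
+//
+//   A: [b .......... L] --tail_--> B: [L+1 ..........)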
+
+size_t StreamDataPage::GetTotalEleSize()
+{
+    if (GetSlotCount() == 0) {
+        return 0;
+    }
+    auto end = GetSlotAddr(0)->LoadOffset(isSharedPage_);
+    auto start = GetSlotAddr(GetSlotCount())->LoadOffset(isSharedPage_);
+    if (end <= start) {
+        LOG(WARNING) << FormatString("The layout of this page may be inconsistent, start: %zu, end: %zu", start, end);
+        return 0;
+    }
+    return end - start;
+}
+
+Status DataElement::CheckAttribute() const
+{
+    // For compatibility with a downlevel client, we only make sure no reserved bits are in use.
+    auto highBits = attr_;
+    CLEARFLAG(highBits, SLOT_VALUE_MASK);
+    CLEARFLAG(highBits, ELEMENT_DATA_CONSISTENT);
+    CLEARFLAG(highBits, REMOTE_ELEMENT_BIT);
+    CLEARFLAG(highBits, BIG_ELEMENT_BIT);
+    CLEARFLAG(highBits, HEADER_BIT);
+    INJECT_POINT("StreamDataPage.CheckHighBits", [&highBits](uint32_t v) {
+        highBits = v;
+        return Status::OK();
+    });
+    // What remains should be 0. If any bit is set, the slot was written by uplevel code.
+    RETURN_OK_IF_TRUE(highBits == 0);
+    std::stringstream oss;
+    oss << "Incompatibility with up level worker detected. Slot value = 0x" << std::hex << attr_;
+    RETURN_STATUS_LOG_ERROR(K_CLIENT_WORKER_VERSION_MISMATCH, oss.str());
+}
+
+bool DataElement::DataIsReady() const
+{
+    return TESTFLAG(attr_, ELEMENT_DATA_CONSISTENT);
+}
+
+bool DataElement::IsRemote() const
+{
+    return TESTFLAG(attr_, REMOTE_ELEMENT_BIT);
+}
+
+bool DataElement::IsBigElement() const
+{
+    return TESTFLAG(attr_, BIG_ELEMENT_BIT);
+}
+
+bool DataElement::HasHeader() const
+{
+    return TESTFLAG(attr_, HEADER_BIT);
+}
+
+uint64_t DataElement::GetStreamNo() const
+{
+    return streamNo_;
+}
+} // namespace datasystem
diff --git a/src/datasystem/common/stream_cache/stream_data_page.h b/src/datasystem/common/stream_cache/stream_data_page.h
new file mode 100644
index 0000000..f926c15
--- /dev/null
+++ b/src/datasystem/common/stream_cache/stream_data_page.h
@@ -0,0 +1,575 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASYSTEM_STREAM_CACHE_STREAM_DATA_PAGE_H
+#define DATASYSTEM_STREAM_CACHE_STREAM_DATA_PAGE_H
+
+#include <climits>
+
+#include "datasystem/client/mmap_table.h"
+#include "datasystem/common/object_cache/lock.h"
+#include "datasystem/common/shared_memory/shm_unit.h"
+#include "datasystem/common/stream_cache/cursor.h"
+#include "datasystem/common/stream_cache/stream_meta_shm.h"
+#include "datasystem/common/util/bitmask_enum.h"
+#include "datasystem/common/util/raii.h"
+#include "datasystem/stream/element.h"
+#include "datasystem/utils/optional.h"
+#include "datasystem/protos/stream_posix.pb.h"
+
+namespace datasystem {
+// For the lock on StreamDataPage, we only need exclusive lock functionality.
+// Producers always lock exclusively. The consumer never locks the page.
+// We only need the lower order bit for locked or unlocked. The rest of the bits
+// are used to store the lockId, so 4 bytes in total are sufficient.
+class PageLock {
+public:
+    constexpr static int NO_LOCK_NUM = 0;
+    constexpr static int WRITE_LOCK_NUM = 1;
+    constexpr static int SHIFT = 1;
+    explicit PageLock(uint32_t *lockArea, uint32_t *waitArea, uint32_t lockId = 0);
+    ~PageLock() = default;
+
+    /**
+     * Lock exclusively
+     * @param timeoutMs in milliseconds
+     * @return OK if successful
+     */
+    Status Lock(uint64_t timeoutMs);
+
+    /**
+     * Unlock a previously held lock
+     */
+    void Unlock();
+
+    /**
+     * @brief General form of calling futex wait
+     * @param lockArea
+     * @param waitCount
+     * @param timeoutMs
+     * @return
+     */
+    static Status FutexWait(uint32_t *lockArea, uint32_t *waitCount, uint32_t val, uint64_t timeoutMs);
+
+    /**
+     * @brief General form of calling futex wake
+     * @param lockArea
+     * @param waitCount
+     * @param numToWakeUp
+     * @return
+     */
+    static Status FutexWake(uint32_t *lockArea, uint32_t *waitCount, int numToWakeUp = INT_MAX);
+
+    /**
+     * Unlock by lockId. Used by the worker for client crash recovery.
+     * @param lockId
+     * @return true if the lock is held by the given lockId
+     */
+    bool TryUnlockByLockId(uint32_t lockId);
+
+private:
+    uint32_t *lockFlag_;
+    uint32_t *waitCount_;
+    uint32_t lockId_;
+};
+
+// A slot is 4 bytes. The 8 high bits are used for special purposes.
+// That leaves us 24 bits for the offset. The maximum offset is then 16'777'215ul
+constexpr static uint32_t REMOTE_ELEMENT_BIT = static_cast<uint32_t>(0x80000000);
+constexpr static uint32_t ELEMENT_DATA_CONSISTENT = static_cast<uint32_t>(0x40000000);
+constexpr static uint32_t BIG_ELEMENT_BIT = static_cast<uint32_t>(0x20000000);
+constexpr static uint32_t HEADER_BIT = static_cast<uint32_t>(0x10000000);
+constexpr static uint32_t FUTURE_1_BIT = static_cast<uint32_t>(0x08000000);
+constexpr static uint32_t FUTURE_2_BIT = static_cast<uint32_t>(0x04000000);
+constexpr static uint32_t FUTURE_3_BIT = static_cast<uint32_t>(0x02000000);
+constexpr static uint32_t FUTURE_4_BIT = static_cast<uint32_t>(0x01000000);
+constexpr static uint32_t SLOT_VALUE_MASK = static_cast<uint32_t>(0x00FFFFFF);
+// The page is a shared page if the first bit of slot0 is set
+constexpr static uint32_t PAGE_SHARED_BIT = static_cast<uint32_t>(0x80000000);
+
+// This class extends the Element class with the additional attribute byte from above
+class DataElement : public Element {
+public:
+    bool DataIsReady() const;
+    bool IsRemote() const;
+    bool IsBigElement() const;
+    bool HasHeader() const;
+    Status CheckAttribute() const;
+    uint64_t GetStreamNo() const;
+    uint32_t attr_{ 0 };
+    uint64_t streamNo_{ 0 };
+};
+
+class ElementHeader {
+public:
+    typedef uint8_t *Ptr;
+    typedef uint32_t Size;
+    typedef uint8_t Version;
+    Ptr headerPtr_{ nullptr };
+    Size headerSize_{ 0 };
+    Version headerVersion_{ 0 };
+    void Set(Ptr ptr, Size size, Version version);
+};
+
+constexpr static ElementHeader::Version DATA_VERIFICATION_HEADER = static_cast<ElementHeader::Version>(1);
+class HeaderAndData : public Element, public ElementHeader {
+public:
+    typedef uint32_t Size;
+    typedef uint8_t *Ptr;
+    HeaderAndData(const Element &element, const ElementHeader &header, uint64_t streamNo);
+    HeaderAndData(const Ptr ptr, const Size size, uint64_t streamNo);
+    Size TotalSize() const;
+    Status MemoryCopyTo(Ptr dest) const;
+    uint64_t streamNo{ 0 };
+};
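+
+// Wire-size annotation for the verification header defined below: seqNo (8) +
+// senderProducerNo (8) + address (4) + port (2) = 22 bytes are copied verbatim
+// in front of the element payload when data verification is enabled.
+static_assert(sizeof(uint64_t) + sizeof(uint64_t) + sizeof(uint32_t) + sizeof(uint16_t) == 22,
+              "DataVerificationHeader byte image is 22 bytes");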
+
+// This struct contains the fields to be copied in front of the element data if needed.
+struct DataVerificationHeader {
+    typedef uint64_t SeqNo;
+    typedef uint64_t SenderProducerNo;
+    typedef uint32_t Size;
+    typedef uint32_t Address;
+    typedef uint16_t Port;
+
+    union {
+        struct {
+            SeqNo seqNo;
+            SenderProducerNo senderProducerNo;
+            Address address;
+            Port port;
+        } hdr;
+        uint8_t bytes[sizeof(SeqNo) + sizeof(SenderProducerNo) + sizeof(Address) + sizeof(Port)];
+    };
+
+    DataVerificationHeader(SeqNo seqNo = std::numeric_limits<SeqNo>::max(),
+                           SenderProducerNo localProducerNo = std::numeric_limits<SenderProducerNo>::max(),
+                           Address address = std::numeric_limits<Address>::max(),
+                           Port port = std::numeric_limits<Port>::max());
+    DataVerificationHeader(const ElementHeader &ele);
+    SeqNo GetSeqNo() const;
+    SenderProducerNo GetSenderProducerNo() const;
+    Address GetAddress() const;
+    Port GetPort() const;
+    Size HeaderSize() const;
+    void Set(SeqNo seqNo, SenderProducerNo producerNo, Address address, Port port);
+    static Status ExtractHeader(DataElement &element, ElementHeader &header);
+};
+
+class StreamPageBase {
+public:
+    explicit StreamPageBase(std::shared_ptr<ShmUnitInfo>);
+    explicit StreamPageBase(std::shared_ptr<ShmUnitInfo> shmInfo, std::shared_ptr<MmapEntry> mmapEntry);
+    virtual ~StreamPageBase() = default;
+
+    /**
+     * Initialization
+     * @param isClient
+     */
+    void Init(bool isClient);
+
+    /**
+     * @brief return the start address
+     * @return start address
+     */
+    void *GetPointer() const
+    {
+        return startOfPage_;
+    }
+
+    /**
+     * @brief return the page size.
+     * @return page size
+     */
+    auto PageSize() const
+    {
+        return pageUnit_->GetSize();
+    }
+
+    static std::string CreatePageId(const std::shared_ptr<ShmUnitInfo> &pageUnit);
+
+    /**
+     * @brief Return the page id
+     * @return
+     */
+    std::string GetPageId() const
+    {
+        return pageUnit_->GetId();
+    }
+
+    /**
+     * @brief Return the ShmView of this page
+     */
+    ShmView GetShmView() const;
+
+    /**
+     * @brief Return the ShmUnitInfo for this page
+     * @return
+     */
+    std::shared_ptr<ShmUnitInfo> GetShmUnitInfo() const;
+
+protected:
+    std::shared_ptr<ShmUnitInfo> pageUnit_;
+    std::shared_ptr<MmapEntry> mmapEntry_;
+    uint8_t *startOfPage_{ nullptr };
+
+private:
+};
+
+// A blank stream page (without any header) that contains only raw data
+class StreamLobPage : public StreamPageBase, public std::enable_shared_from_this<StreamLobPage> {
+public:
+    explicit StreamLobPage(std::shared_ptr<ShmUnitInfo> shmInfo, bool isClient);
+    explicit StreamLobPage(std::shared_ptr<ShmUnitInfo> shmInfo, bool isClient,
+                           std::shared_ptr<MmapEntry> mmapEntry);
+    ~StreamLobPage() override = default;
+    Status Init();
+    Status Insert(const HeaderAndData &element);
+
+private:
+    const bool isClient_;
+};
+
+enum class InsertFlags : uint32_t {
+    NONE = 0,
+    REMOTE_ELEMENT = 1u,
+    DELAY_WAKE = 1u << 1,
+    SKIP_LOCK = 1u << 2,
+    BIG_ELEMENT = 1u << 3,
+    HEADER = 1u << 4,
+    INSERT_SUCCESS = 1u << 5
+};
+ENABLE_BITMASK_ENUM_OPS(InsertFlags);
+
+namespace worker {
+namespace stream_cache {
+class ExclusivePageQueue;
+}
+} // namespace worker
+
+namespace client {
+namespace stream_cache {
+class ProducerImpl;
+}
+} // namespace client
+
+using SlotFlagOffset = uint32_t;
+using SlotFlag = uint32_t;
+using SlotOffset = uint32_t;
+ SlotFlag flag; // 4 bytes + SlotOffset offset; // 4 bytes + uint64_t streamNo; // 8 bytes + } value; + uint64_t flagAndOffset; + + void StoreAll(bool enableSharedPage, SlotFlag flag, SlotOffset offset, uint64_t streamNo, + int memModel = __ATOMIC_RELEASE) + { + if (enableSharedPage) { + SlotType slot; + slot.value.flag = flag; + slot.value.offset = offset; + slot.value.streamNo = streamNo; + __atomic_store_n(&flagAndOffset, slot.flagAndOffset, memModel); + __atomic_store_n(&value.streamNo, slot.value.streamNo, memModel); + } else { + SlotFlagOffset flagOffset = flag & ~SLOT_VALUE_MASK; + flagOffset |= offset & SLOT_VALUE_MASK; + __atomic_store_n(&flagWithOffset, flagOffset, memModel); + } + } + + SlotFlag LoadFlag(bool enableSharedPage, int memModel = __ATOMIC_SEQ_CST) + { + if (enableSharedPage) { + return __atomic_load_n(&value.flag, memModel); + } + return __atomic_load_n(&flagWithOffset, memModel) & ~SLOT_VALUE_MASK; + } + + SlotOffset LoadOffset(bool enableSharedPage, int memModel = __ATOMIC_SEQ_CST) + { + if (enableSharedPage) { + return __atomic_load_n(&value.offset, memModel); + } + return __atomic_load_n(&flagWithOffset, memModel) & SLOT_VALUE_MASK; + } + + uint64_t LoadStreamNo(bool enableSharedPage, int memModel = __ATOMIC_SEQ_CST) + { + if (enableSharedPage) { + return __atomic_load_n(&value.streamNo, memModel); + } + return 0; + } + + void SetFlagBit(SlotFlag addBit) + { + SlotFlag slotFlag = __atomic_load_n(&flagWithOffset, __ATOMIC_ACQUIRE); + SETFLAG(slotFlag, addBit); + __atomic_store_n(&flagWithOffset, slotFlag, __ATOMIC_RELEASE); + } + + void ClearFlagBit(SlotFlag delBit) + { + SlotFlag slotFlag = __atomic_load_n(&flagWithOffset, __ATOMIC_ACQUIRE); + CLEARFLAG(slotFlag, delBit); + __atomic_store_n(&flagWithOffset, slotFlag, __ATOMIC_RELEASE); + } +}; +// Begin memory layout of the page in order. Do NOT add anything that +// breaks alignment. All fields must be packed without any gap. +struct StreamPageHeader { + uint64_t begCursor_; // 8 bytes + uint32_t lockArea_; // 4 bytes + uint32_t lockWait_; // 4 bytes + uint32_t refCount_; // 4 bytes + uint32_t totalFreeSpace_; // 4 bytes + uint32_t slotCount_; // 4 bytes + uint32_t slotWait_; // 4 bytes + SlotType slot0_; // start of variable size array. +}; + +class StreamDataPage : public StreamPageBase, public std::enable_shared_from_this { +public: + explicit StreamDataPage(std::shared_ptr shmInfo, uint32_t lockId, bool isClient, + bool isSharedPage = false, std::shared_ptr mmapEntry = nullptr); + ~StreamDataPage() override = default; + + /** + * Initialization + */ + Status Init(); + + /** + * @brief Init a page to empty. 
Call by worker only + * @return Status object + */ + Status InitEmptyPage(); + + /** + * @brief Reset a valid page back to empty + * @return + */ + Status ResetToEmpty(); + + /** + * @brief Insert one single element + * @param element header + data + * @param timeoutMs timeout in millisecond + * @param logPrefix + * @return Status object + */ + Status Insert(const HeaderAndData &element, uint64_t timeoutMs, InsertFlags &flags, + const std::string &logPrefix = ""); + + /** + * @brief Wake up consumer waiting for new element + */ + Status WakeUpConsumers(); + + /** + * @brief Receive a vector of elements in the form of Element + * @param lastRecvCursor + * @param timeoutMs timeout in millisecond + * @param[out] out + * @param logPrefix + * @return Status object + */ + Status Receive(uint64_t lastRecvCursor, uint64_t timeoutMs, std::vector &out, + const std::string &logPrefix = ""); + + /** + * @brief Return ShmView of the next page + */ + ShmView GetNextPage() const; + + /** + * Set the ShmView of the next page + * @param shm + */ + void SetNextPage(const ShmView &shm); + + /** + * @brief Atomic check if there is a next page. + * @return T/F + */ + bool HasNextPage() const; + + /** + * @brief Get the cursor of the beginning slot + * @return + */ + uint64_t GetBegCursor() const; + + /** + * @brief Get the last cursor of the page + */ + uint64_t GetLastCursor() const; + + /** + * @brief Atomically get the number of elements on the page + * @return number of elements on the page + */ + uint32_t GetSlotCount() const; + + /** + * @brief Check if the page is empty or not + */ + bool Empty() const; + + /** + * @brief Atomically increase the reference count + */ + Status RefPage(const std::string &logPrefix = ""); + + /** + * @brief Atomically decrease the reference count + */ + Status ReleasePage(const std::string &logPrefix = ""); + + /** + * @brief Atomically get the reference count + */ + uint32_t GetRefCount() const; + + /** + * Lock a page exclusively for producer. + * @param timeoutMs in millisecond + * @return Status + * @note Consumer is not affected + */ + Status Lock(uint64_t timeoutMs); + + /** + * Unlock a page. + */ + void Unlock(); + + /** + * Unlock and repair a page held by a crashed client with a given lockId + * @param lockId + */ + void TryUnlockByLockId(uint32_t lockId); + + /** + * @brief Batch insert. + * @param[in] buf contiguous payload of the elements in reverse order + * @param[in] sz vector of the size of the elements + * @param[in] headerBits Is data contain header for each element. + * @param[in] streamMetaShm The pointer to streamMetaShm + * @return Status + */ + Status BatchInsert(void *buf, std::vector &sz, uint64_t timeoutMs, std::pair &res, + InsertFlags flags, const std::vector &headerBits, StreamMetaShm *streamMetaShm); + + /** + * Size of the overhead of a page. + * @return + */ + static size_t PageOverhead(bool enableSharedPage = false); + + /** + * Size of the payload of a page, include slot value/streamNo/data. 
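+ * @note In effect the page size minus the fixed overhead reported by PageOverhead().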
+ * @return size_t + */ + size_t PagePayloadSize(); + + static Status ParseShmViewPb(const void *ptr, size_t sz, ShmView &out); + + static Status SerializeToShmViewPb(const ShmView &pageView, std::string &out); + + StreamPageHeader *GetPageHeader() + { + return pageHeader_; + } + + std::shared_ptr &GetSharedMemViewForNextPage() + { + return nextPage_; + } + + Status Seal(const ShmView &nextPage, uint64_t timeoutMs, + std::function &)> locatePage, + const std::string &logPrefix); + + Status ExtractBigElementsUpTo(uint64_t ackCursor, std::vector> &bigId, bool deCouple); + + bool IsSharedPage() const + { + return isSharedPage_; + } + + static size_t GetMetaSize(bool isSharedPage); + + size_t GetFreeSpaceSize(); + + size_t GetTotalEleSize(); + +private: + friend class worker::stream_cache::ExclusivePageQueue; + friend class client::stream_cache::ProducerImpl; + const uint32_t lockId_; + const bool isClient_; + int64_t maxElementSize_; + std::shared_ptr pageLock_{ nullptr }; + StreamPageHeader *pageHeader_{ nullptr }; + // End of page header. After the header, followed by stream page size of number of free bytes. + // After the free space, we have the pointer to the next StreamDataPage (if any). + SlotFlagOffset *slotDir_{ nullptr }; + SharedMemView *tail_{ nullptr }; // Tail pointer to next page + std::shared_ptr nextPage_; + bool isSharedPage_{ false }; + void UpdateSlotConsistentBit(uint32_t slot); + Status WaitForNewElement(uint64_t lastRecvCursor, uint64_t timeoutMs); + SlotOffset GetSlotOffset(size_t index); + SlotFlag GetSlotFlag(size_t index); + void SetPageHasBigElement(); + void UnsetPageHasBigElement(); + bool PageHasBigElement(); + SlotType *GetSlotAddr(size_t index); +}; + +// A helper class to lock a page and will ensure the lock will be released on exit +class StreamPageLock { +public: + explicit StreamPageLock(std::shared_ptr page); + ~StreamPageLock(); + + /** + * Lock a page exclusively for producer. + * @param timeoutMs in millisecond + * @return Status + * @note Consumer is not affected + */ + Status Lock(uint64_t timeoutMs); + + /** + * Unlock a page. + */ + void Unlock(); + +private: + bool pageLocked_ = false; + std::shared_ptr page_; +}; +} // namespace datasystem + +#endif // DATASYSTEM_STREAM_CACHE_STREAM_DATA_PAGE_H diff --git a/src/datasystem/common/stream_cache/stream_fields.h b/src/datasystem/common/stream_cache/stream_fields.h new file mode 100644 index 0000000..15fb020 --- /dev/null +++ b/src/datasystem/common/stream_cache/stream_fields.h @@ -0,0 +1,154 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: A common class for holding stream configuration fields internally. + */ +#ifndef DATASYSTEM_COMMON_STREAM_CACHE_STREAM_FIELDS_H +#define DATASYSTEM_COMMON_STREAM_CACHE_STREAM_FIELDS_H + +#include +#include + +#include "datasystem/stream/stream_config.h" + +/** + * @brief A simple class to facilitate the passing around of fields related to streams. 
Similar to std::pair, but has + * the benefit of named fields and is better for future growth if new fields are needed. + */ +namespace datasystem { +class StreamFields { +public: + /** + * @brief Init constructor + */ + StreamFields() + : maxStreamSize_(0), + pageSize_(0), + autoCleanup_(false), + retainForNumConsumers_(0), + encryptStream_(false), + reserveSize_(0) + { + } + + /** + * @brief Basic constructor + */ + StreamFields(uint64_t maxStreamSize, size_t pageSize, bool autoCleanup, uint64_t retainForNumConsumers, + bool encryptStream, uint64_t reserveSize, int32_t streamMode) + : maxStreamSize_(maxStreamSize), + pageSize_(pageSize), + autoCleanup_(autoCleanup), + retainForNumConsumers_(retainForNumConsumers), + encryptStream_(encryptStream), + reserveSize_(reserveSize), + streamMode_(static_cast(streamMode)) + { + } + + /** + * @brief Equality operator + * @return true if the objects are the same + */ + bool operator==(const StreamFields &other) const + { + return (this->maxStreamSize_ == other.maxStreamSize_ && this->pageSize_ == other.pageSize_ + && this->autoCleanup_ == other.autoCleanup_ + && this->retainForNumConsumers_ == other.retainForNumConsumers_ + && this->encryptStream_ == other.encryptStream_ + && this->streamMode_ == other.streamMode_); + } + + /** + * @brief Inequality operator + * @return true if the objects are not the same + */ + bool operator!=(const StreamFields &other) const + { + return !(*this == other); + } + + /** + * @brief Check if the fields are empty/un-initialized + * @return True if the fields are empty + */ + bool Empty() const + { + return (maxStreamSize_ == 0 && pageSize_ == 0); + } + + uint64_t maxStreamSize_; + size_t pageSize_; + bool autoCleanup_; + uint64_t retainForNumConsumers_; + bool encryptStream_; + uint64_t reserveSize_; + StreamMode streamMode_ = StreamMode::MPMC; +}; + +class RetainDataState { +public: + enum State : uint32_t { INIT = 1, RETAIN = 2, NOT_RETAIN = 3 }; + std::vector StateNames{ "", "INIT", "RETAIN", "NOT_RETAIN" }; + void Init(uint64_t retainForNumConsumers) + { + if (retainForNumConsumers == 0) { + SetRetainDataState(State::NOT_RETAIN); + } else { + SetRetainDataState(State::RETAIN); + } + } + + void SetRetainDataState(State nextState) + { + // Only update from INIT to RETAIN/NOT_RETAIN, or from RETAIN to NOT_RETAIN, not backward + while (true) { + State currentState = retainDataState_; + if (currentState >= nextState + || retainDataState_.compare_exchange_weak(currentState, nextState, std::memory_order_release, + std::memory_order_acquire)) { + break; + } + } + } + + State GetRetainDataState() + { + return retainDataState_; + } + + std::string PrintCurrentState() + { + return StateNames[retainDataState_]; + } + + // Called when first producer creation fails and reverted + void RollBackToInit() + { + retainDataState_ = State::INIT; + } + + bool IsDataRetained() + { + return retainDataState_ == State::RETAIN; + } + +private: + std::atomic retainDataState_{ State::INIT }; +}; +} // namespace datasystem +#endif // DATASYSTEM_COMMON_STREAM_CACHE_STREAM_FIELDS_H diff --git a/src/datasystem/common/stream_cache/stream_meta_shm.cpp b/src/datasystem/common/stream_cache/stream_meta_shm.cpp new file mode 100644 index 0000000..e679daa --- /dev/null +++ b/src/datasystem/common/stream_cache/stream_meta_shm.cpp @@ -0,0 +1,81 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Record meta for a stream. + */ + +#include "datasystem/common/stream_cache/stream_meta_shm.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/util/format.h" +#include "datasystem/stream/stream_config.h" + +namespace datasystem { +Status StreamMetaShm::Init(std::shared_ptr mmapTableEntry) +{ + RETURN_RUNTIME_ERROR_IF_NULL(shmPtr_); + auto *data = shmPtr_; + usage_ = reinterpret_cast(shmPtr_); + data += sizeof(*(usage_)); + CHECK_FAIL_RETURN_STATUS(static_cast((data) - (shmPtr_)) <= shmSz_, K_RUNTIME_ERROR, + "Work area size too small"); + if (mmapTableEntry != nullptr) { + mmapTableEntry_ = std::move(mmapTableEntry); + } + return Status::OK(); +} + +Status StreamMetaShm::TryIncUsage(uint64_t size) +{ + INJECT_POINT("StreamMetaShm.TryIncUsage"); + bool success = false; + uint64_t currUsage = 0; + do { + currUsage = __atomic_load_n(usage_, __ATOMIC_RELAXED); + CHECK_FAIL_RETURN_STATUS(UINT64_MAX - currUsage >= size, K_OUT_OF_RANGE, + "The usage of shared memory reached UINT64_MAX"); + uint64_t desiredVal = currUsage + size; + CHECK_FAIL_RETURN_STATUS(desiredVal <= maxStreamSize_, K_OUT_OF_MEMORY, + FormatString("stream: %s, currUsage: %llu, tryIncUsage: %llu, maxStreamSize_: %llu", + streamName_, currUsage, size, maxStreamSize_)); + success = + __atomic_compare_exchange_n(usage_, &currUsage, desiredVal, false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } while (!success); + + VLOG(SC_NORMAL_LOG_LEVEL) << "TryIncUsage for streamName:" << streamName_ << ", size: " << size + << ", before: " << currUsage << ", after: " << (currUsage + size); + return Status::OK(); +} + +Status StreamMetaShm::TryDecUsage(uint64_t size) +{ + bool success = false; + uint64_t currUsage = 0; + do { + currUsage = __atomic_load_n(usage_, __ATOMIC_RELAXED); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + currUsage >= size, K_RUNTIME_ERROR, + FormatString("[TryDecUsage error] stream: %s, currUsage: %llu, tryDecUsage: %llu", streamName_, currUsage, + size)); + uint64_t desiredVal = currUsage - size; + success = + __atomic_compare_exchange_n(usage_, &currUsage, desiredVal, false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } while (!success); + VLOG(SC_NORMAL_LOG_LEVEL) << "TryDecUsage for streamName:" << streamName_ << ", size: " << size + << ", before: " << currUsage << ", after: " << (currUsage - size); + return Status::OK(); +} +} // namespace datasystem diff --git a/src/datasystem/common/stream_cache/stream_meta_shm.h b/src/datasystem/common/stream_cache/stream_meta_shm.h new file mode 100644 index 0000000..75b66da --- /dev/null +++ b/src/datasystem/common/stream_cache/stream_meta_shm.h @@ -0,0 +1,65 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Record meta for a stream. + */ +#ifndef DATASYSTEM_COMMON_STREAM_CACHE_STREAM_META_SHM_H +#define DATASYSTEM_COMMON_STREAM_CACHE_STREAM_META_SHM_H + +#include +#include +#include "datasystem/client/mmap_table_entry.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/utils/status.h" + +namespace datasystem { +class StreamMetaShm { +public: + StreamMetaShm(std::string streamName, void *shmPtr, size_t shmSz, uint64_t maxStreamSize) + : streamName_(std::move(streamName)), + shmPtr_(reinterpret_cast((shmPtr))), + shmSz_(shmSz), + maxStreamSize_(maxStreamSize) + { + } + + Status Init(std::shared_ptr mmapTableEntry = nullptr); + + /** + * @brief Try to increase the usage of shared memory in this node for this stream. + * @param[in] size The size to be increased. + * @return Status of the call. + */ + Status TryIncUsage(uint64_t size); + + /** + * @brief Try to decrease the usage of shared memory in this node for this stream. + * @param[in] size The size to be increased. + * @return Status of the call. + */ + Status TryDecUsage(uint64_t size); + +private: + const std::string streamName_; + uint8_t *shmPtr_; + const size_t shmSz_; + uint64_t *usage_{ nullptr }; + std::shared_ptr mmapTableEntry_; // for client. + uint64_t maxStreamSize_ = 0; +}; +} // namespace datasystem +#endif // DATASYSTEM_COMMON_STREAM_CACHE_STREAM_META_SHM_H diff --git a/src/datasystem/common/stream_cache/util.h b/src/datasystem/common/stream_cache/util.h new file mode 100644 index 0000000..fe09606 --- /dev/null +++ b/src/datasystem/common/stream_cache/util.h @@ -0,0 +1,53 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Utility function for stream cache. + */ +#ifndef DATASYSTEM_COMMON_STREAM_CACHE_UTIL_H +#define DATASYSTEM_COMMON_STREAM_CACHE_UTIL_H + +#include + +#include "datasystem/protos/master_stream.service.rpc.pb.h" +#include "datasystem/protos/stream_posix.service.rpc.pb.h" +#include "datasystem/utils/status.h" + +namespace datasystem { +/** + * @brief A helper function to flow error or success back on the unary rpc api. + * @tparam reqT The request type for the unary rpc + * @tparam rspT The response type for the unary rpc + * @param[in] rc The status code to return back through the api + * @param[in] rsp The response structure to return (if not an error case) + * @param[in] errMsg The error message to log if the rc was an error case. 
+ * @param[in] serverApi The unary api to send responses on + */ +template +inline void CheckErrorReturn(const Status &rc, const rspT &rsp, const std::string &errMsg, + std::shared_ptr> serverApi) +{ + if (rc.IsOk()) { + // Success case, flow the response back to client (rc of OK is inferred) + LOG_IF_ERROR(serverApi->Write(rsp), "Write reply to client stream failed"); + } else { + // Error case, flow the rc back to client + LOG(ERROR) << errMsg << rc.ToString(); + LOG_IF_ERROR(serverApi->SendStatus(rc), "Write reply to client stream failed"); + } +} +} // namespace datasystem +#endif // DATASYSTEM_COMMON_STREAM_CACHE_UTIL_H diff --git a/src/datasystem/common/string_intern/CMakeLists.txt b/src/datasystem/common/string_intern/CMakeLists.txt new file mode 100644 index 0000000..aee28ed --- /dev/null +++ b/src/datasystem/common/string_intern/CMakeLists.txt @@ -0,0 +1,11 @@ +set(STRING_REF_SRC + string_entity.cpp +) + +set(STRING_REF_DEPENDS_LIB + ${TBB_LIBRARY} + common_log +) + +add_library(string_ref STATIC ${STRING_REF_SRC}) +target_link_libraries(string_ref PRIVATE ${STRING_REF_DEPENDS_LIB}) diff --git a/src/datasystem/common/string_intern/string_entity.cpp b/src/datasystem/common/string_intern/string_entity.cpp new file mode 100644 index 0000000..9d93e1c --- /dev/null +++ b/src/datasystem/common/string_intern/string_entity.cpp @@ -0,0 +1,97 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: StringEntity with reference count implementation. 
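+ * The entity keeps two counters: countRef_ tracks live references, while delRef_
+ * guards erasure against a concurrent lookup reviving the entry.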
+ */ +#include "datasystem/common/string_intern/string_entity.h" + +#include + +namespace datasystem { +namespace intern { +std::hash hasher; + +StringEntity::StringEntity(std::string val) : countRef_(0), value_(std::move(val)), hash_(hasher(value_)) +{ +} + +StringEntity::StringEntity(const StringEntity &rStr) : countRef_(0), value_(rStr.value_), hash_(rStr.hash_) +{ +} + +StringEntity::StringEntity(StringEntity &&rStr) noexcept + : countRef_(0), value_(std::move(rStr.value_)), hash_(rStr.hash_) +{ +} + +StringEntity &StringEntity::operator=(const StringEntity &rStr) +{ + countRef_ = 0; + value_ = rStr.value_; + hash_ = rStr.hash_; + return *this; +} + +StringEntity &StringEntity::operator=(StringEntity &&rStr) noexcept +{ + countRef_ = 0; + value_ = std::move(rStr.value_); + hash_ = rStr.hash_; + return *this; +} + +const std::string &StringEntity::ToStr() const +{ + return value_; +} + +int32_t StringEntity::IncRef() const +{ + return ++countRef_; +} + +bool StringEntity::DecRef() const +{ + return (--countRef_ == 0); +} + +void StringEntity::IncDelRef() const +{ + (void)delRef_.fetch_add(1, std::memory_order_relaxed); +} + +bool StringEntity::DecDelRef() const +{ + return --delRef_ == 0; +} + +size_t StringEntity::GetHash() const +{ + return hash_; +} + +size_t StringEntity::GetRef() const +{ + return countRef_.load(std::memory_order_relaxed); +} + +bool StringEntity::operator==(const StringEntity &rhs) const +{ + return this == &rhs || this->value_ == rhs.value_; +} +} // namespace intern +} // namespace datasystem diff --git a/src/datasystem/common/string_intern/string_entity.h b/src/datasystem/common/string_intern/string_entity.h new file mode 100644 index 0000000..9f82224 --- /dev/null +++ b/src/datasystem/common/string_intern/string_entity.h @@ -0,0 +1,124 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: StringEntity reference count declaration. + */ +#ifndef DATASYSTEM_COMMON_STRING_INTERN_STRING_ENTITY_H +#define DATASYSTEM_COMMON_STRING_INTERN_STRING_ENTITY_H + +#include +#include +#include + +#include +namespace datasystem { +namespace intern { +class StringEntity { +public: + explicit StringEntity(std::string val); + StringEntity() = delete; + + explicit StringEntity(const StringEntity &rStr); + + StringEntity(StringEntity &&rStr) noexcept; + + StringEntity &operator=(const StringEntity &rStr); + + StringEntity &operator=(StringEntity &&rStr) noexcept; + + /** + * @brief Get the const reference of std::string. + * @return The the const reference of std::string. + */ + const std::string &ToStr() const; + + /** + * @brief Add the reference count of this string. + * @return The reference count after add. + */ + int32_t IncRef() const; + + /** + * @brief Release a reference count of this string. + * @return Whether the reference count is 0 after release. + */ + bool DecRef() const; + + /** + * @brief Add a delete reference count of this string. 
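+ * @note Called when countRef_ changes from 0 to 1; pairs with DecDelRef(),
+ * which gates the actual erase in StringPool::Erase().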
+ */ + void IncDelRef() const; + + /** + * @brief Release a delete reference count of this string. + * @return Whether the reference count is 0 after release. + */ + bool DecDelRef() const; + + /** + * @brief Get the hash of string. + * @return The hash of string. + */ + size_t GetHash() const; + + /** + * @brief Get the reference count of this string. + * @return The reference count. + */ + size_t GetRef() const; + + bool operator==(const StringEntity &rhs) const; + +private: + /** + * Only countRef_ may lead to a data race: + * 1. Thread A detaches the last reference to x and is preempted. + * 2. Thread B looks for x, finds it and attaches a reference to it. + * 3. Thread A resumes and proceeds with erasing x, leaving a dangling reference in thread B. + * Here is where the delRef_ count comes into play. This count is + * incremented when countRef_ changes from 0 to 1, and decremented + * when a thread is about to check a value for erasure. + * (Multiple threads may see countRef_ as 0 and try to call erase.) + * It can be seen that a value is effectively erasable only when the delRef_ count goes down to 0. + */ + mutable std::atomic_int32_t delRef_{ 0 }; + mutable std::atomic_int32_t countRef_{ 0 }; + std::string value_; + size_t hash_; +}; +} // namespace intern +} // namespace datasystem + +namespace tbb { +using datasystem::intern::StringEntity; +template <> +#if TBB_INTERFACE_VERSION >= 12050 +struct detail::d1::tbb_hash_compare { +#else +struct tbb_hash_compare { +#endif + static size_t hash(const StringEntity &a) + { + return a.GetHash(); + } + static size_t equal(const StringEntity &a, const StringEntity &b) + { + return a.ToStr() == b.ToStr(); + } +}; +} // namespace tbb +#endif diff --git a/src/datasystem/common/string_intern/string_pool.h b/src/datasystem/common/string_intern/string_pool.h new file mode 100644 index 0000000..87cf0f4 --- /dev/null +++ b/src/datasystem/common/string_intern/string_pool.h @@ -0,0 +1,108 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: StringPool implementation.
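+ * A singleton intern table backed by tbb::concurrent_hash_map: Intern() inserts
+ * or finds a StringEntity and returns a counted StringPtr; Erase() drops an entry
+ * once its delete-reference count reaches zero.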
+ */ +#ifndef DATASYSTEM_COMMON_STRING_INTERN_STRING_POOL_H +#define DATASYSTEM_COMMON_STRING_INTERN_STRING_POOL_H + +#include + +#include + +#include "datasystem/common/string_intern/string_ptr.h" +#include "datasystem/common/string_intern/string_entity.h" +#include "datasystem/common/log/log.h" + +namespace datasystem { +namespace intern { +using StringEntityMap = tbb::concurrent_hash_map; + +template +class StringPool { +public: + StringPool() = default; + ~StringPool() + { + if (Size() > 0) { + LOG(ERROR) << "Some RCString still in pool: " << Size() << " when pool finalize, may cause segment fault."; + } + } + StringPool(StringPool &&) = delete; // Move construct + StringPool(const StringPool &) = delete; // Copy construct + StringPool &operator=(const StringPool &) = delete; // Copy assign + StringPool &operator=(StringPool &&) = delete; // Move assign + + /** + * @brief Get the Singleton StringPool instance. + * @return StringPool instance. + */ + static StringPool &Instance() + { + static StringPool instance; + return instance; + } + + /** + * @brief Init the StringPool, use it to control the construction and destruction timing. + */ + void Init() + { + LOG(INFO) << "StringPool init"; + } + + /** + * @brief Intern the std::string to pool and return the handle of this string. + * @param[in] val The std::string ready to intern. + * @param[out] The handle of intern string. + */ + StringPtr Intern(const std::string &val) + { + StringEntity rcStr(val); + StringEntityMap::const_accessor readAccessor; + (void)pool_.insert(readAccessor, rcStr); + return StringPtr(readAccessor->first); + } + + /** + * @brief Try to Erase the StringEntity by handle if its reference count is 0. + * @param[in] handle The handle whose ptr_ is ready to erase. + */ + void Erase(StringPtr &handle) + { + const auto val = handle.GetEntity(); + StringEntityMap::accessor accessor; + if (val != nullptr && pool_.find(accessor, *val) && accessor->first.DecDelRef()) { + (void)pool_.erase(accessor); + } + } + + /** + * @brief Return the size ofStringPool + * @return The size ofStringPool + */ + size_t Size() + { + return pool_.size(); + } + +private: + StringEntityMap pool_; +}; +} // namespace intern +} // namespace datasystem +#endif diff --git a/src/datasystem/common/string_intern/string_ptr.h b/src/datasystem/common/string_intern/string_ptr.h new file mode 100644 index 0000000..0e1b752 --- /dev/null +++ b/src/datasystem/common/string_intern/string_ptr.h @@ -0,0 +1,93 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: StringPtr implementation. + */ +#ifndef DATASYSTEM_COMMON_STRING_INTERN_STRING_PTR_H +#define DATASYSTEM_COMMON_STRING_INTERN_STRING_PTR_H + +#include +#include +#include + +#include + +#include "datasystem/common/string_intern/string_entity.h" + +namespace datasystem { +namespace intern { +/** + * @brief A handle of StringEntityto control the reference count like shared_ptr. 
+ */ +class StringPtr { +public: + StringPtr() : ptr_(nullptr) + { + } + + explicit StringPtr(const StringEntity &str) : ptr_(&str) + { + IncRef(); + } + + /** + * @brief Get the const reference of std::string. + * @return The the const reference of std::string. + */ + const std::string &ToStr() const + { + if (ptr_ != nullptr) { + return ptr_->ToStr(); + } + static std::string defaultStr; + return defaultStr; + } + + /** + * @brief Get the const reference of StringEntity. + * @return The the const reference of StringEntity. + */ + const StringEntity *GetEntity() const + { + return ptr_; + } + + size_t GetHash() const + { + static size_t emptyStringHashVal = std::hash()(""); + return ptr_ != nullptr ? ptr_->GetHash() : emptyStringHashVal; + } + + void IncRef() const + { + if (ptr_ != nullptr && ptr_->IncRef() == 1) { + ptr_->IncDelRef(); + } + } + + bool DecRef() const + { + return ptr_ != nullptr ? ptr_->DecRef() : false; + } + +private: + const StringEntity *ptr_; +}; +} // namespace intern +} // namespace datasystem + +#endif diff --git a/src/datasystem/common/string_intern/string_ref.h b/src/datasystem/common/string_intern/string_ref.h new file mode 100644 index 0000000..822c410 --- /dev/null +++ b/src/datasystem/common/string_intern/string_ref.h @@ -0,0 +1,262 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: StringRef implementation. + */ +#ifndef DATASYSTEM_COMMON_STRING_INTERN_STRING_REF_H +#define DATASYSTEM_COMMON_STRING_INTERN_STRING_REF_H + +#include "datasystem/common/string_intern/string_entity.h" +#include "datasystem/common/string_intern/string_ptr.h" +#include "datasystem/common/string_intern/string_pool.h" + +namespace datasystem { +namespace intern { +enum class KeyType : size_t { OTHER, OBJECT_KEY, CLIENT_KEY, SHM_KEY }; + +template +class StringRef { +public: + static StringRef Intern(const std::string &str); + + StringRef() = default; + + StringRef(StringPtr handle) : handle_(handle) + { + } + + StringRef(const StringRef &other) noexcept : handle_(other.handle_) + { + handle_.IncRef(); + } + StringRef &operator=(const StringRef &other) noexcept + { + if (this != &other) { + handle_ = other.ptr_; + handle_.IncRef(); + } + return *this; + } + + StringRef(StringRef &&other) noexcept + { + std::swap(handle_, other.handle_); + } + + StringRef &operator=(StringRef &&other) noexcept + { + std::swap(handle_, other.handle_); + other.Clear(); + return *this; + } + + ~StringRef() + { + Clear(); + } + + void Clear() + { + if (handle_.DecRef()) { + StringPool::Instance().Erase(handle_); + } + handle_ = StringPtr(); + } + + template = true> + StringRef(const std::string &val) noexcept + { + if (!val.empty()) { + handle_ = StringPool::Instance().Intern(val); + } + } + + template = true> + StringRef(const char *cStr) : StringRef(std::string(cStr)) + { + } + + /** + * @brief Get the hash of StringRef. + * @return The hash of StringRef. 
+ */ + size_t GetHash() const + { + return handle_.GetHash(); + } + + /** + * @brief Get the const reference of std::string. + * @return The the const reference of std::string. + */ + const std::string &ToString() const + { + return handle_.ToStr(); + } + + bool operator==(const StringRef &rhs) const + { + return this == &rhs || handle_.GetEntity() == rhs.handle_.GetEntity(); + } + + bool operator!=(const StringRef &rhs) const + { + return this != &rhs && ToString() != rhs.ToString(); + } + + bool operator<(const StringRef &rhs) const + { + return ToString() < rhs.ToString(); + } + + /** + * @brief The operator to convert a StringRef to std::string. + * @return The the const reference of std::string. + */ + operator const std::string &() const + { + return handle_.ToStr(); + } + + const char *Data() const + { + return ToString().data(); + } + + std::string::size_type Size() const + { + return ToString().size(); + } + +private: + StringPtr handle_; +}; + +template +inline StringRef StringRef::Intern(const std::string &str) +{ + if (str.empty()) { + return StringRef(StringPtr()); + } + return StringRef(StringPool::Instance().Intern(str)); +} + +template +std::ostream &operator<<(std::ostream &os, const StringRef &obj) +{ + os << obj.ToString(); + return os; +} + +template +inline std::string operator+(S &&lhs, const StringRef &rhs) +{ + return std::forward(lhs) + rhs.ToString(); +} + +template +inline std::string operator+(const StringRef &lhs, S &&rhs) +{ + return lhs.ToString() + std::forward(rhs); +} +} // namespace intern + +using ObjectKey = intern::StringRef; +using ObjectKeyPool = intern::StringPool; + +using ClientKey = intern::StringRef; +using ClientKeyPool = intern::StringPool; + +using ShmKey = intern::StringRef; +using ShmKeyPool = intern::StringPool; + +using OtherKey = intern::StringRef; +using OtherKeyPool = intern::StringPool; + +} // namespace datasystem + +#if TBB_INTERFACE_VERSION >= 12050 +#define STRING_REF_IMPL_FOR_TBB(key) \ + template <> \ + struct detail::d1::tbb_hash_compare> { \ + static size_t hash(const StringRef &a) \ + { \ + return a.GetHash(); \ + } \ + static size_t equal(const StringRef &a, const StringRef &b) \ + { \ + return a == b; \ + } \ + } +#else +#define STRING_REF_IMPL_FOR_TBB(key) \ + template <> \ + struct tbb_hash_compare> { \ + static size_t hash(const StringRef &a) \ + { \ + return a.GetHash(); \ + } \ + static size_t equal(const StringRef &a, const StringRef &b) \ + { \ + return a == b; \ + } \ + } +#endif + +#define STRING_REF_IMPL_FOR_STD(key) \ + template <> \ + struct hash> { \ + size_t operator()(const StringRef &str) const \ + { \ + return str.GetHash(); \ + } \ + }; \ + \ + template <> \ + struct equal_to> { \ + bool operator()(const StringRef &lhs, const StringRef &rhs) const \ + { \ + return lhs == rhs; \ + } \ + }; \ + \ + template <> \ + struct less> { \ + bool operator()(const StringRef &lhs, const StringRef &rhs) const \ + { \ + return lhs < rhs; \ + } \ + } + +namespace tbb { +using datasystem::intern::KeyType; +using datasystem::intern::StringRef; +STRING_REF_IMPL_FOR_TBB(OBJECT_KEY); +STRING_REF_IMPL_FOR_TBB(CLIENT_KEY); +STRING_REF_IMPL_FOR_TBB(SHM_KEY); +STRING_REF_IMPL_FOR_TBB(OTHER); +} // namespace tbb + +namespace std { +using datasystem::intern::KeyType; +using datasystem::intern::StringRef; +STRING_REF_IMPL_FOR_STD(OBJECT_KEY); +STRING_REF_IMPL_FOR_STD(CLIENT_KEY); +STRING_REF_IMPL_FOR_STD(SHM_KEY); +STRING_REF_IMPL_FOR_STD(OTHER); +} // namespace std +#endif diff --git 
a/src/datasystem/common/util/gflag/common_gflags.cpp b/src/datasystem/common/util/gflag/common_gflags.cpp index 7f58a74..ead11ba 100644 --- a/src/datasystem/common/util/gflag/common_gflags.cpp +++ b/src/datasystem/common/util/gflag/common_gflags.cpp @@ -57,11 +57,31 @@ bool ValidateEnableUrma(const char *flagName, bool value) return true; #endif } + +bool ValidateUrmaMode(const char *flagName, const std::string &value) +{ + (void)flagName; + (void)value; +#ifdef USE_URMA + if (value == "IB") { + return true; + } +#ifdef URMA_OVER_UB + if (value == "UB") { + return true; + } +#endif + return false; +#else + return true; +#endif +} } // namespace DS_DEFINE_bool(enable_urma, false, "Option to turn on urma for OC worker to worker data transfer, default false."); DS_DEFINE_validator(enable_urma, &ValidateEnableUrma); - +DS_DEFINE_string(urma_mode, "IB", "Option to enable URMA over IB or UB, default IB to run with URMA over IB."); +DS_DEFINE_validator(urma_mode, &ValidateUrmaMode); DS_DEFINE_uint32(urma_poll_size, 8, "Number of complete record to poll at a time, 16 is the max this device can poll"); DS_DEFINE_uint32(urma_connection_size, 16, "Number of jfs and jfr pair"); DS_DEFINE_bool(urma_register_whole_arena, true, diff --git a/src/datasystem/common/util/id_tool.cpp b/src/datasystem/common/util/id_tool.cpp index 05c26d6..201180d 100644 --- a/src/datasystem/common/util/id_tool.cpp +++ b/src/datasystem/common/util/id_tool.cpp @@ -25,7 +25,7 @@ #include "datasystem/utils/status.h" #include "datasystem/common/util/id_tool.h" -DS_DECLARE_string(other_az_names); +DS_DECLARE_string(other_cluster_names); namespace datasystem { Status TrySplitWorkerIdFromObjecId(const std::string &objKey, std::string &workerUuid) diff --git a/src/datasystem/common/util/status_code.def b/src/datasystem/common/util/status_code.def index a31f152..6df20e4 100644 --- a/src/datasystem/common/util/status_code.def +++ b/src/datasystem/common/util/status_code.def @@ -32,6 +32,7 @@ STATUS_CODE_DEF(K_SERVER_FD_CLOSED, "The server fd has been closed") STATUS_CODE_DEF(K_RETRY_IF_LEAVING, "Try again when worker is leaving") STATUS_CODE_DEF(K_SCALE_DOWN, "The worker is exiting") STATUS_CODE_DEF(K_SCALING, "The cluster is scaling") +STATUS_CODE_DEF(K_CLIENT_DEADLOCK, "The client may deadlock") STATUS_CODE_DEF(K_LRU_HARD_LIMIT, "Lru hard limit") STATUS_CODE_DEF(K_LRU_SOFT_LIMIT, "Lru soft limit") @@ -52,6 +53,19 @@ STATUS_CODE_DEF(K_FUTURE_TIMEOUT, "The future is timeout") STATUS_CODE_DEF(K_ACL_ERROR, "Acl api error") STATUS_CODE_DEF(K_HCCL_ERROR, "Hccl api error") +// stream +STATUS_CODE_DEF(K_SC_STREAM_NOT_FOUND, "Stream name not found") +STATUS_CODE_DEF(K_SC_PRODUCER_NOT_FOUND, "Producer not found") +STATUS_CODE_DEF(K_SC_CONSUMER_NOT_FOUND, "Consumer not found") +STATUS_CODE_DEF(K_SC_END_OF_PAGE, "End of page reached") +STATUS_CODE_DEF(K_SC_STREAM_IN_RESET_STATE, "Stream is currently in reset state") +STATUS_CODE_DEF(K_SC_WORKER_WAS_LOST, "Worker crashed or restarted") +STATUS_CODE_DEF(K_SC_STREAM_IN_USE, "Stream is still in use") +STATUS_CODE_DEF(K_SC_STREAM_DELETE_IN_PROGRESS, "Stream is getting deleted") +STATUS_CODE_DEF(K_SC_STREAM_RESOURCE_ERROR, "Stream resource error") +STATUS_CODE_DEF(K_SC_ALREADY_CLOSED, "Producer or consumer already closed") +STATUS_CODE_DEF(K_SC_STREAM_NOTIFICATION_PENDING, "Notifications are pending") + // rdma STATUS_CODE_DEF(K_OC_REMOTE_GET_NOT_ENOUGH, "Size on the remote node has changed") STATUS_CODE_DEF(K_URMA_ERROR, "Urma operation failed") diff --git a/src/datasystem/common/util/strings_util.h 
b/src/datasystem/common/util/strings_util.h index d3b5811..df5168f 100644 --- a/src/datasystem/common/util/strings_util.h +++ b/src/datasystem/common/util/strings_util.h @@ -29,6 +29,7 @@ #include #include +#include #include #include @@ -87,6 +88,20 @@ std::string VectorToString(const Vec &vec, bool allowCut = true) return out.str(); } +template +inline std::string ToStringHelper(const T &value) +{ + if constexpr (std::is_base_of_v<::google::protobuf::Message, T>) { + return value.DebugString(); + } else { + if constexpr (std::is_same_v) { + return value; + } else { + return std::to_string(value); + } + } +} + /** * @brief Print map. * @param[in] map Map to print. @@ -97,7 +112,7 @@ std::string MapToString(const Map &map) { std::stringstream out; for (auto &item : map) { - out << "{" << item.first << ": " << item.second << "} "; + out << "{" << item.first << ": " << ToStringHelper(item.second) << "} "; } return out.str(); } @@ -140,9 +155,9 @@ inline bool StringToInt(const std::string &str, int &num) } /** -* @brief Check if the string contains the negative sign. -* @param[in] str string to be checked. -* @return true if has negative sign, else false. + * @brief Check if the string contains the negative sign. + * @param[in] str string to be checked. + * @return true if has negative sign, else false. */ inline bool IsNegative(const std::string &str) { @@ -157,10 +172,10 @@ inline bool IsNegative(const std::string &str) } /** -* @brief Convert string to unsigned long, using stoull directly will interpret negative number as extremely large -* positive values, therefore it’s necessary to check if the string contains a negative sign. -* @param[in] str string to be interpreted. -* @return Converted number. + * @brief Convert string to unsigned long, using stoull directly will interpret negative number as extremely large + * positive values, therefore it’s necessary to check if the string contains a negative sign. + * @param[in] str string to be interpreted. + * @return Converted number. */ inline unsigned long StrToUnsignedLong(const std::string &str) { @@ -171,10 +186,10 @@ inline unsigned long StrToUnsignedLong(const std::string &str) } /** -* @brief Convert string to unsigned long long, using stoull directly will interpret negative number as extremely -* large positive values, therefore it’s necessary to check if the string contains a negative sign. -* @param[in] str string to be interpreted. -* @return Converted number. + * @brief Convert string to unsigned long long, using stoull directly will interpret negative number as extremely + * large positive values, therefore it’s necessary to check if the string contains a negative sign. + * @param[in] str string to be interpreted. + * @return Converted number. */ inline unsigned long long StrToUnsignedLongLong(const std::string &str) { @@ -420,7 +435,7 @@ inline const char *BoolToString(bool val) * @param input The input string. * @param retainDigits The retained digits. * @return std::string Truncated string. 
-*/ + */ inline std::string GetTruncatedStr(const std::string &input, size_t retainDigits = 6) { const size_t minDigits = 10; diff --git a/src/datasystem/common/util/thread_pool.h b/src/datasystem/common/util/thread_pool.h index 450eab8..65352ad 100644 --- a/src/datasystem/common/util/thread_pool.h +++ b/src/datasystem/common/util/thread_pool.h @@ -260,4 +260,109 @@ inline bool IsThreadFinished(std::shared_future const &f, const int &timeout) return f.wait_for(std::chrono::seconds(timeout)) == std::future_status::ready; } } // namespace datasystem + +class OrderedThreadPool { +public: + explicit OrderedThreadPool(size_t threadCount) + : taskQueues_(threadCount), queueMutexes_(threadCount), conditionVars_(threadCount), threadCount_(threadCount) + { + for (size_t i = 0; i < threadCount_; ++i) { + workers_.emplace_back([this, i] { Run(i); }); + } + } + + void Run(size_t index) + { + while (true) { + std::shared_ptr task; + { + std::unique_lock lock(queueMutexes_[index]); + conditionVars_[index].wait(lock, [this, index] { return stop_.load() || !taskQueues_[index].empty(); }); + + if (stop_.load() && taskQueues_[index].empty()) { + return; + } + + task = taskQueues_[index].front(); + taskQueues_[index].pop(); + } + + try { + task->func(); + task->promise.set_value(); + } catch (...) { + task->promise.set_exception(std::current_exception()); + } + } + } + + ~OrderedThreadPool() + { + stop_.store(true); + for (auto &cv : conditionVars_) { + cv.notify_all(); + } + for (auto &worker : workers_) { + worker.join(); + } + } + + std::future Submit(const std::string &key, std::function func) + { + size_t index = GetQueueIndex(key); + auto task = std::make_shared(std::move(func), key); + auto future = task->promise.get_future(); + + { + std::lock_guard lock(queueMutexes_[index]); + taskQueues_[index].push(task); + } + conditionVars_[index].notify_one(); + + return future; + } + + /** + * @brief Check whether some async tasks in the list. + * @return True if all of async list is empty. + */ + bool AreAllQueuesEmpty() + { + for (size_t i = 0; i < threadCount_; ++i) { + std::lock_guard lock(queueMutexes_[i]); + if (!taskQueues_[i].empty()) { + return false; + } + } + return true; + } + +private: + struct Task { + std::function func; + std::promise promise; + std::string key; + + Task(std::function f, const std::string &k) : func(std::move(f)), key(k) + { + } + }; + + std::vector>> taskQueues_; + std::vector queueMutexes_; + std::vector conditionVars_; + std::vector workers_; + std::atomic stop_{ false }; + size_t threadCount_; + + /** + * @brief Calculate a index of list according to key. + * @param[in] key The Id of the object need to be calculated. + * @return Index of list. + */ + size_t GetQueueIndex(const std::string &key) + { + return std::hash{}(key) % threadCount_; + } +}; #endif // DATASYSTEM_COMMON_UTIL_THREAD_POOL_H diff --git a/src/datasystem/common/util/validator.h b/src/datasystem/common/util/validator.h index 33469b9..71c1a01 100644 --- a/src/datasystem/common/util/validator.h +++ b/src/datasystem/common/util/validator.h @@ -921,5 +921,21 @@ public: } return true; } + + /** + * @brief Validate the given string matches the cache type supported. + * @param[in] flagName Cache type flag. + * @param[in] value The string to be checked. + * @return True if valid. 
+ */ + static bool ValidateRocksdbModeType(const char *flagName, const std::string &value) + { + if (value == "none" || value == "sync" || value == "async") { + return true; + } + LOG(ERROR) << FormatString( + "The value of %s flag is %s, which must be 'none'/'sync'/'async'.", flagName, value); + return false; + } }; #endif // DATASYSTEM_COMMON_UTIL_FLAG_VALIDATOR_H \ No newline at end of file diff --git a/src/datasystem/common/util/wait_post.cpp b/src/datasystem/common/util/wait_post.cpp index b2b543f..c74e880 100644 --- a/src/datasystem/common/util/wait_post.cpp +++ b/src/datasystem/common/util/wait_post.cpp @@ -79,4 +79,19 @@ void Barrier::Wait() cv_.wait(lock, [this, gen] { return gen != generation_; }); } + +void WaitPost::SetWithStatus(const Status &status) +{ + std::unique_lock lock(mux_); + val_ = 1; + status_ = status; + cv_.notify_all(); +} + +Status WaitPost::WaitAndGetStatus() +{ + std::unique_lock lock(mux_); + cv_.wait(lock, [this]() { return val_ != 0; }); + return status_; +} } // namespace datasystem diff --git a/src/datasystem/common/util/wait_post.h b/src/datasystem/common/util/wait_post.h index d3956bb..b05ec4e 100644 --- a/src/datasystem/common/util/wait_post.h +++ b/src/datasystem/common/util/wait_post.h @@ -23,6 +23,8 @@ #include #include +#include "datasystem/utils/status.h" + namespace datasystem { /** * A WaitPost is an implementation of &visited) +{ + visited.insert(node); + TbbDeadLockDetectionGraph::const_accessor acc; + if (deadLockDetectionGraph_.find(acc, node)) { + for (const std::string &neighbor : acc->second) { + if (neighbor == parent) { + continue; + } + if (visited.find(neighbor) != visited.end()) { + return true; + } + if (HasCycle(neighbor, node, visited)) { + // Cycle detected in recursive traversal + return true; + } + } + } + return false; +} + +bool MasterDevDeadLockManager::IsExistDeadlock() +{ + PerfPoint point(PerfKey::MASTER_IS_EXIST_DEAD_LOCK); + auto lockGuard = std::lock_guard(edgesMutex_); + + // Duplicate edges immediately indicate potential deadlock + if (!duplicateEdges_.empty()) { + return true; + } + + // Perform DFS cycle detection on each connected component + tbb::concurrent_unordered_set visited; + for (const auto &pair : deadLockDetectionGraph_) { + const std::string &node = pair.first; + if (visited.find(node) == visited.end()) { + if (HasCycle(node, "", visited)) { + return true; + } + } + } + return false; +} + +void MasterDevDeadLockManager::AddDependencyEdge(const std::string &from, const std::string &to) +{ + { + auto lockGuard = std::lock_guard(edgesMutex_); + std::string u = from; + std::string v = to; + if (u > v) + std::swap(u, v); + + // If edge already exists, mark as duplicate and return without adding to graph + auto edge = std::make_pair(u, v); + if (edgesSet_.find(edge) != edgesSet_.end()) { + duplicateEdges_.insert(edge); + return; + } + + // Add the new edge to tracking set and dependency graph + edgesSet_.insert(edge); + } + + TbbDeadLockDetectionGraph::accessor acc1; + deadLockDetectionGraph_.insert(acc1, from); + acc1->second.insert(to); + TbbDeadLockDetectionGraph::accessor acc2; + deadLockDetectionGraph_.insert(acc2, to); + acc2->second.insert(from); +} + +void MasterDevDeadLockManager::RemoveDependencyEdge(const std::string &from, const std::string &to) +{ + { + auto lockGuard = std::lock_guard(edgesMutex_); + + // Normalize edge representation same as in AddDependencyEdge + std::string u = from; + std::string v = to; + if (u > v) + std::swap(u, v); + + // If this edge is marked as duplicate, just remove 
from duplicates + auto edge = std::make_pair(u, v); + if (duplicateEdges_.find(edge) != duplicateEdges_.end()) { + duplicateEdges_.unsafe_erase(edge); + return; + } + + // Remove edge from tracking set + edgesSet_.unsafe_erase(edge); + } + + // Remove bidirectional edges from the dependency graph + TbbDeadLockDetectionGraph::accessor acc; + if (deadLockDetectionGraph_.find(acc, from)) { + acc->second.unsafe_erase(to); + if (acc->second.empty()) { + deadLockDetectionGraph_.erase(acc); + } + } + if (deadLockDetectionGraph_.find(acc, to)) { + acc->second.unsafe_erase(from); + if (acc->second.empty()) { + deadLockDetectionGraph_.erase(acc); + } + } +} +} // namespace master +} // namespace datasystem \ No newline at end of file diff --git a/src/datasystem/master/object_cache/device/master_dev_dead_lock_manager.h b/src/datasystem/master/object_cache/device/master_dev_dead_lock_manager.h new file mode 100644 index 0000000..b6a418b --- /dev/null +++ b/src/datasystem/master/object_cache/device/master_dev_dead_lock_manager.h @@ -0,0 +1,101 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASYSTEM_MASTER_OBJECT_CACHE_DEVICE_MASTER_DEV_DEAD_LOCK_MANAGER_H +#define DATASYSTEM_MASTER_OBJECT_CACHE_DEVICE_MASTER_DEV_DEAD_LOCK_MANAGER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "datasystem/common/log/log.h" +#include "datasystem/common/perf/perf_manager.h" +#include "datasystem/common/util/format.h" +#include +#include + +namespace datasystem { +namespace master { + +using TbbDeadLockDetectionGraph = tbb::concurrent_hash_map>; + +class MasterDevDeadLockManager { +public: + MasterDevDeadLockManager() = default; + + ~MasterDevDeadLockManager() = default; + + /** + * @brief Checks if there exists a deadlock in the current dependency graph. + * @return true if a cycle is detected indicating potential deadlock, false otherwise. + * @note This method acquires a lock to ensure thread-safe graph traversal. + * Should be called when adding new dependencies to prevent deadlocks. + */ + bool IsExistDeadlock(); + + /** + * @brief Adds a dependency edge from one entity to another in the dependency graph. + * @param from The source entity that depends on the target entity. + * @param to The target entity that the source entity depends on. + * @note Duplicate edges are tracked but not added to the graph to maintain efficiency. + * This method is thread-safe. + */ + void AddDependencyEdge(const std::string &from, const std::string &to); + + /** + * @brief Removes a dependency edge from the graph. + * @param from The source entity of the dependency edge to remove. + * @param to The target entity of the dependency edge to remove. + * @note If the edge exists, it will be removed from both the graph and tracking sets. + * This method is thread-safe. 
+ */ + void RemoveDependencyEdge(const std::string &from, const std::string &to); + +private: + /** + * @brief Performs depth-first search to detect cycles in the dependency graph. + * @param node The current node being visited in the DFS traversal. + * @param parent The parent node from which we reached the current node. + * @param visited Set of nodes that have been visited in the current DFS path. + * @return true if a cycle is detected, false otherwise. + * @note This is a recursive helper function used by IsExistDeadlock(). + */ + bool HasCycle(const std::string &node, const std::string &parent, + tbb::concurrent_unordered_set &visited); + + // Thread-safe dependency graph for deadlock detection using TBB containers. + TbbDeadLockDetectionGraph deadLockDetectionGraph_; + + // Protects non-thread-safe STL containers (edgesSet_ and duplicateEdges_) + std::mutex edgesMutex_; + + // Thread-safe set tracking all unique dependency edges in the graph. + tbb::concurrent_unordered_set> edgesSet_; + + // Thread-safe set tracking duplicate edges for monitoring and debugging. + tbb::concurrent_unordered_set> duplicateEdges_; +}; +} // namespace master +} // namespace datasystem +#endif \ No newline at end of file diff --git a/src/datasystem/master/object_cache/device/master_dev_hccl_rootinfo.cpp b/src/datasystem/master/object_cache/device/master_dev_hccl_rootinfo.cpp index 8c22970..187dd10 100644 --- a/src/datasystem/master/object_cache/device/master_dev_hccl_rootinfo.cpp +++ b/src/datasystem/master/object_cache/device/master_dev_hccl_rootinfo.cpp @@ -36,6 +36,13 @@ Status HcclRootInfoTable::PutRootInfo(const std::string &hcclPeerId, const std:: return Status::OK(); } +bool HcclRootInfoTable::IsExistRootInfo(const std::string &dstNpuId) +{ + std::shared_lock lock(rootInfoTableMutex_); + TbbRootInfoTable::accessor rootInfoAccess; + return rootInfoTable_.find(rootInfoAccess, dstNpuId); +} + void HcclRootInfoTable::GetAndEraseRootInfo(const std::string &hcclPeerId, const std::function &onGetCallback) { @@ -58,6 +65,11 @@ void HcclRootInfoTable::EraseRootInfo(const std::string &hcclPeerId) } } +bool HcclRootInfoSubscriptionTable::IsExistRecvRootInfoReq(const std::string &objectKey) +{ + return recvRootInfoRequestTable_.ObjectInRequest(objectKey); +} + void HcclRootInfoSubscriptionTable::EraseRootInfoSubscription(const std::string &hcclPeerId) { recvRootInfoRequestTable_.EraseSub(hcclPeerId); @@ -94,9 +106,10 @@ void HcclRootInfoSubscriptionTable::RemoveRecvRootInfoRequest(std::shared_ptr req) { LOG_IF_ERROR(ReturnFromRecvRootInfoRequest(req), "ReturnFromRecvRootInfoRequest failed"); @@ -125,6 +138,7 @@ Status HcclRootInfoSubscriptionTable::ReturnFromRecvRootInfoRequest(std::shared_ if (request->objects_.find(accessor, objectKey) && accessor->second != nullptr) { auto ¶m = accessor->second; resp.set_root_info(std::string(std::begin(param->rootInfo), std::end(param->rootInfo))); + resp.set_is_dead_lock(param->isDeadLock); isFindObj = true; } if (!isFindObj) { @@ -155,7 +169,7 @@ void HcclRelationshipTable::AddEdge(const std::string &dataReceiver, const std:: { std::shared_lock lock(mutex_); TbbHcclRelationshipTable::accessor acc; - (void)graph_.insert(acc, dataReceiver); + (void)hcclRelationshipGraph_.insert(acc, dataReceiver); acc->second.insert(dataSender); } @@ -163,7 +177,7 @@ bool HcclRelationshipTable::Contains(const std::string &dataReceiver, const std: { std::shared_lock lock(mutex_); TbbHcclRelationshipTable::const_accessor acc; - if (!graph_.find(acc, dataReceiver)) { + if 
(!hcclRelationshipGraph_.find(acc, dataReceiver)) { return false; } return acc->second.count(dataSender) > 0; @@ -173,7 +187,7 @@ std::set HcclRelationshipTable::GetClientNpuId(const std::string &c { std::set clientNpuIds; std::lock_guard lock(mutex_); - for (const auto &iter : graph_) { + for (const auto &iter : hcclRelationshipGraph_) { if (iter.first.find(clientId) != std::string::npos) { clientNpuIds.insert(iter.first); } @@ -191,13 +205,13 @@ void HcclRelationshipTable::EraseNode(const std::string &dataSender, std::set lock(mutex_); TbbHcclRelationshipTable::const_accessor acc; - if (graph_.find(acc, dataSender)) { + if (hcclRelationshipGraph_.find(acc, dataSender)) { connectIds = std::set(acc->second.begin(), acc->second.end()); } - graph_.erase(acc); + hcclRelationshipGraph_.erase(acc); std::set eraseList; - for (auto &iter : graph_) { + for (auto &iter : hcclRelationshipGraph_) { auto &connectSet = iter.second; if (connectSet.find(dataSender) == connectSet.end()) { continue; @@ -209,14 +223,14 @@ void HcclRelationshipTable::EraseNode(const std::string &dataSender, std::set lock(mutex_); - for (const auto &iter : graph_) { + for (const auto &iter : hcclRelationshipGraph_) { auto HcclRelationshipPb = req.add_hccl_relationship(); HcclRelationshipPb->set_data_receiver_id(iter.first); @@ -239,7 +253,7 @@ void HcclRelationshipTable::Clear() { std::lock_guard lock(mutex_); LOG(INFO) << "Clear HcclRelationshipTable."; - graph_.clear(); + hcclRelationshipGraph_.clear(); } } // namespace master diff --git a/src/datasystem/master/object_cache/device/master_dev_hccl_rootinfo.h b/src/datasystem/master/object_cache/device/master_dev_hccl_rootinfo.h index 086d200..8fa9263 100644 --- a/src/datasystem/master/object_cache/device/master_dev_hccl_rootinfo.h +++ b/src/datasystem/master/object_cache/device/master_dev_hccl_rootinfo.h @@ -35,6 +35,13 @@ class HcclRootInfoTable { public: Status PutRootInfo(const std::string &dstNpuId, const std::string &rootInfo); + /** + * @brief Check if the RootInfo exists in the table for the specified NPU ID. + * @param[in] dstNpuId The destination NPU ID to check for existence. + * @return true if the RootInfo exists, false otherwise. + */ + bool IsExistRootInfo(const std::string &dstNpuId); + void GetAndEraseRootInfo(const std::string &dstNpuId, const std::function &onGetCallback); /** @@ -67,12 +74,15 @@ private: }; struct RecvRootInfoEntryParams { - static std::shared_ptr ConstructRecvRootInfoEntryParams(const std::string &rootInfo) + static std::shared_ptr ConstructRecvRootInfoEntryParams(const std::string &rootInfo, + const bool isDeadLock) { - return std::make_shared(RecvRootInfoEntryParams{ .rootInfo = rootInfo }); + return std::make_shared( + RecvRootInfoEntryParams{ .rootInfo = rootInfo, .isDeadLock = isDeadLock }); } const std::string rootInfo; + const bool isDeadLock; }; using RecvRootInfoRequest = UnaryRequest; @@ -91,6 +101,13 @@ public: */ Status AddRecvRootInfoRequest(const std::string &objectKey, std::shared_ptr &request); + /** + * @brief Check if the received RootInfo request exists for the specified object key. + * @param[in] objectKey The object key to check for existence in the request table. + * @return true if the request exists, false otherwise. + */ + bool IsExistRecvRootInfoReq(const std::string &objectKey); + /** * @brief Remove the RecvRootInfo request from the waiting requests table. * @param[in] request The request need to remove. @@ -103,7 +120,8 @@ public: * @param[in] rootInfo The rootInfo message. * @return Status of the call. 
*/ - Status UpdateRecvRootInfoRequestForSuccess(const std::string &objectKey, const std::string &rootInfo); + Status UpdateRecvRootInfoRequestForSuccess(const std::string &objectKey, const std::string &rootInfo, + const bool isDeadLock = false); /** * @brief Reply to client with the device RecvRootInfo request. @@ -127,7 +145,7 @@ using TbbHcclRelationshipTable = tbb::concurrent_hash_map value: Set(DataSender) - + HcclRelationshipTable() = default; ~HcclRelationshipTable() = default; @@ -174,13 +192,13 @@ public: void SaveMigrateData(const MigrateMetadataReqPb &req); /** - * @brief Clear graph_ table. + * @brief Clear hcclRelationshipGraph_ table. */ void Clear(); private: std::shared_timed_mutex mutex_; - TbbHcclRelationshipTable graph_; + TbbHcclRelationshipTable hcclRelationshipGraph_; }; } // namespace master diff --git a/src/datasystem/master/object_cache/device/master_dev_oc_manager.cpp b/src/datasystem/master/object_cache/device/master_dev_oc_manager.cpp index 990b7d3..dcb232b 100644 --- a/src/datasystem/master/object_cache/device/master_dev_oc_manager.cpp +++ b/src/datasystem/master/object_cache/device/master_dev_oc_manager.cpp @@ -59,6 +59,8 @@ void MasterDevOcManager::Init() objectKeyLockTable_ = std::make_shared(); + masterDevDeadLockManager_ = std::make_shared(); + CheckAndClearDeviceMeta::GetInstance().AddSubscriber( MASTER_DEV_OC_MANAGER, [this](const std::string &objectKey) { return CheckAndClearDeviceMeta(objectKey); }); @@ -314,46 +316,90 @@ Status MasterDevOcManager::ProcessSubscribeReceiveEventRequest( Status MasterDevOcManager::SendRootInfoImpl(const SendRootInfoReqPb &req, SendRootInfoRspPb &resp) { - auto hcclPeerId = GetHcclPeerId(req.src_client_id(), req.src_device_id(), req.dst_client_id(), req.dst_device_id()); - std::string rootInfo = std::string(std::begin(req.root_info()), std::end(req.root_info())); - // Step 1: Update Key Value Table. - - LOG_IF_ERROR(deviceMetaOpRecordTable_->AddValue(RecordType::HCCLPEERID, req.dst_client_id(), hcclPeerId), - "add record error"); - LOG_IF_ERROR(deviceMetaOpRecordTable_->AddValue(RecordType::HCCLPEERID, req.src_client_id(), hcclPeerId), - "add record error"); + (void)resp; - rootInfoTable_->PutRootInfo(hcclPeerId, rootInfo); + // Extract request parameters + auto srcClientId = req.src_client_id(); + auto dstClientId = req.dst_client_id(); + auto hcclPeerId = GetHcclPeerId(req.src_client_id(), req.src_device_id(), req.dst_client_id(), req.dst_device_id()); + auto rootInfo = std::string(std::begin(req.root_info()), std::end(req.root_info())); auto dataSender = ConcatClientAndDeviceId(req.src_client_id(), req.src_device_id()); auto dataReceiver = ConcatClientAndDeviceId(req.dst_client_id(), req.dst_device_id()); - commRelationMap_->AddEdge(dataReceiver, dataSender); - // Step 2: Notify the Request Table. 
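// A sketch, not part of this patch: the rewritten SendRootInfoImpl below calls
// MasterDevDeadLockManager::IsExistDeadlock(), whose DFS helper HasCycle() is declared in the
// new header but defined outside the hunks shown here. A minimal implementation consistent
// with the declared signature, assuming the graph stores each dependency in both directions
// (as RemoveDependencyEdge implies): a parent-skipping DFS that reaches any already-visited
// node has found a cycle, i.e. a potential deadlock.
bool MasterDevDeadLockManager::HasCycle(const std::string &node, const std::string &parent,
                                        tbb::concurrent_unordered_set<std::string> &visited)
{
    visited.insert(node);
    std::vector<std::string> neighbors;
    {
        TbbDeadLockDetectionGraph::const_accessor acc;
        if (!deadLockDetectionGraph_.find(acc, node)) {
            return false;  // Isolated node: no outgoing dependencies.
        }
        neighbors.assign(acc->second.begin(), acc->second.end());
    }  // Release the accessor before recursing so no nested bucket locks are held.
    for (const auto &next : neighbors) {
        if (next == parent) {
            continue;  // Ignore the edge we arrived through.
        }
        if (visited.count(next) > 0 || HasCycle(next, node, visited)) {
            return true;
        }
    }
    return false;
}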
- // Try to update the SubscribeReceiveEvent request because there may be a request subscribed to the objectKey - RETURN_IF_NOT_OK(rootInfoSubscriptionTable_->UpdateRecvRootInfoRequestForSuccess(hcclPeerId, rootInfo)); - (void)resp; - return Status::OK(); + auto recordPeerIds = [&]() { + RETURN_IF_NOT_OK(deviceMetaOpRecordTable_->AddValue(RecordType::HCCLPEERID, dstClientId, hcclPeerId)); + return deviceMetaOpRecordTable_->AddValue(RecordType::HCCLPEERID, srcClientId, hcclPeerId); + }; + + auto updateCommRelation = [&]() { + commRelationMap_->AddEdge(ConcatClientAndDeviceId(dstClientId, req.dst_device_id()), + ConcatClientAndDeviceId(srcClientId, req.src_device_id())); + return rootInfoTable_->PutRootInfo(hcclPeerId, rootInfo); + }; + + // Case 1: Matching RecvRootInfo subscription exists + if (rootInfoSubscriptionTable_->IsExistRecvRootInfoReq(hcclPeerId)) { + masterDevDeadLockManager_->RemoveDependencyEdge(srcClientId, dstClientId); + RETURN_IF_NOT_OK(recordPeerIds()); + RETURN_IF_NOT_OK(updateCommRelation()); + + return rootInfoSubscriptionTable_->UpdateRecvRootInfoRequestForSuccess(hcclPeerId, rootInfo); + } + + // Case 2: No matching subscription exists + masterDevDeadLockManager_->AddDependencyEdge(srcClientId, dstClientId); + // Check for deadlock + if (masterDevDeadLockManager_->IsExistDeadlock()) { + std::string msg = + FormatString("Deadlock detected (peer:%s), notify the receiver to release the lock and retry.", hcclPeerId); + LOG(INFO) << msg; + masterDevDeadLockManager_->RemoveDependencyEdge(srcClientId, dstClientId); + RETURN_STATUS(K_CLIENT_DEADLOCK, msg); + } + + // Store root info and update communication relation + LOG(INFO) << FormatString("updateCommRelation, hcclPeerId:%s", hcclPeerId); + RETURN_IF_NOT_OK(recordPeerIds()); + return updateCommRelation(); } Status MasterDevOcManager::ProcessRecvRootInfoRequest( const RecvRootInfoReqPb &req, const std::shared_ptr> &serverApi) { + auto srcClientId = req.src_client_id(); + auto dstClientId = req.dst_client_id(); auto hcclPeerId = GetHcclPeerId(req.src_client_id(), req.src_device_id(), req.dst_client_id(), req.dst_device_id()); auto request = std::make_shared(std::vector{ hcclPeerId }, serverApi, req.dst_client_id(), req.dst_device_id(), req); - // Step 1: To avoid missing subscription notifications, - // we firstly subscribe and then check the key value table. - RETURN_IF_NOT_OK(rootInfoSubscriptionTable_->AddRecvRootInfoRequest(hcclPeerId, request)); + bool isExistRootInfo = rootInfoTable_->IsExistRootInfo(hcclPeerId); + if (isExistRootInfo) { + masterDevDeadLockManager_->RemoveDependencyEdge(srcClientId, dstClientId); + rootInfoTable_->GetAndEraseRootInfo(hcclPeerId, [&request, &hcclPeerId](const std::string &rootInfo) { + request->objects_.emplace(hcclPeerId, + RecvRootInfoEntryParams::ConstructRecvRootInfoEntryParams(rootInfo, false)); + request->numSatisfiedObjects_.fetch_add(1); + }); - // Step 2: Check root info. 
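// SendRootInfoImpl above and ProcessRecvRootInfoRequest below repeat the same speculative
// add-edge / cycle-check / roll-back idiom. A hypothetical helper capturing that shape
// (TryAddDependency is an illustrative name, not from this patch), built from the patch's
// own RETURN_STATUS macro and K_CLIENT_DEADLOCK status code:
Status TryAddDependency(MasterDevDeadLockManager &mgr, const std::string &src, const std::string &dst)
{
    mgr.AddDependencyEdge(src, dst);
    if (mgr.IsExistDeadlock()) {
        // Undo the speculative edge so a client retry starts from a clean graph.
        mgr.RemoveDependencyEdge(src, dst);
        RETURN_STATUS(K_CLIENT_DEADLOCK, FormatString("Deadlock detected between %s and %s.", src, dst));
    }
    return Status::OK();
}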
- rootInfoTable_->GetAndEraseRootInfo(hcclPeerId, [&request, &hcclPeerId](const std::string &rootInfo) { - request->objects_.emplace(hcclPeerId, RecvRootInfoEntryParams::ConstructRecvRootInfoEntryParams(rootInfo)); - request->numSatisfiedObjects_.fetch_add(1); - }); + RETURN_IF_NOT_OK(rootInfoSubscriptionTable_->AddRecvRootInfoRequest(hcclPeerId, request)); + + return DirectRespOrAddTimer(request, [this](std::shared_ptr req) { + return rootInfoSubscriptionTable_->ReturnFromRecvRootInfoRequest(std::move(req)); + }); + } + + masterDevDeadLockManager_->AddDependencyEdge(srcClientId, dstClientId); + bool isExistDeadlock = masterDevDeadLockManager_->IsExistDeadlock(); + if (isExistDeadlock) { + masterDevDeadLockManager_->RemoveDependencyEdge(srcClientId, dstClientId); + RecvRootInfoRspPb resp; + resp.set_is_dead_lock(true); + return request->serverApi_->Write(resp); + } + RETURN_IF_NOT_OK(rootInfoSubscriptionTable_->AddRecvRootInfoRequest(hcclPeerId, request)); return DirectRespOrAddTimer(request, [this](std::shared_ptr req) { return rootInfoSubscriptionTable_->ReturnFromRecvRootInfoRequest(std::move(req)); }); diff --git a/src/datasystem/master/object_cache/device/master_dev_oc_manager.h b/src/datasystem/master/object_cache/device/master_dev_oc_manager.h index ea92a50..a6e98bc 100644 --- a/src/datasystem/master/object_cache/device/master_dev_oc_manager.h +++ b/src/datasystem/master/object_cache/device/master_dev_oc_manager.h @@ -27,6 +27,7 @@ #include "datasystem/common/immutable_string/immutable_string.h" #include "datasystem/common/rpc/rpc_server_stream_base.h" #include "datasystem/client/hetero_cache/device_util.h" +#include "datasystem/master/object_cache/device/master_dev_dead_lock_manager.h" #include "datasystem/master/object_cache/device/master_dev_hccl_rootinfo.h" #include "datasystem/master/object_cache/device/master_dev_npu_events.h" #include "datasystem/master/object_cache/device/master_dev_oc_directory.h" @@ -321,7 +322,9 @@ private: std::shared_ptr commRelationMap_{ nullptr }; - std::shared_ptr objectKeyLockTable_{ nullptr }; + std::shared_ptr objectKeyLockTable_ {nullptr}; + + std::shared_ptr masterDevDeadLockManager_ {nullptr}; }; } // namespace master } // namespace datasystem diff --git a/src/datasystem/master/object_cache/master_master_oc_api.cpp b/src/datasystem/master/object_cache/master_master_oc_api.cpp index 431ee59..eb5767d 100644 --- a/src/datasystem/master/object_cache/master_master_oc_api.cpp +++ b/src/datasystem/master/object_cache/master_master_oc_api.cpp @@ -70,7 +70,7 @@ Status MasterMasterOCApi::ReleaseGRefsOfRemoteClientId(const ReleaseGRefsReqPb & { int64_t remainingTime = reqTimeoutDuration.CalcRemainingTime(); CHECK_FAIL_RETURN_STATUS(remainingTime > 0, K_RPC_DEADLINE_EXCEEDED, - FormatString("Request timeout (%ld ms).", -remainingTime)); + FormatString("Request timeout (%lld ms).", -remainingTime)); if (remainingTime > INT_MAX) { remainingTime = INT_MAX; } diff --git a/src/datasystem/master/object_cache/master_oc_service_impl.h b/src/datasystem/master/object_cache/master_oc_service_impl.h index e019a43..420f2c0 100644 --- a/src/datasystem/master/object_cache/master_oc_service_impl.h +++ b/src/datasystem/master/object_cache/master_oc_service_impl.h @@ -51,6 +51,9 @@ public: */ ~MasterOCServiceImpl(); + /** + * @brief shutdown master oc server impl. 
+ */ void Shutdown(); /** diff --git a/src/datasystem/master/object_cache/master_worker_oc_api.cpp b/src/datasystem/master/object_cache/master_worker_oc_api.cpp index 754f7e0..8ad7322 100644 --- a/src/datasystem/master/object_cache/master_worker_oc_api.cpp +++ b/src/datasystem/master/object_cache/master_worker_oc_api.cpp @@ -33,6 +33,7 @@ namespace master { static constexpr int64_t MASTER_TIMEOUT_MINUS_MILLISECOND = 5 * 1000; static constexpr float MASTER_TIMEOUT_DESCEND_FACTOR = 0.9; +std::atomic MasterLocalWorkerOCApi::g_localTagGen_{1}; inline int64_t MasterGetRequestTimeout(int32_t timeout) { return std::max(int64_t(timeout * MASTER_TIMEOUT_DESCEND_FACTOR), timeout - MASTER_TIMEOUT_MINUS_MILLISECOND); @@ -307,30 +308,33 @@ Status MasterLocalWorkerOCApi::DeleteNotificationSend(std::unique_ptrDeleteNotification. + // A shared_mutex keeps concurrent Send/Receive operations on the same API + // instance serialized. RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(*req)); - deleteReq_.SetRealObject(std::move(req)); + tag = g_localTagGen_.fetch_add(1); + + std::unique_lock lock(localReqMutex_); + localReqMap_[tag] = std::move(req); return Status::OK(); } Status MasterLocalWorkerOCApi::DeleteNotificationReceive(int64_t tag, DeleteObjectRspPb &rsp) { - CHECK_FAIL_RETURN_STATUS(tag == 0, K_RUNTIME_ERROR, "Invalid tag for local api"); - CHECK_FAIL_RETURN_STATUS(deleteReq_.IsWLockedByCurrentThread(), K_RUNTIME_ERROR, - "Async read does not have request locked."); - auto reqPtr = deleteReq_.Detach(); - deleteReq_.WUnlock(); - RETURN_IF_NOT_OK(workerOC_->DeleteNotification(*reqPtr, rsp)); - return Status::OK(); + std::unique_ptr req; + { + std::unique_lock lock(localReqMutex_); + auto it = localReqMap_.find(tag); + if (it == localReqMap_.end() || !it->second) { + RETURN_STATUS(StatusCode::K_NOT_FOUND, FormatString("Local tag %ld not found", tag)); + } + req = std::move(it->second); + localReqMap_.erase(it); + } + return workerOC_->DeleteNotification(*req, rsp); } Status MasterLocalWorkerOCApi::QueryGlobalRefNumOnWorker(QueryGlobalRefNumReqPb &req, QueryGlobalRefNumRspPb &rsp) diff --git a/src/datasystem/master/object_cache/master_worker_oc_api.h b/src/datasystem/master/object_cache/master_worker_oc_api.h index b0f43cf..1534d9f 100644 --- a/src/datasystem/master/object_cache/master_worker_oc_api.h +++ b/src/datasystem/master/object_cache/master_worker_oc_api.h @@ -235,7 +235,11 @@ public: private: object_cache::MasterWorkerOCServiceImpl *workerOC_{ nullptr }; - SafeObject deleteReq_; + mutable std::shared_mutex localReqMutex_; // protects localReqMap_ + // Map from local tag to pending DeleteObject request. + std::unordered_map> localReqMap_; + // Atomic tag generator for unique local DeleteObject request identification. 
+ static std::atomic g_localTagGen_; }; } // namespace master diff --git a/src/datasystem/master/object_cache/oc_global_cache_delete_manager.cpp b/src/datasystem/master/object_cache/oc_global_cache_delete_manager.cpp index 63bec7b..846f606 100644 --- a/src/datasystem/master/object_cache/oc_global_cache_delete_manager.cpp +++ b/src/datasystem/master/object_cache/oc_global_cache_delete_manager.cpp @@ -304,7 +304,7 @@ Status OCGlobalCacheDeleteManager::RecoverDeletedIds(bool isFromRocksdb, const s RETURN_IF_NOT_OK(objectStore_->PutToRocksStore(GLOBAL_CACHE_TABLE, iter.first, iter.second)); } } else { - if (isFromRocksdb) { + if (isFromRocksdb && objectStore_->IsRocksdbEnableWriteMeta()) { RETURN_IF_NOT_OK_PRINT_ERROR_MSG(objectStore_->GetAllFromRocks(GLOBAL_CACHE_TABLE, deleteObjects), "Load global cache delete objects from rocksdb failed."); } else { diff --git a/src/datasystem/master/object_cache/oc_global_cache_delete_manager.h b/src/datasystem/master/object_cache/oc_global_cache_delete_manager.h index 985c74e..38c7774 100644 --- a/src/datasystem/master/object_cache/oc_global_cache_delete_manager.h +++ b/src/datasystem/master/object_cache/oc_global_cache_delete_manager.h @@ -158,6 +158,7 @@ private: * @param[in] objectKey the object path in cloud storage, equal to objectKey * @param[in] objectVersion the object version * @param[in] maxVersionToDel indicate delete all the versions which <= maxVersionToDel + * @return Status of the call. */ Status DelPersistenceObj(const std::string &objectKey, uint64_t objectVersion, uint64_t maxVersionToDel); diff --git a/src/datasystem/master/object_cache/oc_metadata_manager.cpp b/src/datasystem/master/object_cache/oc_metadata_manager.cpp index 9efd118..d0853a3 100644 --- a/src/datasystem/master/object_cache/oc_metadata_manager.cpp +++ b/src/datasystem/master/object_cache/oc_metadata_manager.cpp @@ -63,6 +63,7 @@ #include "datasystem/protos/master_object.pb.h" #include "datasystem/protos/object_posix.pb.h" #include "datasystem/protos/worker_object.pb.h" +#include "datasystem/protos/worker_stream.pb.h" #include "datasystem/utils/status.h" #include "datasystem/worker/cluster_manager/etcd_cluster_manager.h" #include "datasystem/worker/object_cache/master_worker_oc_service_impl.h" @@ -81,7 +82,8 @@ DS_DECLARE_int32(rpc_thread_num); DS_DECLARE_bool(enable_meta_replica); DS_DECLARE_bool(oc_io_from_l2cache_need_metadata); DS_DECLARE_bool(enable_reconciliation); -DS_DECLARE_string(other_az_names); +DS_DECLARE_string(other_cluster_names); +DS_DECLARE_string(rocksdb_write_mode); namespace datasystem { namespace master { @@ -444,7 +446,7 @@ void OCMetadataManager::SetMetaInfo(const ObjectMetaPb &newMeta, const std::stri Status OCMetadataManager::NotifyOtherAzNodeRemoveMeta(const std::string &objectKey, int64_t version, const ObjectMetaPb &newMeta) { - if (!FLAGS_other_az_names.empty()) { + if (!FLAGS_other_cluster_names.empty()) { LOG(INFO) << "Notify nodes in other clusters to remove meta for object: " << objectKey; } std::unordered_map metaAddrInfos; @@ -934,7 +936,6 @@ Status OCMetadataManager::CreateCopyMeta(const CreateCopyMetaReqPb &request, Cre "CreateCopyMeta: Cannot CreateCopyMeta with empty objectKey or server address."); FillRedirectResponseInfo(response, objectKey, redirect); RETURN_OK_IF_TRUE(redirect); - uint32_t writeMode; { // Check meta info in cache and rocksdb. 
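// The DeleteNotificationSend/Receive rewrite above hands a pending DeleteObject request
// between threads through an integer tag instead of the old single SafeObject slot, so
// several sends can be in flight before their matching receives run. A self-contained
// sketch of that registry pattern (TaggedRequestRegistry is an illustrative name, not a
// type from this patch):
#include <atomic>
#include <cstdint>
#include <memory>
#include <mutex>
#include <unordered_map>

template <typename Req>
class TaggedRequestRegistry {
public:
    // Store a request and return the tag that later retrieves it.
    int64_t Put(std::unique_ptr<Req> req)
    {
        int64_t tag = nextTag_.fetch_add(1);
        std::lock_guard<std::mutex> lock(mutex_);
        pending_[tag] = std::move(req);
        return tag;
    }

    // Remove and return the request for a tag; nullptr if unknown or already consumed.
    std::unique_ptr<Req> Take(int64_t tag)
    {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = pending_.find(tag);
        if (it == pending_.end()) {
            return nullptr;
        }
        auto req = std::move(it->second);
        pending_.erase(it);
        return req;
    }

private:
    std::atomic<int64_t> nextTag_{1};  // Monotonic tag generator, as in g_localTagGen_ above.
    std::mutex mutex_;                 // Guards pending_.
    std::unordered_map<int64_t, std::unique_ptr<Req>> pending_;
};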
Timer timer; @@ -951,14 +952,13 @@ Status OCMetadataManager::CreateCopyMeta(const CreateCopyMetaReqPb &request, Cre response.set_version(accessor->second.meta.version()); response.set_life_state(accessor->second.meta.life_state()); - writeMode = accessor->second.meta.config().write_mode(); // If the address already exists, return success. if (!accessor->second.locations.insert(address).second) { return Status::OK(); } accessor.release(); } - return objectStore_->AddObjectLocation(objectKey, address, WriteMode2MetaType(writeMode)); + return objectStore_->AddObjectLocation(objectKey, address); } std::string OCMetadataManager::SelectObjectLocation(const std::string &objectKey, const std::string &sourceWorker, @@ -1053,8 +1053,7 @@ Status OCMetadataManager::QueryMetaFromMetaTable(const QueryMetaReqPb &req, cons bool updateLocation = (ConsistencyType)(accessor->second.meta.config().consistency_type()) == ConsistencyType::PRAM; if (updateLocation && accessor->second.locations.insert(address).second) { - RETURN_IF_NOT_OK(objectStore_->AddObjectLocation( - objectKey, address, WriteMode2MetaType(accessor->second.meta.config().write_mode()))); + RETURN_IF_NOT_OK(objectStore_->AddObjectLocation(objectKey, address)); } accessor.release(); continue; @@ -1252,7 +1251,7 @@ void OCMetadataManager::GiveUpPrimaryLocation(const RemoveMetaReqPb &request, co (void)objectStore_->RemoveObjectLocation(objectKey, address); } for (const auto &objectKey : needRemoveIds) { - LOG_IF_ERROR(objectStore_->RemoveObjectLocation(objectKey, address, false), "Remove location failed"); + LOG_IF_ERROR(objectStore_->RemoveObjectLocation(objectKey, address), "Remove location failed"); LOG_IF_ERROR(objectStore_->RemoveMeta(objectKey, false), "Remove meta failed"); } SendChangePrimaryCopy(workerForChangePrimaryIds, response); @@ -1745,7 +1744,7 @@ void OCMetadataManager::NotifyDeleteAndClearMeta(DeleteObjectMediator &delMediat const auto &sendAllDelObjs = delMediator.GetIdsNeedToNotifyWorker(); INJECT_POINT_NO_RETURN("NotifyDeleteAndClearMeta"); Status lastErr = NotifyWorkerDelete(delMediator.GetSourceWorker(), sendAllDelObjs, false, failedNotifyObjects); - Raii removeIsDeletingObjs([&sendAllDelObjs, this] () { + Raii removeIsDeletingObjs([&sendAllDelObjs, this]() { for (const auto &info : sendAllDelObjs) { std::lock_guard l(isDeletingObjMutex_); isDeletingObjs_.erase(info.first); @@ -1876,7 +1875,7 @@ Status OCMetadataManager::ClearOneMetaInfo(const TbbMetaTable::const_accessor &a const auto &objectKey = accessor->first; // remove object location. 
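// NotifyDeleteAndClearMeta above registers removeIsDeletingObjs through the project's Raii
// helper so the isDeletingObjs_ entries are erased on every exit path, including early
// returns. A minimal stand-in showing the idiom (ScopeGuard is illustrative; the real Raii
// class is defined elsewhere in the tree):
#include <functional>
#include <utility>

class ScopeGuard {
public:
    explicit ScopeGuard(std::function<void()> onExit) : onExit_(std::move(onExit)) {}
    ~ScopeGuard()
    {
        if (onExit_) {
            onExit_();  // Runs when the scope unwinds, whether by return or exception.
        }
    }
    ScopeGuard(const ScopeGuard &) = delete;
    ScopeGuard &operator=(const ScopeGuard &) = delete;

private:
    std::function<void()> onExit_;
};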
for (const auto &address : accessor->second.locations) { - RETURN_IF_NOT_OK_PRINT_ERROR_MSG(objectStore_->RemoveObjectLocation(objectKey, address, !isDataMigration), + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(objectStore_->RemoveObjectLocation(objectKey, address), FormatString("[ObjectKey %s] RemoveObjectLocation failed", objectKey)); } // remote meta info @@ -2139,8 +2138,7 @@ Status OCMetadataManager::RecoverObjectLocations( return Status::OK(); } -Status OCMetadataManager::LoadObjectLocations(bool isFromRocksdb, const std::vector &workerUuids, - const worker::HashRange &extraRanges, +Status OCMetadataManager::LoadObjectLocations(bool isFromRocksdb, std::unordered_map> &objLocMap) { INJECT_POINT("OCNotifyWorkerManager.NoNeedRecoveryMeta"); @@ -2148,10 +2146,6 @@ Status OCMetadataManager::LoadObjectLocations(bool isFromRocksdb, const std::vec if (isFromRocksdb) { RETURN_IF_NOT_OK_PRINT_ERROR_MSG(objectStore_->GetAllFromRocks(LOCATION_TABLE, objectLocations), "Load object location from rocksdb into memory failed."); - } else { - RETURN_IF_NOT_OK_PRINT_ERROR_MSG(objectStore_->GetFromEtcd(ETCD_LOCATION_TABLE_PREFIX, LOCATION_TABLE, - workerUuids, extraRanges, objectLocations), - "Load object location from etcd into memory failed."); } for (auto &info : objectLocations) { // key format: WorkerAddr_ObjectKey @@ -2257,10 +2251,10 @@ Status OCMetadataManager::LoadMeta(bool isFromRocksdb, const std::vector> objLocMap; - RETURN_IF_NOT_OK(LoadObjectLocations(isFromRocksdb, workerUuids, extraRanges, objLocMap)); + RETURN_IF_NOT_OK(LoadObjectLocations(isFromRocksdb, objLocMap)); RETURN_IF_NOT_OK(HandleLoadMeta(metas, expireObjects, objLocMap, isFromRocksdb, workerUuids, extraRanges)); RETURN_IF_NOT_OK_PRINT_ERROR_MSG(RecoverObjectLocations(objLocMap), "Recovery object locations into memory failed"); - if (isFromRocksdb) { + if (isFromRocksdb && objectStore_->IsRocksdbEnableWriteMeta()) { RETURN_IF_NOT_OK_PRINT_ERROR_MSG(nestedRefManager_->RecoverRelationshipData(NESTED_TABLE, NESTED_COUNT_TABLE), "Load Nested relationship for rocksdb failed."); } @@ -4155,7 +4149,7 @@ Status OCMetadataManager::CheckRocksdbStatusAndLoadL2Table(const std::string &ta RETURN_IF_NOT_OK(objectStore_->PutToRocksStore(rocksTable, iter.first, iter.second)); } } else { - if (isFromRocksdb) { + if (isFromRocksdb && objectStore_->IsRocksdbEnableWriteMeta()) { RETURN_IF_NOT_OK_PRINT_ERROR_MSG(objectStore_->GetAllFromRocks(rocksTable, outMetas), "Load meta from rocksdb into memory failed."); LOG(INFO) << "Load meta from rocksdb, count:" << outMetas.size(); @@ -4171,6 +4165,7 @@ Status OCMetadataManager::CheckRocksdbStatusAndLoadL2Table(const std::string &ta Status OCMetadataManager::ReplacePrimary(const ReplacePrimaryReqPb &req, ReplacePrimaryRspPb &rsp) { + INJECT_POINT("OCMetadataManager.ReplacePrimary"); std::vector notRedirectObjectKeys; std::transform(req.object_infos().begin(), req.object_infos().end(), std::back_inserter(notRedirectObjectKeys), [](const ReplacePrimaryReqPb::ObjectInfoPb &info) { return info.object_key(); }); diff --git a/src/datasystem/master/object_cache/oc_metadata_manager.h b/src/datasystem/master/object_cache/oc_metadata_manager.h index b6ffbbc..4d44455 100644 --- a/src/datasystem/master/object_cache/oc_metadata_manager.h +++ b/src/datasystem/master/object_cache/oc_metadata_manager.h @@ -1232,14 +1232,10 @@ private: /** * @brief Load object locations * @param[in] isFromRocksdb Specifies whether to obtain data from rocksdb. - * @param[in] workerUuids Recover location of specified worker uuids. 
If the value is empty, recover the data of the - * current worker. - * @param[in] extraRanges Recover location of specified hash ranges if not empty. * @param[out] objLocMap The map record object and locations. * @return Status of the call */ - Status LoadObjectLocations(bool isFromRocksdb, const std::vector &workerUuids, - const worker::HashRange &extraRanges, + Status LoadObjectLocations(bool isFromRocksdb, std::unordered_map> &objLocMap); /** diff --git a/src/datasystem/master/object_cache/oc_notify_worker_manager.cpp b/src/datasystem/master/object_cache/oc_notify_worker_manager.cpp index 1761aa7..efbc8be 100644 --- a/src/datasystem/master/object_cache/oc_notify_worker_manager.cpp +++ b/src/datasystem/master/object_cache/oc_notify_worker_manager.cpp @@ -80,11 +80,20 @@ OCNotifyWorkerManager::~OCNotifyWorkerManager() } } +struct SendResult { + std::shared_ptr api; + int64_t tag = -1; + std::string address; + Status status; +}; // Result bundle for a single DeleteObject notification + Status OCNotifyWorkerManager::Init() { LOG(INFO) << "init OCNotifyWorkerManager" << this; thread_ = std::make_unique(&OCNotifyWorkerManager::ProcessAsyncNotifyOp, this); thread_->set_name("ProcessAsyncNotifyOp"); + deleteThreadPool_ = std::make_unique(minDeleteThreadSize, maxDeleteThreadSize, + "NotifyDeleteSend"); EraseFailedNodeApiEvent::GetInstance().AddSubscriber(subscriberPrefix_ + "OCNotifyWorkerManager", [this](HostPort &node) { EraseMasterWorkerApi(node); }); RemoveDeadWorkerEvent::GetInstance().AddSubscriber( @@ -518,7 +527,7 @@ Status OCNotifyWorkerManager::RecoverCacheInvalidAndRemoveMeta(bool isFromRocksd RETURN_IF_NOT_OK(objectStore_->PutToRocksStore(ASYNC_WORKER_OP_TABLE, iter.first, iter.second)); } } else { - if (isFromRocksdb) { + if (isFromRocksdb && objectStore_->IsRocksdbEnableWriteMeta()) { RETURN_IF_NOT_OK_PRINT_ERROR_MSG(objectStore_->GetAllFromRocks(ASYNC_WORKER_OP_TABLE, cacheInvalids), "Load meta from rocksdb into memory failed."); } else { @@ -675,7 +684,6 @@ Status OCNotifyWorkerManager::DoNotifyWorkerDelete( { std::unordered_map, std::pair> api2Tag; Status lastErr = DoNotifyWorkerDeleteSendRequest(sourceWorker, replicas2Obj, isAsync, failedObjects, api2Tag); - for (const auto &kv : api2Tag) { const auto &masterWorkerApi = kv.first; const auto &tag = kv.second.first; @@ -712,7 +720,6 @@ Status OCNotifyWorkerManager::DoNotifyWorkerDelete( LOG(INFO) << FormatString("Start to remove meta location for objects[%s]", oss.str()); replicas2Obj.erase(address); } - return lastErr; } @@ -738,7 +745,7 @@ Status OCNotifyWorkerManager::ClearDataWithoutMeta(const worker::HashRange &rang } void OCNotifyWorkerManager::SetDeleteObjectReq( - std::unique_ptr &request, bool &isAsync, const std::string &sourceWorker, + std::unique_ptr &request, bool isAsync, const std::string &sourceWorker, const std::unordered_map> &objectItem) { for (const auto &item : objectItem) { @@ -750,49 +757,22 @@ void OCNotifyWorkerManager::SetDeleteObjectReq( VLOG(1) << "Notify worker to delete the object " << LogHelper::IgnoreSensitive(*request); } -Status OCNotifyWorkerManager::HandleDeleteNotificationSend( - const std::string &sourceWorker, const std::string &address, - const std::unordered_map> &objectItem, bool &isAsync, - std::unordered_map, std::pair> &api2Tag, Status &lastErr) -{ - std::shared_ptr masterWorkerApi; - Status status = GetMasterWorkerApi(address, masterWorkerApi); - - auto request = std::make_unique(); - SetDeleteObjectReq(request, isAsync, sourceWorker, objectItem); - - if (status.IsOk()) { - int64_t 
tag; - LOG(INFO) << FormatString("Send delete notify to: %s, objects[%s]", address, - VectorToString(request->object_keys())); - status = masterWorkerApi->DeleteNotificationSend(std::move(request), tag); - if (status.IsOk()) { - api2Tag.emplace(masterWorkerApi, std::make_pair(tag, address)); - } - } - if (status.IsError()) { - LOG(ERROR) << FormatString( - "Connect to worker failed when notify worker to delete object, address: %s, error: %s", address, - status.ToString()); - if (!isAsync) { - lastErr = status; - return status; - } - } - return Status::OK(); -} - Status OCNotifyWorkerManager::DoNotifyWorkerDeleteSendRequest( const std::string &sourceWorker, std::unordered_map>> &replicas2Obj, bool isAsync, std::unordered_set &failedObjects, std::unordered_map, std::pair> &api2Tag) { - Status lastErr; std::vector> asyncNotifyIds; - for (const auto &kv : replicas2Obj) { - auto &address = kv.first; - auto &objectItem = kv.second; + Timer timer; + int64_t realTimeoutMs = timeoutDuration.CalcRealRemainingTime(); + std::string traceID = Trace::Instance().GetTraceID(); + std::vector> futures; + futures.reserve(replicas2Obj.size()); + std::atomic needAbort{false}; + for (const auto &item : replicas2Obj) { + const auto &address = item.first; + const auto &objectItem = item.second; if (objectItem.empty()) { continue; } @@ -803,12 +783,46 @@ Status OCNotifyWorkerManager::DoNotifyWorkerDeleteSendRequest( if (!HandleWorkerDisconnection(address, objectItem, asyncNotifyIds)) { continue; } - Status rc = HandleDeleteNotificationSend(sourceWorker, address, objectItem, isAsync, api2Tag, lastErr); - if (rc.IsError()) { + if (needAbort.load()) { + LOG(WARNING) << "Aborting remaining tasks due to timeout."; break; } + futures.emplace_back( + deleteThreadPool_->Submit([=, &needAbort, &timer]() -> SendResult { + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID); + int64_t elapsed = static_cast(timer.ElapsedMilliSecond()); + if (elapsed >= realTimeoutMs) { + LOG(ERROR) << "RPC timeout. time elapsed " << elapsed << ", realTimeoutMs:" << realTimeoutMs + << ", NotifyDeleteSend threads Statistics: " << deleteThreadPool_->GetStatistics(); + needAbort.store(true); + return {nullptr, -1, address, Status(StatusCode::K_RUNTIME_ERROR, "Rpc timeout")}; + } + timeoutDuration.Init(realTimeoutMs - elapsed); + std::shared_ptr api; + Status st = GetMasterWorkerApi(address, api); + int64_t tag = -1; + if (st.IsOk()) { + auto req = std::make_unique(); + SetDeleteObjectReq(req, isAsync, sourceWorker, objectItem); + st = api->DeleteNotificationSend(std::move(req), tag); + } + return {api, tag, address, st}; + }) + ); + } + Status lastErr; + SendResult res; + for (auto &f : futures) { + res = f.get(); + if (res.status.IsError()) { + LOG(ERROR) << "Send delete to " << res.address << " failed: " << res.status.ToString(); + if (!isAsync) { + lastErr = res.status; + } + } else { + api2Tag.emplace(res.api, std::make_pair(res.tag, res.address)); + } } - RETURN_IF_NOT_OK(AsyncNotifyWorkerDelete(asyncNotifyIds, replicas2Obj, failedObjects)); return lastErr; } diff --git a/src/datasystem/master/object_cache/oc_notify_worker_manager.h b/src/datasystem/master/object_cache/oc_notify_worker_manager.h index b1e2e52..f4ad782 100644 --- a/src/datasystem/master/object_cache/oc_notify_worker_manager.h +++ b/src/datasystem/master/object_cache/oc_notify_worker_manager.h @@ -141,25 +141,9 @@ public: * @param[in] objectItem The failed object list. 
*/ - void SetDeleteObjectReq(std::unique_ptr &request, bool &isAsync, const std::string &sourceWorker, + void SetDeleteObjectReq(std::unique_ptr &request, bool isAsync, const std::string &sourceWorker, const std::unordered_map> &objectItem); - /** - * @brief Notify worker which has object data to delete. - * @param[in] sourceWorker The worker initiates to delete object. - * @param[in] replicas2Obj Worker address and their object key. - * @param[in] isAsync Is async process mode. - * @param[out] failedObjects The failed object list. - * @param[out] api2Tag The map from request to tag. - * @param[out] lastErr The last error status of request. - * @return Status of the call. - */ - Status HandleDeleteNotificationSend( - const std::string &sourceWorker, const std::string &address, - const std::unordered_map> &objectItem, bool &isAsync, - std::unordered_map, std::pair> &api2Tag, - Status &lastErr); - /** * @brief Notify worker which has object data to delete. * @param[in] sourceWorker The worker initiates to delete object. @@ -559,6 +543,10 @@ private: Status RemoveNoTargetAsyncWorkerOp( const std::unordered_map> &objectKeys, NotifyWorkerOpType op); + const size_t minDeleteThreadSize = 1; + const size_t maxDeleteThreadSize = 8; + // Global thread pool for reusing worker threads across delete requests. + std::unique_ptr deleteThreadPool_; std::shared_ptr objectStore_; // Metadata store for object. std::shared_timed_mutex notifyWorkerOpMutex_; TbbNotifyWorkerOpTable notifyWorkerOpTable_; // Key is worker address, value is object keys. diff --git a/src/datasystem/master/object_cache/store/meta_async_queue.h b/src/datasystem/master/object_cache/store/meta_async_queue.h index 5fb5b66..ad04e75 100644 --- a/src/datasystem/master/object_cache/store/meta_async_queue.h +++ b/src/datasystem/master/object_cache/store/meta_async_queue.h @@ -31,7 +31,9 @@ #include #include #include + #include "datasystem/common/util/format.h" +#include "datasystem/utils/status.h" namespace datasystem { namespace master { @@ -238,6 +240,19 @@ public: return os; } + void SetPostHandler(std::function &&postHandler) + { + postHandler_ = std::move(postHandler); + } + + Status ExcutePostHandler() + { + if (postHandler_ == nullptr) { + return Status::OK(); + } + return postHandler_(); + } + private: const ReqType reqType_; const std::string objectKey_; @@ -246,6 +261,7 @@ private: const std::string value_; const std::chrono::time_point beginTimestamp_; const std::string traceID_; + std::function postHandler_ = nullptr; }; class MetaAsyncQueue { diff --git a/src/datasystem/master/object_cache/store/object_meta_store.cpp b/src/datasystem/master/object_cache/store/object_meta_store.cpp index cfb2366..9c1e9f5 100644 --- a/src/datasystem/master/object_cache/store/object_meta_store.cpp +++ b/src/datasystem/master/object_cache/store/object_meta_store.cpp @@ -52,6 +52,7 @@ DS_DEFINE_uint32(etcd_meta_pool_size, 8, "ETCD metadata async pool size"); DS_DECLARE_bool(oc_io_from_l2cache_need_metadata); +DS_DECLARE_string(rocksdb_write_mode); static bool ValidateEtcdPoolSize(const char *flagName, uint32_t value) { @@ -121,8 +122,6 @@ Status ObjectMetaStore::InitEtcdStore() // Hash table for normal key. 
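// DoNotifyWorkerDeleteSendRequest above now submits one send per worker address to
// deleteThreadPool_ and joins the std::future results, flagging an abort once the deadline
// passes. A reduced, self-contained sketch of that submit/collect shape; std::async stands
// in for the project's ThreadPool::Submit, and the per-address send is elided:
#include <atomic>
#include <chrono>
#include <future>
#include <string>
#include <vector>

struct Result {
    std::string address;
    bool ok = false;
};

std::vector<Result> FanOut(const std::vector<std::string> &addresses, std::chrono::milliseconds deadline)
{
    auto start = std::chrono::steady_clock::now();
    std::atomic<bool> abort{false};
    std::vector<std::future<Result>> futures;
    futures.reserve(addresses.size());
    for (const auto &addr : addresses) {
        if (abort.load()) {
            break;  // Stop queuing new work once any task has hit the deadline.
        }
        futures.emplace_back(std::async(std::launch::async, [&, addr]() -> Result {
            if (std::chrono::steady_clock::now() - start >= deadline) {
                abort.store(true);
                return {addr, false};
            }
            // ... perform the per-address send here ...
            return {addr, true};
        }));
    }
    std::vector<Result> results;
    results.reserve(futures.size());
    for (auto &f : futures) {
        results.push_back(f.get());  // Join in submission order; failures surface here.
    }
    return results;
}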
RETURN_IF_NOT_OK(etcdStore_->CreateTable(std::string(ETCD_META_TABLE_PREFIX) + ETCD_HASH_SUFFIX, std::string(ETCD_META_TABLE_PREFIX) + ETCD_HASH_SUFFIX)); - RETURN_IF_NOT_OK(etcdStore_->CreateTable(std::string(ETCD_LOCATION_TABLE_PREFIX) + ETCD_HASH_SUFFIX, - std::string(ETCD_LOCATION_TABLE_PREFIX) + ETCD_HASH_SUFFIX)); RETURN_IF_NOT_OK(etcdStore_->CreateTable(std::string(ETCD_ASYNC_WORKER_OP_TABLE_PREFIX) + ETCD_HASH_SUFFIX, std::string(ETCD_ASYNC_WORKER_OP_TABLE_PREFIX) + ETCD_HASH_SUFFIX)); RETURN_IF_NOT_OK(etcdStore_->CreateTable(std::string(ETCD_GLOBAL_CACHE_TABLE_PREFIX) + ETCD_HASH_SUFFIX, @@ -131,8 +130,6 @@ Status ObjectMetaStore::InitEtcdStore() // Worker table for key with worker id. RETURN_IF_NOT_OK(etcdStore_->CreateTable(std::string(ETCD_META_TABLE_PREFIX) + ETCD_WORKER_SUFFIX, std::string(ETCD_META_TABLE_PREFIX) + ETCD_WORKER_SUFFIX)); - RETURN_IF_NOT_OK(etcdStore_->CreateTable(std::string(ETCD_LOCATION_TABLE_PREFIX) + ETCD_WORKER_SUFFIX, - std::string(ETCD_LOCATION_TABLE_PREFIX) + ETCD_WORKER_SUFFIX)); RETURN_IF_NOT_OK(etcdStore_->CreateTable(std::string(ETCD_ASYNC_WORKER_OP_TABLE_PREFIX) + ETCD_WORKER_SUFFIX, std::string(ETCD_ASYNC_WORKER_OP_TABLE_PREFIX) + ETCD_WORKER_SUFFIX)); return etcdStore_->CreateTable(std::string(ETCD_GLOBAL_CACHE_TABLE_PREFIX) + ETCD_WORKER_SUFFIX, @@ -239,12 +236,28 @@ void ObjectMetaStore::AsyncMetaOpToEtcdStorageHandler(int threadNum, const std:: if (!ret || element == nullptr) { continue; } - INJECT_POINT("AsyncMetaOpToEtcdStorageHandler.delete.delay", [element](int delayS) { - if (element->Table().find(ETCD_META_TABLE_PREFIX) != std::string::npos) { - std::this_thread::sleep_for(std::chrono::seconds(delayS)); +#ifdef WITH_TESTS + static const auto injectFunc = [&element, this](int delayMs, const std::string &tableName, bool passAdd = false, + bool passDel = false) { + if (passAdd && element->RequestType() == AsyncElement::ReqType::ADD) { + return; } - return; - }); + if (passDel && element->RequestType() == AsyncElement::ReqType::DEL) { + return; + } + if (element->Table().find(tableName) != std::string::npos) { + Timer timer; + while (running_ && timer.ElapsedMilliSecond() < delayMs) { + const int checkIntervalMs = 100; + std::this_thread::sleep_for(std::chrono::milliseconds(checkIntervalMs)); + } + } + }; +#endif + INJECT_POINT("ObjectMetaStore.AsyncMetaOpToEtcdStorageHandler.Delay.MetaTable", + [](int delayS) { injectFunc(delayS, ETCD_META_TABLE_PREFIX); }); + INJECT_POINT("ObjectMetaStore.AsyncMetaOpToEtcdStorageHandler.Delay.GlobalCacheTable.PassAdd", + [](int delayS) { injectFunc(delayS, ETCD_GLOBAL_CACHE_TABLE_PREFIX, true); }); const auto &etcdKey = element->Key(); TraceGuard traceGuard = Trace::Instance().SetTraceNewID(element->TraceID()); VLOG(1) << FormatString("handler %d get key: %s", threadNum, etcdKey); @@ -259,6 +272,7 @@ void ObjectMetaStore::AsyncMetaOpToEtcdStorageHandler(int threadNum, const std:: break; case AsyncElement::ReqType::DEL: EXEC_UTIL_SUCCESS(etcdStore_->Delete(element->Table(), etcdKey, asyncElapse), K_NOT_FOUND, !running_); + LOG_IF_ERROR(element->ExcutePostHandler(), "Excute post handler failed, etcd key: " + etcdKey); break; default: LOG(WARNING) << "unknown operation: " << static_cast(element->RequestType()); @@ -271,7 +285,7 @@ void ObjectMetaStore::AsyncMetaOpToEtcdStorageHandler(int threadNum, const std:: Status ObjectMetaStore::AddOneAsyncTaskToEtcdStore(const std::string &objectKey, const std::string &table, const std::string &etcdKey, const std::string &value, AsyncElement::ReqType requestType, 
uint64_t timestamp, - const std::string &traceId) + const std::string &traceId, std::function &&postHandler) { CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(!queues_.empty(), K_NOT_READY, "It does not support executing etcd asynchronous tasks currently."); @@ -279,6 +293,9 @@ Status ObjectMetaStore::AddOneAsyncTaskToEtcdStore(const std::string &objectKey, auto element = timestamp == 0 ? std::make_shared(objectKey, table, etcdKey, value, requestType) : std::make_shared(objectKey, table, etcdKey, value, requestType, timestamp, traceId); + if (postHandler) { + element->SetPostHandler(std::move(postHandler)); + } auto threadIdx = MurmurHash3_32(objectKey) % FLAGS_etcd_meta_pool_size; std::shared_ptr elderElement; int incrCnt = 0; @@ -354,7 +371,7 @@ Status ObjectMetaStore::BatchPutToEtcdStore(const std::string &tablePrefix, } Status ObjectMetaStore::RemoveEtcdKey(const std::string &objectKey, const std::string &key, - const std::string &tablePrefix) + const std::string &tablePrefix, std::function &&postHandler) { RETURN_OK_IF_TRUE(!EtcdEnable()); auto res = Split(objectKey, ";"); @@ -387,7 +404,8 @@ Status ObjectMetaStore::RemoveEtcdKey(const std::string &objectKey, const std::s Status rc; std::string etcdKey = Hash2Str(hash) + "/" + key; if (async) { - RETURN_IF_NOT_OK(AddOneAsyncTaskToEtcdStore(objectKey, table, etcdKey, "", AsyncElement::ReqType::DEL)); + RETURN_IF_NOT_OK(AddOneAsyncTaskToEtcdStore(objectKey, table, etcdKey, "", AsyncElement::ReqType::DEL, 0, "", + std::move(postHandler))); } else { rc = etcdStore_->Delete(table, etcdKey); } @@ -395,9 +413,9 @@ Status ObjectMetaStore::RemoveEtcdKey(const std::string &objectKey, const std::s if (rc.IsError() && rc.GetCode() != StatusCode::K_NOT_FOUND) { std::lock_guard l(etcdMtx_); (void)etcdKeyMap_[table].emplace(key, std::make_pair(hash, async)); + return rc; } - RETURN_OK_IF_TRUE(rc.GetCode() == StatusCode::K_NOT_FOUND); - return rc; + return postHandler != nullptr && !async ? 
postHandler() : Status::OK(); } void ObjectMetaStore::PrefixSearchAndErase(const std::string &table, const std::string &prefixKey, @@ -479,6 +497,11 @@ Status ObjectMetaStore::InitRocksStore() return Replica::CreateOcTable(rocksStore_); } +bool ObjectMetaStore::IsRocksdbEnableWriteMeta() +{ + return FLAGS_rocksdb_write_mode != "none"; +} + Status ObjectMetaStore::AddRocksdbHealthTag() { RETURN_IF_NOT_OK(PutToRocksStore(HEALTH_TABLE, "status", HEALTH_STATUS)); @@ -505,7 +528,7 @@ void ObjectMetaStore::GetMetasMatch( for (size_t i = 0; i < queues_.size(); ++i) { uint64_t count = 0; queues_[i]->PollMetasByObjectKey(std::forward>(matchFunc), objAsyncMap, - count); + count); (void)asyncReqSize_.fetch_sub(count, std::memory_order_relaxed); } { @@ -517,7 +540,7 @@ void ObjectMetaStore::GetMetasMatch( } void ObjectMetaStore::PollAsyncElementsByObjectKey(const std::string &objectKey, - std::unordered_set> &elements) + std::unordered_set> &elements) { if (queues_.empty()) { LOG(WARNING) << "It does not support executing etcd asynchronous tasks currently."; @@ -582,6 +605,7 @@ Status ObjectMetaStore::CreateOrUpdateBatchMeta(std::unordered_mapBatchPut(META_TABLE, metaInfos), FormatString("Failed to add object meta: %s", MapToString(metaInfos))); + Status rc = BatchPutToEtcdStore(ETCD_META_TABLE_PREFIX, metaInfos, type, true); if (rc.IsError()) { LOG(ERROR) << FormatString("Failed to add object meta to etcd store: %s", MapToString(metaInfos)); @@ -603,33 +627,23 @@ Status ObjectMetaStore::RemoveMeta(const std::string &key, bool needRemoveEtcdDa return Status::OK(); } -Status ObjectMetaStore::AddObjectLocation(const std::string &objectKey, const std::string &workerAddr, WriteType type) +Status ObjectMetaStore::AddObjectLocation(const std::string &objectKey, const std::string &workerAddr) { RETURN_OK_IF_TRUE(!isPersistenceEnabled_); PerfPoint point(PerfKey::MASTER_ROCKSDB_ADD_OBJ_LOCATION); std::string key = workerAddr + "_" + objectKey; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->Put(LOCATION_TABLE, key, ""), FormatString("Failed to add global ref to rocksdb: %s", key)); - Status rc = PutToEtcdStore(ETCD_LOCATION_TABLE_PREFIX, objectKey, key, "", type); - if (rc.IsError()) { - LOG(ERROR) << FormatString("Failed to add object meta to etcd store: %s", key); - (void)rocksStore_->Delete(LOCATION_TABLE, key); - } - return rc; + return Status::OK(); } -Status ObjectMetaStore::RemoveObjectLocation(const std::string &objectKey, const std::string &workerAddr, - bool needRemoveEtcdData) +Status ObjectMetaStore::RemoveObjectLocation(const std::string &objectKey, const std::string &workerAddr) { RETURN_OK_IF_TRUE(!isPersistenceEnabled_); PerfPoint point(PerfKey::MASTER_ROCKSDB_REMOVE_OBJ_LOCATION); std::string key = workerAddr + "_" + objectKey; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(RemoveRocksKey(key, LOCATION_TABLE), FormatString("Failed to delete location from rocksdb: %s", key)); - if (needRemoveEtcdData) { - RETURN_IF_NOT_OK_PRINT_ERROR_MSG(RemoveEtcdKey(objectKey, key, ETCD_LOCATION_TABLE_PREFIX), - FormatString("Failed to delete location from etcd: %s", key)); - } return Status::OK(); } @@ -653,7 +667,6 @@ Status ObjectMetaStore::AddNestedRelationship(const std::string &parentObjKey, c { RETURN_OK_IF_TRUE(!isPersistenceEnabled_); std::string key = parentObjKey + "_" + childObjKey; - RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->Put(NESTED_TABLE, key, childObjKey), FormatString("Failed to add nested relationship, objectKey: %s", childObjKey)); return Status::OK(); @@ -892,11 +905,16 @@ Status 
ObjectMetaStore::RemoveDeletedObject(const std::string &objectKey, uint64 { std::string versionStr = std::to_string(version); std::string key = objectKey + "/" + versionStr; + + auto postRemoveEtcdKeyFunc = [this, key, objectKey, versionStr]() -> Status { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + RemoveRocksKey(key, GLOBAL_CACHE_TABLE), + FormatString("Failed to delete l2 cache from rocksdb: objectKey=%s,version=%s", objectKey, versionStr)); + return Status::OK(); + }; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( - RemoveRocksKey(key, GLOBAL_CACHE_TABLE), - FormatString("Failed to delete l2 cache from rocksdb: objectKey=%s,version=%s", objectKey, versionStr)); - RETURN_IF_NOT_OK_PRINT_ERROR_MSG( - RemoveEtcdKey(objectKey, key, ETCD_GLOBAL_CACHE_TABLE_PREFIX), + RemoveEtcdKey(objectKey, key, ETCD_GLOBAL_CACHE_TABLE_PREFIX, std::move(postRemoveEtcdKeyFunc)), FormatString("Failed to delete l2 cache from etcd: objectKey=%s,version=%s", objectKey, versionStr)); return Status::OK(); } diff --git a/src/datasystem/master/object_cache/store/object_meta_store.h b/src/datasystem/master/object_cache/store/object_meta_store.h index bd65e3c..c8887f5 100644 --- a/src/datasystem/master/object_cache/store/object_meta_store.h +++ b/src/datasystem/master/object_cache/store/object_meta_store.h @@ -146,20 +146,17 @@ public: * @brief Add object location to rocksdb. * @param[in] objectKey Object key to be added. * @param[in] workerAddr Location of the object. - * @param[in] type Kv write type. * @return Status of the call. */ - Status AddObjectLocation(const std::string &objectKey, const std::string &workerAddr, WriteType type = ROCKS_ONLY); + Status AddObjectLocation(const std::string &objectKey, const std::string &workerAddr); /** * @brief Remove object location from rocksdb. * @param[in] objectKey Object key to be removed. * @param[in] workerAddr Location of the object. - * @param[in] needRemoveEtcdData Indicates whether to delete etcd data. * @return Status of the call. */ - Status RemoveObjectLocation(const std::string &objectKey, const std::string &workerAddr, - bool needRemoveEtcdData = true); + Status RemoveObjectLocation(const std::string &objectKey, const std::string &workerAddr); /** * @brief Get all pairs from KvStore table @@ -353,6 +350,12 @@ public: return isRocksdbRunning_; } + /** + * @brief Check whether to support metadata written to rocksdb. + * @return True if support metadata written to rocksdb. + */ + bool IsRocksdbEnableWriteMeta(); + /** * @brief Change rocksdb to running. */ @@ -399,7 +402,7 @@ public: * @param[out] elements Async elements. */ void PollAsyncElementsByObjectKey(const std::string &objectKey, - std::unordered_set> &elements); + std::unordered_set> &elements); /** * @brief Insert wait async elements to object meta store. @@ -466,9 +469,12 @@ private: * @param[in] objectKey Object key. * @param[in] key Key need to store. * @param[in] tablePrefix ETCD table prefix. + * @param[in] postHandler To ensure consistency, some things can only be done after successfully deleting the key in + * etcd. * @return Status of the call. */ - Status RemoveEtcdKey(const std::string &objectKey, const std::string &key, const std::string &tablePrefix); + Status RemoveEtcdKey(const std::string &objectKey, const std::string &key, const std::string &tablePrefix, + std::function &&postHandler = nullptr); /** * @brief Prefix search key and erase them from etcdKeyMap_ @@ -547,12 +553,14 @@ private: * @param[in] etcdKey Key need to remove. 
* @param[in] value Value need to store * @param[in] requestType The request's type, see the AsyncEtcdOpElement::RequestType for details. + * @param[in] postHandler To ensure consistency, some things can only be done after successfully deleting the key in + * etcd. * @return Status of the call */ Status AddOneAsyncTaskToEtcdStore(const std::string &objectKey, const std::string &table, const std::string &etcdKey, const std::string &value, AsyncElement::ReqType requestType, uint64_t timestamp = 0, - const std::string &traceId = ""); + const std::string &traceId = "", std::function &&postHandler = nullptr); // The backend rocksdb storage. RocksStore *rocksStore_; diff --git a/src/datasystem/master/replica_manager.cpp b/src/datasystem/master/replica_manager.cpp index c6f9967..c7eeedf 100644 --- a/src/datasystem/master/replica_manager.cpp +++ b/src/datasystem/master/replica_manager.cpp @@ -42,7 +42,9 @@ #include "datasystem/master/replica_rpc_channel_impl.h" #include "datasystem/master/object_cache/oc_metadata_manager.h" #include "datasystem/master/object_cache/store/object_meta_store.h" +#include "datasystem/master/stream_cache/sc_metadata_manager.h" #include "datasystem/protos/worker_object.pb.h" +#include "datasystem/protos/worker_stream.pb.h" #include "datasystem/utils/status.h" #include "datasystem/worker/cluster_event_type.h" #include "datasystem/worker/cluster_manager/etcd_cluster_manager.h" @@ -51,6 +53,7 @@ DS_DEFINE_bool(enable_meta_replica, false, "Controls whether to enable multiple meta replica"); DS_DECLARE_uint32(rolling_update_timeout_s); +DS_DECLARE_string(rocksdb_write_mode); namespace datasystem { const std::string REPLICA_MANAGER = "ReplicaManager"; @@ -71,6 +74,9 @@ void MetadataManager::Shutdown() if (oc != nullptr) { oc->Shutdown(); } + if (sc != nullptr) { + sc->Shutdown(); + } } ReplicaManager::~ReplicaManager() @@ -92,6 +98,10 @@ ReplicaManager::~ReplicaManager() Status ReplicaManager::Init(ReplicaManagerParam param) { LOG(INFO) << "Init replica manager."; + if (FLAGS_enable_meta_replica && FLAGS_rocksdb_write_mode == "none") { + RETURN_STATUS(StatusCode::K_INVALID, + "When using enable_meta_replica, rocksdb_write_mode cannot be set to none."); + } dbRootPath_ = std::move(param.dbRootPath); currentWorkerId_ = std::move(param.currWorkerId); akSkManager_ = param.akSkManager; @@ -101,7 +111,9 @@ Status ReplicaManager::Init(ReplicaManagerParam param) etcdCM_ = param.etcdCM; masterWorkerService_ = param.masterWorkerService; workerWorkerService_ = param.workerWorkerService; + rpcSessionManager_ = param.rpcSessionManager; isOcEnabled_ = param.isOcEnabled; + isScEnabled_ = param.isScEnabled; bool multiReplicaEnabled = MultiReplicaEnabled(); if (multiReplicaEnabled) { const int queueSize = 1024; @@ -139,21 +151,22 @@ void ReplicaManager::SubscribeEvent() ReplicaEvent::GetInstance().AddSubscriber(REPLICA_MANAGER, [this](mvccpb::Event &event) { return EnqueEvent(event); }); - HashRingEvent::ClusterInitFinish::GetInstance().AddSubscriber( - REPLICA_MANAGER, [this](const std::string &primaryWorkerId, const std::string &standbyWorkerId) { - if (currentWorkerId_ != primaryWorkerId) { - return; - } - INJECT_POINT("worker.ClusterInitFinish", [&] { - ReplicaGroupPb replicaGroupPb = - CreateReplicaGroupPb(standbyWorkerId, { primaryWorkerId, standbyWorkerId }); - LOG_IF_ERROR(PutReplicaGroupToEtcd(primaryWorkerId, replicaGroupPb), "PutReplicaGroupToEtcd failed"); - }); - - LOG_IF_ERROR(AdjustReplicaLocationImpl(primaryWorkerId, { standbyWorkerId }), - "AdjustReplicaLocationImpl failed"); + 
HashRingEvent::ClusterInitFinish::GetInstance().AddSubscriber(REPLICA_MANAGER, [this](const std::string + &primaryWorkerId, + const std::string + &standbyWorkerId) { + if (currentWorkerId_ != primaryWorkerId) { + return; + } + INJECT_POINT("worker.ClusterInitFinish", [&] { + ReplicaGroupPb replicaGroupPb = CreateReplicaGroupPb(standbyWorkerId, { primaryWorkerId, standbyWorkerId }); + LOG_IF_ERROR(PutReplicaGroupToEtcd(primaryWorkerId, replicaGroupPb), "PutReplicaGroupToEtcd failed"); }); + LOG_IF_ERROR(AdjustReplicaLocationImpl(primaryWorkerId, { standbyWorkerId }), + "AdjustReplicaLocationImpl failed"); + }); + NodeTimeoutEvent::GetInstance().AddSubscriber( REPLICA_MANAGER, [this](const std::string &workerAddr, bool, bool, bool isOtherAzNode) { if (!isOtherAzNode) { @@ -187,13 +200,15 @@ bool ReplicaManager::MultiReplicaEnabled() return FLAGS_enable_meta_replica && (etcdCM_ == nullptr || !etcdCM_->IsCentralized()); } -Status ReplicaManager::CreateMetaManager(const std::string &dbName, RocksStore *objectRocksStore - ) +Status ReplicaManager::CreateMetaManager(const std::string &dbName, RocksStore *objectRocksStore, + RocksStore *streamRocksStore) { + (void)streamRocksStore; auto iter = metadataManagers_.find(dbName); if (iter == metadataManagers_.end()) { Timer timer; double ocElapsed = 0; + double scElapsed = 0; MetadataManager metadataManager; if (isOcEnabled_) { // create OCMetadataManager instance @@ -207,8 +222,20 @@ Status ReplicaManager::CreateMetaManager(const std::string &dbName, RocksStore * metadataManager.oc = std::move(oc); } + if (isScEnabled_) { + // create SCMetadataManager instance + auto sc = std::make_shared(masterAddress_, akSkManager_, rpcSessionManager_, + etcdCM_, streamRocksStore, dbName); + LOG(INFO) << "Start init SCMetadataManager for " << dbName; + timer.Reset(); + RETURN_IF_NOT_OK(sc->Init()); + scElapsed = timer.ElapsedMilliSecond(); + metadataManager.sc = std::move(sc); + } + metadataManagers_.emplace(dbName, std::move(metadataManager)); - LOG(INFO) << "OCMetadataManager init cost:" << ocElapsed << "ms for " << dbName; + LOG(INFO) << "OCMetadataManager init cost:" << ocElapsed << "ms, SCMetadataManager init cost:" << scElapsed + << "ms for " << dbName; } return Status::OK(); } @@ -249,6 +276,16 @@ Status ReplicaManager::GetOcMetadataManager(const std::string &dbName, return Status::OK(); } +Status ReplicaManager::GetScMetadataManager(const std::string &dbName, + std::shared_ptr &scMetadataManager) +{ + MetadataManager metadataManager; + RETURN_IF_NOT_OK(GetMetadataManager(dbName, metadataManager)); + RETURN_RUNTIME_ERROR_IF_NULL(metadataManager.sc); + scMetadataManager = metadataManager.sc; + return Status::OK(); +} + bool ReplicaManager::HaveAsyncMetaRequest() { std::shared_lock locker(mutex_); @@ -377,6 +414,10 @@ Status ReplicaManager::AddOrSwitchTo(const std::string &dbName, ReplicaType type RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Replica::CreateOcTable(replica->GetObjectRocksStore()), "Replica create oc table failed"); } + if (isScEnabled_) { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Replica::CreateScTable(replica->GetStreamRocksStore()), + "Replica create sc table failed"); + } iter = replicas_.emplace(dbName, std::move(replica)).first; } auto &replica = *iter->second; @@ -387,7 +428,7 @@ Status ReplicaManager::AddOrSwitchTo(const std::string &dbName, ReplicaType type LOG(INFO) << "Replica " << dbName << ", switch to " << Replica::ReplicaTypeToString(type); replica.SetReplicaType(type); if (type == ReplicaType::Primary) { - RETURN_IF_NOT_OK(CreateMetaManager(dbName, 
replica.GetObjectRocksStore())); + RETURN_IF_NOT_OK(CreateMetaManager(dbName, replica.GetObjectRocksStore(), replica.GetStreamRocksStore())); } else { RETURN_IF_NOT_OK(DestroyMetaManager(dbName)); } @@ -507,7 +548,8 @@ Status ReplicaManager::AddDelayElectionTask(uint64_t delaySec, const std::string return Election(dbName); } TimerQueue::TimerImpl timer; - TimerQueue::GetInstance()->AddTimer(delaySec * SECTOMILLI, [this, dbName] { Election(dbName); }, timer); + TimerQueue::GetInstance()->AddTimer( + delaySec * SECTOMILLI, [this, dbName] { Election(dbName); }, timer); std::lock_guard locker(mutex_); timers_.emplace(dbName, std::make_unique(timer)); @@ -1163,7 +1205,8 @@ bool ReplicaManager::CheckMetaEmpty(const std::string &dbName) } auto oc = metadataManager.oc; - if (oc != nullptr && !oc->CheckMetaTableEmpty()) { + auto sc = metadataManager.sc; + if ((oc != nullptr && !oc->CheckMetaTableEmpty()) || (sc != nullptr && !sc->CheckMetaTableEmpty())) { return false; } @@ -1241,6 +1284,20 @@ Status ReplicaManager::CheckMappingExpired(std::set &expiredUuids) objKeys, &isUuidsEmpty); } + if (metadataManager.sc != nullptr && !isUuidsEmpty) { + metadataManager.sc->GetMetasMatch( + [&expiredUuids, &isUuidsEmpty](const std::string &streamName) { + std::string curUuid; + if (TrySplitWorkerIdFromObjecId(streamName, curUuid).IsOk() && ContainsKey(expiredUuids, curUuid)) { + expiredUuids.erase(curUuid); + isUuidsEmpty = expiredUuids.empty(); + return true; + } + return false; + }, + objKeys, &isUuidsEmpty); + } + return Status::OK(); } diff --git a/src/datasystem/master/replica_manager.h b/src/datasystem/master/replica_manager.h index 6e1afdc..e930a0d 100644 --- a/src/datasystem/master/replica_manager.h +++ b/src/datasystem/master/replica_manager.h @@ -30,6 +30,7 @@ #include "datasystem/common/util/queue/queue.h" #include "datasystem/master/meta_addr_info.h" #include "datasystem/master/object_cache/oc_metadata_manager.h" +#include "datasystem/master/stream_cache/sc_metadata_manager.h" #include "datasystem/worker/object_cache/master_worker_oc_service_impl.h" #include "datasystem/worker/object_cache/worker_worker_oc_service_impl.h" @@ -47,11 +48,14 @@ struct ReplicaManagerParam { EtcdClusterManager *etcdCM; object_cache::MasterWorkerOCServiceImpl *masterWorkerService; object_cache::WorkerWorkerOCServiceImpl *workerWorkerService; + std::shared_ptr rpcSessionManager; bool isOcEnabled; + bool isScEnabled; }; struct MetadataManager { std::shared_ptr oc; + std::shared_ptr sc; void Shutdown(); }; @@ -103,6 +107,15 @@ public: Status GetOcMetadataManager(const std::string &dbName, std::shared_ptr &ocMetadataManager); + /** + * @brief Get the ScMetadataManager instance. + * @param[in] dbName The rocksdb name. + * @param[out] scMetadataManager The ScMetadataManager instance. + * @return Status of this call. + */ + Status GetScMetadataManager(const std::string &dbName, + std::shared_ptr &scMetadataManager); + /** * @brief Check whether there are any requests for asynchronously writing metadata to ETCD. * @return True if there are unfinished async requests. @@ -275,9 +288,11 @@ protected: * @brief Create the metadata manager instance. * @param[in] dbName The rocksdb name. * @param[in] objectRocksStore The RocksStore instance for object. + * @param[in] streamRocksStore The RocksStore instance for stream. * @return Status of this call. 
*/ - virtual Status CreateMetaManager(const std::string &dbName, RocksStore *objectRocksStore); + virtual Status CreateMetaManager(const std::string &dbName, RocksStore *objectRocksStore, + RocksStore *streamRocksStore); /** * @brief Destroy the metadata manager instance. @@ -472,7 +487,9 @@ protected: object_cache::MasterWorkerOCServiceImpl *masterWorkerService_; object_cache::WorkerWorkerOCServiceImpl *workerWorkerService_; // rpc session manager for stream + std::shared_ptr<RpcSessionManager> rpcSessionManager_; bool isOcEnabled_; + bool isScEnabled_; bool isNewNode_; std::unique_ptr channel_; diff --git a/src/datasystem/master/stream_cache/CMakeLists.txt b/src/datasystem/master/stream_cache/CMakeLists.txt new file mode 100644 index 0000000..ed66912 --- /dev/null +++ b/src/datasystem/master/stream_cache/CMakeLists.txt @@ -0,0 +1,26 @@ +add_subdirectory(store) +set(MASTER_SC_SRCS + master_sc_service_impl.cpp + master_worker_sc_api.cpp + rpc_session_manager.cpp + stream_metadata.cpp + topology_manager.cpp + sc_metadata_manager.cpp + sc_migrate_metadata_manager.cpp + sc_notify_worker_manager.cpp + ) + +set(MASTER_SC_DEPEND_LIBS + common_log + common_util + common_rpc_zmq + common_event_loop + master_stream_cache_store + master_stream_protos + cluster_manager + ) + +add_library(master_stream_cache STATIC ${MASTER_SC_SRCS}) +target_link_libraries(master_stream_cache PRIVATE ${MASTER_SC_DEPEND_LIBS}) +add_dependencies(master_stream_cache + master_stream_protos) \ No newline at end of file diff --git a/src/datasystem/master/stream_cache/master_sc_service_impl.cpp b/src/datasystem/master/stream_cache/master_sc_service_impl.cpp new file mode 100644 index 0000000..744b650 --- /dev/null +++ b/src/datasystem/master/stream_cache/master_sc_service_impl.cpp @@ -0,0 +1,346 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Implement the stream cache services on the master.
+ */ +#include "datasystem/master/stream_cache/master_sc_service_impl.h" + +#include + +#include "datasystem/common/log/log_helper.h" +#include "datasystem/common/stream_cache/util.h" +#include "datasystem/common/util/raii.h" +#include "datasystem/common/util/thread_local.h" +#include "datasystem/worker/stream_cache/master_worker_sc_service_impl.h" +#include "datasystem/master/stream_cache/sc_metadata_manager.h" +#include "datasystem/master/stream_cache/sc_migrate_metadata_manager.h" + +DS_DEFINE_int32(master_sc_thread_num, 128, "Max number of threads for (non rpc) master stream cache service work"); + +namespace datasystem { +namespace master { +MasterSCServiceImpl::MasterSCServiceImpl(const HostPort &masterAddress, std::shared_ptr akSkManager, + ReplicaManager *replicaManager) + : MasterSCService(masterAddress), akSkManager_(std::move(akSkManager)), replicaManager_(replicaManager) +{ +} + +void MasterSCServiceImpl::Shutdown() +{ + LOG(INFO) << "MasterSCServiceImpl shutdown."; + SCMigrateMetadataManager::Instance().Shutdown(); +} + +Status MasterSCServiceImpl::Init() +{ + RETURN_IF_NOT_OK(MasterSCService::Init()); + const size_t MIN_THREADS = 1; + size_t minThreads = std::min(MIN_THREADS, FLAGS_master_sc_thread_num); + RETURN_IF_EXCEPTION_OCCURS(threadPool_ = + std::make_unique(minThreads, FLAGS_master_sc_thread_num, "MScThreads")); + RETURN_IF_NOT_OK(SCMigrateMetadataManager::Instance().Init(GetLocalAddr(), akSkManager_, etcdCM_, replicaManager_)); + VLOG(SC_NORMAL_LOG_LEVEL) << "MasterSCServiceImpl initialization success"; + return Status::OK(); +} + +Status MasterSCServiceImpl::CreateProducer( + std::shared_ptr> serverApi) +{ + CreateProducerReqPb req; + CreateProducerRspPb rsp; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + return CreateProducerImpl(serverApi, req, rsp); +} + +Status MasterSCServiceImpl::CreateProducerImpl( + const std::shared_ptr> &serverApi, + const CreateProducerReqPb &req, CreateProducerRspPb &rsp) +{ + Timer timer(req.timeout()); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + LOG(INFO) << FormatString("Master receive create producer request: <%s> with timeout: %d", + LogHelper::IgnoreSensitive(req.producer_meta()), req.timeout()); + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + std::shared_ptr scMetadataManager; + INJECT_POINT("master.CreateProducer"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(replicaManager_->GetScMetadataManager(GetDbName(), scMetadataManager), + "GetScMetadataManager failed"); + if (serverApi) { + // Launch child thread to run the real logic and then return this thread. This avoids the rpc thread being + // active during the logic of this request so that it can be re-used by other requests. 
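A note on the dispatch idiom used by every handler in this service: the rpc thread only reads the request, authenticates it, and resolves the metadata manager, and the real work is parked on a pool so the transport thread is immediately reusable. A self-contained sketch of the idea follows; FakeServerApi and std::async here are illustrative stand-ins for this patch's serverApi and ThreadPool, not the real classes:

#include <future>
#include <iostream>
#include <memory>
#include <string>

// Stand-in for the rpc writer: the response is sent from whichever thread
// finishes the work, not from the thread that accepted the request.
struct FakeServerApi {
    void Write(const std::string &rsp) { std::cout << rsp << std::endl; }
};

// The accepting thread enqueues and returns immediately; std::async stands in
// for the bounded thread pool used by the service.
std::future<void> HandleRequest(std::shared_ptr<FakeServerApi> serverApi, std::string req)
{
    return std::async(std::launch::async, [serverApi, req = std::move(req)]() {
        // ... the real metadata work would happen here ...
        serverApi->Write("done: " + req);
    });
}

int main()
{
    auto fut = HandleRequest(std::make_shared<FakeServerApi>(), "stream-0");
    fut.wait();  // a real rpc thread would not block like this
    return 0;
}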
+ auto traceId = Trace::Instance().GetTraceID(); + threadPool_->Execute([=]() mutable { + scTimeoutDuration.Init(timer.GetRemainingTimeMs()); + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + Status rc = scMetadataManager->CreateProducer(req, rsp); + CheckErrorReturn( + rc, rsp, FormatString("[S:%s] CreateProducerImpl failed with rc ", req.producer_meta().stream_name()), + serverApi); + }); + } else { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(scMetadataManager->CreateProducer(req, rsp), "CreateProducer failed"); + } + LOG(INFO) << FormatString("Master create producer request: <%s> Successful", + LogHelper::IgnoreSensitive(req.producer_meta())); + return Status::OK(); +} + +Status MasterSCServiceImpl::CloseProducer( + std::shared_ptr> serverApi) +{ + CloseProducerReqPb req; + CloseProducerRspPb rsp; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + return CloseProducerImpl(serverApi, req, rsp); +} + +Status MasterSCServiceImpl::CloseProducerImpl( + const std::shared_ptr> &serverApi, + const CloseProducerReqPb &req, CloseProducerRspPb &rsp) +{ + Timer timer(req.timeout()); + INJECT_POINT("master.CloseProducerImpl"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + std::string infoMsg; + // If there is more than one producer to close, only log the count. If there is only one, show the detail. + if (req.producer_infos_size() == 1) { + infoMsg = FormatString("S:%s", req.producer_infos(0).stream_name()); + } else { + infoMsg = FormatString("Number of producers: %d", req.producer_infos_size()); + } + LOG(INFO) << "Master receive close producer request: " << infoMsg << " with timeout: " << req.timeout(); + std::shared_ptr<SCMetadataManager> scMetadataManager; + INJECT_POINT("master.CloseProducer"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(replicaManager_->GetScMetadataManager(GetDbName(), scMetadataManager), + "GetScMetadataManager failed"); + if (serverApi) { + // Launch child thread to run the real logic and then return this thread. This avoids the rpc thread being + // active during the logic of this request so that it can be re-used by other requests.
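+ // Note: the lambda below re-arms the thread-local budget with timer.GetRemainingTimeMs() rather than the + // original req.timeout(), so time already spent on signature verification and dispatch still counts against + // the caller's deadline.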
+ auto traceId = Trace::Instance().GetTraceID(); + threadPool_->Execute([=]() mutable { + scTimeoutDuration.Init(timer.GetRemainingTimeMs()); + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + Status rc = scMetadataManager->CloseProducer(req, rsp); + CheckErrorReturn(rc, rsp, "CloseProducerImpl failed with rc", serverApi); + }); + } else { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(scMetadataManager->CloseProducer(req, rsp), "CloseProducer failed"); + } + + return Status::OK(); +} + +Status MasterSCServiceImpl::Subscribe( + std::shared_ptr> serverApi) +{ + SubscribeReqPb req; + SubscribeRspPb rsp; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + return SubscribeImpl(serverApi, req, rsp); +} + +Status MasterSCServiceImpl::SubscribeImpl( + const std::shared_ptr> &serverApi, + const SubscribeReqPb &req, SubscribeRspPb &rsp) +{ + Timer timer(req.timeout()); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + LOG(INFO) << FormatString("Master receive subscribe request: <%s> with timeout: %d", + LogHelper::IgnoreSensitive(req.consumer_meta()), req.timeout()); + std::shared_ptr scMetadataManager; + INJECT_POINT("master.Subscribe"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(replicaManager_->GetScMetadataManager(GetDbName(), scMetadataManager), + "GetScMetadataManager failed"); + if (serverApi) { + // Launch child thread to run the real logic and then return this thread. This avoids the rpc thread being + // active during the logic of this request so that it can be re-used by other requests. + auto traceId = Trace::Instance().GetTraceID(); + threadPool_->Execute([=]() mutable { + scTimeoutDuration.Init(timer.GetRemainingTimeMs()); + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + Status rc = scMetadataManager->Subscribe(req, rsp); + CheckErrorReturn( + rc, rsp, FormatString("[S:%s] SubscribeImpl failed with rc", + req.consumer_meta().stream_name()), serverApi); + }); + } else { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(scMetadataManager->Subscribe(req, rsp), "Subscribe failed"); + } + return Status::OK(); +} + +Status MasterSCServiceImpl::CloseConsumer( + std::shared_ptr> serverApi) +{ + CloseConsumerReqPb req; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + CloseConsumerRspPb rsp; + return CloseConsumerImpl(serverApi, req, rsp); +} + +Status MasterSCServiceImpl::CloseConsumerImpl( + const std::shared_ptr> &serverApi, + const CloseConsumerReqPb &req, CloseConsumerRspPb &rsp) +{ + Timer timer(req.timeout()); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + LOG(INFO) << FormatString("Master receive close consumer request: <%s> with timeout: %d", + LogHelper::IgnoreSensitive(req), req.timeout()); + std::shared_ptr scMetadataManager; + INJECT_POINT("master.CloseConsumer"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(replicaManager_->GetScMetadataManager(GetDbName(), scMetadataManager), + "GetScMetadataManager failed"); + if (serverApi) { + // Launch child thread to run the real logic and then return this thread. This avoids the rpc thread being + // active during the logic of this request so that it can be re-used by other requests. 
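+ // As in the handlers above, the trace id is captured by value on the rpc thread and re-installed via + // SetTraceNewID inside the pool thread, because trace state is thread-local and does not follow the task + // across threads.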
+ auto traceId = Trace::Instance().GetTraceID(); + threadPool_->Execute([=]() mutable { + scTimeoutDuration.Init(timer.GetRemainingTimeMs()); + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + Status rc = scMetadataManager->CloseConsumer(req, rsp); + CheckErrorReturn(rc, rsp, FormatString("[S:%s] CloseConsumer failed with rc", + req.consumer_meta().stream_name()), serverApi); + }); + } else { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(scMetadataManager->CloseConsumer(req, rsp), "CloseConsumer failed"); + scTimeoutDuration.Reset(); + } + + return Status::OK(); +} + +Status MasterSCServiceImpl::DeleteStream(const DeleteStreamReqPb &req, DeleteStreamRspPb &rsp) +{ + scTimeoutDuration.Init(req.timeout()); + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + std::shared_ptr scMetadataManager; + INJECT_POINT("master.DeleteStream"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(replicaManager_->GetScMetadataManager(GetDbName(), scMetadataManager), + "GetScMetadataManager failed"); + LOG(INFO) << FormatString("Master receive delete stream request: <%s> with timeout: %d", + LogHelper::IgnoreSensitive(req), req.timeout()); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(scMetadataManager->DeleteStream(req, rsp), "DeleteStream failed"); + return Status::OK(); +} + +Status MasterSCServiceImpl::QueryGlobalProducersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Master receive query producer number request: <%s>", + LogHelper::IgnoreSensitive(req)); + std::shared_ptr scMetadataManager; + INJECT_POINT("master.QueryGlobalProducersNum"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(replicaManager_->GetScMetadataManager(GetDbName(), scMetadataManager), + "GetScMetadataManager failed"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(scMetadataManager->QueryGlobalProducersNum(req, rsp), + "QueryGlobalProducersNum failed"); + return Status::OK(); +} + +Status MasterSCServiceImpl::QueryGlobalConsumersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Master receive query consumer number request: <%s>", + LogHelper::IgnoreSensitive(req)); + std::shared_ptr scMetadataManager; + INJECT_POINT("master.QueryGlobalConsumersNum"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(replicaManager_->GetScMetadataManager(GetDbName(), scMetadataManager), + "GetScMetadataManager failed"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(scMetadataManager->QueryGlobalConsumersNum(req, rsp), + "QueryGlobalConsumersNum failed"); + return Status::OK(); +} + +Status MasterSCServiceImpl::StartCheckMetadata() +{ + bool isRestart = false; + RETURN_IF_NOT_OK(etcdCM_->IsRestart(isRestart)); + if (!isRestart || !etcdCM_->IsEtcdAvailableWhenStart()) { + return Status::OK(); + } + RETURN_IF_NOT_OK(etcdCM_->CheckWaitNodeTableComplete()); + std::vector nodeAddrs; + // Why does it get node list from etcd instead of cluster manager or hashring? + // Because in case of centralized master, we have no hashring, and can get list from only etcd and cluster manager. + // We want a complete list without absence of any running node. 
Since the list in cluster manager is from etcd, + // getting the list directly from etcd avoids possible rpc delays that could cause nodes to be missed. + RETURN_IF_NOT_OK(etcdCM_->GetNodeAddrListFromEtcd(nodeAddrs)); + const size_t maxThreadNum = 20; + // Guard against creating a thread pool with minThreadNum 0, which would raise a cpp runtime exception. + if (nodeAddrs.empty()) { + return Status::OK(); + } + auto checkPool = std::make_unique<ThreadPool>(std::min(maxThreadNum, nodeAddrs.size()), 0, "MScCheck"); + // broadcast over all active masters + std::vector<std::future<void>> rcs(nodeAddrs.size()); + for (size_t i = 0; i < nodeAddrs.size(); ++i) { + rcs[i] = checkPool->Submit([this, i, &nodeAddrs]() { + auto func = [i, &nodeAddrs](const std::string &dbName, MetadataManager metadataManager) { + auto traceGuard = Trace::Instance().SetTraceNewID(GetStringUuid() + "-sc-check"); + LOG(INFO) << "Check metadata for db name " << dbName; + if (metadataManager.sc != nullptr) { + metadataManager.sc->StartCheckMetadata(nodeAddrs[i]); + } + return Status::OK(); + }; + (void)replicaManager_->ApplyForAllMetaManager(func); + }); + } + // wait for all reconciliations to finish; + // we do not check the results and ignore any failed status + for (const auto &rc : rcs) { + rc.wait(); + } + return Status::OK(); +} + +Status MasterSCServiceImpl::MigrateSCMetadata(const MigrateSCMetadataReqPb &req, MigrateSCMetadataRspPb &rsp) +{ + masterOperationTimeCost.Clear(); + Timer timer; + auto copyReq = req; + for (int i = 0; i < copyReq.stream_metas_size(); ++i) { + auto *meta = copyReq.mutable_stream_metas(i); + if (meta != nullptr) { + meta->clear_notifications(); + } + } + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(copyReq), "AK/SK failed."); + std::shared_ptr<SCMetadataManager> scMetadataManager; + INJECT_POINT("master.MigrateSCMetadata"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(replicaManager_->GetScMetadataManager(GetDbName(), scMetadataManager), + "GetScMetadataManager failed"); + RETURN_IF_NOT_OK(scMetadataManager->SaveMigrationMetadata(req, rsp)); + masterOperationTimeCost.Append("Total MigrateMetadata", timer.ElapsedMilliSecond()); + LOG(INFO) << FormatString("The operations of SC master MigrateMetadata %s", masterOperationTimeCost.GetInfo()); + return Status::OK(); +} + +std::string MasterSCServiceImpl::GetDbName() +{ + if (replicaManager_->MultiReplicaEnabled()) { + return g_MetaRocksDbName; + } + return replicaManager_->GetCurrentWorkerUuid(); +} +} // namespace master +} // namespace datasystem diff --git a/src/datasystem/master/stream_cache/master_sc_service_impl.h b/src/datasystem/master/stream_cache/master_sc_service_impl.h new file mode 100644 index 0000000..c586c64 --- /dev/null +++ b/src/datasystem/master/stream_cache/master_sc_service_impl.h @@ -0,0 +1,204 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Implement the stream cache services on the master.
+ */ +#ifndef DATASYSTEM_MASTER_STREAM_CACHE_MASTER_SC_SERVICE_IMPL_H +#define DATASYSTEM_MASTER_STREAM_CACHE_MASTER_SC_SERVICE_IMPL_H + +#include + +#include + +#include "datasystem/common/ak_sk/ak_sk_manager.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/master/replica_manager.h" +#include "datasystem/master/stream_cache/master_worker_sc_api.h" +#include "datasystem/master/stream_cache/rpc_session_manager.h" +#include "datasystem/master/stream_cache/stream_metadata.h" +#include "datasystem/protos/master_stream.service.rpc.pb.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/worker/cluster_manager/etcd_cluster_manager.h" + +namespace datasystem { +namespace master { +class MasterSCServiceImpl : public MasterSCService { +public: + MasterSCServiceImpl(const HostPort &masterAddress, std::shared_ptr akSkManager, + ReplicaManager *replicaManager); + MasterSCServiceImpl() = default; + ~MasterSCServiceImpl() override = default; + + /** + * @brief Shutdown the sc metadata manager module. + */ + static void Shutdown(); + + /** + * @brief Initialize master service. + * @return Status of call. + */ + Status Init() override; + + /** + * @brief Create a producer, i.e., register a publisher to a stream. Similar to worker::CreateProducer. + * This version is a wrapper that will redirect the call to the CreateProducerImpl() function. + * @param[in] serverApi Used to read request from rpc client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status CreateProducer( + std::shared_ptr> serverApi) override; + + /** + * @brief The real work of the CreateProducer is started in this function. + * @param[in] serverApi Used to read request from rpc client and write response to client. nullptr can be passed + * here if the caller is directly calling this function (not through rpc). + * @param[in] req The create producer request details. + * @param[out] rsp The create producer response. + * @return K_OK on success; the error code otherwise. + */ + Status CreateProducerImpl( + const std::shared_ptr> &serverApi, + const CreateProducerReqPb &req, CreateProducerRspPb &rsp); + + /** + * @brief Close a producer, force flushing and page seal, unregister a publisher to a stream. + * Similar to worker::CloseProducer. This is a wrapper version that will redirect the call to the + * CloseProducerImpl() function. + * @param[in] serverApi Used to read request from rpc client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status CloseProducer( + std::shared_ptr> serverApi) override; + + /** + * @brief The real work of the CloseProducer is started in this function. + * @param[in] serverApi Used to read request from rpc client and write response to client. nullptr can be passed + * here if the caller is directly calling this function (not through rpc). + * @param[in] req The close consumer request details + * @param[out] rsp The close consumer response details + * @return K_OK on success; the error code otherwise. + */ + Status CloseProducerImpl( + const std::shared_ptr> &serverApi, + const CloseProducerReqPb &req, CloseProducerRspPb &rsp); + + /** + * @brief Subscribe to a stream, using a subscription name, i.e., register a consumer to a subscription. + * Similar to worker::Subscribe. This is a wrapper version that will redirect the call to the SubscribeImpl() + * function. + * @param[in] serverApi Used to read request from rpc client and write response to client. 
+ * @return K_OK on success; the error code otherwise. + */ + Status Subscribe(std::shared_ptr> serverApi) override; + + /** + * @brief The real work for the Subscribe is started in this function. + * @param[in] serverApi Used to read request from rpc client and write response to client. nullptr can be passed + * here if the caller is directly calling this function (not through rpc). + * @param[in] req The subscribe request details + * @param[out] rsp The subscribe response info + * @return K_OK on success; the error code otherwise. + */ + Status SubscribeImpl(const std::shared_ptr> &serverApi, + const SubscribeReqPb &req, SubscribeRspPb &rsp); + + /** + * @brief Close a consumer, trigger subscription cursor change and unregister a subscribed consumer to a stream. + * Similar to worker::CloseConsumer. This is a wrapper version that will redirect the call to the + * CloseConsumerImpl() function. + * @param[in] serverApi Used to read request from rpc client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status CloseConsumer( + std::shared_ptr> serverApi) override; + + /** + * @brief The real work of the CloseConsumer is started in this function. + * @param[in] serverApi Used to read request from rpc client and write response to client. nullptr can be passed + * here if the caller is directly calling this function (not through rpc). + * @param[in] req The close consumer request details. + * @param[out] rsp The close consumer response. + * @return K_OK on success; the error code otherwise. + */ + Status CloseConsumerImpl( + const std::shared_ptr> &serverApi, + const CloseConsumerReqPb &req, CloseConsumerRspPb &rsp); + + /** + * @brief Delete a stream. + * @param[in] req The rpc request protobuf. + * @param[out] rsp The rpc response protobuf. + * @return K_OK on success; the error code otherwise. + */ + Status DeleteStream(const DeleteStreamReqPb &req, DeleteStreamRspPb &rsp) override; + + /** + * @brief Query global producers for a stream. + * @param[in] req The rpc request protobuf. + * @param[out] rsp The rpc response protobuf. + * @return K_OK on success; the error code otherwise. + */ + Status QueryGlobalProducersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) override; + + /** + * @brief Query global consumers for a stream. + * @param[in] req The rpc request protobuf. + * @param[out] rsp The rpc response protobuf. + * @return K_OK on success; the error code otherwise. + */ + Status QueryGlobalConsumersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) override; + + /** + * @brief Check metadata when master starts + * @return K_OK on success; the error code otherwise. + */ + Status StartCheckMetadata(); + + /** + * @brief Setter function to assign the cluster manager back pointer. + * @param[in] etcdCM The cluster manager pointer to assign + */ + void SetClusterManager(EtcdClusterManager *etcdCM) + { + etcdCM_ = etcdCM; + } + + /** + * @brief Migrate stream metadata. + * @param[in] req The rpc request protobuf. + * @param[out] rsp The rpc response protobuf. + * @return K_OK on success; the error code otherwise. + */ + Status MigrateSCMetadata(const MigrateSCMetadataReqPb &req, MigrateSCMetadataRspPb &rsp) override; + +protected: + /** + * @brief Get the current db name. + * @return std::string The db name. 
*/ + std::string GetDbName(); + + std::shared_ptr<AkSkManager> akSkManager_{ nullptr }; + EtcdClusterManager *etcdCM_{ nullptr }; // back pointer to the cluster manager + std::unique_ptr<ThreadPool> threadPool_{ nullptr }; + ReplicaManager *replicaManager_; +}; +} // namespace master +} // namespace datasystem + +#endif // DATASYSTEM_MASTER_STREAM_CACHE_MASTER_SC_SERVICE_IMPL_H diff --git a/src/datasystem/master/stream_cache/master_worker_sc_api.cpp b/src/datasystem/master/stream_cache/master_worker_sc_api.cpp new file mode 100644 index 0000000..27b0035 --- /dev/null +++ b/src/datasystem/master/stream_cache/master_worker_sc_api.cpp @@ -0,0 +1,393 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: The interface of the stream cache worker descriptor. + */ +#include "datasystem/master/stream_cache/master_worker_sc_api.h" + +#include "datasystem/common/rpc/rpc_auth_key_manager.h" +#include "datasystem/common/rpc/rpc_stub_base.h" +#include "datasystem/common/rpc/rpc_stub_cache_mgr.h" +#include "datasystem/worker/stream_cache/master_worker_sc_service_impl.h" + +namespace datasystem { +namespace master { + +// Base class methods +MasterWorkerSCApi::MasterWorkerSCApi(HostPort localMasterAddress, std::shared_ptr<AkSkManager> akSkManager) + : localMasterAddress_(std::move(localMasterAddress)), akSkManager_(std::move(akSkManager)) +{ +} + +std::shared_ptr<MasterWorkerSCApi> MasterWorkerSCApi::CreateMasterWorkerSCApi( + const HostPort &hostPort, const HostPort &localHostPort, const std::shared_ptr<AkSkManager> &akSkManager, + worker::stream_cache::MasterWorkerSCServiceImpl *service) +{ + if (hostPort != localHostPort) { + LOG(INFO) << "Master and worker are not collocated. Creating a MasterWorkerSCApi as RPC-based api."; + return std::make_shared<MasterRemoteWorkerSCApi>(hostPort, localHostPort, akSkManager); + } + if (service == nullptr) { + LOG(INFO) << "Master and worker are collocated but the worker service is not provided. Local bypass disabled."; + return std::make_shared<MasterRemoteWorkerSCApi>(hostPort, localHostPort, akSkManager); + } + LOG(INFO) << "Master and worker are collocated. Creating a MasterWorkerSCApi with local bypass optimization."; + return std::make_shared<MasterLocalWorkerSCApi>(service, localHostPort, akSkManager); +}
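CreateMasterWorkerSCApi above is the selection point for the local-bypass optimization: the direct in-process implementation is chosen only when master and worker share an address and the service pointer is supplied. Callers hold the base-class pointer and never branch on the variant. A minimal sketch under that assumption; the BroadcastDelete helper and its arguments are hypothetical, only the api calls come from this patch:

Status BroadcastDelete(const HostPort &workerAddr, const HostPort &localAddr,
                       const std::shared_ptr<AkSkManager> &akSkManager,
                       worker::stream_cache::MasterWorkerSCServiceImpl *service)
{
    // RPC-based or in-process is decided here once; the call sites below are identical.
    std::shared_ptr<MasterWorkerSCApi> api =
        MasterWorkerSCApi::CreateMasterWorkerSCApi(workerAddr, localAddr, akSkManager, service);
    RETURN_IF_NOT_OK(api->Init());
    return api->DelStreamContextBroadcast("example-stream", /*forceDelete=*/false);
}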
+ +void MasterWorkerSCApi::ConstructSyncConsumerNodePb(const std::string &streamName, + const std::vector<ConsumerMetaPb> &consumerMetas, + const HostPort &src, const RetainDataState::State retainData, + bool isRecon, SyncConsumerNodeReqPb &req) noexcept +{ + req.set_stream_name(streamName); + req.set_retain_data(retainData); + req.set_is_reconciliation(isRecon); + for (const auto &consumerMeta : consumerMetas) { + ConsumerMetaPb *consumerMetaPb = req.add_consumer_meta_vector(); + if (consumerMetaPb == nullptr) { + LOG(ERROR) << FormatString("[S:%s] add_consumer_meta_vector pointer does not exist", streamName); + return; + } + // ConsumerMetaPb = (streamName, workerAddress, consumerId, subConfig, lastAckCursor). + consumerMetaPb->set_stream_name(consumerMeta.stream_name()); + *consumerMetaPb->mutable_worker_address() = consumerMeta.worker_address(); + + consumerMetaPb->set_consumer_id(consumerMeta.consumer_id()); + *consumerMetaPb->mutable_sub_config() = consumerMeta.sub_config(); + + consumerMetaPb->set_last_ack_cursor(consumerMeta.last_ack_cursor()); + } + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString( + "Stream:<%s>, Pub node:<%s>, Operation:<SyncConsumerNode>, nodes size:<%d>", streamName, src.ToString(), + req.consumer_meta_vector_size()); +} + +void MasterWorkerSCApi::ConstructSyncPubNodePb(const std::string &streamName, const std::set<HostPort> &pubTable, + const HostPort &src, bool isRecon, SyncPubNodeReqPb &req) noexcept +{ + req.set_stream_name(streamName); + req.set_is_reconciliation(isRecon); + for (auto &pubNode : pubTable) { + auto workerNodePb = req.add_worker_address_vector(); + if (workerNodePb == nullptr) { + LOG(ERROR) << FormatString("[S:%s] add_worker_address_vector pointer does not exist", streamName); + return; + } + workerNodePb->set_host(pubNode.Host()); + workerNodePb->set_port(pubNode.Port()); + } + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Stream:<%s>, Dest:<%s>, Operation:<SyncPubNode>, Size:<%d>", + streamName, src.ToString(), req.worker_address_vector_size()); +} + +// MasterRemoteWorkerSCApi methods +MasterRemoteWorkerSCApi::MasterRemoteWorkerSCApi(HostPort workerAddress, const HostPort &localAddress, + std::shared_ptr<AkSkManager> akSkManager) + : MasterWorkerSCApi(localAddress, std::move(akSkManager)), workerAddress_(std::move(workerAddress)) +{ +} + +Status MasterRemoteWorkerSCApi::Init() +{ + std::shared_ptr rpcStub; + RETURN_IF_NOT_OK( + RpcStubCacheMgr::Instance().GetStub(workerAddress_, StubType::MASTER_WORKER_SC_SVC, rpcStub)); + rpcSession_ = std::dynamic_pointer_cast(rpcStub); + RETURN_RUNTIME_ERROR_IF_NULL(rpcSession_); + return Status::OK(); +} + +MasterWorkerSCApiType MasterRemoteWorkerSCApi::TypeId() +{ + return MasterWorkerSCApiType::MasterRemoteWorkerSCApi; +} + +Status MasterRemoteWorkerSCApi::DelStreamContextBroadcast(const std::string &streamName, bool forceDelete) +{ + DelStreamContextReqPb req; + req.set_stream_name(streamName); + req.set_force_delete(forceDelete); + INJECT_POINT("MasterRemoteWorkerSCApi.DelStreamContextBroadcast.sleep"); + DelStreamContextRspPb rsp; + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + req.set_timeout(opts.GetTimeout()); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(rpcSession_->DelStreamContext(opts, req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Delete stream broadcast success.", LogPrefix(), streamName); + return Status::OK(); +} + +Status MasterRemoteWorkerSCApi::DelStreamContextBroadcastAsyncWrite(const std::string &streamName, bool forceDelete, + int64_t &tagId) +{ + DelStreamContextReqPb req; + req.set_stream_name(streamName); + req.set_force_delete(forceDelete); + INJECT_POINT("MasterRemoteWorkerSCApi.DelStreamContextBroadcast.sleep"); + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + req.set_timeout(opts.GetTimeout()); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(rpcSession_->DelStreamContextAsyncWrite(opts, req, tagId)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Delete stream async broadcast send success.", LogPrefix(), + streamName); + return Status::OK(); +} + +Status MasterRemoteWorkerSCApi::DelStreamContextBroadcastAsyncRead(int64_t tagId, RpcRecvFlags flags) +{ + DelStreamContextRspPb rsp; +
RETURN_IF_NOT_OK(rpcSession_->DelStreamContextAsyncRead(tagId, rsp, flags)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Delete stream async broadcast receive success.", LogPrefix()); + return Status::OK(); +} + +Status MasterRemoteWorkerSCApi::SyncPubNode(const std::string &streamName, const std::set &pubNodeSet, + bool isRecon) +{ + SyncPubNodeReqPb req; + ConstructSyncPubNodePb(streamName, pubNodeSet, workerAddress_, isRecon, req); + + SyncPubNodeRspPb rsp; + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(rpcSession_->SyncPubNode(opts, req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] SyncPubNode success, node size:[%d]", LogPrefix(), streamName, + pubNodeSet.size()); + return Status::OK(); +} + +Status MasterRemoteWorkerSCApi::SyncConsumerNode(const std::string &streamName, + const std::vector &consumerMetas, + const RetainDataState::State retainData, bool isRecon) +{ + SyncConsumerNodeReqPb req; + ConstructSyncConsumerNodePb(streamName, consumerMetas, workerAddress_, retainData, isRecon, req); + + SyncConsumerNodeRspPb rsp; + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(rpcSession_->SyncConsumerNode(opts, req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] SyncConsumer success, consumer size:[%d]", LogPrefix(), + streamName, consumerMetas.size()); + return Status::OK(); +} + +Status MasterRemoteWorkerSCApi::ClearAllRemotePub(const std::string &streamName) +{ + ClearRemoteInfoReqPb req; + req.set_stream_name(streamName); + + ClearRemoteInfoRspPb rsp; + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Send ClearAllRemotePub request to worker", LogPrefix(), + streamName); + RETURN_IF_NOT_OK(rpcSession_->ClearAllRemotePub(opts, req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] ClearAllRemotePub to worker success", LogPrefix(), + streamName); + return Status::OK(); +} + +std::string MasterRemoteWorkerSCApi::LogPrefix() const +{ + return FormatString("MasterWorkerApi, EndPoint:%s", workerAddress_.ToString()); +} + +Status MasterRemoteWorkerSCApi::QueryMetadata( + std::unique_ptr> &stream) +{ + return rpcSession_->QueryMetadata(&stream); +} + +Status MasterRemoteWorkerSCApi::UpdateTopoNotification(UpdateTopoNotificationReq &req) +{ + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + INJECT_POINT("master.UpdateTopoNotification.setTimeout", [&opts](int timeout) { + LOG(INFO) << "set rpc timeout to " << timeout; + opts.SetTimeout(timeout); + return Status::OK(); + }); + UpdateTopoNotificationRsp rsp; + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + return rpcSession_->UpdateTopoNotification(opts, req, rsp); +} + +Status MasterRemoteWorkerSCApi::ClearAllRemotePubAsynWrite(const std::string &streamName, int64_t &tagId) +{ + ClearRemoteInfoReqPb req; + req.set_stream_name(streamName); + + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s]Asyn write ClearAllRemotePub request to worker", LogPrefix(), + streamName); + Status rc = rpcSession_->ClearAllRemotePubAsyncWrite(opts, req, tagId); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s]Asyn write ClearAllRemotePub to worker success", 
LogPrefix(), + streamName); + return rc; +} + +Status MasterRemoteWorkerSCApi::ClearAllRemotePubAsynRead(int64_t tagId, RpcRecvFlags flags) +{ + ClearRemoteInfoRspPb rsp; + Status rc = rpcSession_->ClearAllRemotePubAsyncRead(tagId, rsp, flags); + return rc; +} + +// MasterLocalWorkerSCApi methods +MasterLocalWorkerSCApi::MasterLocalWorkerSCApi(worker::stream_cache::MasterWorkerSCServiceImpl *service, + const HostPort &localAddress, std::shared_ptr akSkManager) + : MasterWorkerSCApi(localAddress, std::move(akSkManager)), workerSC_(service) +{ +} + +Status MasterLocalWorkerSCApi::Init() +{ + RETURN_RUNTIME_ERROR_IF_NULL(workerSC_); + return Status::OK(); +} + +MasterWorkerSCApiType MasterLocalWorkerSCApi::TypeId() +{ + return MasterWorkerSCApiType::MasterLocalWorkerSCApi; +} + +Status MasterLocalWorkerSCApi::DelStreamContextBroadcast(const std::string &streamName, bool forceDelete) +{ + DelStreamContextReqPb req; + req.set_stream_name(streamName); + req.set_force_delete(forceDelete); + INJECT_POINT("MasterLocalWorkerSCApi.DelStreamContextBroadcast.sleep"); + // We use timeout to avoid deadlock + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + req.set_timeout(opts.GetTimeout()); + INJECT_POINT("MasterLocalWorkerSCApi.DelStreamContextBroadcast.setTimeout", [&req](int timeout) { + LOG(INFO) << "set rpc timeout to " << timeout; + req.set_timeout(timeout); + return Status::OK(); + }); + DelStreamContextRspPb rsp; + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(workerSC_->DelStreamContext(req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Delete stream broadcast success.", LogPrefix(), streamName); + return Status::OK(); +} + +Status MasterLocalWorkerSCApi::DelStreamContextBroadcastAsyncWrite(const std::string &streamName, bool forceDelete, + int64_t &tagId) +{ + (void)forceDelete; + (void)tagId; + // AsyncWrite and AsynRead with local bypass is not supported. + LOG(WARNING) << FormatString("[%s, S:%s] Async DelStreamContextBroadcast not supported for MasterLocalWorkerSCApi", + LogPrefix(), streamName); + return Status::OK(); +} + +Status MasterLocalWorkerSCApi::DelStreamContextBroadcastAsyncRead(int64_t tagId, RpcRecvFlags flags) +{ + (void)tagId; + (void)flags; + // AsyncWrite and AsynRead with local bypass is not supported. 
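Because the local-bypass variant stubs out these async methods (they log a warning and return OK), a caller that needs genuinely asynchronous broadcasts should branch on TypeId() first. A hedged sketch; DeleteStreamOnWorker is hypothetical, while the api methods and MasterWorkerSCApiType come from this patch:

Status DeleteStreamOnWorker(const std::shared_ptr<MasterWorkerSCApi> &api,
                            const std::string &streamName, RpcRecvFlags flags)
{
    if (api->TypeId() == MasterWorkerSCApiType::MasterRemoteWorkerSCApi) {
        int64_t tagId = 0;
        RETURN_IF_NOT_OK(api->DelStreamContextBroadcastAsyncWrite(streamName, /*forceDelete=*/false, tagId));
        return api->DelStreamContextBroadcastAsyncRead(tagId, flags);
    }
    // Local bypass: the async variants are no-ops here, so stay synchronous.
    return api->DelStreamContextBroadcast(streamName, /*forceDelete=*/false);
}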
+ LOG(WARNING) << FormatString("[%s] Async DelStreamContextBroadcast not supported for MasterLocalWorkerSCApi", + LogPrefix()); + return Status::OK(); +} + +Status MasterLocalWorkerSCApi::SyncPubNode(const std::string &streamName, const std::set &pubNodeSet, + bool isRecon) +{ + SyncPubNodeReqPb req; + ConstructSyncPubNodePb(streamName, pubNodeSet, localMasterAddress_, isRecon, req); + + SyncPubNodeRspPb rsp; + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(workerSC_->SyncPubNode(req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] SyncPubNode success, node size:[%d]", LogPrefix(), streamName, + pubNodeSet.size()); + return Status::OK(); +} + +Status MasterLocalWorkerSCApi::SyncConsumerNode(const std::string &streamName, + const std::vector &consumerMetas, + const RetainDataState::State retainData, bool isRecon) +{ + SyncConsumerNodeReqPb req; + ConstructSyncConsumerNodePb(streamName, consumerMetas, localMasterAddress_, retainData, isRecon, req); + + SyncConsumerNodeRspPb rsp; + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(workerSC_->SyncConsumerNode(req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] SyncConsumer success, consumer size:[%d]", LogPrefix(), + streamName, consumerMetas.size()); + return Status::OK(); +} + +Status MasterLocalWorkerSCApi::ClearAllRemotePub(const std::string &streamName) +{ + ClearRemoteInfoReqPb req; + req.set_stream_name(streamName); + + ClearRemoteInfoRspPb rsp; + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Send ClearAllRemotePub request to worker", LogPrefix(), + streamName); + RETURN_IF_NOT_OK(workerSC_->ClearAllRemotePub(req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] ClearAllRemotePub to worker success", LogPrefix(), + streamName); + return Status::OK(); +} + +std::string MasterLocalWorkerSCApi::LogPrefix() const +{ + // local version of the api has worker and master as same address (the local one) + return FormatString("MasterWorkerApi, EndPoint:%s", localMasterAddress_.ToString()); +} + +Status MasterLocalWorkerSCApi::QueryMetadata( + std::unique_ptr> &stream) +{ + // The local version of the api should not call this one. It should use the other syntax for local-only calls. + stream.reset(); + RETURN_STATUS_LOG_ERROR(K_RUNTIME_ERROR, "Local version of MasterWorkerApi was incorrectly used."); +} + +Status MasterLocalWorkerSCApi::QueryMetadata(const GetMetadataAllStreamReqPb &req, GetMetadataAllStreamRspPb &rsp) +{ + return workerSC_->QueryMetadata(req, rsp); +} + + +Status MasterLocalWorkerSCApi::UpdateTopoNotification(UpdateTopoNotificationReq &req) +{ + UpdateTopoNotificationRsp rsp; + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(workerSC_->UpdateTopoNotification(req, rsp)); + return Status::OK(); +} +} // namespace master +} // namespace datasystem diff --git a/src/datasystem/master/stream_cache/master_worker_sc_api.h b/src/datasystem/master/stream_cache/master_worker_sc_api.h new file mode 100644 index 0000000..012ee58 --- /dev/null +++ b/src/datasystem/master/stream_cache/master_worker_sc_api.h @@ -0,0 +1,291 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: The interface of the stream cache worker descriptor. + */ + +#ifndef DATASYSTEM_MASTER_STREAM_CACHE_MASTER_WORKER_SC_API_H +#define DATASYSTEM_MASTER_STREAM_CACHE_MASTER_WORKER_SC_API_H + +#include +#include + +#include "datasystem/common/ak_sk/ak_sk_manager.h" +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/protos/worker_stream.stub.rpc.pb.h" +namespace datasystem { +namespace worker { +namespace stream_cache { +class MasterWorkerSCServiceImpl; +} // namespace stream_cache +} // namespace worker +namespace master { + +enum class MasterWorkerSCApiType : int { MasterLocalWorkerSCApi = 0, MasterRemoteWorkerSCApi = 1 }; + +/** + * @brief The MasterWorkerSCApi is an abstract class that defines the interface for interactions with the stream cache + * worker service. + */ +class MasterWorkerSCApi { +public: + /** + * Default destructor + */ + virtual ~MasterWorkerSCApi() = default; + + /** + * @brief Initialize the MasterWorkerSCApi Object (include rpc channel). + * @return Status of the call. + */ + virtual Status Init() = 0; + + /** + * @brief The type id of MasterWorkerSCApi Object. + * @return The type id of MasterWorkerSCApi Object. + */ + virtual MasterWorkerSCApiType TypeId() = 0; + + /** + * @brief Broadcast delete-stream to all related nodes for a stream. + * @param[in] streamName Target stream. + * @param[in] forceDelete Force deletion. + * @return Status of the call. + */ + virtual Status DelStreamContextBroadcast(const std::string &streamName, bool forceDelete) = 0; + + /** + * @brief Async broadcast delete-stream to all related nodes for a stream. + * @param[in] streamName Target stream. + * @param[in] forceDelete Force deletion. + * @param[out] tagId The async RPC tag id. + * @return Status of the call. + */ + virtual Status DelStreamContextBroadcastAsyncWrite(const std::string &streamName, bool forceDelete, + int64_t &tagId) = 0; + + /** + * @brief Read the async delete-stream broadcast response. + * @param[in] tagId The async RPC tag id. + * @param[in] flags The RPC receive flag option. + * @return Status of the call. + */ + virtual Status DelStreamContextBroadcastAsyncRead(int64_t tagId, RpcRecvFlags flags) = 0; + + /** + * @brief Sync all pub nodes (except the src node) for the target stream to the src node. + * @param[in] streamName Target stream name. + * @param[in] pubNodeSet All pub node set. + * @param[in] isRecon Is this part of reconciliation process. + * @return Status of the call. + */ + virtual Status SyncPubNode(const std::string &streamName, const std::set<HostPort> &pubNodeSet, bool isRecon) = 0; + + /** + * @brief Sync all consumer nodes (except consumers generated from the src node) for the target stream to the src node. + * @param[in] streamName Target stream name. + * @param[in] consumerMetas All consumer metadata list. + * @param[in] retainData Ask producers to retain data if needed. + * @param[in] isRecon Is this part of reconciliation process. + * @return Status of the call.
+ */ + virtual Status SyncConsumerNode(const std::string &streamName, const std::vector &consumerMetas, + const RetainDataState::State retainData, bool isRecon) = 0; + + /** + * @brief Clear all remote pub node for target stream on src node. + * @param[in] streamName Target stream name. + * @return Status of the call. + */ + virtual Status ClearAllRemotePub(const std::string &streamName) = 0; + + /** + * @brief The stream rpc used to query metadata in worker. + * @param[in/out] stream The stream rpc reader writer. + * @return Status of the call. + */ + virtual Status QueryMetadata( + std::unique_ptr> &stream) = 0; + + /** + * @brief Notify worker the topo update. + * @param[in] req The request send to worker. + * @return Status of the call. + */ + virtual Status UpdateTopoNotification(UpdateTopoNotificationReq &req) = 0; + + /** + * @brief A factory method to instantiate the correct derived version of the api. Remote masters will use an + * rpc-based api, whereas local masters can be optimized for in-process pointer based api. + * @param[in] hostPort The host port of the target master + * @param[in] localHostPort The local worker rpc service host port. + * @param[in] akSkManager Used to do AK/SK authenticate. + * @param[in] service The local pointer to the master SC service implementation. If null, the created api must + * default to the RPC-based version. + * @return A base class pointer to the correct derived type of api. + */ + static std::shared_ptr CreateMasterWorkerSCApi( + const HostPort &hostPort, const HostPort &localHostPort, const std::shared_ptr &akSkManager, + worker::stream_cache::MasterWorkerSCServiceImpl *service); + +protected: + /** + * @brief Constructor, Create a new MasterWorkerSCApi object from master to a particular worker. + * @param[in] localMasterAddress The source master address (local) + * @param[in] akSkManager Used to do AK/SK authenticate. + */ + MasterWorkerSCApi(HostPort localMasterAddress, std::shared_ptr akSkManager); + + /** + * @brief Get log prefix + * @return The log prefix + */ + [[nodiscard]] virtual std::string LogPrefix() const = 0; + + /** + * @brief Construct synchronize sub consumer nodes table protobuf. + * @param[in] streamName Related stream name. + * @param[in] consumerMetas The set of all sub consumer nodes information. + * @param[in] src The source worker node of new producer. + * @param[in] retainData Ask producer to retain data if needed. + * @param[in] isRecon Is this part of reconciliation process. + * @param[out] req Pointer to sync sub consumer nodes table protobuf. + */ + static void ConstructSyncConsumerNodePb(const std::string &streamName, + const std::vector &consumerMetas, const HostPort &src, + const RetainDataState::State retainData, bool isRecon, + SyncConsumerNodeReqPb &req) noexcept; + + /** + * @brief Construct synchronize pub worker nodes table protobuf. + * @param[in] streamName Related stream name. + * @param[in] pubTable The set of all pub worker nodes address. + * @param[in] src The source worker node of subscription. + * @param[in] isRecon Is this part of reconciliation process. + * @param[out] req Pointer to sync pub worker nodes table protobuf. 
+ */ + static void ConstructSyncPubNodePb(const std::string &streamName, const std::set &pubTable, + const HostPort &src, bool isRecon, SyncPubNodeReqPb &req) noexcept; + + HostPort localMasterAddress_; + std::shared_ptr akSkManager_{ nullptr }; +}; + +/** + * @brief MasterRemoteWorkerSCApi is the derived remote version of the api for sending and receiving worker SC requests + * where the worker is on a different host. This class will use an RPC mechanism for communication to the remote + * location. + * Callers will access this class naturally through base class polymorphism. + * See the parent interface for function argument documentation. + */ +class MasterRemoteWorkerSCApi : public MasterWorkerSCApi { +public: + /** + * @brief Constructor, Create a new MasterWorkerSCApi object from master to a particular worker. + * @param[in] workerAddress The target worker node address + * @param[in] localAddress The source address of this host + * @param[in] akSkManager Used to do AK/SK authenticate. + */ + explicit MasterRemoteWorkerSCApi(HostPort workerAddress, const HostPort &localAddress, + std::shared_ptr akSkManager); + ~MasterRemoteWorkerSCApi() override = default; + Status Init() override; + MasterWorkerSCApiType TypeId() override; + Status DelStreamContextBroadcast(const std::string &streamName, bool forceDelete) override; + Status DelStreamContextBroadcastAsyncWrite(const std::string &streamName, bool forceDelete, + int64_t &tagId) override; + Status DelStreamContextBroadcastAsyncRead(int64_t tagId, RpcRecvFlags flags) override; + Status SyncPubNode(const std::string &streamName, const std::set &pubNodeSe, bool isRecon) override; + Status SyncConsumerNode(const std::string &streamName, const std::vector &consumerMetas, + const RetainDataState::State retainData, bool isRecon) override; + Status ClearAllRemotePub(const std::string &streamName) override; + Status QueryMetadata( + std::unique_ptr> &stream) override; + Status UpdateTopoNotification(UpdateTopoNotificationReq &req) override; + + Status ClearAllRemotePubAsynWrite(const std::string &streamName, int64_t &tagId); + Status ClearAllRemotePubAsynRead(int64_t tagId, RpcRecvFlags flags); + +private: + /** + * @brief Get log prefix + * @return The log prefix + */ + [[nodiscard]] std::string LogPrefix() const override; + + HostPort workerAddress_; + std::shared_ptr rpcSession_{ nullptr }; // Session to the worker rpc service. +}; + +/** + * @brief MasterLocalWorkerSCApi is the derived local version of the api for sending and receiving worker SC requests + * where the worker exists in the same process as the service. This class will directly reference the service through a + * pointer and does not use any RPC mechanism for communication. + * Callers will access this class naturally through base class polymorphism. + * See the parent interface for function argument documentation. + */ +class MasterLocalWorkerSCApi : public MasterWorkerSCApi { +public: + /** + * @brief Constructor, Create a new MasterWorkerSCApi object from master to a particular worker. + * @param[in] service The direct pointer to the service for the requests. + * @param[in] localAddress The source address of this host + * @param[in] akSkManager Used to do AK/SK authenticate. 
+ */ + explicit MasterLocalWorkerSCApi(worker::stream_cache::MasterWorkerSCServiceImpl *service, + const HostPort &localAddress, std::shared_ptr akSkManager); + ~MasterLocalWorkerSCApi() override = default; + Status Init() override; + MasterWorkerSCApiType TypeId() override; + Status DelStreamContextBroadcast(const std::string &streamName, bool forceDelete) override; + Status DelStreamContextBroadcastAsyncWrite(const std::string &streamName, bool forceDelete, + int64_t &tagId) override; + Status DelStreamContextBroadcastAsyncRead(int64_t tagId, RpcRecvFlags flags) override; + Status SyncPubNode(const std::string &streamName, const std::set &pubNodeSet, bool isRecon) override; + Status SyncConsumerNode(const std::string &streamName, const std::vector &consumerMetas, + const RetainDataState::State retainData, bool isRecon) override; + Status ClearAllRemotePub(const std::string &streamName) override; + Status QueryMetadata( + std::unique_ptr> &stream) override; + Status QueryMetadata(const GetMetadataAllStreamReqPb &req, GetMetadataAllStreamRspPb &rsp); + Status UpdateTopoNotification(UpdateTopoNotificationReq &req) override; + +private: + /** + * @brief Get log prefix + * @return The log prefix + */ + [[nodiscard]] std::string LogPrefix() const override; + worker::stream_cache::MasterWorkerSCServiceImpl *workerSC_{ nullptr }; +}; + +/** + * @brief Convert HostPortPb to string. + * @param hostPb The HostPortPb object. + * @return The string of the HostPortPb. + */ +inline std::string HostPb2Str(const HostPortPb &hostPb) noexcept +{ + HostPort addr(hostPb.host(), hostPb.port()); + return addr.ToString(); +} +} // namespace master +} // namespace datasystem + +#endif // DATASYSTEM_MASTER_STREAM_CACHE_MASTER_WORKER_SC_API_H diff --git a/src/datasystem/master/stream_cache/rpc_session_manager.cpp b/src/datasystem/master/stream_cache/rpc_session_manager.cpp new file mode 100644 index 0000000..91366f2 --- /dev/null +++ b/src/datasystem/master/stream_cache/rpc_session_manager.cpp @@ -0,0 +1,33 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: The interface of rpc session manager. + */ +#include "datasystem/master/stream_cache/rpc_session_manager.h" +#include "datasystem/common/util/status_helper.h" + +namespace datasystem { +namespace master { + +Status RpcSessionManager::GetRpcSession(const HostPort &endPoint, std::shared_ptr &rpcStub, + const std::shared_ptr &akSkManager) +{ + rpcStub = MasterWorkerSCApi::CreateMasterWorkerSCApi(endPoint, localMaster_, akSkManager, masterWorkerSvc_.get()); + return rpcStub->Init(); +} +} // namespace master +} // namespace datasystem diff --git a/src/datasystem/master/stream_cache/rpc_session_manager.h b/src/datasystem/master/stream_cache/rpc_session_manager.h new file mode 100644 index 0000000..023a790 --- /dev/null +++ b/src/datasystem/master/stream_cache/rpc_session_manager.h @@ -0,0 +1,61 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 
2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: The interface of rpc session manager. + */ +#ifndef DATASYSTEM_MASTER_STREAM_CACHE_RPC_SESSION_MANAGER_H +#define DATASYSTEM_MASTER_STREAM_CACHE_RPC_SESSION_MANAGER_H + +#include +#include +#include + +#include "datasystem/master/stream_cache/master_worker_sc_api.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +class MasterWorkerSCServiceImpl; +} // namespace stream_cache +} // namespace worker +namespace master { +class RpcSessionManager { +public: + /** + * @brief Get RpcSession by endPoint. + * @param[in] endPoint Worker endpoint address. + * @param[out] rpcStub Pointer to target rpcSession. + * @param[in] akSkManager Used to do AK/SK authenticate. + * @return Status of the call. + */ + Status GetRpcSession(const HostPort &endPoint, std::shared_ptr &rpcStub, + const std::shared_ptr &akSkManager); + + void SetLocalArgs(const HostPort &localMaster, + std::shared_ptr service) + { + localMaster_ = localMaster; + masterWorkerSvc_ = service; + } + +private: + HostPort localMaster_; + std::shared_ptr masterWorkerSvc_{ nullptr }; +}; +} // namespace master +} // namespace datasystem +#endif // DATASYSTEM_MASTER_STREAM_CACHE_RPC_SESSION_MANAGER_H diff --git a/src/datasystem/master/stream_cache/sc_metadata_manager.cpp b/src/datasystem/master/stream_cache/sc_metadata_manager.cpp new file mode 100644 index 0000000..417f058 --- /dev/null +++ b/src/datasystem/master/stream_cache/sc_metadata_manager.cpp @@ -0,0 +1,1075 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Module responsible for managing the stream cache metadata on the master. 
+ */ +#include "datasystem/master/stream_cache/sc_metadata_manager.h" + +#include +#include + +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/common/util/container_util.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/common/util/rpc_util.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/common/util/uri.h" +#include "datasystem/master/stream_cache/sc_notify_worker_manager.h" +#include "datasystem/master/stream_cache/stream_metadata.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/utils/status.h" +#include "datasystem/worker/cluster_event_type.h" +#include "datasystem/worker/hash_ring/hash_ring_event.h" +#include "datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h" + +DS_DECLARE_string(rocksdb_store_dir); +DS_DECLARE_int32(sc_regular_socket_num); +DS_DECLARE_int32(sc_stream_socket_num); +DS_DECLARE_uint32(node_dead_timeout_s); + +namespace datasystem { +namespace master { +constexpr size_t THREAD_POOL_SIZE = 8; +const std::string SC_METADATA_MANAGER = "SCMetadataManager-"; +namespace { +inline void HostPb2Host(const HostPortPb &hostPb, HostPort &host) noexcept +{ + host = HostPort(hostPb.host(), hostPb.port()); +} + +bool EnableSCService() +{ + return FLAGS_sc_regular_socket_num > 0 && FLAGS_sc_stream_socket_num > 0; +} +} // namespace + +SCMetadataManager::SCMetadataManager(const HostPort &masterHostPort, std::shared_ptr akSkManager, + std::shared_ptr rpcSessionManager, EtcdClusterManager *cm, + RocksStore *rocksStore, const std::string &dbName) + : MetadataRedirectHelper(cm), + masterAddress_(masterHostPort), + akSkManager_(std::move(akSkManager)), + rpcSessionManager_(std::move(rpcSessionManager)), + dbName_(dbName), + eventName_(SC_METADATA_MANAGER + dbName) +{ + streamMetaStore_ = std::make_shared(rocksStore); + exitFlag_ = std::make_shared(false); +} + +SCMetadataManager::~SCMetadataManager() +{ + LOG(INFO) << "Destroy SCMetadataManager."; + Shutdown(); +} + +void SCMetadataManager::Shutdown() +{ + if (exitFlag_->load()) { + return; + } + exitFlag_->store(true); + LOG(INFO) << "Start shutdown ScMetadataManager for " << dbName_; + if (!EnableSCService()) { + return; + } + CheckNewNodeMetaEvent::GetInstance().RemoveSubscriber(eventName_); + StartClearWorkerMeta::GetInstance().RemoveSubscriber(eventName_); + ClearWorkerMeta::GetInstance().RemoveSubscriber(eventName_); + HashRingEvent::RecoverMetaRanges::GetInstance().RemoveSubscriber(eventName_); + if (notifyWorkerManager_ != nullptr) { + notifyWorkerManager_->Shutdown(); + } + if (asyncReconciliationPool_ != nullptr) { + asyncReconciliationPool_.reset(); + } +} + +void SCMetadataManager::SetClusterManagerToNullptr() +{ + MetadataRedirectHelper::Shutdown(); +} + +Status SCMetadataManager::Init() +{ + RETURN_OK_IF_TRUE(!EnableSCService()); + CHECK_FAIL_RETURN_STATUS(rpcSessionManager_ != nullptr, StatusCode::K_RUNTIME_ERROR, + "Runtime error, failed to get RpcSessionManager"); + RETURN_IF_NOT_OK(Uri::NormalizePathWithUserHomeDir(FLAGS_rocksdb_store_dir, "~/.datasystem/rocksdb", "/master")); + RETURN_IF_NOT_OK(streamMetaStore_->Init()); + RETURN_IF_EXCEPTION_OCCURS(asyncReconciliationPool_ = + std::make_unique(0, THREAD_POOL_SIZE, "MScAsyncReconcilation", false)); + + notifyWorkerManager_ = + std::make_unique(streamMetaStore_, akSkManager_, rpcSessionManager_, etcdCM_, this); + RETURN_IF_NOT_OK(notifyWorkerManager_->Init()); + RETURN_IF_NOT_OK(LoadMeta()); + 
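+    // Register the cluster event subscribers: metadata checks when a new node joins,
+    // asynchronous and synchronous cleanup of a lost worker's metadata, and metadata
+    // recovery when hash-ring ranges change.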
CheckNewNodeMetaEvent::GetInstance().AddSubscriber(eventName_, [this](const HostPort &eventNodeKey) { + StartCheckMetadata(eventNodeKey); + return Status::OK(); + }); + StartClearWorkerMeta::GetInstance().AddSubscriber(eventName_, [this](const HostPort &eventNodeKey) { + StartClearWorkerMetadata(eventNodeKey); + return Status::OK(); + }); + ClearWorkerMeta::GetInstance().AddSubscriber( + eventName_, [this](const HostPort &eventNodeKey) { return ClearWorkerMetadata(eventNodeKey); }); + HashRingEvent::RecoverMetaRanges::GetInstance().AddSubscriber( + eventName_, [this](const std::vector &workerUuids, const worker::HashRange &extraRanges) { + return RecoverMetadataOfFaultyWorker(workerUuids, extraRanges); + }); + LOG(INFO) << FormatString("[%s] Initialize success", LogPrefix()); + return Status::OK(); +} + +bool SCMetadataManager::MetaIsFound(const std::string &streamName) +{ + ReadLockHelper rlocker(LOCK_ARGS_MSG(metaDictMutex_, streamName)); + TbbMetaHashmap::const_accessor accessor; + return streamMetaManagerDict_.find(accessor, streamName); +} + +bool SCMetadataManager::CheckMetaTableEmpty() +{ + ReadLockHelper rlocker(LOCK_ARGS(metaDictMutex_)); + return streamMetaManagerDict_.empty(); +} + +void SCMetadataManager::GetMetasMatch(std::function &&matchFunc, + std::vector &streamNames, bool *exitEarly) +{ + int timeoutSec = 300; + INJECT_POINT("SCMetadataManager.GetMetasMatch.timeout", [&timeoutSec] (int timeout) { + timeoutSec = timeout; + }); + WriteLockHelper wlocker(DEFER_LOCK_ARGS(metaDictMutex_)); + if (!wlocker.TryLock(timeoutSec)) { + LOG(WARNING) << "[GetMetasMatch] Failed to acquire lock within " << timeoutSec + << " seconds, aborting metadata matching operation"; + return; + } + for (const auto &it : streamMetaManagerDict_) { + if (exitEarly && *exitEarly) { + break; + } + if (it.first.empty()) { + LOG(ERROR) << "[GetMetasMatch] stream name is empty!"; + continue; + } + if (matchFunc(it.first)) { + streamNames.emplace_back(it.first); + } + } +} + +Status SCMetadataManager::SaveMigrationMetadata(const MigrateSCMetadataReqPb &req, MigrateSCMetadataRspPb &rsp) +{ + CHECK_FAIL_RETURN_STATUS(etcdCM_->CheckReceiveMigrateInfo(), K_NOT_READY, + "wait and retry, worker don't receive addnode info"); + LOG(INFO) << "Recv migrate metadata msg. 
source:" << req.source_addr() + << ", stream count:" << req.stream_metas().size(); + + auto injectTest = []() { + INJECT_POINT("master.sc.fail_save_migration_data", []() { return true; }); + return false; + }; + Status status; + for (auto &streamMeta : req.stream_metas()) { + if (injectTest()) { + rsp.add_results(MigrateSCMetadataRspPb::FAILED); + continue; + } + + if (SaveMigrationData(streamMeta, status, rsp).IsError()) { + rsp.add_results(MigrateSCMetadataRspPb::FAILED); + continue; + } + rsp.add_results(MigrateSCMetadataRspPb::SUCCESSFUL); + } + return Status::OK(); +} + +Status SCMetadataManager::SaveMigrationData(const MetaForSCMigrationPb &streamMeta, Status &status, + MigrateSCMetadataRspPb &rsp) +{ + (void)status; + (void)rsp; + ReadLockHelper rlocker(LOCK_ARGS(metaDictMutex_)); + const auto &meta = streamMeta.meta(); + const std::string &streamName = meta.stream_name(); + StreamFields streamFields(meta.max_stream_size(), meta.page_size(), meta.auto_cleanup(), meta.retain_num_consumer(), + meta.encrypt_stream(), meta.reserve_size(), meta.stream_mode()); + bool needRevert = true; + // clang-format off + CHECK_FAIL_RETURN_STATUS(streamMetaManagerDict_.emplace( + streamName, std::make_shared(streamName, streamFields, streamMetaStore_.get(), + akSkManager_, rpcSessionManager_, etcdCM_, notifyWorkerManager_.get())), + StatusCode::K_RUNTIME_ERROR, "Load meta reconstruction insertion failed"); + // clang-format on + TbbMetaHashmap::accessor accessor; + RETURN_IF_NOT_OK(GetStreamMetadata(streamName, accessor)); + RaiiPlus raiiP([this, &needRevert, &accessor]() { + if (needRevert) { + streamMetaManagerDict_.erase(accessor); + } + }); + StreamMetadata *metadata = accessor->second.get(); + CHECK_FAIL_RETURN_STATUS(metadata != nullptr, K_RUNTIME_ERROR, "metadata is null"); + if (ScMetricsMonitor::Instance()->IsEnabled()) { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + metadata->InitStreamMetrics(), + FormatString("[%s, S:%s] Init master sc metrics failed", LogPrefix(), streamName)); + } + std::vector producerRelatedNodes = { streamMeta.producer_rel_nodes().begin(), + streamMeta.producer_rel_nodes().end() }; + std::vector consumerRelatedNodes = { streamMeta.consumer_rel_nodes().begin(), + streamMeta.consumer_rel_nodes().end() }; + metadata->PreparePubSubRelNodes(producerRelatedNodes, consumerRelatedNodes); + // Make use of recovery logic so no notification is sent. + for (const auto &producerMetaPb : streamMeta.producers()) { + RETURN_IF_NOT_OK(metadata->RecoveryPubMeta(producerMetaPb)); + RETURN_IF_NOT_OK(streamMetaStore_->AddPubNode(producerMetaPb)); + raiiP.AddTask([this, &needRevert, &producerMetaPb]() { + if (needRevert) { + LOG_IF_ERROR(streamMetaStore_->DelPubNode(producerMetaPb), + "Rollback persisted producer failed in migration failure."); + } + }); + } + for (const auto &consumerMetaPb : streamMeta.consumers()) { + RETURN_IF_NOT_OK(metadata->RecoverySubMeta(consumerMetaPb)); + RETURN_IF_NOT_OK(streamMetaStore_->AddSubNode(consumerMetaPb)); + raiiP.AddTask([this, &needRevert, &consumerMetaPb]() { + if (needRevert) { + LOG_IF_ERROR(streamMetaStore_->DelSubNode(consumerMetaPb.stream_name(), consumerMetaPb.consumer_id()), + "Rollback persisted consumer failed in migration failure."); + } + }); + } + // Recover the lifetime consumer count and retain data state. 
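+    // This mirrors the recovery path in LoadMeta(): restore the lifetime consumer count
+    // first, then re-derive the retain-data state.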
+ const auto &consumerLifeCount = meta.consumer_life_count(); + RETURN_IF_NOT_OK(metadata->RestoreConsumerLifeCount(consumerLifeCount)); + auto currentState = metadata->CheckNUpdateNeedRetainData(); + VLOG(SC_NORMAL_LOG_LEVEL) << "[RetainData] RetainData state is restored for stream: " << streamName << " to " + << currentState; + + RETURN_IF_NOT_OK(streamMetaStore_->AddStream(streamName, streamFields)); + raiiP.AddTask([this, &needRevert, &streamName]() { + if (needRevert) { + LOG_IF_ERROR(streamMetaStore_->DelStream(streamName), + "Rollback persisted stream failed in migration failure."); + } + }); + RETURN_IF_NOT_OK(streamMetaStore_->UpdateLifeTimeConsumerCount(streamName, consumerLifeCount)); + // Migrate the async notifications. + RETURN_IF_NOT_OK(notifyWorkerManager_->AddAsyncNotifications(streamFields, streamName, streamMeta)); + // If auto clean up is true and there is no more producer/consumer, delete the stream. + RETURN_IF_NOT_OK(metadata->AutoCleanupIfNeeded(HostPort())); + needRevert = false; + return Status::OK(); +} + +Status SCMetadataManager::FillMetadataForMigration(const std::string &streamName, MetaForSCMigrationPb *meta) +{ + TbbMetaHashmap::accessor accessor; + RETURN_IF_NOT_OK(GetStreamMetadata(streamName, accessor)); + // Marker is migrating + migratingItems_.insert({ streamName, true }); + + // Fill metadata + StreamMetadata *metadata = accessor->second.get(); + StreamMetaPb *streamMetaPb = meta->mutable_meta(); + streamMetaPb->set_stream_name(streamName); + const auto &streamFields = metadata->GetStreamFields(); + streamMetaPb->set_max_stream_size(streamFields.maxStreamSize_); + streamMetaPb->set_page_size(streamFields.pageSize_); + streamMetaPb->set_auto_cleanup(streamFields.autoCleanup_); + streamMetaPb->set_retain_num_consumer(streamFields.retainForNumConsumers_); + streamMetaPb->set_encrypt_stream(streamFields.encryptStream_); + streamMetaPb->set_reserve_size(streamFields.reserveSize_); + streamMetaPb->set_stream_mode(streamFields.streamMode_); + streamMetaPb->set_consumer_life_count(metadata->GetConsumerLifeCount()); + + std::vector producerRelatedNodes; + std::vector consumerRelatedNodes; + std::vector masterProducers; + std::vector masterConsumers; + metadata->GetAllProducerConsumer(masterProducers, masterConsumers, producerRelatedNodes, consumerRelatedNodes); + *meta->mutable_producers() = { masterProducers.begin(), masterProducers.end() }; + *meta->mutable_consumers() = { masterConsumers.begin(), masterConsumers.end() }; + *meta->mutable_producer_rel_nodes() = { producerRelatedNodes.begin(), producerRelatedNodes.end() }; + *meta->mutable_consumer_rel_nodes() = { consumerRelatedNodes.begin(), consumerRelatedNodes.end() }; + std::vector notifications; + notifyWorkerManager_->GetPendingNotificationByStreamName(streamName, notifications); + *meta->mutable_notifications() = { notifications.begin(), notifications.end() }; + return Status::OK(); +} + +void SCMetadataManager::HandleMetaDataMigrationSuccess(const std::string &streamName) +{ + Raii outer([this, &streamName]() { migratingItems_.erase(streamName); }); + + auto func = [this, &streamName]() { + TbbMetaHashmap::accessor accessor; + ReadLockHelper rlocker(LOCK_ARGS_MSG(metaDictMutex_, streamName)); + RETURN_IF_NOT_OK(GetStreamMetadataNoLock(streamName, accessor)); + // Cleanup rocksdb after migration, so it will not be reloaded after a restart + StreamMetadata *metadata = accessor->second.get(); + RETURN_IF_NOT_OK(metadata->CleanUpStreamPersistent(streamName)); + // Also cleanup the notifications since 
they are also migrated already + RETURN_IF_NOT_OK(notifyWorkerManager_->RemovePendingNotificationByStreamName(streamName)); + CHECK_FAIL_RETURN_STATUS(streamMetaManagerDict_.erase(accessor) == 1, StatusCode::K_NOT_FOUND, + FormatString("Stream <%s> does not exist on master", streamName)); + return Status::OK(); + }; + LOG_IF_ERROR(func(), ""); +} + +void SCMetadataManager::HandleMetaDataMigrationFailed(const MetaForSCMigrationPb &streamMeta) +{ + (void)streamMeta; + // Placeholder, do nothing +} + +Status SCMetadataManager::CreateProducer(const CreateProducerReqPb &req, CreateProducerRspPb &rsp) +{ + INJECT_POINT("SCMetadataManager.CreateProducer.wait"); + const auto &producerMeta = req.producer_meta(); + const std::string &streamName = producerMeta.stream_name(); + bool redirect = req.redirect(); + FillRedirectResponseInfo(rsp, streamName, redirect); + RETURN_OK_IF_TRUE(redirect); + StreamFields streamFields(req.max_stream_size(), req.page_size(), req.auto_cleanup(), req.retain_num_consumer(), + req.encrypt_stream(), req.reserve_size(), req.stream_mode()); + // Hold const_accessor to allow parallel CreateProducer. + TbbMetaHashmap::const_accessor accessor; + Status rc = GetStreamMetadata(streamName, accessor); + if (rc.GetCode() == K_NOT_FOUND) { + RETURN_IF_NOT_OK(CreateStreamMetadata(streamName)); + RETURN_IF_NOT_OK(GetStreamMetadata(streamName, accessor)); + } + StreamMetadata *metadata = accessor->second.get(); + RETURN_IF_NOT_OK(VerifyStreamMode(static_cast(req.stream_mode()), metadata->GetConsumerCount(), + metadata->GetProducerCount() + 1)); + RETURN_IF_NOT_OK(metadata->PubIncreaseNode(producerMeta, streamFields)); + LOG(INFO) << FormatString("[%s, S:%s, W:%s] Create producer on master success", LogPrefix(), streamName, + HostPb2Str(producerMeta.worker_address())); + return Status::OK(); +} + +bool SCMetadataManager::HandleCloseProducerError(bool &firstError, const Status &rc, const ProducerInfoPb &info, + CloseProducerRspPb &rsp) +{ + if (rc.IsError()) { + LOG(INFO) << FormatString("[%s, S:%s ] CloseProducer failed on master: %s", LogPrefix(), info.stream_name(), + rc.ToString()); + INJECT_POINT("master.sc.close_producer_error", []() { return true; }); + if (firstError) { + // The first time that an error is hit, we save it into the response struct. + rsp.mutable_err()->set_error_code(rc.GetCode()); + rsp.mutable_err()->set_error_msg(rc.GetMsg()); + firstError = false; + } + // Add this producer to the failed list + auto failedProducerPb = rsp.add_failed_producers(); + failedProducerPb->CopyFrom(info); + return true; + } + return false; +} + +Status SCMetadataManager::CloseProducer(const CloseProducerReqPb &req, CloseProducerRspPb &rsp) +{ + INJECT_POINT("SCMetadataManager.CloseProducer.wait"); + int numProducers = req.producer_infos_size(); + VLOG(SC_NORMAL_LOG_LEVEL) << "Starting to close " << numProducers << " producers in master from worker " + << HostPb2Str(req.worker_address()); + const bool forceClose = req.force_close(); + int numSuccess = 0; + bool firstError = true; + // Check redirect for producers + std::vector streams; + for (const ProducerInfoPb &currProducer : req.producer_infos()) { + streams.emplace_back(currProducer.stream_name()); + } + bool redirect = req.redirect(); + FillRedirectResponseInfos(rsp, streams, redirect); + // For now the close producer request is grouped by stream name, so if redirect is needed for any, + // all of them will need the redirect, otherwise it will be empty. 
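+    // Per-producer failures are collected into the response instead of aborting the
+    // batch; the first error code is preserved so the caller can see the root cause.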
+ RETURN_OK_IF_TRUE(streams.empty()); + // For producer that does not need redirect, drive close logic against them. + // Any failed producers are tracked, and then return to the caller. + for (const ProducerInfoPb &currProducer : req.producer_infos()) { + const std::string &streamName = currProducer.stream_name(); + + // Hold const_accessor to allow parallel CloseProducer. + TbbMetaHashmap::const_accessor accessor; + Status rc = GetStreamMetadata(streamName, accessor); + if (HandleCloseProducerError(firstError, rc, currProducer, rsp)) { + continue; // loop to next producer in the list. This one failed. + } + + ProducerMetaPb producerMeta; + producerMeta.set_stream_name(currProducer.stream_name()); + producerMeta.mutable_worker_address()->CopyFrom(req.worker_address()); + StreamMetadata *metadata = accessor->second.get(); + + rc = metadata->PubDecreaseNode(producerMeta, forceClose); + if (HandleCloseProducerError(firstError, rc, currProducer, rsp)) { + continue; // loop to next producer in the list. This one failed. + } + ++numSuccess; + auto successProducerPb = rsp.add_success_producers(); + successProducerPb->CopyFrom(currProducer); + VLOG(SC_NORMAL_LOG_LEVEL) << "Producer for stream " << currProducer.stream_name() + << " successfully closed on master."; + } + + VLOG(SC_NORMAL_LOG_LEVEL) << "Finished closing producers in master. Num successful: " << numSuccess + << ". Num failed: " << rsp.failed_producers_size(); + + // Always return success. The error code for any failures is packed into the rsp structure that the sending side + // must unpack. + return Status::OK(); +} + +Status SCMetadataManager::Subscribe(const SubscribeReqPb &req, SubscribeRspPb &rsp) +{ + INJECT_POINT("SCMetadataManager.Subscribe.wait"); + CHECK_FAIL_RETURN_STATUS(req.has_consumer_meta(), StatusCode::K_RUNTIME_ERROR, + "Runtime error in get consumer_meta"); + const auto &consumerMeta = req.consumer_meta(); + const auto &streamName = consumerMeta.stream_name(); + bool redirect = req.redirect(); + FillRedirectResponseInfo(rsp, streamName, redirect); + RETURN_OK_IF_TRUE(redirect); + + // Accessor as write lock for this stream. + // A procedure lock for the stream. 
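+    // Unlike CreateProducer, which creates the entry only on K_NOT_FOUND, Subscribe always
+    // goes through CreateOrGetStreamMetadata, so a consumer may be the first to create the
+    // stream entry.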
+ TbbMetaHashmap::accessor accessor; + RETURN_IF_NOT_OK(CreateOrGetStreamMetadata(streamName, accessor)); + StreamMetadata *metadata = accessor->second.get(); + RETURN_IF_NOT_OK(VerifyStreamMode(metadata->GetStreamFields().streamMode_, metadata->GetConsumerCount() + 1, + metadata->GetProducerCount())); + RETURN_IF_NOT_OK(metadata->SubIncreaseNode(consumerMeta)); + const StreamFields &streamFields = metadata->GetStreamFields(); + rsp.set_max_stream_size(streamFields.maxStreamSize_); + rsp.set_page_size(streamFields.pageSize_); + rsp.set_auto_cleanup(streamFields.autoCleanup_); + rsp.set_retain_num_consumer(streamFields.retainForNumConsumers_); + rsp.set_retain_data(metadata->CheckNUpdateNeedRetainData()); + rsp.set_encrypt_stream(streamFields.encryptStream_); + rsp.set_reserve_size(streamFields.reserveSize_); + rsp.set_stream_mode(streamFields.streamMode_); + LOG(INFO) << FormatString("[%s, S:%s, C:%s, W:%s] Create consumer on master success", LogPrefix(), streamName, + consumerMeta.consumer_id(), HostPb2Str(req.consumer_meta().worker_address())); + return Status::OK(); +} + +Status SCMetadataManager::CloseConsumer(const CloseConsumerReqPb &req, CloseConsumerRspPb &rsp) +{ + INJECT_POINT("SCMetadataManager.CloseConsumer.wait"); + (void)rsp; + CHECK_FAIL_RETURN_STATUS(req.has_consumer_meta(), StatusCode::K_RUNTIME_ERROR, + "Runtime error in get consumer_meta"); + const auto &consumerMeta = req.consumer_meta(); + const auto &streamName = consumerMeta.stream_name(); + bool redirect = req.redirect(); + FillRedirectResponseInfo(rsp, streamName, redirect); + RETURN_OK_IF_TRUE(redirect); + + // Accessor as write lock for this stream. + // A procedure lock for the stream. + TbbMetaHashmap::accessor accessor; + RETURN_IF_NOT_OK(GetStreamMetadata(streamName, accessor)); + StreamMetadata *metadata = accessor->second.get(); + RETURN_IF_NOT_OK(metadata->SubDecreaseNode(consumerMeta)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s, C:%s, W:%s] Close consumer on master success", LogPrefix(), + streamName, consumerMeta.consumer_id(), + HostPb2Str(req.consumer_meta().worker_address())); + return Status::OK(); +} + +Status SCMetadataManager::SendDeleteStreamReqToWorker(const std::string &streamName, const HostPort workerNode) +{ + std::shared_ptr masterWorkerApi = nullptr; + RETURN_IF_NOT_OK(rpcSessionManager_->GetRpcSession(workerNode, masterWorkerApi, akSkManager_)); + RETURN_IF_NOT_OK_EXCEPT(masterWorkerApi->DelStreamContextBroadcast(streamName, false), + StatusCode::K_SC_STREAM_NOT_FOUND); + return Status::OK(); +} + +Status SCMetadataManager::SendDeleteStreamReq(const std::string &streamName, std::set &relatedWorkerSet) +{ + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Sending DelStreamContextBroadcast request to %d workers", + LogPrefix(), streamName, relatedWorkerSet.size()); + std::unordered_map, int64_t> tagIds; + bool local = false; + for (const auto &workerNode : relatedWorkerSet) { + // Deal with local worker separately. + if (workerNode == masterAddress_) { + local = true; + continue; + } else if (CheckWorkerStatus(workerNode).IsError()) { + // Avoid unlimited retry due to node lost. + continue; + } + // Each related node maybe a worker which has not been recorded, so we should initialize rpc channel. 
+ int64_t tagId; + std::shared_ptr masterWorkerApi = nullptr; + RETURN_IF_NOT_OK(rpcSessionManager_->GetRpcSession(workerNode, masterWorkerApi, akSkManager_)); + RETURN_IF_NOT_OK(masterWorkerApi->DelStreamContextBroadcastAsyncWrite(streamName, false, tagId)); + tagIds.emplace(masterWorkerApi, tagId); + } + // Process the local bypass request to local worker in between AsyncWrite and AsyncRead for better time utilization. + if (local) { + RETURN_IF_NOT_OK(SendDeleteStreamReqToWorker(streamName, masterAddress_)); + } + for (const auto &pair : tagIds) { + RETURN_IF_NOT_OK_EXCEPT(pair.first->DelStreamContextBroadcastAsyncRead(pair.second, RpcRecvFlags::NONE), + StatusCode::K_SC_STREAM_NOT_FOUND); + } + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] DelStreamContextBroadcast requests are done", LogPrefix(), + streamName); + return Status::OK(); +} + +Status SCMetadataManager::DeleteStream(const DeleteStreamReqPb &req, DeleteStreamRspPb &rsp) +{ + INJECT_POINT("SCMetadataManager.DeleteStream.sleep"); + (void)rsp; + const std::string &streamName = req.stream_name(); + bool redirect = req.redirect(); + FillRedirectResponseInfo(rsp, streamName, redirect); + RETURN_OK_IF_TRUE(redirect); + HostPort srcWorkerAddr; + HostPb2Host(req.src_node_addr(), srcWorkerAddr); + std::set relatedWorkerSet; + bool decrementRef = false; + Raii deleter([this, streamName, &decrementRef]() { + TbbMetaHashmap::const_accessor accessor; + Status rc = GetStreamMetadata(streamName, accessor); + if (rc.IsOk()) { + // undo delete if stream still exists + StreamMetadata *metadata = accessor->second.get(); + RETURN_RUNTIME_ERROR_IF_NULL(metadata); + metadata->UndoDeleteStream(decrementRef); + } + // ignore K_NOT_FOUND error as stream might have deleted + return Status::OK(); + }); + + // Accessor as write lock for this stream. + // A procedure lock for the stream. + // Only change local metadata under lock. 
+ { + TbbMetaHashmap::accessor accessor; + RETURN_IF_NOT_OK(GetStreamMetadata(streamName, accessor)); + StreamMetadata *metadata = accessor->second.get(); + RETURN_RUNTIME_ERROR_IF_NULL(metadata); + RETURN_IF_NOT_OK(metadata->DeleteStreamStart(srcWorkerAddr, relatedWorkerSet)); + decrementRef = true; + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Started to Delete stream on master", LogPrefix(), + streamName); + } + + // Send requests to all workers without a lock + // here all other requests to the stream will be rejected (including another delete) + // based on delete stream state we set above so need a lock to + // prevent this + INJECT_POINT("SCMetadataManager.DeleteStream.SendReqs"); + RETURN_IF_NOT_OK(SendDeleteStreamReq(streamName, relatedWorkerSet)); + INJECT_POINT("SCMetadataManager.DeleteStream.SentReqs"); + // We need lock here as someone else might be accessing streamMetaManagerDict_ + { + TbbMetaHashmap::accessor accessor; + ReadLockHelper rlocker(LOCK_ARGS_MSG(metaDictMutex_, streamName)); + RETURN_IF_NOT_OK(GetStreamMetadataNoLock(streamName, accessor)); + StreamMetadata *metadata = accessor->second.get(); + RETURN_RUNTIME_ERROR_IF_NULL(metadata); + RETURN_IF_NOT_OK(metadata->DeleteStreamEnd()); + CHECK_FAIL_RETURN_STATUS(streamMetaManagerDict_.erase(accessor) == 1, StatusCode::K_NOT_FOUND, + FormatString("Stream <%s> does not exist on master", streamName)); + } + + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Delete stream on master success", LogPrefix(), streamName); + return Status::OK(); +} + +Status SCMetadataManager::QueryGlobalProducersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) +{ + const auto &streamName = req.stream_name(); + bool redirect = req.redirect(); + FillRedirectResponseInfo(rsp, streamName, redirect); + RETURN_OK_IF_TRUE(redirect); + // Accessor as read lock for this stream. + // A procedure lock for the stream. + TbbMetaHashmap::const_accessor accessor; + Status rc = GetStreamMetadata(streamName, accessor); + RETURN_OK_IF_TRUE(rc.GetCode() == StatusCode::K_NOT_FOUND); + RETURN_IF_NOT_OK(rc); + StreamMetadata *metadata = accessor->second.get(); + auto count = metadata->GetProducerCount(); + rsp.set_global_count(count); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] The global producer number is %ld", LogPrefix(), streamName, + count); + return Status::OK(); +} + +Status SCMetadataManager::QueryGlobalConsumersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) +{ + const auto &streamName = req.stream_name(); + bool redirect = req.redirect(); + FillRedirectResponseInfo(rsp, streamName, redirect); + RETURN_OK_IF_TRUE(redirect); + // Accessor as read lock for this stream. + // A procedure lock for the stream. 
+    TbbMetaHashmap::const_accessor accessor;
+    Status rc = GetStreamMetadata(streamName, accessor);
+    RETURN_OK_IF_TRUE(rc.GetCode() == StatusCode::K_NOT_FOUND);
+    RETURN_IF_NOT_OK(rc);
+    StreamMetadata *metadata = accessor->second.get();
+    auto count = metadata->GetConsumerCount();
+    rsp.set_global_count(count);
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] The global consumer number is %ld", LogPrefix(), streamName,
+                                              count);
+    return Status::OK();
+}
+
+std::string SCMetadataManager::LogPrefix() const
+{
+    return FormatString("MasterSvc, Node:%s", masterAddress_.ToString());
+}
+
+Status SCMetadataManager::CreateOrGetStreamMetadata(const std::string &streamName, TbbMetaHashmap::accessor &accessor)
+{
+    ReadLockHelper rlocker(LOCK_ARGS_MSG(metaDictMutex_, streamName));
+
+    // The accessor works as a procedure lock for the stream.
+    bool isFirst = streamMetaManagerDict_.insert(accessor, streamName);
+    if (isFirst) {
+        StreamFields streamFields(0, 0, false, 0, false, 0,
+                                  StreamMode::MPMC);  // empty stream fields to start for a new entry
+        // If this stream was just created on the master, save it into rocksdb.
+        auto status = streamMetaStore_->AddStream(streamName, streamFields);
+        if (status.IsError()) {
+            // Allow the stream to be recreated correctly.
+            (void)streamMetaManagerDict_.erase(accessor);
+            return status;
+        }
+        // Initialize the consumer count for the stream.
+        status = streamMetaStore_->UpdateLifeTimeConsumerCount(streamName, 0);
+        if (status.IsError()) {
+            // Roll back the previous changes.
+            streamMetaStore_->DelStream(streamName);
+            // Allow the stream to be recreated correctly.
+            (void)streamMetaManagerDict_.erase(accessor);
+            return status;
+        }
+        accessor->second =
+            std::make_shared<StreamMetadata>(streamName, streamFields, streamMetaStore_.get(), akSkManager_,
+                                             rpcSessionManager_, etcdCM_, notifyWorkerManager_.get());
+        if (ScMetricsMonitor::Instance()->IsEnabled()) {
+            RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
+                accessor->second->InitStreamMetrics(),
+                FormatString("[%s, S:%s] Init master sc metrics failed", LogPrefix(), streamName));
+        }
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Create new StreamMetaManager success", LogPrefix(),
+                                                  streamName);
+    } else {
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] StreamMetaManager already exists", LogPrefix(),
+                                                  streamName);
+    }
+    return Status::OK();
+}
+
+Status SCMetadataManager::CreateStreamMetadata(const std::string &streamName)
+{
+    TbbMetaHashmap::accessor accessor;
+    INJECT_POINT("SCMetadataManager.CreateStreamMetadata", [this, &accessor](const std::string &stream) {
+        (void)CreateOrGetStreamMetadata(stream, accessor);
+        return Status::OK();
+    });
+    RETURN_IF_NOT_OK(CreateOrGetStreamMetadata(streamName, accessor));
+    return Status::OK();
+}
+
+Status SCMetadataManager::LoadMeta()
+{
+    LOG(INFO) << "Start to load meta data from rocksdb into memory";
+    INJECT_POINT("master.SCMetadataManager.LoadMeta");
+    std::vector streamMetas;
+    RETURN_IF_NOT_OK(streamMetaStore_->GetAllStream(streamMetas));
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Total stream number:<%d>", streamMetas.size());
+    for (const auto &meta : streamMetas) {
+        const auto &streamName = meta.stream_name();
+        // No need to hold the lock here, because we do not provide RPC service to others before we finish.
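+        // Rebuild the in-memory StreamMetadata from the persisted stream fields, then
+        // replay the producers and consumers recorded in rocksdb.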
+ StreamFields streamFields(meta.max_stream_size(), meta.page_size(), meta.auto_cleanup(), + meta.retain_num_consumer(), meta.encrypt_stream(), meta.reserve_size(), + meta.stream_mode()); + + // clang-format off + CHECK_FAIL_RETURN_STATUS(streamMetaManagerDict_.emplace( + streamName, std::make_shared(streamName, streamFields, streamMetaStore_.get(), + akSkManager_, rpcSessionManager_, etcdCM_, notifyWorkerManager_.get())), + StatusCode::K_RUNTIME_ERROR, "Load meta reconstruction insertion failed"); + // clang-format on + TbbMetaHashmap::accessor accessor; + RETURN_IF_NOT_OK(GetStreamMetadata(streamName, accessor)); + StreamMetadata *metadata = accessor->second.get(); + CHECK_FAIL_RETURN_STATUS(metadata != nullptr, K_RUNTIME_ERROR, "metadata is null"); + if (ScMetricsMonitor::Instance()->IsEnabled()) { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + metadata->InitStreamMetrics(), + FormatString("[%s, S:%s] Init master sc metrics failed", LogPrefix(), streamName)); + } + std::vector producerMetaPbVector; + std::vector consumerMetaPbVector; + RETURN_IF_NOT_OK(streamMetaStore_->GetOneStreamProducers(streamName, producerMetaPbVector)); + RETURN_IF_NOT_OK(streamMetaStore_->GetOneStreamConsumers(streamName, consumerMetaPbVector)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Stream:<%s>, Pub Worker number:<%d>, Consumer number:<%d>", + streamName, producerMetaPbVector.size(), consumerMetaPbVector.size()); + + for (const auto &producerMetaPb : producerMetaPbVector) { + RETURN_IF_NOT_OK(metadata->RecoveryPubMeta(producerMetaPb)); + } + for (const auto &consumerMetaPb : consumerMetaPbVector) { + // We assume only master restart, so the topo on worker does not need to sync. + RETURN_IF_NOT_OK(metadata->RecoverySubMeta(consumerMetaPb)); + } + + // Recover the consumer count + uint32_t consumerLifeCount = 0; + WARN_IF_ERROR(streamMetaStore_->GetLifeTimeConsumerCount(streamName, consumerLifeCount), + "Reading an older version of Metadata, without ConsumerLifeCount field."); + RETURN_IF_NOT_OK(metadata->RestoreConsumerLifeCount(consumerLifeCount)); + + // Recover the Retain data state + auto currentState = metadata->CheckNUpdateNeedRetainData(); + VLOG(SC_NORMAL_LOG_LEVEL) << "[RetainData] RetainData state is restored for stream: " << streamName << " to " + << currentState; + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Stream:<%s>, Status:", streamName); + } + LOG(INFO) << "Recovery meta data into memory success"; + return Status::OK(); +} + +std::vector> SCMetadataManager::GetStreamMetaByWorkerAddr(const HostPort &workerAddr, + bool clearMeta) +{ + std::vector> streamMetadatas; + WriteLockHelper wlocker(LOCK_ARGS(metaDictMutex_)); + for (const auto &streamMetadata : streamMetaManagerDict_) { + const auto &workerAddress = workerAddr.ToString(); + INJECT_POINT_NO_RETURN("SCMetadataManager.SkipClearEmptyMeta", [&clearMeta]() { clearMeta = false; }); + if (clearMeta) { + streamMetadata.second->ClearEmptyMeta(workerAddress); + } + if (streamMetadata.second->CheckWorkerExistsPubSub(workerAddr.ToString())) { + streamMetadatas.push_back(streamMetadata.second); + } + } + return streamMetadatas; +} + +std::vector> SCMetadataManager::GetAllStreamMeta() +{ + std::vector> streamMetadatas; + WriteLockHelper wlocker(LOCK_ARGS(metaDictMutex_)); + for (const auto &streamMetadata : streamMetaManagerDict_) { + streamMetadatas.push_back(streamMetadata.second); + } + return streamMetadatas; +} + +Status SCMetadataManager::ClearWorkerMetadata(const HostPort &workerAddr, const bool forceClose) +{ + RETURN_OK_IF_TRUE(!EnableSCService()); 
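+    // Drop any notifications still queued for this worker first, then remove its
+    // producers and consumers from every stream that references it.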
+ LOG(INFO) << "Start clear metadata from " << workerAddr.ToString() << " ,forceClose " << forceClose; + + // Clear the existing task. + RETURN_IF_NOT_OK(notifyWorkerManager_->ClearPendingNotification(workerAddr.ToString())); + + bool clearMeta = true; + std::vector> streamMetas = GetStreamMetaByWorkerAddr(workerAddr, clearMeta); + + Status status; + for (auto &streamMeta : streamMetas) { + LOG(INFO) << FormatString("ClearWorkerMetadata for stream [%s]", streamMeta->GetStreamName()); + auto tmpRc = streamMeta->ClearWorkerMetadata(workerAddr, forceClose); + if (tmpRc.IsError() && tmpRc.GetCode() != K_NOT_FOUND) { + LOG(ERROR) << FormatString("Clear worker meta for stream [%s] failed", streamMeta->GetStreamName()); + status = std::move(tmpRc); + } + } + LOG(INFO) << "Metadata cleared for " << workerAddr.ToString(); + return status; +} + +Status SCMetadataManager::CheckMetadata(const std::vector &workerAddrs, const worker::HashRange &hashRanges) +{ + RETURN_OK_IF_TRUE(!EnableSCService()); + Status lastRc; + std::unordered_set streamsMayNeedAutoDelete; + for (const auto &workerAddr : workerAddrs) { + auto rc = CheckMetadataImpl(workerAddr, hashRanges, streamsMayNeedAutoDelete); + if (rc.IsError()) { + LOG(ERROR) << "CheckMetadata for " << workerAddr << " failed, detail: " << rc.ToString(); + lastRc = std::move(rc); + } + } + // Now that we have complete metadata, we need to do some cleanup. + LOG(INFO) << "streamsMayNeedAutoDelete: " << VectorToString(streamsMayNeedAutoDelete); + PostCheckMetadata(streamsMayNeedAutoDelete); + return lastRc; +} + +Status SCMetadataManager::CheckMetadataImpl(const HostPort &workerAddr, const worker::HashRange &hashRanges, + std::unordered_set &streamsMayNeedAutoDelete) +{ + LOG(INFO) << "Started CheckMetadata with worker: " << workerAddr; + std::shared_ptr masterWorkerApi = nullptr; + RETURN_IF_NOT_OK(rpcSessionManager_->GetRpcSession(workerAddr, masterWorkerApi, akSkManager_)); + auto *localApi = dynamic_cast(masterWorkerApi.get()); + GetMetadataAllStreamReqPb req; + std::vector metaRsp; + bool isPassiveScaleDown = !hashRanges.empty(); + if (!isPassiveScaleDown) { + req.set_master_address(masterAddress_.ToString()); + } + for (const auto &range : hashRanges) { + GetMetadataAllStreamReqPb::RangePb rangePb; + rangePb.set_from(range.first); + rangePb.set_end(range.second); + req.mutable_hash_ranges()->Add(std::move(rangePb)); + } + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + if (localApi) { + GetMetadataAllStreamRspPb rsp; + RETURN_IF_NOT_OK_APPEND_MSG(localApi->QueryMetadata(req, rsp), "QueryMetadata send to worker failed"); + metaRsp.insert(metaRsp.end(), rsp.stream_meta().begin(), rsp.stream_meta().end()); + } else { + // non-local worker, use the stream api for this query + std::unique_ptr> stream; + INJECT_POINT("worker.MasterRemoteWorkerSCApi.QueryMetadata"); + RETURN_IF_NOT_OK_APPEND_MSG(masterWorkerApi->QueryMetadata(stream), "QueryMetadata send to worker failed"); + RETURN_IF_NOT_OK(stream->Write(req)); + Status rc; + do { + GetStreamMetadataRspPb rsp; + rc = stream->Read(rsp); + if (rc.IsOk()) { + metaRsp.emplace_back(rsp); + } + } while (rc.IsOk()); + CHECK_FAIL_RETURN_STATUS(rc.GetCode() == K_RPC_STREAM_END, rc.GetCode(), rc.GetMsg()); + LOG_IF_ERROR(stream->Finish(), "Closing of stream rpc failed in master"); + } + RETURN_IF_NOT_OK_APPEND_MSG(UpdateMetadata(metaRsp, workerAddr, streamsMayNeedAutoDelete, hashRanges), + "UpdateMetadata failed"); + LOG(INFO) << "Finish CheckMetadata with " << workerAddr; + return Status::OK(); +} + +Status 
SCMetadataManager::UpdateMetadata(std::vector &metaRsp, const HostPort &workerAddr, + std::unordered_set &streamsMayNeedAutoDelete, + const worker::HashRange &hashRanges) +{ + Status status; + std::vector receivedStreams; + for (const auto &meta : metaRsp) { + const auto &streamName = meta.stream_name(); + receivedStreams.push_back(streamName); + StatusCode code = static_cast(meta.error().error_code()); + if (code != StatusCode::K_OK) { + status = Status(code, meta.error().error_msg()); + LOG(ERROR) << "Master failed to receive meta for stream: " << streamName + << " from worker: " << workerAddr.ToString() << " with message: " << status.GetMsg(); + continue; + } + UpdateSetByCondition(streamName, meta.producers().empty() && meta.consumers().empty(), + streamsMayNeedAutoDelete); + TbbMetaHashmap::accessor accessor; + status = CreateOrGetStreamMetadata(streamName, accessor); + if (status.IsError()) { + continue; + } + StreamMetadata *metadata = accessor->second.get(); + status = metadata->CheckMetadata(meta, workerAddr); + } + // Clear metadata for streams not found in worker but master has metadata for worker. + // Fixme: It is too wasteful to traverse all the streams. These streams should be classified according to the + // granularity of the worker. + WriteLockHelper wlocker(LOCK_ARGS(metaDictMutex_)); + for (const auto &streamMetadata : streamMetaManagerDict_) { + const auto &streamName = streamMetadata.first; + if ((hashRanges.empty() || etcdCM_->IsInRange(hashRanges, streamName, "")) + && std::find(receivedStreams.begin(), receivedStreams.end(), streamName) == receivedStreams.end()) { + auto rc = streamMetadata.second->ClearWorkerMetadata(workerAddr, false, false); + LOG_IF_ERROR_EXCEPT(rc, + FormatString("ClearWorkerMetadata for stream[%s] on worker[%s] failed, detail: %s", + streamMetadata.first, workerAddr.ToString(), rc.ToString()), + K_NOT_FOUND); + if (rc.IsOk()) { + streamsMayNeedAutoDelete.insert(streamMetadata.first); + } + } + } + return status; +} + +void SCMetadataManager::CheckMetadataWithAsyncRetry(const HostPort &workerAddr, std::shared_ptr timer, + size_t retryTimes) +{ + const int64_t secToMs = 1000; + int64_t elapsedMs = static_cast(std::round(timer->ElapsedMilliSecond())); + int64_t remaining = static_cast(FLAGS_node_dead_timeout_s) * secToMs - elapsedMs; + if (remaining <= 0) { + LOG(WARNING) << FormatString("CheckMetadata with %s timeout.", workerAddr.ToString()); + return; + } + Status rc = CheckMetadata({ workerAddr }); + if (IsRpcTimeoutOrTryAgain(rc)) { + static std::vector retryDelaySec = { 1, 2, 4, 8, 16, 32, 64 }; + uint64_t delaySec = retryTimes < retryDelaySec.size() ? 
retryDelaySec[retryTimes] : retryDelaySec.back();
+        // remaining is positive here, so the conversion is safe.
+        uint64_t delayMs = std::min<uint64_t>(remaining, delaySec * secToMs);
+
+        auto traceID = Trace::Instance().GetTraceID();
+        auto delayTask = [this, workerAddr, retryTimes, traceID, timer = std::move(timer), exitFlag = exitFlag_] {
+            if (exitFlag->load()) {
+                return;
+            }
+            asyncReconciliationPool_->Execute([this, workerAddr, retryTimes, traceID, timer = std::move(timer)] {
+                TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID);
+                CheckMetadataWithAsyncRetry(workerAddr, std::move(timer), retryTimes + 1);
+            });
+        };
+        TimerQueue::TimerImpl timerImpl;
+        TimerQueue::GetInstance()->AddTimer(delayMs, delayTask, timerImpl);
+    }
+    if (rc.IsError()) {
+        LOG(WARNING) << FormatString("CheckMetadata with %s failed: %s", workerAddr.ToString(), rc.GetMsg());
+    } else {
+        LOG(INFO) << FormatString("CheckMetadata done with %s ", workerAddr.ToString());
+    }
+}
+
+void SCMetadataManager::StartCheckMetadata(const HostPort &workerAddr)
+{
+    if (!EnableSCService()) {
+        return;
+    }
+    if (!asyncReconciliationPool_) {
+        LOG(WARNING) << "Reconciliation pool does not exist.";
+        return;
+    }
+    auto traceID = Trace::Instance().GetTraceID();
+    asyncReconciliationPool_->Execute([this, workerAddr, traceID] {
+        TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID);
+        LOG(INFO) << FormatString("CheckMetadata starts with %s ", workerAddr.ToString());
+        auto timer = std::make_shared<Timer>();
+        CheckMetadataWithAsyncRetry(workerAddr, std::move(timer));
+    });
+}
+
+void SCMetadataManager::StartClearWorkerMetadata(const HostPort &workerAddr)
+{
+    if (!EnableSCService()) {
+        return;
+    }
+    if (!asyncReconciliationPool_) {
+        LOG(WARNING) << "Reconciliation pool does not exist.";
+        return;
+    }
+    auto traceID = Trace::Instance().GetTraceID();
+    asyncReconciliationPool_->Execute([this, workerAddr, traceID] {
+        TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID);
+        Status rc = ClearWorkerMetadata(workerAddr, true);
+        if (rc.IsError()) {
+            LOG(ERROR) << FormatString("ClearWorkerMetadata with %s failed: %s", workerAddr.ToString(), rc.GetMsg());
+        }
+    });
+}
+
+Status SCMetadataManager::RecoverMetadataOfFaultyWorker(const std::vector &workerUuids,
+                                                        const worker::HashRange &extraRanges)
+{
+    if (extraRanges.empty()) {
+        LOG_IF(INFO, !workerUuids.empty()) << "Only supports RecoverMetadataOfFaultyWorker by hash ranges for SC";
+        return Status::OK();
+    }
+    LOG(INFO) << "Start RecoverMetadataOfFaultyWorker by ranges";
+    auto func = [this](const worker::HashRange &extraRanges) {
+        std::vector nodeAddrs;
+        RETURN_IF_NOT_OK(etcdCM_->GetNodeAddrListFromEtcd(nodeAddrs));
+        RETURN_IF_NOT_OK(CheckMetadata(nodeAddrs, extraRanges));
+        return Status::OK();
+    };
+    LOG_IF_ERROR(func(extraRanges), "RecoverMetadataOfFaultyWorker failed");
+    LOG(INFO) << "Finish RecoverMetadataOfFaultyWorker";
+    return Status::OK();
+}
+
+Status SCMetadataManager::CheckWorkerStatus(const HostPort &workerHostPort)
+{
+    if (etcdCM_ == nullptr) {
+        RETURN_STATUS_LOG_ERROR(StatusCode::K_INVALID, "ETCD cluster manager is nullptr.");
+    }
+    auto rc = etcdCM_->CheckConnection(workerHostPort);
+    if (rc.IsError()) {
+        RETURN_STATUS_LOG_ERROR(K_WORKER_ABNORMAL,
+                                FormatString("The Worker %s is abnormal.", workerHostPort.ToString()));
+    }
+    return rc;  // Status is OK
+}
+
+Status SCMetadataManager::VerifyStreamMode(StreamMode streamMode, size_t consumerNumAfterModify,
+                                           size_t producerNumAfterModify)
+{
+    CHECK_FAIL_RETURN_STATUS(
+        streamMode == StreamMode::MPMC || consumerNumAfterModify <= 1, K_INVALID,
FormatString("There can be at most one consumer in this stream mode: %d. Consumer num after modify: %zd", + static_cast(streamMode), consumerNumAfterModify)); + CHECK_FAIL_RETURN_STATUS( + streamMode != StreamMode::SPSC || producerNumAfterModify <= 1, K_INVALID, + FormatString("There can be at most one producer in this stream mode: %d. Producer num after modify: %zd", + static_cast(streamMode), producerNumAfterModify)); + return Status::OK(); +} + +bool SCMetadataManager::CheckSCMetaExist(const HostPort &workerAddr) +{ + std::vector> streamMetas = GetStreamMetaByWorkerAddr(workerAddr); + return !streamMetas.empty(); +} + +void SCMetadataManager::TriggerAutoDelActively(const std::unordered_set &streamsMayNeedAutoDelete) +{ + for (const auto &stream : streamsMayNeedAutoDelete) { + TbbMetaHashmap::accessor accessor; + auto rc = GetStreamMetadata(stream, accessor); + LOG_IF_ERROR_EXCEPT(rc, "TriggerAutoDelActively failed", K_NOT_FOUND); + if (rc.IsError()) { + continue; + } + StreamMetadata *metadata = accessor->second.get(); + LOG_IF_ERROR(metadata->AutoCleanupIfNeeded(HostPort()), "TriggerAutoDelActively failed"); + } +} +} // namespace master +} // namespace datasystem diff --git a/src/datasystem/master/stream_cache/sc_metadata_manager.h b/src/datasystem/master/stream_cache/sc_metadata_manager.h new file mode 100644 index 0000000..6eb6a49 --- /dev/null +++ b/src/datasystem/master/stream_cache/sc_metadata_manager.h @@ -0,0 +1,432 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Module responsible for managing the stream cache metadata on the master. + */ +#ifndef DATASYSTEM_MASTER_STREAM_CACHE_SC_METADATA_MANAGER_H +#define DATASYSTEM_MASTER_STREAM_CACHE_SC_METADATA_MANAGER_H + +#include + +#include + +#include "datasystem/common/log/log.h" +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/common/util/lock_helper.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/master/metadata_redirect_helper.h" +#include "datasystem/master/stream_cache/sc_notify_worker_manager.h" +#include "datasystem/master/stream_cache/stream_metadata.h" +#include "datasystem/protos/worker_stream.pb.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/worker/cluster_manager/etcd_cluster_manager.h" + +namespace datasystem { +namespace master { +using TbbMetaHashmap = tbb::concurrent_hash_map>; + +class SCMetadataManager : public MetadataRedirectHelper { +public: + /** + * @brief Construct a new SCMetadataManager instance. + * @param[in] masterHostPort The master address. + * @param[in] akSkManager Used to do AK/SK authenticate. + * @param[in] rpcSessionManager Master to Worker session manager. + * @param[in] cm The etcd cluster manager instance. + * @param[in] rocksStore The rocks store instance. + * @param[in] dbName The db name. 
+     */
+    SCMetadataManager(const HostPort &masterHostPort, std::shared_ptr akSkManager,
+                      std::shared_ptr rpcSessionManager, EtcdClusterManager *cm,
+                      RocksStore *rocksStore, const std::string &dbName);
+
+    /**
+     * @brief Shutdown the sc metadata manager module.
+     */
+    void Shutdown() override;
+
+    /**
+     * @brief WorkerOCServer uses the SetClusterManager method to directly pass the std::unique_ptr address of
+     * etcdCM_ to SCMetadataManager. If the etcdCM_ destructor releases the memory, a core dump occurs when the
+     * SCMetadataManager object that holds the pointer address operates on that address. Therefore, the
+     * EtcdClusterManager needs to notify the SCMetadataManager before exiting.
+     */
+    void SetClusterManagerToNullptr();
+
+    /**
+     * @brief Initialize SCMetadataManager.
+     * @return Status of the call.
+     */
+    Status Init();
+
+    ~SCMetadataManager();
+
+    /**
+     * @brief Check if the metadata is available for the given stream name.
+     * @param[in] streamName The stream name.
+     * @return Returns true if the meta is found in the meta table.
+     */
+    bool MetaIsFound(const std::string &streamName) override;
+
+    /**
+     * @brief Get the streams that meet the meta conditions.
+     * @param[in] matchFunc The conditions to meet.
+     * @param[out] streamNames Streams that meet the meta conditions.
+     * @param[in] exitEarly Whether to exit the loop early.
+     */
+    void GetMetasMatch(std::function &&matchFunc, std::vector &streamNames,
+                       bool *exitEarly = nullptr);
+
+    /**
+     * @brief Saves metadata migrated from other masters.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return Status of the result.
+     */
+    Status SaveMigrationMetadata(const MigrateSCMetadataReqPb &req, MigrateSCMetadataRspPb &rsp);
+
+    /**
+     * @brief Save migration data.
+     * @param[in] streamMeta The stream meta.
+     * @param[out] status The status of the save operation.
+     * @param[out] rsp The rpc response protobuf.
+     * @return Status of the call.
+     */
+    Status SaveMigrationData(const MetaForSCMigrationPb &streamMeta, Status &status, MigrateSCMetadataRspPb &rsp);
+
+    /**
+     * @brief Fill in the data to be migrated.
+     * @param[in] streamName The stream name to be migrated.
+     * @param[out] meta Data to be sent.
+     * @return Status of the result.
+     */
+    Status FillMetadataForMigration(const std::string &streamName, MetaForSCMigrationPb *meta);
+
+    /**
+     * @brief Handle a failed data migration.
+     * @param[in] streamMeta Metadata to be migrated.
+     */
+    void HandleMetaDataMigrationFailed(const MetaForSCMigrationPb &streamMeta);
+
+    /**
+     * @brief Delete the migrated metadata.
+     * @param[in] streamName The stream name to be deleted.
+     */
+    void HandleMetaDataMigrationSuccess(const std::string &streamName);
+
+    /**
+     * @brief Create a producer, i.e., register a publisher to a stream. Similar to worker::CreateProducer.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CreateProducer(const CreateProducerReqPb &req, CreateProducerRspPb &rsp);
+
+    /**
+     * @brief Close a producer, force flushing and page seal, unregister a publisher from a stream.
+     * Similar to worker::CloseProducer.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CloseProducer(const CloseProducerReqPb &req, CloseProducerRspPb &rsp);
+
+    /**
+     * @brief Subscribe to a stream, using a subscription name, i.e., register a consumer to a subscription.
+     * Similar to worker::Subscribe.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status Subscribe(const SubscribeReqPb &req, SubscribeRspPb &rsp);
+
+    /**
+     * @brief Close a consumer, trigger the subscription cursor change and unregister a subscribed consumer
+     * from a stream. Similar to worker::CloseConsumer.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CloseConsumer(const CloseConsumerReqPb &req, CloseConsumerRspPb &rsp);
+
+    /**
+     * @brief Delete a stream.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status DeleteStream(const DeleteStreamReqPb &req, DeleteStreamRspPb &rsp);
+
+    /**
+     * @brief Query the global producer number for a stream.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status QueryGlobalProducersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp);
+
+    /**
+     * @brief Query the global consumer number for a stream.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status QueryGlobalConsumersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp);
+
+    /**
+     * @brief Clear worker metadata.
+     * @param[in] workerAddr The worker address.
+     * @param[in] forceClose Whether the worker crashed rather than closed normally.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status ClearWorkerMetadata(const HostPort &workerAddr, const bool forceClose = false);
+
+    /**
+     * @brief Check metadata with workers.
+     * @param[in] workerAddrs The worker addresses.
+     * @param[in] hashRanges The optional hash ranges, for passive scale down recovery purposes.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CheckMetadata(const std::vector &workerAddrs, const worker::HashRange &hashRanges = {});
+
+    /**
+     * @brief Start checking metadata with a worker.
+     * @param[in] workerAddr The worker address.
+     */
+    void StartCheckMetadata(const HostPort &workerAddr);
+
+    /**
+     * @brief Start clearing worker metadata.
+     * @param[in] workerAddr The worker address.
+     */
+    void StartClearWorkerMetadata(const HostPort &workerAddr);
+
+    /**
+     * @brief Get the rocksdb name.
+     * @return std::string The rocksdb name.
+     */
+    std::string GetDbName()
+    {
+        return dbName_;
+    }
+
+    /**
+     * @brief Check whether the meta table is empty.
+     * @return Whether the meta table is empty.
+     */
+    bool CheckMetaTableEmpty();
+
+    /**
+     * @brief Check if there is SC metadata on the worker node.
+     * @param[in] workerAddr The worker address.
+     * @return True if it exists, false otherwise.
+     */
+    bool CheckSCMetaExist(const HostPort &workerAddr);
+
+private:
+    friend SCNotifyWorkerManager;
+
+    /**
+     * @brief Check metadata with a worker.
+     * @param[in] workerAddr The worker address.
+     * @param[in] hashRanges The optional hash ranges, for passive scale down recovery purposes.
+     * @param[in/out] streamsMayNeedAutoDelete Streams that may need to be automatically deleted.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CheckMetadataImpl(const HostPort &workerAddr, const worker::HashRange &hashRanges,
+                             std::unordered_set &streamsMayNeedAutoDelete);
+
+    /**
+     * @brief Create the stream metadata if it does not exist and add it to the hashmap.
+     * @note If it already exists, or we are not the first to create it, we find the slot and take the accessor lock.
+     * @param[in] streamName Name of the stream.
+     * @param[out] accessor Lock of the stream entry.
+     */
+    Status CreateOrGetStreamMetadata(const std::string &streamName, TbbMetaHashmap::accessor &accessor);
+
+    /**
+     * @brief Create the stream metadata if it does not exist and add it to the hashmap.
+     * @param[in] streamName Name of the stream.
+     */
+    Status CreateStreamMetadata(const std::string &streamName);
+
+    /**
+     * @brief Get the stream metadata without the lock; the accessor can be either TbbMetaHashmap::accessor or
+     * TbbMetaHashmap::const_accessor.
+     * @param[in] streamName Name of the stream.
+     * @param[out] accessor Lock of the stream entry.
+     */
+    template <typename T>
+    Status GetStreamMetadataNoLock(const std::string &streamName, T &accessor)
+    {
+        if (!streamMetaManagerDict_.find(accessor, streamName)) {
+            RETURN_STATUS(K_NOT_FOUND, FormatString("stream [%s] not found", streamName));
+        }
+        return Status::OK();
+    }
+
+    /**
+     * @brief Get the stream metadata; the accessor can be either TbbMetaHashmap::accessor or
+     * TbbMetaHashmap::const_accessor.
+     * @param[in] streamName Name of the stream.
+     * @param[out] accessor Lock of the stream entry.
+     */
+    template <typename T>
+    Status GetStreamMetadata(const std::string &streamName, T &accessor)
+    {
+        ReadLockHelper rlocker(LOCK_ARGS_MSG(metaDictMutex_, streamName));
+        return GetStreamMetadataNoLock(streamName, accessor);
+    }
+
+    /**
+     * @brief Get the stream metadata objects by worker address.
+     * @param[in] workerAddr The worker address.
+     * @param[in] clearMeta Whether to clear metadata.
+     * @return The stream metadata list.
+     */
+    std::vector> GetStreamMetaByWorkerAddr(const HostPort &workerAddr,
+                                                                            bool clearMeta = false);
+
+    /**
+     * @brief Get all stream metadata objects.
+     * @return The stream metadata list.
+     */
+    std::vector> GetAllStreamMeta();
+
+    /**
+     * @brief If the input rc is an error case, set the response rc only if this is the first error,
+     * then append the producer to the response pb's list of failed producers. No-op if the input rc is ok;
+     * the caller tracks the successful producer in the response success list.
+     * @param[in/out] firstError Indicator if this is the first time an error was hit.
+     * @param[in] rc The rc to process.
+     * @param[in] info The current producer that was being closed.
+     * @param[out] rsp The CloseProducer response proto to adjust if there was an error.
+     * @return true if an error was handled; false if there was no error in the input rc.
+     */
+    bool HandleCloseProducerError(bool &firstError, const Status &rc, const ProducerInfoPb &info,
+                                  CloseProducerRspPb &rsp);
+
+    /**
+     * @brief Get the log prefix.
+     * @return The log prefix.
+     */
+    std::string LogPrefix() const;
+
+    /**
+     * @brief Load the stream meta persisted by the last runtime from the rocksdb store into memory.
+     * @return Status of the call.
+     */
+    Status LoadMeta();
+
+    /**
+     * @brief Update metadata in the master with the responses received from a worker.
+     * @param[in] metaRsp The vector of metadata responses received from the worker.
+     * @param[in] workerAddr The target worker address for reconciliation.
+     * @param[in/out] streamsMayNeedAutoDelete Streams that may need to be automatically deleted.
+     * @param[in] hashRanges The optional hash ranges, for passive scale down recovery purposes.
+ * @return K_OK on success; the error code otherwise + */ + Status UpdateMetadata(std::vector &metaRsp, const HostPort &workerAddr, + std::unordered_set &streamsMayNeedAutoDelete, + const worker::HashRange &hashRanges); + + /** + * @brief Recover metadata of faulty worker from the other workers. + * @param[in] workerUuids The uuids to be recovered. + * @param[in] extraRanges The hash range of faulty worker. + * @return Status of the result. + */ + Status RecoverMetadataOfFaultyWorker(const std::vector &workerUuids, + const worker::HashRange &extraRanges); + + /** + * @brief Send Delete Stream Context to a worker + * @param[in] streamName The stream to be deleted. + * @param[in] workerNode Worker in which stream needs to be deleted. + * @return Status of the result. + */ + Status SendDeleteStreamReqToWorker(const std::string &streamName, const HostPort workerNode); + + /** + * @brief Send delete stream requests to all related workers + * @param[in] streamName The name of the stream getting deleted. + * @param[in] relatedWorkerSet The list of workers. + * @return Status of the result. + */ + Status SendDeleteStreamReq(const std::string &streamName, std::set &relatedWorkerSet); + + /** + * @brief Check worker status. + * @param[in] workerHostPort The target worker address. + * @return Status of the call. + */ + Status CheckWorkerStatus(const HostPort &workerHostPort); + + /** + * @brief Verify Stream Mode. + * @param[in] streamMode The stream mode. + * @param[in] consumerNumAfterModify The consumer number after modify. + * @param[in] producerNumAfterModify The producer number after modify. + * @return Status of the call. + */ + static Status VerifyStreamMode(StreamMode streamMode, size_t consumerNumAfterModify, size_t producerNumAfterModify); + + /** + * @brief Check metadata with worker and retry for rpc timeout. + * @param[in] workerAddr The worker address. + */ + void CheckMetadataWithAsyncRetry(const HostPort &workerAddr, std::shared_ptr timer, size_t retryTimes = 0); + + /** + * @brief Actively trigger automatic deletion + * @param[in] streamsMayNeedAutoDelete Streams that may need to be automatically deleted. + */ + void TriggerAutoDelActively(const std::unordered_set &streamsMayNeedAutoDelete); + + /** + * @brief Now that we have complete metadata, we need to do some cleanup. + * @param[in] streamsMayNeedAutoDelete Streams that may need to be automatically deleted. + */ + void PostCheckMetadata(const std::unordered_set &streamsMayNeedAutoDelete) + { + TriggerAutoDelActively(streamsMayNeedAutoDelete); + } + + mutable std::shared_timed_mutex metaDictMutex_; + + HostPort masterAddress_; + + std::unique_ptr notifyWorkerManager_; + + // key:streamName value:StreamMetadata pointer. + TbbMetaHashmap streamMetaManagerDict_; + + std::shared_ptr streamMetaStore_{ nullptr }; + + std::unique_ptr asyncReconciliationPool_{ nullptr }; + std::shared_ptr akSkManager_{ nullptr }; + std::shared_ptr rpcSessionManager_{ nullptr }; + std::shared_ptr exitFlag_; + + const std::string dbName_; + const std::string eventName_; +}; +} // namespace master +} // namespace datasystem + +#endif // DATASYSTEM_MASTER_STREAM_CACHE_SC_METADATA_MANAGER_H diff --git a/src/datasystem/master/stream_cache/sc_migrate_metadata_manager.cpp b/src/datasystem/master/stream_cache/sc_migrate_metadata_manager.cpp new file mode 100644 index 0000000..64af19e --- /dev/null +++ b/src/datasystem/master/stream_cache/sc_migrate_metadata_manager.cpp @@ -0,0 +1,340 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. 
All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Migrating Data in Scaling Scenarios for Stream Cache. + */ +#include "datasystem/master/stream_cache/sc_migrate_metadata_manager.h" + +#include "datasystem/common/rpc/rpc_auth_key_manager.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/common/util/thread_local.h" +#include "datasystem/common/util/gflag/common_gflags.h" +#include "datasystem/master/stream_cache/sc_metadata_manager.h" +#include "datasystem/worker/hash_ring/hash_ring_event.h" + +DS_DECLARE_uint32(node_dead_timeout_s); + +namespace datasystem { +namespace master { +static constexpr int MOVE_THREAD_NUM = 4; +static constexpr int MAX_MIGRATE_CNT_PER_STREAM = 30; +MasterMasterSCApi::MasterMasterSCApi(const HostPort &hostPort, const HostPort &localHostPort, + std::shared_ptr akSkManager) + : destHostPort_(hostPort), localHostPort_(localHostPort), akSkManager_(std::move(akSkManager)) +{ +} + +Status MasterMasterSCApi::Init() +{ + RpcCredential cred; + RETURN_IF_NOT_OK(RpcAuthKeyManager::CreateCredentials(WORKER_SERVER_NAME, cred)); + auto channel = std::make_shared(destHostPort_, cred); + rpcSession_ = std::make_unique(channel); + LOG(INFO) << FormatString("start stream meta client: %s", destHostPort_.ToString()); + return Status::OK(); +} + +Status MasterMasterSCApi::MigrateSCMetadata(MigrateSCMetadataReqPb &req, MigrateSCMetadataRspPb &rsp) +{ + return rpcSession_->MigrateSCMetadata(req, rsp); +} + +SCMigrateMetadataManager &SCMigrateMetadataManager::Instance() +{ + static SCMigrateMetadataManager instance; + return instance; +} + +Status SCMigrateMetadataManager::Init(const HostPort &localHostPort, std::shared_ptr akSkManager, + EtcdClusterManager *cm, ReplicaManager *replicaManager) +{ + localHostPort_ = localHostPort; + akSkManager_ = std::move(akSkManager); + cm_ = cm; + threadPool_ = std::make_unique(0, MOVE_THREAD_NUM, "ScMigrateMetadata"); + replicaManager_ = replicaManager; + + HashRingEvent::MigrateRanges::GetInstance().AddSubscriber( + "SCMigrateMetadataManager", + [this](const std::string &dbName, const std::string &dest, const std::string &destDbName, + const worker::HashRange &ranges, bool isNetworkRecovery) { + return MigrateByRanges(dbName, dest, destDbName, ranges, isNetworkRecovery); + }); + return Status::OK(); +} + +SCMigrateMetadataManager::~SCMigrateMetadataManager() +{ + Shutdown(); + LOG(INFO) << "~SCMigrateMetadataManager"; +} + +void SCMigrateMetadataManager::Shutdown() +{ + exitFlag_ = true; + HashRingEvent::MigrateRanges::GetInstance().RemoveSubscriber("SCMigrateMetadataManager"); + cm_ = nullptr; +} + +Status SCMigrateMetadataManager::MigrateByRanges(const std::string &dbName, const std::string &dest, + const std::string &destDbName, const worker::HashRange &ranges, + bool isNetworkRecovery) +{ + CHECK_FAIL_RETURN_STATUS(cm_ != nullptr, K_RUNTIME_ERROR, "SCMigrateMetadataManager has not inited."); + + std::shared_ptr scMetadataManager; + 
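+ // Streams are selected for migration by testing each stream name against the migrating hash ranges.
+ // A minimal sketch of the selection predicate (mirroring the GetMetasMatch call just below; the
+ // IsInRange semantics are assumed from EtcdClusterManager):
+ //
+ // auto inMigratingRange = [this, &ranges](const std::string &streamName) {
+ // return cm_->IsInRange(ranges, streamName, "");
+ // };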
RETURN_IF_NOT_OK_PRINT_ERROR_MSG(replicaManager_->GetScMetadataManager(dbName, scMetadataManager), + "dbName not exists"); + MigrateMetaInfo info; + info.destAddr = dest; + info.destDbName = destDbName; + scMetadataManager->GetMetasMatch( + [this, &ranges](const std::string &objKey) { return cm_->IsInRange(ranges, objKey, ""); }, info.streamNames); + + return MigrateMetaDataWithRetry(scMetadataManager, info, isNetworkRecovery); +} + +void SCMigrateMetadataManager::HandleMigrationFailed( + const std::shared_ptr &scMetadataManager, MigrateMetaInfo &info, + std::unordered_map &retryCounter) +{ + info.streamNames.clear(); + for (const auto &stream : info.failedStreamNames) { + if (++retryCounter[stream] > MAX_MIGRATE_CNT_PER_STREAM) { + scMetadataManager->HandleMetaDataMigrationSuccess(stream); + LOG(WARNING) << "Stream " << stream << " abandoned after multiple failed migration retries"; + } else { + info.streamNames.emplace_back(stream); + } + } + return; +} + +Status SCMigrateMetadataManager::MigrateMetaDataWithRetry( + const std::shared_ptr &scMetadataManager, MigrateMetaInfo &info, bool isNetworkRecovery) +{ + int timeInterval = 500; + INJECT_POINT("SCMigrateMetadataManager.MigrateMetaDataWithRetry.interval", [&timeInterval] (int interval) { + timeInterval = interval; + return Status::OK(); + }); + Status status; + Timer timer; + HostPort destAddr; + RETURN_IF_NOT_OK(destAddr.ParseString(info.destAddr)); + std::unordered_map retryCounter; + Raii clean([&scMetadataManager, &info]() { scMetadataManager->CleanMigratingItems(info.streamNames); }); + + while (!exitFlag_) { + if ((!isNetworkRecovery && cm_->CheckConnection(destAddr).IsError()) + || (isNetworkRecovery && timer.ElapsedSecond() > FLAGS_node_timeout_s)) { + break; + } + status = MigrateMetaData(scMetadataManager, info); + if (status.IsError() && timer.ElapsedSecond() < FLAGS_node_dead_timeout_s) { + std::this_thread::sleep_for(std::chrono::milliseconds(timeInterval)); + continue; + } + + if (!info.failedStreamNames.empty()) { + HandleMigrationFailed(scMetadataManager, info, retryCounter); + if (!info.streamNames.empty()) { + std::this_thread::sleep_for(std::chrono::milliseconds(timeInterval)); + continue; + } + } + info.streamNames.clear(); + LOG(INFO) << "Migrate to " << info.destAddr << " success."; + return Status::OK(); + } + + return Status(K_RPC_UNAVAILABLE, + FormatString("LastStatus: %s. The connection to %s is %u. Unfinished stream size %u. " + "Time elapsed %d seconds. isNetworkRecovery %s", + status.ToString(), info.destAddr, cm_->CheckConnection(destAddr).IsOk(), + info.streamNames.size(), timer.ElapsedSecond(), isNetworkRecovery)); +} + +Status SCMigrateMetadataManager::MigrateMetaData(const std::shared_ptr &scMetadataManager, + MigrateMetaInfo &info) +{ + auto status = StartMigrateMetadataForScaleout(scMetadataManager, info); + if (status.IsError()) { + LOG(ERROR) << "Submit migrate task failed: " << status.GetMsg(); + return status; + } + auto dbName = scMetadataManager->GetDbName(); + status = GetMigrateMetadataResult(dbName, info.destAddr, info.failedStreamNames); + if (status.IsError()) { + LOG(ERROR) << "GetMigrateMetadataResult failed. 
" << status.GetMsg(); + } + return status; +} + +Status SCMigrateMetadataManager::StartMigrateMetadataForScaleout( + const std::shared_ptr &scMetadataManager, MigrateMetaInfo &info) +{ + auto futureKey = std::make_pair(info.destAddr, scMetadataManager->GetDbName()); + TbbFutureThreadTable::accessor accessor; + if (futureThread_.find(accessor, futureKey)) { + RETURN_STATUS( + StatusCode::K_TRY_AGAIN, + FormatString("The destination address[%s] has unfinished tasks. Please try again later.", info.destAddr)); + } else { + auto traceId = Trace::Instance().GetTraceID(); + std::future>> future = + threadPool_->Submit([this, scMetadataManager, &info, traceId] { + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + return AsyncMigrateMetadata(scMetadataManager, info); + }); + futureThread_.emplace(accessor, futureKey, std::move(future)); + } + return Status::OK(); +} + +Status SCMigrateMetadataManager::GetMigrateMetadataResult(const std::string &dbName, const std::string &destination, + std::vector &failedStreams) +{ + auto futureKey = std::make_pair(destination, dbName); + TbbFutureThreadTable::accessor accessor; + auto found = futureThread_.find(accessor, futureKey); + CHECK_FAIL_RETURN_STATUS(found, StatusCode::K_RUNTIME_ERROR, "Can't find async future."); + accessor->second.wait(); + auto result = accessor->second.get(); + futureThread_.erase(accessor); + failedStreams = result.second; + return result.first; +} + +std::pair> SCMigrateMetadataManager::AsyncMigrateMetadata( + const std::shared_ptr &scMetadataManager, MigrateMetaInfo &info) +{ + LOG(INFO) << "Start migrate metadata. destination:" << info.destAddr + << ", source dbName:" << scMetadataManager->GetDbName() << ", dest dbName:" << info.destDbName + << ", stream count:" << info.streamNames.size(); + + std::unique_ptr api; + auto CreateApi = [this, &info, &api]() -> Status { + HostPort dest; + RETURN_IF_NOT_OK(dest.ParseString(info.destAddr)); + api = std::make_unique(dest, localHostPort_, akSkManager_); + RETURN_IF_NOT_OK(api->Init()); + g_MetaRocksDbName = info.destDbName; + return Status::OK(); + }; + + Status s = CreateApi(); + if (s.IsError()) { + return make_pair(s, info.streamNames); + } + + std::vector failedStreams; + s = MigrateMetadataForScaleout(scMetadataManager, api, info.streamNames, failedStreams); + LOG(INFO) << "Final migrate metadata. 
destination: " << info.destAddr + << ", source dbName:" << scMetadataManager->GetDbName() << ", dest dbName:" << info.destDbName + << ", stream count: " << info.streamNames.size() << ", failed stream count: " << failedStreams.size() + << ", status: " << s.ToString(); + return make_pair(s, failedStreams); +} + +Status SCMigrateMetadataManager::BatchMigrateMetadata( + const std::shared_ptr &scMetadataManager, std::unique_ptr &api, + MigrateSCMetadataReqPb &req, std::vector &failedStreams) +{ + INJECT_POINT("BatchMigrateMetadata.delay", [](uint32_t delay_s) { + sleep(delay_s); + return Status::OK(); + }); + + MigrateSCMetadataRspPb rsp; + auto streamSendData = [this, &api, &req, &rsp]() -> Status { + auto copyReq = req; + for (int i = 0; i < copyReq.stream_metas_size(); ++i) { + auto *meta = copyReq.mutable_stream_metas(i); + if (meta != nullptr) { + meta->clear_notifications(); + } + } + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(copyReq)); + req.set_access_key(copyReq.access_key()); + req.set_timestamp(copyReq.timestamp()); + req.set_signature(copyReq.signature()); + RETURN_IF_NOT_OK(api->MigrateSCMetadata(req, rsp)); + return Status::OK(); + }; + + Status s = streamSendData(); + if (s.IsError()) { + LOG(WARNING) << "Send metadata for migration failed. s=" << s.ToString(); + for (const auto &meta : req.stream_metas()) { + scMetadataManager->HandleMetaDataMigrationFailed(meta); + failedStreams.emplace_back(meta.meta().stream_name()); + } + return s; + } else { + int num = 0; + for (auto &result : rsp.results()) { + auto &meta = req.stream_metas()[num]; + if (result == MigrateSCMetadataRspPb::SUCCESSFUL) { + scMetadataManager->HandleMetaDataMigrationSuccess(meta.meta().stream_name()); + } else { + scMetadataManager->HandleMetaDataMigrationFailed(meta); + failedStreams.emplace_back(meta.meta().stream_name()); + } + ++num; + } + } + INJECT_POINT("BatchMigrateMetadata.finish"); + return Status::OK(); +} + +Status SCMigrateMetadataManager::MigrateMetadataForScaleout( + const std::shared_ptr &scMetadataManager, std::unique_ptr &api, + const std::vector &streamNames, std::vector &failedStreams) +{ + MigrateSCMetadataReqPb req; + uint32_t objBatch = 300; // Comparison test: The performance is optimal when the batch number is 300. + uint32_t count = 0; + req.set_source_addr(localHostPort_.ToString()); + Status lastRc; + for (auto &streamName : streamNames) { + req.set_source_addr(localHostPort_.ToString()); + Status s = scMetadataManager->FillMetadataForMigration(streamName, req.add_stream_metas()); + if (s.IsError()) { + LOG(WARNING) << "Fill metadata for migration failed. s=" << s.ToString(); + req.mutable_stream_metas()->RemoveLast(); + scMetadataManager->CleanMigratingItems({ streamName }); + continue; + } + ++count; + if (count >= objBatch) { + auto rc = BatchMigrateMetadata(scMetadataManager, api, req, failedStreams); + lastRc = rc.IsError() ? rc : lastRc; + req.Clear(); + req.set_source_addr(localHostPort_.ToString()); + count = 0; + } + } + + if (count > 0) { + auto rc = BatchMigrateMetadata(scMetadataManager, api, req, failedStreams); + lastRc = rc.IsError() ? rc : lastRc; + } + return lastRc; +} +} // namespace master +} // namespace datasystem diff --git a/src/datasystem/master/stream_cache/sc_migrate_metadata_manager.h b/src/datasystem/master/stream_cache/sc_migrate_metadata_manager.h new file mode 100644 index 0000000..10d043e --- /dev/null +++ b/src/datasystem/master/stream_cache/sc_migrate_metadata_manager.h @@ -0,0 +1,229 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 
2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Migrating Data in Scaling Scenarios for Stream Cache.
+ */
+#ifndef DATASYSTEM_MASTER_STREAM_CACHE_SC_MIGRATE_METADATA_MANAGER_H
+#define DATASYSTEM_MASTER_STREAM_CACHE_SC_MIGRATE_METADATA_MANAGER_H
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include "datasystem/common/ak_sk/ak_sk_manager.h"
+#include "datasystem/common/rpc/rpc_constants.h"
+#include "datasystem/common/util/thread_pool.h"
+#include "datasystem/master/replica_manager.h"
+#include "datasystem/protos/master_stream.stub.rpc.pb.h"
+#include "datasystem/worker/cluster_manager/etcd_cluster_manager.h"
+
+namespace datasystem {
+#ifdef WITH_TESTS
+namespace ut {
+class SCMigrateMetadataManagerTest;
+}
+#endif
+
+namespace master {
+
+class MasterMasterSCApi {
+public:
+ /**
+ * @brief Constructor for the remote version of the API.
+ * @param[in] hostPort The host port of the target master.
+ * @param[in] localHostPort The local worker rpc service host port.
+ * @param[in] akSkManager Used for AK/SK authentication.
+ */
+ MasterMasterSCApi(const HostPort &hostPort, const HostPort &localHostPort,
+ std::shared_ptr<AkSkManager> akSkManager);
+
+ ~MasterMasterSCApi() = default;
+
+ /**
+ * @brief Initialization.
+ * @return Status of the call.
+ */
+ Status Init();
+
+ /**
+ * @brief Migrate the metadata.
+ * @param[in] req The rpc request protobuf.
+ * @param[out] rsp The rpc response protobuf.
+ * @return Status of the call.
+ */
+ Status MigrateSCMetadata(MigrateSCMetadataReqPb &req, MigrateSCMetadataRspPb &rsp);
+
+private:
+ HostPort destHostPort_; // The HostPort of the destination node
+ HostPort localHostPort_; // The HostPort of the local node
+ std::shared_ptr<AkSkManager> akSkManager_;
+ std::unique_ptr rpcSession_{ nullptr }; // session to the master rpc service
+};
+
+using TbbFutureThreadTable = tbb::concurrent_hash_map<std::pair<std::string, std::string>,
+ std::future<std::pair<Status, std::vector<std::string>>>>;
+
+class SCMigrateMetadataManager {
+public:
+ struct MigrateMetaInfo {
+ std::vector<std::string> streamNames;
+ std::vector<std::string> failedStreamNames;
+ std::string destAddr;
+ std::string destDbName;
+ };
+ SCMigrateMetadataManager(const SCMigrateMetadataManager &other) = delete;
+ SCMigrateMetadataManager(SCMigrateMetadataManager &&other) = delete;
+ SCMigrateMetadataManager &operator=(const SCMigrateMetadataManager &) = delete;
+ SCMigrateMetadataManager &operator=(SCMigrateMetadataManager &&) = delete;
+
+ /**
+ * @brief Singleton mode; obtain the instance.
+ * @return SCMigrateMetadataManager reference.
+ */
+ static SCMigrateMetadataManager &Instance();
+
+ ~SCMigrateMetadataManager();
+
+ /**
+ * @brief Initialization.
+ * @param[in] localHostPort The local worker rpc service host port.
+ * @param[in] akSkManager Used for AK/SK authentication.
+ * @param[in] cm The etcd cluster manager, used to check worker status and hash ranges.
+ * @param[in] replicaManager The replica manager.
+ * @return Status of the call.
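+ * @note A hedged usage sketch (the argument objects are assumed to come from worker startup wiring):
+ * @code
+ * HostPort local;
+ * RETURN_IF_NOT_OK(local.ParseString("127.0.0.1:9089"));
+ * auto &mgr = SCMigrateMetadataManager::Instance();
+ * RETURN_IF_NOT_OK(mgr.Init(local, akSkManager, etcdClusterManager, replicaManager));
+ * @endcode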
+ */
+ Status Init(const HostPort &localHostPort, std::shared_ptr<AkSkManager> akSkManager, EtcdClusterManager *cm,
+ ReplicaManager *replicaManager);
+
+ /**
+ * @brief Shutdown the sc migrate metadata module.
+ */
+ void Shutdown();
+
+ /**
+ * @brief Migrate data by hash range.
+ * @param[in] dbName The rocksdb name.
+ * @param[in] dest Destination address of the migration.
+ * @param[in] destDbName The dest rocksdb name.
+ * @param[in] ranges The hash ranges that select the streams to be migrated.
+ * @param[in] isNetworkRecovery True if under network recovery scenario.
+ * @return Status of the call.
+ */
+ Status MigrateByRanges(const std::string &dbName, const std::string &dest, const std::string &destDbName,
+ const worker::HashRange &ranges, bool isNetworkRecovery);
+
+ /**
+ * @brief Start data migration.
+ * @param[in] scMetadataManager The SCMetadataManager instance.
+ * @param[in] info The migrate meta info.
+ * @return Status of the call.
+ */
+ Status StartMigrateMetadataForScaleout(const std::shared_ptr<SCMetadataManager> &scMetadataManager,
+ MigrateMetaInfo &info);
+
+ /**
+ * @brief Obtain the data migration result.
+ * @param[in] dbName The rocksdb name.
+ * @param[in] destination Destination address of the migration.
+ * @param[out] failedStreams Failed stream names.
+ * @return Status of the call.
+ */
+ Status GetMigrateMetadataResult(const std::string &dbName, const std::string &destination,
+ std::vector<std::string> &failedStreams);
+
+private:
+ SCMigrateMetadataManager() = default;
+
+ /**
+ * @brief Migrate metadata with retry on error.
+ * @param[in] scMetadataManager The SCMetadataManager instance.
+ * @param[in] info The migrate meta info.
+ * @param[in] isNetworkRecovery True if under network recovery scenario.
+ * @return Status of the call.
+ */
+ Status MigrateMetaDataWithRetry(const std::shared_ptr<SCMetadataManager> &scMetadataManager,
+ MigrateMetaInfo &info, bool isNetworkRecovery);
+
+ /**
+ * @brief Migrate metadata once; failed streams are recorded in info.
+ * @param[in] scMetadataManager The SCMetadataManager instance.
+ * @param[in] info The migrate meta info.
+ * @return Status of the call.
+ */
+ Status MigrateMetaData(const std::shared_ptr<SCMetadataManager> &scMetadataManager, MigrateMetaInfo &info);
+
+ /**
+ * @brief Perform async data migration.
+ * @param[in] scMetadataManager The SCMetadataManager instance.
+ * @param[in] info The migrate meta info.
+ * @return Status of the call and failed stream names.
+ */
+ std::pair<Status, std::vector<std::string>> AsyncMigrateMetadata(
+ const std::shared_ptr<SCMetadataManager> &scMetadataManager, MigrateMetaInfo &info);
+
+ /**
+ * @brief Migrate metadata.
+ * @param[in] scMetadataManager The SCMetadataManager instance.
+ * @param[in] api RPC channel for sending data.
+ * @param[in] streamNames The stream names to be migrated.
+ * @param[out] failedStreams Failed stream names.
+ * @return Status of the call.
+ */
+ Status MigrateMetadataForScaleout(const std::shared_ptr<SCMetadataManager> &scMetadataManager,
+ std::unique_ptr<MasterMasterSCApi> &api,
+ const std::vector<std::string> &streamNames,
+ std::vector<std::string> &failedStreams);
+
+ /**
+ * @brief Migrate data in batches.
+ * @param[in] scMetadataManager The SCMetadataManager instance.
+ * @param[in] api RPC channel for sending data.
+ * @param[in] req The request message for data migration.
+ * @param[out] failedStreams Failed stream names.
+ * @return Status of the call.
+ */
+ Status BatchMigrateMetadata(const std::shared_ptr<SCMetadataManager> &scMetadataManager,
+ std::unique_ptr<MasterMasterSCApi> &api, MigrateSCMetadataReqPb &req,
+ std::vector<std::string> &failedStreams);
+
+ /**
+ * @brief Handle migration failed streams.
+ * @param[in] scMetadataManager The SCMetadataManager instance.
+ * @param[in] info The migrate meta info.
+ * @param[in] retryCounter The per-stream retry counter.
+ */
+ void HandleMigrationFailed(const std::shared_ptr<SCMetadataManager> &scMetadataManager,
+ MigrateMetaInfo &info, std::unordered_map<std::string, int> &retryCounter);
+
+#ifdef WITH_TESTS
+ friend class ::datasystem::ut::SCMigrateMetadataManagerTest;
+#endif
+
+ HostPort localHostPort_;
+ std::shared_ptr<AkSkManager> akSkManager_;
+ EtcdClusterManager *cm_{ nullptr };
+ std::unique_ptr<ThreadPool> threadPool_;
+ // (destination address, db name) -> future of (status, failed stream names).
+ TbbFutureThreadTable futureThread_;
+ std::atomic<bool> exitFlag_{ false };
+ ReplicaManager *replicaManager_{ nullptr };
+};
+} // namespace master
+} // namespace datasystem
+#endif // DATASYSTEM_MASTER_STREAM_CACHE_SC_MIGRATE_METADATA_MANAGER_H
diff --git a/src/datasystem/master/stream_cache/sc_notify_worker_manager.cpp b/src/datasystem/master/stream_cache/sc_notify_worker_manager.cpp
new file mode 100644
index 0000000..daaac53
--- /dev/null
+++ b/src/datasystem/master/stream_cache/sc_notify_worker_manager.cpp
@@ -0,0 +1,777 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Managing notifications sent to workers.
+ */
+
+#include "datasystem/master/stream_cache/sc_notify_worker_manager.h"
+
+#include "datasystem/common/inject/inject_point.h"
+#include "datasystem/common/log/log_helper.h"
+#include "datasystem/common/stream_cache/stream_fields.h"
+#include "datasystem/common/util/rpc_util.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/common/util/strings_util.h"
+#include "datasystem/common/util/uuid_generator.h"
+#include "datasystem/master/stream_cache/sc_metadata_manager.h"
+#include "datasystem/master/stream_cache/stream_metadata.h"
+#include "datasystem/protos/master_stream.pb.h"
+#include "datasystem/protos/worker_stream.pb.h"
+#include "datasystem/stream/stream_config.h"
+#include "datasystem/worker/cluster_manager/etcd_cluster_manager.h"
+
+namespace datasystem {
+namespace master {
+const size_t ASYNC_NOTIFY_THREAD_NUM = 8;
+const size_t DELETE_STREAM_THREAD_NUM = 8;
+SCNotifyWorkerManager::SCNotifyWorkerManager(std::shared_ptr<RocksStreamMetaStore> streamMetaStore,
+ std::shared_ptr<AkSkManager> akSkManager,
+ std::shared_ptr<RpcSessionManager> rpcSessionManager,
+ EtcdClusterManager *cm, SCMetadataManager *scMetadataManager)
+ : streamMetaStore_(std::move(streamMetaStore)),
+ akSkManager_(std::move(akSkManager)),
+ rpcSessionManager_(std::move(rpcSessionManager)),
+ etcdCM_(cm),
+ scMetadataManager_(scMetadataManager)
+{
+}
+
+SCNotifyWorkerManager::~SCNotifyWorkerManager()
+{
+ LOG(INFO) << "Destroy SCNotifyWorkerManager.";
+ if (!interruptFlag_) {
+ Shutdown();
+ }
+}
+
+Status SCNotifyWorkerManager::Init()
+{
+ LOG(INFO) << "Init SCNotifyWorkerManager";
+ RETURN_IF_NOT_OK_PRINT_ERROR_MSG(RecoverNotification(), "Recover notification from rocksdb failed.");
+ notifyThreadPool_ = std::make_unique<ThreadPool>(1, ASYNC_NOTIFY_THREAD_NUM, "ScNotify");
+ notifyFut_ =
notifyThreadPool_->Submit(&SCNotifyWorkerManager::ProcessAsyncNotify, this); + deleteThreadPool_ = std::make_unique(1, DELETE_STREAM_THREAD_NUM, "ScDelete"); + deleteFut_ = deleteThreadPool_->Submit(&SCNotifyWorkerManager::ProcessDeleteStreams, this); + return Status::OK(); +} + +void SCNotifyWorkerManager::Shutdown() +{ + LOG(INFO) << "SCNotifyWorkerManager shutdown."; + if (interruptFlag_.exchange(true)) { + return; + } + cvLock_.Set(); + WARN_IF_ERROR(notifyFut_.get(), ""); + WARN_IF_ERROR(deleteFut_.get(), ""); + notifyThreadPool_.reset(); + deleteThreadPool_.reset(); +} + +Status SCNotifyWorkerManager::ProcessAsyncNotify() +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + LOG(INFO) << "Starting async notify thread."; + const size_t numNotifyThread = ASYNC_NOTIFY_THREAD_NUM - 1; + while (!interruptFlag_) { + INJECT_POINT("master.ProcessAsyncNotify"); + // Group notifications by stream name and partition them to the threads. + std::unordered_map streamNameToPartitionNum; + std::vector>> parts(numNotifyThread); + size_t index = 0; + { + std::lock_guard locker(notifyMutex_); + for (const auto &kv : notifyWorkerMap_) { + for (auto const &item : kv.second) { + auto &streamName = item.first; + size_t partitionNum; + auto it = streamNameToPartitionNum.find(streamName); + if (it == streamNameToPartitionNum.end()) { + partitionNum = index % numNotifyThread; + streamNameToPartitionNum.emplace(streamName, partitionNum); + index += 1; + } else { + partitionNum = it->second; + } + // worker addr and stream name. + parts[partitionNum].emplace_back(kv.first, streamName); + } + } + } + + std::vector> futures; + for (auto &streamList : parts) { + if (!streamList.empty()) { + futures.emplace_back( + notifyThreadPool_->Submit([this, &streamList] { return SendPendingNotification(streamList); })); + } + } + + for (auto &fut : futures) { + LOG_IF_ERROR(fut.get(), "SendPendingNotification failed"); + } + cvLock_.WaitFor(ASYNC_NOTIFY_TIME_MS); + } + LOG(INFO) << "Terminating async notify thread."; + return Status::OK(); +} + +Status SCNotifyWorkerManager::ProcessDeleteStreams() +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + LOG(INFO) << "Starting async delete streams thread."; + const size_t numDeleteThread = DELETE_STREAM_THREAD_NUM - 1; + while (!interruptFlag_) { + INJECT_POINT("master.ProcessDeleteStreams"); + // Get the pending delete stream list. + std::set deleteStreamList; + { + std::lock_guard locker(deleteMutex_); + deleteStreamList.swap(pendingDeleteStreams_); + } + if (deleteStreamList.empty()) { + cvLock_.WaitFor(ASYNC_NOTIFY_TIME_MS); + continue; + } + + // partition based on the number of delete stream threads + size_t index = 0; + std::vector> parts(numDeleteThread); + for (auto &streamName : deleteStreamList) { + parts[index % numDeleteThread].insert(streamName); + index += 1; + } + + // submit to thread pool. 
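+ // Only DELETE_STREAM_THREAD_NUM - 1 partitions are used because one pool thread is occupied by this
+ // dispatcher loop itself; the remaining threads can then run all partitions concurrently while the
+ // join below waits for the round to finish.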
+ std::vector> futures; + for (auto &streams : parts) { + if (!streams.empty()) { + futures.emplace_back(deleteThreadPool_->Submit([this, &streams] { return DeleteStreams(streams); })); + } + } + for (auto &fut : futures) { + LOG_IF_ERROR(fut.get(), "DeleteStreams failed"); + } + cvLock_.WaitFor(ASYNC_NOTIFY_TIME_MS); + } + LOG(INFO) << "Terminating async delete streams thread."; + return Status::OK(); +} + +Status SCNotifyWorkerManager::NotifyNewPubNode(const HostPort &workerAddr, const std::string &streamName, + const StreamFields &streamFields, const HostPort &srcWorkerAddr) +{ + bool forceClose = false; + bool asyncMode = false; + return NotifyPubNodeImpl(workerAddr, streamName, streamFields, srcWorkerAddr, false, forceClose, asyncMode); +} + +Status SCNotifyWorkerManager::NotifyDelPubNode(const HostPort &workerAddr, const std::string &streamName, + const HostPort &srcWorkerAddr, const bool forceClose) +{ + // Sending an async updateTopo + bool asyncMode = true; + StreamFields streamFields(0, 0, false, 0, false, 0, StreamMode::MPMC); + return NotifyPubNodeImpl(workerAddr, streamName, streamFields, srcWorkerAddr, true, forceClose, asyncMode); +} + +Status SCNotifyWorkerManager::NotifyNewConsumer(const HostPort &workerAddr, const ConsumerMetaPb &consumerMeta, + const RetainDataState::State retainData) +{ + bool asyncMode = false; + return NotifyConsumerImpl(workerAddr, consumerMeta, false, retainData, asyncMode); +} + +Status SCNotifyWorkerManager::NotifyDelConsumer(const HostPort &workerAddr, const ConsumerMetaPb &consumerMeta) +{ + // Sending an async updateTopo + bool asyncMode = true; + return NotifyConsumerImpl(workerAddr, consumerMeta, true, RetainDataState::State::INIT, asyncMode); +} + +Status SCNotifyWorkerManager::ClearPendingNotification(const std::string &workerAddress) +{ + LOG(INFO) << "Clear pending notification send to " << workerAddress; + std::shared_lock locker(notifyMutex_); + TbbNotifyWorkerMap::accessor accessor; + if (notifyWorkerMap_.find(accessor, workerAddress)) { + RETURN_IF_NOT_OK(streamMetaStore_->RemoveNotificationByWorker(workerAddress)); + CHECK_FAIL_RETURN_STATUS(notifyWorkerMap_.erase(accessor) == true, K_RUNTIME_ERROR, + "erase worker address failed"); + } else { + LOG(INFO) << "Not exists pending notification for " << workerAddress; + } + return Status::OK(); +} + +Status SCNotifyWorkerManager::SendNotification(const HostPort &workerAddr, UpdateTopoNotificationReq &req) +{ + std::shared_ptr masterWorkerApi = nullptr; + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Stream:<%s>, Dest:<%s>, UpdateTopoNotification req: %s", + req.stream_name(), workerAddr.ToString(), + LogHelper::IgnoreSensitive(req)); + RETURN_IF_NOT_OK(rpcSessionManager_->GetRpcSession(workerAddr, masterWorkerApi, akSkManager_)); + Status status = masterWorkerApi->UpdateTopoNotification(req); + LOG(INFO) << FormatString("Stream:<%s>, Dest:<%s>, UpdateTopoNotification result: %s", req.stream_name(), + workerAddr.ToString(), status.GetMsg()); + return status; +} + +Status SCNotifyWorkerManager::SendPendingNotification(std::vector> &streamList) +{ + INJECT_POINT("master.SendPendingNotification"); + for (const auto &item : streamList) { + auto traceGuard = Trace::Instance().SetTraceNewID(GetStringUuid() + "-sync"); + const auto &workerAddr = item.first; + const auto &streamName = item.second; + LOG_IF_ERROR(SendPendingNotificationForStream(workerAddr, streamName), + FormatString("SendPendingNotification to %s for stream %s failed", workerAddr, streamName)); + } + return Status::OK(); +}; + 
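+/*
+ * A hedged walk-through of the flush path below (helper names are real, the scenario is illustrative):
+ * a producer change for stream "s1" that could not be delivered is parked under the target worker:
+ *
+ * NotifyPubPb pub;
+ * pub.set_stream_name("s1");
+ * pub.set_worker_addr(srcWorker.ToString());
+ * (void)AddAsyncPubNotification("10.0.0.2:9000", pub);
+ *
+ * ProcessAsyncNotify later hands the ("10.0.0.2:9000", "s1") pair to SendPendingNotificationForStream,
+ * which re-checks worker health, rebuilds an UpdateTopoNotificationReq from the parked entries, and
+ * erases them only once the send outcome is not an RPC timeout.
+ */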
+Status SCNotifyWorkerManager::SendPendingNotificationForStream(const std::string &workerAddr, + const std::string &streamName) +{ + Status checkRc = CheckWorkerStatus(workerAddr); + // skip notify if target worker not exists in cluster. + bool skipNotify = checkRc.GetCode() == K_NOT_FOUND; + RETURN_OK_IF_TRUE(checkRc.IsError() && !skipNotify); + + TbbMetaHashmap::accessor streamAccessor; + Status status = scMetadataManager_->GetStreamMetadata(streamName, streamAccessor); + if (status.IsError()) { + LOG(INFO) << "GetStreamMetadata failed:" << status.GetMsg(); + skipNotify = true; + } + + std::shared_lock locker(notifyMutex_); + TbbNotifyWorkerMap::accessor accessor; + auto task = GetPendingNotification(workerAddr, streamName, accessor); + RETURN_OK_IF_TRUE(task == nullptr); + + if (skipNotify || task->Empty()) { + LOG(INFO) << "Skip send notify to " << workerAddr << " for stream " << streamName; + return RemoveAsyncNotification(accessor, streamName, task); + } + + UpdateTopoNotificationReq req; + task->ConstructRequest(req); + HostPort address; + RETURN_IF_NOT_OK(address.ParseString(workerAddr)); + status = SendNotification(address, req); + if (status.IsError()) { + LOG(WARNING) << "SendNotification failed:" << status.GetMsg(); + if (IsRpcTimeout(status)) { + return Status::OK(); + } + } + return RemoveAsyncNotification(accessor, streamName, task); +} + +Status SCNotifyWorkerManager::RemoveAsyncNotification(TbbNotifyWorkerMap::accessor &accessor, + const std::string &streamName, + std::shared_ptr task) +{ + for (const auto &kv : task->pubs) { + RETURN_IF_NOT_OK(streamMetaStore_->RemoveNotifyPub(accessor->first, kv.second)); + } + + for (const auto &kv : task->subs) { + RETURN_IF_NOT_OK(streamMetaStore_->RemoveNotifySub(accessor->first, kv.second)); + } + (void)accessor->second.erase(streamName); + return Status::OK(); +} + +Status SCNotifyWorkerManager::NotifyPubNodeImpl(const HostPort &workerAddr, const std::string &streamName, + const StreamFields &streamFields, const HostPort &srcWorkerAddr, + bool isClose, const bool forceClose, bool asyncMode) +{ + NotifyPubPb pub; + pub.set_is_close(isClose); + pub.set_force_close(forceClose); + pub.set_stream_name(streamName); + pub.set_worker_addr(srcWorkerAddr.ToString()); + pub.set_max_stream_size(streamFields.maxStreamSize_); + pub.set_page_size(streamFields.pageSize_); + pub.set_auto_cleanup(streamFields.autoCleanup_); + pub.set_retain_num_consumer(streamFields.retainForNumConsumers_); + pub.set_encrypt_stream(streamFields.encryptStream_); + pub.set_reserve_size(streamFields.reserveSize_); + pub.set_stream_mode(streamFields.streamMode_); + + Status status = CheckWorkerStatus(workerAddr.ToString()); + INJECT_POINT_NO_RETURN("SCNotifyWorkerManager.ForceAsyncNotification", + [&status]() { status = Status(StatusCode::K_RPC_UNAVAILABLE, ""); }); + if (status.IsError()) { + LOG(WARNING) << "Worker abnormal, async send notification. 
" << status.GetMsg(); + return AddAsyncPubNotification(workerAddr.ToString(), pub); + } + + { + std::shared_lock locker(notifyMutex_); + TbbNotifyWorkerMap::accessor accessor; + auto task = GetPendingNotification(workerAddr.ToString(), streamName, accessor); + bool exists = false; + RETURN_IF_NOT_OK(HandleExistsPubNotification(task, srcWorkerAddr.ToString(), isClose, true, exists)); + if (exists) { + LOG(INFO) << FormatString("Exists notification send to %s, NotifyPubPb: %s", workerAddr.ToString(), + LogHelper::IgnoreSensitive(pub)); + return Status::OK(); + } + } + + // If async mode is set, just enqueue the request + if (asyncMode) { + return AddAsyncPubNotification(workerAddr.ToString(), pub); + } + + UpdateTopoNotificationReq req; + req.set_stream_name(streamName); + *req.add_pubs() = pub; + status = SendNotification(workerAddr, req); + if (IsRpcTimeout(status)) { + LOG(WARNING) << "RPC timeout, async send notification. " << status.GetMsg(); + return AddAsyncPubNotification(workerAddr.ToString(), pub); + } + return status; +} + +Status SCNotifyWorkerManager::AddAsyncStopDataRetentionNotification(const HostPort &workerAddr, + const std::string &streamName) +{ + std::shared_lock locker(notifyMutex_); + TbbNotifyWorkerMap::accessor accessor; + auto task = GetOrCreatePendingNotification(workerAddr.ToString(), streamName, accessor); + CHECK_FAIL_RETURN_STATUS(task != nullptr, K_RUNTIME_ERROR, "task is null"); + + // If its same return okay + if (task->retainData == RetainDataState::State::NOT_RETAIN) { + return Status::OK(); + } else { + // Else change state and store it in meta store + task->retainData = RetainDataState::State::NOT_RETAIN; + } + + LOG(INFO) << FormatString("Stream:<%s>, Dest:<%s>, AsyncStopRetentionState", streamName, workerAddr.ToString()); + return Status::OK(); +} + +Status SCNotifyWorkerManager::NotifyConsumerImpl(const HostPort &workerAddr, const ConsumerMetaPb &consumerMeta, + bool isClose, RetainDataState::State retainData, bool asyncMode) +{ + NotifyConsumerPb sub; + *sub.mutable_consumer() = consumerMeta; + sub.set_is_close(isClose); + + Status status = CheckWorkerStatus(workerAddr.ToString()); + INJECT_POINT_NO_RETURN("SCNotifyWorkerManager.ForceAsyncNotification", + [&status]() { status = Status(StatusCode::K_RPC_UNAVAILABLE, ""); }); + if (status.IsError()) { + LOG(WARNING) << "Worker abnormal, async send notification. " << status.GetMsg(); + return AddAsyncSubNotification(workerAddr.ToString(), sub, true); + } + + { + std::shared_lock locker(notifyMutex_); + TbbNotifyWorkerMap::accessor accessor; + auto task = GetPendingNotification(workerAddr.ToString(), consumerMeta.stream_name(), accessor); + + bool exists = false; + RETURN_IF_NOT_OK(HandleExistsSubNotification(task, consumerMeta.consumer_id(), isClose, true, exists)); + if (exists) { + LOG(INFO) << FormatString("Exists notification send to %s, NotifyConsumerPb: %s", workerAddr.ToString(), + LogHelper::IgnoreSensitive(sub)); + return Status::OK(); + } + } + + UpdateTopoNotificationReq req; + req.set_stream_name(consumerMeta.stream_name()); + *req.add_subs() = sub; + req.set_retain_data(retainData); + + // If async mode is set, just enqueue the request + if (asyncMode) { + return AddAsyncSubNotification(workerAddr.ToString(), sub, true); + } + + status = SendNotification(workerAddr, req); + if (IsRpcTimeout(status)) { + LOG(WARNING) << "RPC timeout, async send notification. 
" << status.GetMsg(); + return AddAsyncSubNotification(workerAddr.ToString(), sub, true); + } + return status; +} + +Status SCNotifyWorkerManager::HandleExistsPubNotification(std::shared_ptr task, + const std::string &workerAddr, bool isClose, bool needPersist, + bool &exists) +{ + RETURN_OK_IF_TRUE(task == nullptr); + Status rc; + exists = false; + auto iter = task->pubs.find(workerAddr); + if (iter != task->pubs.end()) { + exists = true; + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString( + "Exists pending pub notification send to <%s>, isClose:%d, detail: %s", task->workerAddr, isClose, + LogHelper::IgnoreSensitive(iter->second)); + if (iter->second.is_close() == isClose) { + // Exists duplicate notification. + return Status::OK(); + } else { + // Exists opposite notification + if (needPersist) { + rc = streamMetaStore_->RemoveNotifyPub(task->workerAddr, iter->second); + } + (void)task->pubs.erase(iter); + return rc; + } + } + return rc; +} + +Status SCNotifyWorkerManager::HandleExistsSubNotification(std::shared_ptr task, + const std::string &consumerId, bool isClose, bool needPersist, + bool &exists) +{ + RETURN_OK_IF_TRUE(task == nullptr); + Status rc; + exists = false; + auto iter = task->subs.find(consumerId); + if (iter != task->subs.end()) { + exists = true; + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString( + "Exists pending sub notification send to <%s>, isClose:%d, detail: %s", task->workerAddr, isClose, + LogHelper::IgnoreSensitive(iter->second)); + if (iter->second.is_close() == isClose) { + // Exists duplicate notification. + return Status::OK(); + } else { + // Exists opposite notification + if (needPersist) { + rc = streamMetaStore_->RemoveNotifySub(task->workerAddr, iter->second); + } + (void)task->subs.erase(iter); + return rc; + } + } + return rc; +} + +Status SCNotifyWorkerManager::AddAsyncDeleteNotification(const std::string &streamName) +{ + LOG(INFO) << FormatString("Enqueuing AutoDelete request for stream %s", streamName); + std::unique_lock lock(deleteMutex_); + pendingDeleteStreams_.insert(streamName); + return Status::OK(); +} + +Status SCNotifyWorkerManager::AddAsyncPubNotification(const std::string &workerAddr, const NotifyPubPb &pub, + bool needPersist) +{ + const auto &streamName = pub.stream_name(); + std::shared_lock locker(notifyMutex_); + TbbNotifyWorkerMap::accessor accessor; + auto task = GetOrCreatePendingNotification(workerAddr, streamName, accessor); + CHECK_FAIL_RETURN_STATUS(task != nullptr, K_RUNTIME_ERROR, "task is null"); + bool exists = false; + Status rc = HandleExistsPubNotification(task, pub.worker_addr(), pub.is_close(), needPersist, exists); + if (exists) { + return rc; + } else { + (void)task->pubs.emplace(pub.worker_addr(), pub); + } + LOG(INFO) << FormatString("Stream:<%s>, Dest:<%s>, AsyncSendPubNodeChange: %s", streamName, workerAddr, + LogHelper::IgnoreSensitive(pub)); + if (needPersist) { + RETURN_IF_NOT_OK(streamMetaStore_->AddNotifyPub(workerAddr, pub)); + } + return Status::OK(); +} + +Status SCNotifyWorkerManager::AddAsyncSubNotification(const std::string &workerAddr, const NotifyConsumerPb &sub, + bool needPersist) +{ + const auto &consumerMeta = sub.consumer(); + const auto &streamName = consumerMeta.stream_name(); + + std::shared_lock locker(notifyMutex_); + TbbNotifyWorkerMap::accessor accessor; + auto task = GetOrCreatePendingNotification(workerAddr, streamName, accessor); + CHECK_FAIL_RETURN_STATUS(task != nullptr, K_RUNTIME_ERROR, "task is null"); + bool exists = false; + // no need to check retainData as same event (consumer id and 
close) will have the same retainData.
+ Status rc = HandleExistsSubNotification(task, consumerMeta.consumer_id(), sub.is_close(), needPersist, exists);
+ if (exists) {
+ return rc;
+ } else {
+ (void)task->subs.emplace(consumerMeta.consumer_id(), sub);
+ }
+
+ LOG(INFO) << FormatString("Stream:<%s>, Dest:<%s>, AsyncSendConsumerChange: %s", streamName, workerAddr,
+ LogHelper::IgnoreSensitive(sub));
+ if (needPersist) {
+ RETURN_IF_NOT_OK(streamMetaStore_->AddNotifySub(workerAddr, sub));
+ }
+ return Status::OK();
+}
+
+std::shared_ptr<PendingNotification> SCNotifyWorkerManager::GetOrCreatePendingNotification(
+ const std::string &workerAddr, const std::string &streamName, TbbNotifyWorkerMap::accessor &accessor)
+{
+ // if another thread wins the insert race, insert returns false and the accessor still locates the entry.
+ if (notifyWorkerMap_.insert(accessor, workerAddr)) {
+ std::unordered_map<std::string, std::shared_ptr<PendingNotification>> data;
+ accessor->second = std::move(data);
+ }
+ auto iter = accessor->second.find(streamName);
+ if (iter == accessor->second.end()) {
+ auto data = std::make_shared<PendingNotification>(streamName, workerAddr);
+ iter = accessor->second.emplace(streamName, std::move(data)).first;
+ }
+ return iter->second;
+}
+
+std::shared_ptr<PendingNotification> SCNotifyWorkerManager::GetPendingNotification(
+ const std::string &workerAddr, const std::string &streamName, TbbNotifyWorkerMap::accessor &accessor)
+{
+ if (!notifyWorkerMap_.find(accessor, workerAddr)) {
+ return nullptr;
+ }
+ auto iter = accessor->second.find(streamName);
+ if (iter == accessor->second.end()) {
+ return nullptr;
+ }
+ return iter->second;
+}
+
+bool SCNotifyWorkerManager::ExistsPendingNotification(const std::string &streamName)
+{
+ {
+ std::shared_lock locker(notifyMutex_);
+ if (notifyWorkerMap_.empty()) {
+ return false;
+ }
+ }
+
+ std::lock_guard locker(notifyMutex_);
+ for (const auto &kv : notifyWorkerMap_) {
+ auto iter = kv.second.find(streamName);
+ if (iter != kv.second.end() && !iter->second->Empty()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+Status SCNotifyWorkerManager::CheckWorkerStatus(const std::string &workerAddr)
+{
+ // CheckConnection returns an error if the node is down or there is a problem, such as K_RPC_UNAVAILABLE.
+ // Callers treat any error here as the worker being abnormal and fall back to async notification.
+ HostPort workerHostPort;
+ RETURN_IF_NOT_OK(workerHostPort.ParseString(workerAddr));
+ if (etcdCM_ == nullptr) {
+ RETURN_STATUS(StatusCode::K_INVALID, "ETCD cluster manager is nullptr.");
+ }
+ return etcdCM_->CheckConnection(workerHostPort);
+}
+
+Status SCNotifyWorkerManager::RecoverNotification()
+{
+ CHECK_FAIL_RETURN_STATUS(streamMetaStore_ != nullptr, K_RUNTIME_ERROR, "streamMetaStore_ is null");
+ std::vector<std::pair<std::string, NotifyPubPb>> pubs;
+ RETURN_IF_NOT_OK(streamMetaStore_->GetAllNotifyPub(pubs));
+ std::vector<std::pair<std::string, NotifyConsumerPb>> subs;
+ RETURN_IF_NOT_OK(streamMetaStore_->GetAllNotifySub(subs));
+
+ // the key is WorkerAddr_StreamName_PubWorkerAddr
+ for (const auto &kv : pubs) {
+ auto keyVec = Split(kv.first, "_");
+ if (keyVec.size() > 1) {
+ (void)AddAsyncPubNotification(keyVec[0], kv.second, false);
+ }
+ }
+
+ // the key is WorkerAddr_StreamName_ConsumerId
+ for (const auto &kv : subs) {
+ auto keyVec = Split(kv.first, "_");
+ if (keyVec.size() > 1) {
+ (void)AddAsyncSubNotification(keyVec[0], kv.second, false);
+ }
+ }
+
+ return Status::OK();
+}
+
+bool SCNotifyWorkerManager::CanRetryDeleteStream(const std::string &streamName)
+{
+ std::shared_lock lock(deleteMutex_);
+ auto itr = pendingDeleteStreamsLastRetry_.find(streamName);
+ if (itr == pendingDeleteStreamsLastRetry_.end()) {
+ // if the entry does not exist, this is the first retry
+ return true;
+ } else {
+ // Check whether WAIT_TIME_BETWEEN_DELSTREAM_MS has elapsed since the last retry
+ auto start = itr->second;
+ auto now = std::chrono::high_resolution_clock::now();
+ auto elapsedTimeMs = std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();
+ return (elapsedTimeMs >= WAIT_TIME_BETWEEN_DELSTREAM_MS);
+ }
+}
+
+Status SCNotifyWorkerManager::DeleteStreams(std::set<std::string> &deleteStreams)
+{
+ INJECT_POINT("SCNotifyWorkerManager.DeleteStreams");
+ auto iter = deleteStreams.begin();
+ while (iter != deleteStreams.end()) {
+ auto traceGuard = Trace::Instance().SetTraceNewID(GetStringUuid() + "-del");
+ const std::string streamName = *iter;
+ if (!CanRetryDeleteStream(streamName)) {
+ // Do not retry this stream yet; skip to the next one
+ ++iter;
+ continue;
+ }
+ const std::string infoMsg = FormatString("AutoDelete stream %s", streamName);
+ LOG(INFO) << infoMsg;
+ DeleteStreamReqPb req;
+ DeleteStreamRspPb rsp;
+ req.set_stream_name(streamName);
+ req.mutable_src_node_addr()->set_host("");
+ req.mutable_src_node_addr()->set_port(-1);
+ Status rc = scMetadataManager_->DeleteStream(req, rsp);
+ // Stop retrying if delete is successful or stream not found or if stream is in use (has a consumer or producer)
+ // We will retry if notifications are pending
+ if (rc.IsOk() || rc.GetCode() == StatusCode::K_NOT_FOUND || rc.GetCode() == StatusCode::K_SC_STREAM_IN_USE) {
+ LOG(INFO) << "AutoDelete for stream: " << streamName << " done with status " << rc.ToString();
+ iter = deleteStreams.erase(iter);
+ std::unique_lock lock(deleteMutex_);
+ pendingDeleteStreamsLastRetry_.erase(streamName);
+ } else {
+ LOG(INFO) << FormatString("%s AutoDelete failed with error %s", infoMsg, rc.ToString());
+ ++iter;
+ std::unique_lock lock(deleteMutex_);
+ pendingDeleteStreamsLastRetry_[streamName] = std::chrono::high_resolution_clock::now();
+ }
+ }
+ if (!deleteStreams.empty()) {
+ std::unique_lock lock(deleteMutex_);
+ pendingDeleteStreams_.insert(deleteStreams.begin(), deleteStreams.end());
+ }
+ return Status::OK();
+}
+
+Status SCNotifyWorkerManager::GetPendingNotificationByStreamName(const std::string &streamName,
+ std::vector &notifications)
+{
+ std::lock_guard locker(notifyMutex_);
+ for (const auto &kv : notifyWorkerMap_) {
+ auto iter = kv.second.find(streamName);
+ if (iter != kv.second.end() && !iter->second->Empty()) {
+ // Get notifications including notifyPub, notifySub and also stopRetention notifications.
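+ // For each worker with pending entries this emits one record per parked pub (id = the pub's worker
+ // address), one per parked sub (id = consumer id), plus a trailing record carrying only the
+ // retain-data state; the receiver presumably tells them apart via is_pub() and retain_data().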
+ for (auto &pair : iter->second->pubs) { + auto &pub = pair.second; + notifications.emplace_back(); + auto ¬ification = notifications.back(); + notification.set_is_pub(true); + notification.set_is_close(pub.is_close()); + notification.set_target_worker(kv.first); + notification.set_id(pub.worker_addr()); + notification.set_force_close(pub.force_close()); + notification.set_retain_data(RetainDataState::State::INIT); + } + for (auto &pair : iter->second->subs) { + auto &sub = pair.second; + notifications.emplace_back(); + auto ¬ification = notifications.back(); + notification.set_is_pub(false); + notification.set_is_close(sub.is_close()); + notification.set_target_worker(kv.first); + notification.set_id(sub.consumer().consumer_id()); + notification.set_retain_data(RetainDataState::State::INIT); + } + notifications.emplace_back(); + auto ¬ification = notifications.back(); + notification.set_retain_data(iter->second->retainData); + notification.set_target_worker(kv.first); + } + } + return Status::OK(); +} + +Status SCNotifyWorkerManager::AddAsyncNotifications(const StreamFields &streamFields, const std::string &streamName, + const MetaForSCMigrationPb &streamMeta) +{ + for (const auto ¬ification : streamMeta.notifications()) { + if (notification.retain_data() != RetainDataState::State::INIT) { + HostPort targetAddress; + RETURN_IF_NOT_OK(targetAddress.ParseString(notification.target_worker())); + RETURN_IF_NOT_OK(AddAsyncStopDataRetentionNotification(targetAddress, streamName)); + } else if (notification.is_pub()) { + NotifyPubPb pub; + pub.set_is_close(notification.is_close()); + pub.set_force_close(notification.force_close()); + pub.set_stream_name(streamName); + pub.set_worker_addr(notification.id()); + pub.set_max_stream_size(streamFields.maxStreamSize_); + pub.set_page_size(streamFields.pageSize_); + pub.set_auto_cleanup(streamFields.autoCleanup_); + pub.set_retain_num_consumer(streamFields.retainForNumConsumers_); + pub.set_encrypt_stream(streamFields.encryptStream_); + pub.set_reserve_size(streamFields.reserveSize_); + pub.set_stream_mode(streamFields.streamMode_); + RETURN_IF_NOT_OK(AddAsyncPubNotification(notification.target_worker(), pub)); + } else { + NotifyConsumerPb sub; + bool found = false; + for (const auto &consumerMetaPb : streamMeta.consumers()) { + if (consumerMetaPb.consumer_id() == notification.id()) { + *sub.mutable_consumer() = consumerMetaPb; + found = true; + break; + } + } + CHECK_FAIL_RETURN_STATUS(found, K_RUNTIME_ERROR, "Migrate async consumer notification failed."); + sub.set_is_close(notification.is_close()); + RETURN_IF_NOT_OK(AddAsyncSubNotification(notification.target_worker(), sub)); + } + } + return Status::OK(); +} + +Status SCNotifyWorkerManager::RemovePendingNotificationByStreamName(const std::string &streamName) +{ + std::lock_guard locker(notifyMutex_); + for (auto &kv : notifyWorkerMap_) { + auto iter = kv.second.find(streamName); + if (iter != kv.second.end() && !iter->second->Empty()) { + for (const auto &pubPair : iter->second->pubs) { + RETURN_IF_NOT_OK(streamMetaStore_->RemoveNotifyPub(kv.first, pubPair.second)); + } + for (const auto &subPair : iter->second->subs) { + RETURN_IF_NOT_OK(streamMetaStore_->RemoveNotifySub(kv.first, subPair.second)); + } + (void)kv.second.erase(iter); + } + } + return Status::OK(); +} +} // namespace master +} // namespace datasystem diff --git a/src/datasystem/master/stream_cache/sc_notify_worker_manager.h b/src/datasystem/master/stream_cache/sc_notify_worker_manager.h new file mode 100644 index 0000000..2f91ccb 
--- /dev/null
+++ b/src/datasystem/master/stream_cache/sc_notify_worker_manager.h
@@ -0,0 +1,394 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Managing notifications sent to workers.
+ */
+#ifndef DATASYSTEM_MASTER_STREAM_CACHE_SC_NOTIFY_WORKER_MANAGER_H
+#define DATASYSTEM_MASTER_STREAM_CACHE_SC_NOTIFY_WORKER_MANAGER_H
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "datasystem/common/stream_cache/stream_fields.h"
+#include "datasystem/common/util/net_util.h"
+#include "datasystem/common/util/thread_pool.h"
+#include "datasystem/common/util/wait_post.h"
+#include "datasystem/master/stream_cache/rpc_session_manager.h"
+#include "datasystem/master/stream_cache/store/rocks_stream_meta_store.h"
+#include "datasystem/protos/worker_stream.pb.h"
+#include "datasystem/worker/cluster_manager/etcd_cluster_manager.h"
+
+namespace datasystem {
+namespace master {
+static constexpr int WAIT_TIME_BETWEEN_DELSTREAM_MS = 10000; // waiting time between retries.
+struct PendingNotification {
+ PendingNotification(std::string streamName, std::string workerAddr)
+ : streamName(std::move(streamName)), workerAddr(std::move(workerAddr))
+ {
+ }
+ ~PendingNotification() = default;
+
+ /**
+ * @brief Generate the notification request.
+ * @param[out] req The notification request.
+ */
+ void ConstructRequest(UpdateTopoNotificationReq &req)
+ {
+ req.set_stream_name(streamName);
+ for (const auto &kv : pubs) {
+ *req.add_pubs() = kv.second;
+ }
+ for (const auto &kv : subs) {
+ *req.add_subs() = kv.second;
+ }
+ req.set_retain_data(retainData);
+ }
+
+ bool Empty()
+ {
+ return pubs.empty() && subs.empty() && retainData == RetainDataState::State::INIT;
+ }
+
+ std::string streamName;
+ // the worker this notification is sent to.
+ std::string workerAddr;
+ // notification to stop retaining data
+ RetainDataState::State retainData{ RetainDataState::State::INIT };
+ // worker address -> NotifyPubPb
+ std::unordered_map<std::string, NotifyPubPb> pubs;
+ // consumer id -> NotifyConsumerPb
+ std::unordered_map<std::string, NotifyConsumerPb> subs;
+};
+
+// worker address -> stream name -> PendingNotification.
+using TbbNotifyWorkerMap =
+ tbb::concurrent_hash_map<std::string, std::unordered_map<std::string, std::shared_ptr<PendingNotification>>>;
+using RocksStreamMetaStore = stream_cache::RocksStreamMetaStore;
+
+class SCMetadataManager;
+class SCNotifyWorkerManager {
+public:
+ /**
+ * @brief Construct a new SCNotifyWorkerManager instance.
+ * @param[in] streamMetaStore The stream rocksdb store object.
+ * @param[in] akSkManager Used for AK/SK authentication.
+ * @param[in] rpcSessionManager Master to Worker session manager.
+ * @param[in] cm The etcd cluster manager instance.
+ * @param[in] scMetadataManager The sc metadata manager instance.
+ */
+ SCNotifyWorkerManager(std::shared_ptr<RocksStreamMetaStore> streamMetaStore,
+ std::shared_ptr<AkSkManager> akSkManager,
+ std::shared_ptr<RpcSessionManager> rpcSessionManager, EtcdClusterManager *cm,
+ SCMetadataManager *scMetadataManager);
+
+ ~SCNotifyWorkerManager();
+
+ /**
+ * @brief Initialization.
+ * @return Status of the call.
+ */
+ Status Init();
+
+ /**
+ * @brief New pub node notification.
+ * @param[in] workerAddr The worker address the notification is sent to.
+ * @param[in] streamName The stream name.
+ * @param[in] streamFields The stream fields.
+ * @param[in] srcWorkerAddr The source worker address.
+ * @return Status of the call.
+ */
+ Status NotifyNewPubNode(const HostPort &workerAddr, const std::string &streamName, const StreamFields &streamFields,
+ const HostPort &srcWorkerAddr);
+
+ /**
+ * @brief Del pub node notification.
+ * @param[in] workerAddr The worker address the notification is sent to.
+ * @param[in] streamName The stream name.
+ * @param[in] srcWorkerAddr The source worker address.
+ * @param[in] forceClose True if the pub node crashed; false for a regular close.
+ * @return Status of the call.
+ */
+ Status NotifyDelPubNode(const HostPort &workerAddr, const std::string &streamName, const HostPort &srcWorkerAddr,
+ bool forceClose);
+
+ /**
+ * @brief New consumer notification.
+ * @param[in] workerAddr The worker address the notification is sent to.
+ * @param[in] consumerMeta The consumer metadata.
+ * @param[in] retainData Whether to retain data.
+ * @return Status of the call.
+ */
+ Status NotifyNewConsumer(const HostPort &workerAddr, const ConsumerMetaPb &consumerMeta,
+ const RetainDataState::State retainData = RetainDataState::State::INIT);
+
+ /**
+ * @brief Del consumer notification.
+ * @param[in] workerAddr The worker address the notification is sent to.
+ * @param[in] consumerMeta The consumer metadata.
+ * @return Status of the call.
+ */
+ Status NotifyDelConsumer(const HostPort &workerAddr, const ConsumerMetaPb &consumerMeta);
+
+ /**
+ * @brief Add the async pub change notification.
+ * @param[in] workerAddr The worker address the notification is sent to.
+ * @param[in] pub The producer change notification.
+ * @param[in] needPersist Indicates whether to persist to rocksdb.
+ * @return Status of the call.
+ */
+ Status AddAsyncPubNotification(const std::string &workerAddr, const NotifyPubPb &pub, bool needPersist = true);
+
+ /**
+ * @brief Add the async consumer change notification.
+ * @param[in] workerAddr The worker address the notification is sent to.
+ * @param[in] sub The consumer change notification.
+ * @param[in] needPersist Indicates whether to persist to rocksdb.
+ * @return Status of the call.
+ */
+ Status AddAsyncSubNotification(const std::string &workerAddr, const NotifyConsumerPb &sub, bool needPersist = true);
+
+ /**
+ * @brief Clear all pending notifications sent to the specific worker.
+ * @param[in] workerAddr The worker address.
+ * @return Status of the call.
+ */
+ Status ClearPendingNotification(const std::string &workerAddr);
+
+ /**
+ * @brief Check whether the stream has pending notifications.
+ * @param[in] streamName The stream name.
+ * @return True if the stream has pending notifications.
+ */
+ bool ExistsPendingNotification(const std::string &streamName);
+
+ /**
+ * @brief Shutdown the sc notify manager module.
+ */
+ void Shutdown();
+
+ /**
+ * @brief Add an async notice to drop the stream when there are no more consumers/producers globally.
+ * @param[in] streamName Name of the stream.
+ * @return Status of the call.
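+ * @note Illustrative only: metadata handling can queue the drop instead of deleting inline when the
+ * last producer/consumer closes, e.g. (notifyWorkerManager_ is the owning SCMetadataManager's member):
+ * @code
+ * RETURN_IF_NOT_OK(notifyWorkerManager_->AddAsyncDeleteNotification(streamName));
+ * @endcode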
+ */
+ Status AddAsyncDeleteNotification(const std::string &streamName);
+
+ /**
+ * @brief Add an async notice to stop retaining data at producers.
+ * @param[in] workerAddr Producer worker address.
+ * @param[in] streamName Name of the stream.
+ * @return Status of the call.
+ */
+ Status AddAsyncStopDataRetentionNotification(const HostPort &workerAddr, const std::string &streamName);
+
+ /**
+ * @brief Get the async notifications related to a stream name.
+ * @param[in] streamName Name of the stream.
+ * @param[out] notifications The async notifications.
+ * @return Status of the call.
+ */
+ Status GetPendingNotificationByStreamName(const std::string &streamName,
+ std::vector &notifications);
+
+ /**
+ * @brief Add async notifications for a stream.
+ * @param[in] streamFields The stream fields.
+ * @param[in] streamName Name of the stream.
+ * @param[in] streamMeta The stream migration meta, containing the notifications.
+ * @return Status of the call.
+ */
+ Status AddAsyncNotifications(const StreamFields &streamFields, const std::string &streamName,
+ const MetaForSCMigrationPb &streamMeta);
+
+ /**
+ * @brief Remove the async notifications related to a stream name.
+ * @param[in] streamName Name of the stream.
+ * @return Status of the call.
+ */
+ Status RemovePendingNotificationByStreamName(const std::string &streamName);
+
+private:
+ /**
+ * @brief Notification sending process.
+ */
+ Status ProcessAsyncNotify();
+
+ /**
+ * @brief Delete stream sending process.
+ */
+ Status ProcessDeleteStreams();
+
+ /**
+ * @brief Check worker status.
+ * @param[in] workerAddr The target worker address.
+ * @return Status of the call.
+ */
+ Status CheckWorkerStatus(const std::string &workerAddr);
+
+ /**
+ * @brief If the target worker is abnormal, enqueue an async sending task; otherwise send directly.
+ * @param[in] workerAddr The worker address the notification is sent to.
+ * @param[in] streamName The stream name.
+ * @param[in] streamFields The stream fields.
+ * @param[in] srcWorkerAddr The source worker address.
+ * @param[in] isClose True for close producer, false for create producer.
+ * @param[in] forceClose If the pub node had a crash or regular close when isClose = true.
+ * @param[in] asyncMode Skip the sync call and always send the async call.
+ * @return Status of the call.
+ */
+ Status NotifyPubNodeImpl(const HostPort &workerAddr, const std::string &streamName,
+ const StreamFields &streamFields, const HostPort &srcWorkerAddr, bool isClose,
+ bool forceClose, bool asyncMode);
+
+ /**
+ * @brief If the target worker is abnormal, enqueue an async sending task; otherwise send directly.
+ * @param[in] workerAddr The worker address the notification is sent to.
+ * @param[in] consumerMeta The consumer metadata.
+ * @param[in] isClose True for close consumer, false for create consumer.
+ * @param[in] retainData Tells the worker whether to stop retaining data.
+ * @param[in] asyncMode Skip the sync call and always send the async call.
+ * @return Status of the call.
+ */
+ Status NotifyConsumerImpl(const HostPort &workerAddr, const ConsumerMetaPb &consumerMeta, bool isClose,
+ RetainDataState::State retainData, bool asyncMode);
+
+ /**
+ * @brief Get or create the pending notification task.
+ * @param[in] workerAddr The worker address.
+ * @param[in] streamName The stream name.
+ * @param[out] accessor The tbb accessor.
+ * @return The pending notification object.
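+ * @note A minimal sketch of the caller pattern (callers hold notifyMutex_ shared before taking the
+ * accessor, which then locks the worker's map entry while the task is mutated):
+ * @code
+ * std::shared_lock locker(notifyMutex_);
+ * TbbNotifyWorkerMap::accessor accessor;
+ * auto task = GetOrCreatePendingNotification(workerAddr, streamName, accessor);
+ * CHECK_FAIL_RETURN_STATUS(task != nullptr, K_RUNTIME_ERROR, "task is null");
+ * (void)task->pubs.emplace(pub.worker_addr(), pub);
+ * @endcode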
+     */
+    std::shared_ptr GetOrCreatePendingNotification(const std::string &workerAddr,
+                                                   const std::string &streamName,
+                                                   TbbNotifyWorkerMap::accessor &accessor);
+    /**
+     * @brief Get the pending notification task.
+     * @param[in] workerAddr The worker address.
+     * @param[in] streamName The stream name.
+     * @param[out] accessor The tbb accessor.
+     * @return The pending notification object.
+     */
+    std::shared_ptr GetPendingNotification(const std::string &workerAddr,
+                                           const std::string &streamName,
+                                           TbbNotifyWorkerMap::accessor &accessor);
+
+    /**
+     * @brief Send the pending notification.
+     * @param[in] streamList The list of (worker address, stream name) pairs with pending notifications.
+     * @return Status of the call.
+     */
+    Status SendPendingNotification(std::vector<std::pair<std::string, std::string>> &streamList);
+
+    /**
+     * @brief Remove the async notification.
+     * @param[in] accessor The tbb accessor.
+     * @param[in] streamName The stream name.
+     * @param[in] task The pending notification task.
+     * @return Status of the call.
+     */
+    Status RemoveAsyncNotification(TbbNotifyWorkerMap::accessor &accessor, const std::string &streamName,
+                                   std::shared_ptr task);
+
+    /**
+     * @brief Send the topo change notification to worker.
+     * @param[in] workerAddr The worker address to which the notification is sent.
+     * @param[in] req The notification request.
+     * @return Status of the call.
+     */
+    Status SendNotification(const HostPort &workerAddr, UpdateTopoNotificationReq &req);
+
+    /**
+     * @brief Recover notification from rocksdb.
+     * @return Status of the call.
+     */
+    Status RecoverNotification();
+
+    /**
+     * @brief Handle a pub notification that already exists.
+     * @param[in] task The pending notification task.
+     * @param[in] workerAddr The worker address.
+     * @param[in] isClose Is close scenario or not.
+     * @param[in] needPersist Indicates whether the notification needs to be persisted to rocksdb.
+     * @param[out] exists True if a notification for the worker already exists.
+     * @return Status of the call.
+     */
+    Status HandleExistsPubNotification(std::shared_ptr task, const std::string &workerAddr,
+                                       bool isClose, bool needPersist, bool &exists);
+
+    /**
+     * @brief Handle a sub notification that already exists.
+     * @param[in] task The pending notification task.
+     * @param[in] consumerId The consumer id.
+     * @param[in] isClose Is close scenario or not.
+     * @param[in] needPersist Indicates whether the notification needs to be persisted to rocksdb.
+     * @param[out] exists True if a notification for the consumer already exists.
+     * @return Status of the call.
+     */
+    Status HandleExistsSubNotification(std::shared_ptr task, const std::string &consumerId,
+                                       bool isClose, bool needPersist, bool &exists);
+
+    /**
+     * @brief Async delete streams.
+     * @param[in] deleteStreams The set of stream names to delete.
+     * @return Status of the call.
+     */
+    Status DeleteStreams(std::set<std::string> &deleteStreams);
+
+    /**
+     * @brief Determines whether DeleteStream can be retried yet.
+     * @param[in] streamName Stream name.
+     * @return True if DeleteStream can be retried, false if not.
+     */
+    bool CanRetryDeleteStream(const std::string &streamName);
+
+    /**
+     * @brief Send the pending notification for a stream.
+     * @param[in] workerAddr The worker address.
+     * @param[in] streamName The stream name.
+     * @return Status of the call.
+     */
+    Status SendPendingNotificationForStream(const std::string &workerAddr, const std::string &streamName);
+
+    const int ASYNC_NOTIFY_TIME_MS = 100;  // Time interval between two async notification rounds.
+    std::unique_ptr notifyThreadPool_{ nullptr };
+    std::unique_ptr deleteThreadPool_{ nullptr };
+    std::future notifyFut_;
+    std::future deleteFut_;
+    WaitPost cvLock_;
+    std::atomic<bool> interruptFlag_{ false };
+
+    // Protect notifyWorkerMap_.
+    std::shared_timed_mutex notifyMutex_;
+    TbbNotifyWorkerMap notifyWorkerMap_;
+    // Protect pendingDeleteStreams_ and pendingDeleteStreamsLastRetry_.
+    std::shared_timed_mutex deleteMutex_;
+    std::set<std::string> pendingDeleteStreams_;
+    std::unordered_map pendingDeleteStreamsLastRetry_;
+
+    std::shared_ptr streamMetaStore_{ nullptr };
+
+    std::shared_ptr akSkManager_{ nullptr };
+    std::shared_ptr rpcSessionManager_{ nullptr };
+    EtcdClusterManager *etcdCM_{ nullptr };
+    SCMetadataManager *scMetadataManager_{ nullptr };
+};
+} // namespace master
+} // namespace datasystem
+#endif // DATASYSTEM_MASTER_STREAM_CACHE_SC_NOTIFY_WORKER_MANAGER_H
diff --git a/src/datasystem/master/stream_cache/store/CMakeLists.txt b/src/datasystem/master/stream_cache/store/CMakeLists.txt
new file mode 100644
index 0000000..3c6ccd6
--- /dev/null
+++ b/src/datasystem/master/stream_cache/store/CMakeLists.txt
@@ -0,0 +1,19 @@
+set(MASTER_SC_STORE_SRCS
+    rocks_stream_meta_store.cpp
+    )
+
+set(MASTER_SC_STORE_DEPEND_LIBS
+    RocksDB::rocksdb
+    common_log
+    common_rocksdb
+    master_stream_protos
+    posix_protos
+    worker_stream_protos
+    )
+
+add_library(master_stream_cache_store STATIC ${MASTER_SC_STORE_SRCS})
+target_link_libraries(master_stream_cache_store PRIVATE ${MASTER_SC_STORE_DEPEND_LIBS})
+add_dependencies(master_stream_cache_store
+    posix_protos
+    worker_stream_protos
+    master_stream_protos)
\ No newline at end of file
diff --git a/src/datasystem/master/stream_cache/store/rocks_stream_meta_store.cpp b/src/datasystem/master/stream_cache/store/rocks_stream_meta_store.cpp
new file mode 100644
index 0000000..10e7462
--- /dev/null
+++ b/src/datasystem/master/stream_cache/store/rocks_stream_meta_store.cpp
@@ -0,0 +1,337 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Define interface to store stream meta in RocksDB.
+ */
+#include "datasystem/master/stream_cache/store/rocks_stream_meta_store.h"
+#include "datasystem/common/log/log.h"
+#include "datasystem/common/constants.h"
+#include "datasystem/common/inject/inject_point.h"
+#include "datasystem/common/kvstore/rocksdb/replica.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/common/util/strings_util.h"
+#include "datasystem/common/util/uri.h"
+#include "datasystem/master/stream_cache/store/stream_transform.h"
+
+namespace datasystem {
+namespace master {
+namespace stream_cache {
+const std::string RocksStreamMetaStore::streamTableName_ = STREAM_TABLE_NAME;
+const std::string RocksStreamMetaStore::pubTableName_ = PUB_TABLE_NAME;
+const std::string RocksStreamMetaStore::subTableName_ = SUB_TABLE_NAME;
+const std::string RocksStreamMetaStore::notifyPubTableName_ = NOTIFY_PUB_TABLE_NAME;
+const std::string RocksStreamMetaStore::notifySubTableName_ = NOTIFY_SUB_TABLE_NAME;
+const std::string RocksStreamMetaStore::streamConCntTableName_ = STREAM_CON_CNT_TABLE_NAME;
+const std::string RocksStreamMetaStore::streamProCntTableName_ = STREAM_PRODUCER_COUNT;
+RocksStreamMetaStore::RocksStreamMetaStore(RocksStore *rocksStore) : rocksStore_(rocksStore)
+{
+}
+
+Status RocksStreamMetaStore::Init()
+{
+    RETURN_RUNTIME_ERROR_IF_NULL(rocksStore_);
+    return Replica::CreateScTable(rocksStore_);
+}
+
+Status RocksStreamMetaStore::AddPubNode(const ProducerMetaPb &producerMeta)
+{
+    INJECT_POINT("master.RocksStreamMetaStore.DoNotAddPubSubMetadata");
+    // Construct key-value pair in rocksdb
+    HostPort workerAddr(producerMeta.worker_address().host(), producerMeta.worker_address().port());
+    std::string key(producerMeta.stream_name() + PREFIX_SPLITTER + workerAddr.ToString());
+
+    std::string serializedStr;
+    CHECK_FAIL_RETURN_STATUS(producerMeta.SerializeToString(&serializedStr), StatusCode::K_UNKNOWN_ERROR,
+                             "Failed to Serialize Pub meta");
+
+    std::lock_guard<std::shared_timed_mutex> lock(mutex_);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->Put(pubTableName_, key, serializedStr),
+                                     "Failed to add pub meta: " + key);
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Success to add pub meta: %s", LogPrefix(), key);
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::DelPubNode(const ProducerMetaPb &producerMeta)
+{
+    // Construct key-value pair in rocksdb
+    HostPort workerAddr(producerMeta.worker_address().host(), producerMeta.worker_address().port());
+    std::string key(producerMeta.stream_name() + PREFIX_SPLITTER + workerAddr.ToString());
+
+    std::lock_guard<std::shared_timed_mutex> lock(mutex_);
+    return rocksStore_->Delete(pubTableName_, key);
+}
+
+Status RocksStreamMetaStore::AddSubNode(const ConsumerMetaPb &consumerMeta)
+{
+    INJECT_POINT("master.RocksStreamMetaStore.DoNotAddPubSubMetadata");
+    std::string key(consumerMeta.stream_name() + PREFIX_SPLITTER + consumerMeta.consumer_id());
+    std::string serializedStr;
+    CHECK_FAIL_RETURN_STATUS(consumerMeta.SerializeToString(&serializedStr), StatusCode::K_UNKNOWN_ERROR,
+                             "Failed to Serialize Sub meta");
+    std::lock_guard<std::shared_timed_mutex> lock(mutex_);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->Put(subTableName_, key, serializedStr),
+                                     "Failed to add sub meta: " + key);
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Success to add sub meta: %s", LogPrefix(), key);
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::GetLifeTimeConsumerCount(const std::string &streamName, uint32_t &consumerCount)
+{
+    std::string outValue;
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Get consumer life counts", LogPrefix(),
+                                              streamName);
+
+    std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->Get(streamConCntTableName_, streamName, outValue),
+                                     "Failed to get consumer count from table: " + streamConCntTableName_);
+    lock.unlock();
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Succeed to get count from table: %s, streamName:<%s>: count:<%s>",
+                                              streamConCntTableName_, streamName, outValue);
+    consumerCount = static_cast<uint32_t>(std::stoul(outValue));
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::UpdateLifeTimeConsumerCount(const std::string &streamName, const uint32_t consumerCount)
+{
+    std::lock_guard<std::shared_timed_mutex> lock(mutex_);
+    const std::string &key(streamName);
+    std::string value = std::to_string(consumerCount);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->Put(streamConCntTableName_, key, value),
+                                     "Failed to put stream consumer count: " + key);
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::DelSubNode(const std::string &streamName, const std::string &consumerId)
+{
+    std::string key(streamName + PREFIX_SPLITTER + consumerId);
+    std::lock_guard<std::shared_timed_mutex> lock(mutex_);
+    return rocksStore_->Delete(subTableName_, key);
+}
+
+Status RocksStreamMetaStore::AddOrUpdateStream(const std::string &streamName, const StreamFields &streamFields)
+{
+    const std::string &key(streamName);
+    StreamMetaPb streamMetaPb;
+    streamMetaPb.set_stream_name(streamName);
+    streamMetaPb.set_max_stream_size(streamFields.maxStreamSize_);
+    streamMetaPb.set_page_size(streamFields.pageSize_);
+    streamMetaPb.set_auto_cleanup(streamFields.autoCleanup_);
+    streamMetaPb.set_retain_num_consumer(streamFields.retainForNumConsumers_);
+    streamMetaPb.set_encrypt_stream(streamFields.encryptStream_);
+    streamMetaPb.set_reserve_size(streamFields.reserveSize_);
+    streamMetaPb.set_stream_mode(streamFields.streamMode_);
+    std::string serializedStr;
+    CHECK_FAIL_RETURN_STATUS(streamMetaPb.SerializeToString(&serializedStr), StatusCode::K_UNKNOWN_ERROR,
+                             "Failed to Serialize Stream meta");
+    std::lock_guard<std::shared_timed_mutex> lock(mutex_);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->Put(streamTableName_, key, serializedStr),
+                                     "Failed to put stream meta: " + key);
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::AddStream(const std::string &streamName, const StreamFields &streamFields)
+{
+    RETURN_IF_NOT_OK(AddOrUpdateStream(streamName, streamFields));
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Success to add stream meta: %s", LogPrefix(), streamName);
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::UpdateStream(const std::string &streamName, const StreamFields &streamFields)
+{
+    RETURN_IF_NOT_OK(AddOrUpdateStream(streamName, streamFields));
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Success to update stream meta: %s", LogPrefix(), streamName);
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::DelStream(const std::string &streamName)
+{
+    std::lock_guard<std::shared_timed_mutex> lock(mutex_);
+    return rocksStore_->Delete(streamTableName_, streamName);
+}
+
+Status RocksStreamMetaStore::GetOneStreamProducers(const std::string &streamName,
+                                                   std::vector<ProducerMetaPb> &pubWorkerMetas)
+{
+    std::vector<std::pair<std::string, std::string>> outKeyValues;
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Get pub nodes", LogPrefix(), streamName);
+
+    std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
+        rocksStore_->PrefixSearch(pubTableName_, streamName + PREFIX_SPLITTER, outKeyValues),
+        "Failed to get all pairs from table: " + pubTableName_);
+    lock.unlock();
+
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Succeed to get all pairs from table: %s, Pub outKeyValues.size() = %zu",
+                                              pubTableName_, outKeyValues.size());
+    ProducerMetaPb producerMetaPb;
+    for (const auto &outKeyValue : outKeyValues) {
+        CHECK_FAIL_RETURN_STATUS(producerMetaPb.ParseFromString(outKeyValue.second), StatusCode::K_UNKNOWN_ERROR,
+                                 "Parse string to producerMetaPb failed.");
+        if (streamName != producerMetaPb.stream_name()) {
+            continue;
+        }
+        pubWorkerMetas.emplace_back(producerMetaPb);
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("HostPort:<%s:%d>, streamName:<%s>",
+                                                  producerMetaPb.worker_address().host(),
+                                                  producerMetaPb.worker_address().port(), producerMetaPb.stream_name());
+    }
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::GetOneStreamConsumers(const std::string &streamName,
+                                                   std::vector<ConsumerMetaPb> &consumerMetas)
+{
+    std::vector<std::pair<std::string, std::string>> outKeyValues;
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Get consumers", LogPrefix(), streamName);
+
+    std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
+        rocksStore_->PrefixSearch(subTableName_, streamName + PREFIX_SPLITTER, outKeyValues),
+        "Failed to get all pairs from table: " + subTableName_);
+    lock.unlock();
+
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString(
+        "Succeed to get all pairs from table: %s, Consumer outKeyValues.size() = %zu", subTableName_,
+        outKeyValues.size());
+    ConsumerMetaPb consumerMetaPb;
+    for (const auto &outKeyValue : outKeyValues) {
+        CHECK_FAIL_RETURN_STATUS(consumerMetaPb.ParseFromString(outKeyValue.second), StatusCode::K_UNKNOWN_ERROR,
+                                 "Parse string to ConsumerMetaPb failed.");
+        if (streamName != consumerMetaPb.stream_name()) {
+            continue;
+        }
+        consumerMetas.emplace_back(std::move(consumerMetaPb));
+    }
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::GetAllStream(std::vector<StreamMetaPb> &streamMetas)
+{
+    std::vector<std::pair<std::string, std::string>> outKeyValues;
+
+    std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->GetAll(streamTableName_, outKeyValues),
+                                     FormatString("Fail to get all pairs from table:<%s>", streamTableName_));
+    lock.unlock();
+
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString(
+        "Succeed to get all pairs from table: %s, stream outKeyValues.size() = %zu", streamTableName_,
+        outKeyValues.size());
+    StreamMetaPb streamMetaPb;
+    for (const auto &outKeyValue : outKeyValues) {
+        CHECK_FAIL_RETURN_STATUS(streamMetaPb.ParseFromString(outKeyValue.second), StatusCode::K_UNKNOWN_ERROR,
+                                 "Parse string to streamMetaPb failed.");
+        CHECK_FAIL_RETURN_STATUS(
+            streamMetaPb.stream_name() == outKeyValue.first, StatusCode::K_RUNTIME_ERROR,
+            FormatString("Key:<%s>, value:<%s> are not equal", outKeyValue.first, streamMetaPb.stream_name()));
+        streamMetas.emplace_back(streamMetaPb);
+    }
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::AddNotifyPub(const std::string &workerAddr, const NotifyPubPb &pub)
+{
+    std::string key = workerAddr + "_" + pub.stream_name() + "_" + pub.worker_addr();
+    std::string value;
+    if (!pub.SerializeToString(&value)) {
+        RETURN_STATUS(StatusCode::K_RUNTIME_ERROR, "Failed to Serialize NotifyPubPb");
+    }
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->Put(notifyPubTableName_, key, value),
+                                     FormatString("Failed to put: %s", key));
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::RemoveNotifyPub(const std::string &workerAddr, const NotifyPubPb &pub)
+{
+    std::string key = workerAddr + "_" + pub.stream_name() + "_" + pub.worker_addr();
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->Delete(notifyPubTableName_, key),
+                                     FormatString("Failed to delete: %s", key));
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::AddNotifySub(const std::string &workerAddr, const NotifyConsumerPb &sub)
+{
+    const auto &meta = sub.consumer();
+    std::string key = workerAddr + "_" + meta.stream_name() + "_" + meta.consumer_id();
+    std::string value;
+    if (!sub.SerializeToString(&value)) {
+        RETURN_STATUS(StatusCode::K_RUNTIME_ERROR, "Failed to Serialize NotifyConsumerPb");
+    }
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->Put(notifySubTableName_, key, value),
+                                     FormatString("Failed to put: %s", key));
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::RemoveNotifySub(const std::string &workerAddr, const NotifyConsumerPb &sub)
+{
+    const auto &meta = sub.consumer();
+    std::string key = workerAddr + "_" + meta.stream_name() + "_" + meta.consumer_id();
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->Delete(notifySubTableName_, key),
+                                     FormatString("Failed to delete: %s", key));
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::RemoveNotificationByWorker(const std::string &workerAddr)
+{
+    std::string prefixKey = workerAddr + "_";
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->PrefixDelete(notifyPubTableName_, prefixKey),
+                                     FormatString("Failed to delete prefix key: %s", prefixKey));
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->PrefixDelete(notifySubTableName_, prefixKey),
+                                     FormatString("Failed to delete prefix key: %s", prefixKey));
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::GetAllNotifyPub(std::vector<std::pair<std::string, NotifyPubPb>> &pubs)
+{
+    std::vector<std::pair<std::string, std::string>> outKeyValues;
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->GetAll(notifyPubTableName_, outKeyValues),
+                                     FormatString("Failed to get all pairs from table: %s", notifyPubTableName_));
+
+    NotifyPubPb pub;
+    pubs.clear();
+    pubs.reserve(outKeyValues.size());
+    for (const auto &kv : outKeyValues) {
+        CHECK_FAIL_RETURN_STATUS(pub.ParseFromString(kv.second), StatusCode::K_UNKNOWN_ERROR,
+                                 "Parse string to NotifyPubPb failed.");
+        pubs.emplace_back(kv.first, std::move(pub));
+    }
+    return Status::OK();
+}
+
+Status RocksStreamMetaStore::GetAllNotifySub(std::vector<std::pair<std::string, NotifyConsumerPb>> &subs)
+{
+    std::vector<std::pair<std::string, std::string>> outKeyValues;
+    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rocksStore_->GetAll(notifySubTableName_, outKeyValues),
+                                     FormatString("Failed to get all pairs from table: %s", notifySubTableName_));
+
+    NotifyConsumerPb sub;
+    subs.clear();
+    subs.reserve(outKeyValues.size());
+    for (const auto &kv : outKeyValues) {
+        CHECK_FAIL_RETURN_STATUS(sub.ParseFromString(kv.second), StatusCode::K_UNKNOWN_ERROR,
+                                 "Parse string to NotifyConsumerPb failed.");
+        subs.emplace_back(kv.first, std::move(sub));
+    }
+    return Status::OK();
+}
+
+std::string RocksStreamMetaStore::LogPrefix()
+{
+    return "SC Rocksdb on master";
+}
+} // namespace stream_cache
+} // namespace master
+} // namespace datasystem
diff --git a/src/datasystem/master/stream_cache/store/rocks_stream_meta_store.h b/src/datasystem/master/stream_cache/store/rocks_stream_meta_store.h
new file mode 100644
index 0000000..2daf605
--- /dev/null
+++ b/src/datasystem/master/stream_cache/store/rocks_stream_meta_store.h
@@ -0,0 +1,233 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Declare interface to store stream meta in RocksDB.
+ */
+#ifndef DATASYSTEM_MASTER_STREAM_CACHE_STORE_ROCKS_STREAM_META_STORE_H
+#define DATASYSTEM_MASTER_STREAM_CACHE_STORE_ROCKS_STREAM_META_STORE_H
+
+#include <shared_mutex>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "datasystem/common/constants.h"
+#include "datasystem/common/kvstore/rocksdb/rocks_store.h"
+#include "datasystem/common/stream_cache/stream_fields.h"
+#include "datasystem/common/util/net_util.h"
+#include "datasystem/protos/master_stream.pb.h"
+#include "datasystem/protos/worker_stream.pb.h"
+#include "datasystem/stream/stream_config.h"
+#include "datasystem/common/util/status_helper.h"
+
+namespace datasystem {
+namespace master {
+namespace stream_cache {
+
+// Definition:
+// PubNode: Worker node which contains at least one producer.
+// SubNode: Consumer node which includes all information about one consumer.
+class RocksStreamMetaStore {
+public:
+    /**
+     * @brief Construct RocksStreamMetaStore.
+     */
+    explicit RocksStreamMetaStore(RocksStore *rocksStore);
+
+    /**
+     * @brief Use path backStorePath to start rocksdb in rocksStore.
+     * @return Status of the call.
+     */
+    Status Init();
+
+    /**
+     * @brief Add Pub node meta in Rocksdb.
+     * @param[in] producerMeta The producer metadata.
+     * @return Status of the call.
+     */
+    Status AddPubNode(const ProducerMetaPb &producerMeta);
+
+    /**
+     * @brief Remove pub meta from Rocksdb.
+     * @param[in] producerMeta The producer metadata.
+     * @return Status of the call.
+     */
+    Status DelPubNode(const ProducerMetaPb &producerMeta);
+
+    /**
+     * @brief Add Sub consumer meta in Rocksdb.
+     * @param[in] consumerMeta The consumer metadata.
+     * @return Status of the call.
+     */
+    Status AddSubNode(const ConsumerMetaPb &consumerMeta);
+
+    /**
+     * @brief Remove consumer meta from Rocksdb.
+     * @param[in] streamName The stream name.
+     * @param[in] consumerId The consumer id.
+     * @return Status of the call.
+     */
+    Status DelSubNode(const std::string &streamName, const std::string &consumerId);
+
+    /**
+     * @brief Add stream meta in Rocksdb, we save the key as streamName.
+     * @param[in] streamName The target stream.
+     * @param[in] streamFields The stream fields.
+     * @return Status of the call.
+     */
+    Status AddStream(const std::string &streamName, const StreamFields &streamFields);
+
+    /**
+     * @brief Update stream meta in Rocksdb, we save the key as streamName.
+     * @param[in] streamName The target stream.
+     * @param[in] streamFields The stream fields.
+     * @return Status of the call.
+     */
+    Status UpdateStream(const std::string &streamName, const StreamFields &streamFields);
+
+    /**
+     * @brief Remove stream meta from Rocksdb.
+     * @param[in] streamName The stream meta key to be removed.
+     * @return Status of the call.
+     */
+    Status DelStream(const std::string &streamName);
+
+    /**
+     * @brief Updates consumer lifetime count in Rocksdb.
+     * @param[in] streamName The stream name.
+     * @param[in] consumerCount The consumer count to be updated.
+     * @return Status of the call.
+     */
+    Status UpdateLifeTimeConsumerCount(const std::string &streamName, const uint32_t consumerCount);
+
+    /**
+     * @brief Gets consumer lifetime count from Rocksdb.
+     * @param[in] streamName The stream name.
+     * @param[out] consumerCount The consumer count.
+     * @return Status of the call.
+     */
+    Status GetLifeTimeConsumerCount(const std::string &streamName, uint32_t &consumerCount);
+
+    /**
+     * @brief Get all pub node k-v pairs from rocksdb pubTable for a stream.
+     * @param[in] streamName The stream name.
+     * @param[out] producerMetas The producer metadata.
+     * @return Status of the call.
+     */
+    Status GetOneStreamProducers(const std::string &streamName, std::vector<ProducerMetaPb> &producerMetas);
+
+    /**
+     * @brief Get all sub node k-v pairs from rocksdb subTable for a stream.
+     * @param[in] streamName The stream name.
+     * @param[out] consumerMetas The consumer metadata.
+     * @return Status of the call.
+     */
+    Status GetOneStreamConsumers(const std::string &streamName, std::vector<ConsumerMetaPb> &consumerMetas);
+
+    /**
+     * @brief Get all stream k-v pairs from rocksdb streamTable.
+     * @param[out] streamMetas The output metas in table.
+     * @return Status of the call.
+     */
+    Status GetAllStream(std::vector<StreamMetaPb> &streamMetas);
+
+    /**
+     * @brief Add pub node change notification to Rocksdb.
+     * @param[in] workerAddr The target worker address.
+     * @param[in] pub The pub metadata.
+     * @return Status of the call.
+     */
+    Status AddNotifyPub(const std::string &workerAddr, const NotifyPubPb &pub);
+
+    /**
+     * @brief Remove pub node change notification in Rocksdb.
+     * @param[in] workerAddr The target worker address.
+     * @param[in] pub The pub metadata.
+     * @return Status of the call.
+     */
+    Status RemoveNotifyPub(const std::string &workerAddr, const NotifyPubPb &pub);
+
+    /**
+     * @brief Add sub node change notification to Rocksdb.
+     * @param[in] workerAddr The target worker address.
+     * @param[in] sub The sub metadata.
+     * @return Status of the call.
+     */
+    Status AddNotifySub(const std::string &workerAddr, const NotifyConsumerPb &sub);
+
+    /**
+     * @brief Remove sub node change notification in Rocksdb.
+     * @param[in] workerAddr The target worker address.
+     * @param[in] sub The sub metadata.
+     * @return Status of the call.
+     */
+    Status RemoveNotifySub(const std::string &workerAddr, const NotifyConsumerPb &sub);
+
+    /**
+     * @brief Remove all notifications sent to the specific worker.
+     * @param[in] workerAddr The target worker address.
+     * @return Status of the call.
+     */
+    Status RemoveNotificationByWorker(const std::string &workerAddr);
+
+    /**
+     * @brief Get all pub change notifications from Rocksdb.
+     * @param[out] pubs The list of pub change notifications.
+     * @return Status of the call.
+     */
+    Status GetAllNotifyPub(std::vector<std::pair<std::string, NotifyPubPb>> &pubs);
+
+    /**
+     * @brief Get all sub change notifications from Rocksdb.
+     * @param[out] subs The list of sub change notifications.
+     * @return Status of the call.
+     */
+    Status GetAllNotifySub(std::vector<std::pair<std::string, NotifyConsumerPb>> &subs);
+
+    /**
+     * @brief Get log prefix.
+     * @return The log prefix.
+     */
+    static std::string LogPrefix();
+
+private:
+
+    /**
+     * @brief Add or update stream meta in Rocksdb, we save the key as streamName.
+     * @param[in] streamName The target stream.
+     * @param[in] streamFields The stream fields.
+     * @return Status of the call.
+     */
+    Status AddOrUpdateStream(const std::string &streamName, const StreamFields &streamFields);
+
+    const static std::string streamTableName_;       // the stream meta table name
+    const static std::string pubTableName_;          // the global pub table name
+    const static std::string subTableName_;          // the global sub table name
+    const static std::string notifyPubTableName_;    // the notify pub table name
+    const static std::string notifySubTableName_;    // the notify sub table name
+    const static std::string streamConCntTableName_; // the table to store lifetime consumer count
+    const static std::string streamProCntTableName_; // the table to store lifetime producer count
+
+    std::shared_timed_mutex mutex_; // Concurrent control for this class
+
+    // The backend rocksdb storage.
+    RocksStore *rocksStore_;
+};
+} // namespace stream_cache
+} // namespace master
+} // namespace datasystem
+#endif // DATASYSTEM_MASTER_STREAM_CACHE_STORE_ROCKS_STREAM_META_STORE_H
diff --git a/src/datasystem/master/stream_cache/store/stream_transform.h b/src/datasystem/master/stream_cache/store/stream_transform.h
new file mode 100644
index 0000000..3ef1a98
--- /dev/null
+++ b/src/datasystem/master/stream_cache/store/stream_transform.h
@@ -0,0 +1,69 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Interface to search stream meta from RocksDB.
+ */
+#ifndef DATASYSTEM_MASTER_STREAM_CACHE_STORE_STREAM_TRANSFORM_H
+#define DATASYSTEM_MASTER_STREAM_CACHE_STORE_STREAM_TRANSFORM_H
+
+#include <algorithm>
+#include <string>
+
+#include "rocksdb/slice.h"
+#include "rocksdb/slice_transform.h"
+
+#include "datasystem/common/log/log.h"
+
+namespace datasystem {
+namespace master {
+namespace stream_cache {
+static constexpr char PREFIX_SPLITTER[] = "_";
+class StreamTransform : public rocksdb::SliceTransform {
+public:
+    StreamTransform() = default;
+
+    const char *Name() const override
+    {
+        return "StreamTransform";
+    }
+
+    rocksdb::Slice Transform(const rocksdb::Slice &src) const override
+    {
+        DCHECK(InDomain(src));
+        return GetPrefix(src);
+    }
+
+    bool InDomain(const rocksdb::Slice &src) const override
+    {
+        return !GetPrefix(src).empty();
+    }
+
+    bool SameResultWhenAppended(const rocksdb::Slice &prefix) const override
+    {
+        return InDomain(prefix);
+    }
+
+private:
+    static rocksdb::Slice GetPrefix(const rocksdb::Slice &src)
+    {
+        std::string splitter(PREFIX_SPLITTER);
+        std::string::size_type sz =
+            std::find_end(src.data(), src.data() + src.size(), splitter.begin(), splitter.end()) - src.data();
+        return { src.data(), sz };
+    }
+};
+} // namespace stream_cache
+} // namespace master
+} // namespace datasystem
+#endif
\ No newline at end of file
diff --git a/src/datasystem/master/stream_cache/stream_metadata.cpp b/src/datasystem/master/stream_cache/stream_metadata.cpp
new file mode 100644
index 0000000..8155e88
--- /dev/null
+++ b/src/datasystem/master/stream_cache/stream_metadata.cpp
@@ -0,0 +1,906 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: The definition of stream metadata object. + */ +#include "datasystem/master/stream_cache/stream_metadata.h" + +#include +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/log/log_helper.h" +#include "datasystem/common/util/format.h" +#include "datasystem/common/util/lock_helper.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/common/util/raii.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/master/stream_cache/sc_notify_worker_manager.h" +#include "datasystem/utils/status.h" +#include "datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h" + +namespace datasystem { +namespace master { +StreamMetadata::StreamMetadata(std::string streamName, const StreamFields &streamFields, + RocksStreamMetaStore *streamMetaStore, std::shared_ptr akSkManager, + std::shared_ptr rpcSessionManager, EtcdClusterManager *etcdCM, + SCNotifyWorkerManager *notifyWorkerManager) + : streamName_(std::move(streamName)), + streamFields_(streamFields), + topoManager_(std::make_unique(streamName_)), + streamStore_(streamMetaStore), + alive_(true), + akSkManager_(std::move(akSkManager)), + rpcSessionManager_(std::move(rpcSessionManager)), + etcdCM_(etcdCM), + notifyWorkerManager_(notifyWorkerManager) +{ +} + +StreamMetadata::~StreamMetadata() +{ + // Update stream metrics final time before exit + if (scStreamMetrics_) { + UpdateStreamMetrics(); + } +} + +Status StreamMetadata::PubIncreaseNode(const ProducerMetaPb &producerMeta, StreamFields &streamFields) +{ + HostPort pubWorkerAddress(producerMeta.worker_address().host(), producerMeta.worker_address().port()); + LOG(INFO) << "Topo for stream " << streamName_ << ":" << *topoManager_; + // Check topo manager status and return if stream is getting deleted. + CHECK_FAIL_RETURN_STATUS(!topoManager_->GetStreamStatus(), StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS, + FormatString("Stream <%s> is undergoing deletion, do not operate it", streamName_)); + bool isRecon = false; + // The procedure is not entirely locked in this code path to allow parallel CreateProducer on master. + // PubIncreaseNodeStart and PubIncreaseNodeEnd will acquire unique lock when necessary. + // While the RPC requests can be sent in parallel without lock. 
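+    // alreadyLocked = false: this public entry point does not hold mutex_ itself, so
+    // PubIncreaseNodeStart/PubIncreaseNodeEnd acquire the lock when they need it.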
+    bool alreadyLocked = false;
+    RETURN_IF_NOT_OK(PubIncreaseNodeInternal(producerMeta, streamFields, pubWorkerAddress, isRecon, alreadyLocked));
+    return Status::OK();
+}
+
+Status StreamMetadata::UpdateStreamFields(const StreamFields &streamFields)
+{
+    RETURN_IF_NOT_OK(streamStore_->UpdateStream(streamName_, streamFields));
+    streamFields_ = streamFields;
+    LOG(INFO) << FormatString(
+        "[%s] Stream configuration updated with max stream size: %zu, page size: %zu, "
+        "auto cleanup: %s, retainDataForNumConsumers: %zu, encrypt stream: %s, and reserve size: %zu",
+        LogPrefix(), streamFields_.maxStreamSize_, streamFields_.pageSize_,
+        streamFields_.autoCleanup_ ? "true" : "false", streamFields_.retainForNumConsumers_,
+        streamFields_.encryptStream_ ? "true" : "false", streamFields_.reserveSize_);
+    retainData_.Init(streamFields_.retainForNumConsumers_);
+    return Status::OK();
+}
+
+Status StreamMetadata::PubIncreaseNodeStart(const StreamFields &streamFields, const ProducerMetaPb &producerMeta,
+                                            bool alreadyLocked, bool &streamFieldsVerified, bool &saveToRocksdb,
+                                            bool &isFirstProducer)
+{
+    WriteLockHelper wlocker(DEFER_LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+    if (!alreadyLocked) {
+        wlocker.AcquireLock();
+    }
+    // Check timeout before processing isFirstProducer; it might happen that the request already timed out,
+    // so another CreateProducer request for the stream comes through from the same worker.
+    CHECK_FAIL_RETURN_STATUS(scTimeoutDuration.CalcRealRemainingTime() > 0, K_RPC_DEADLINE_EXCEEDED,
+                             "CreateProducer RPC timeout.");
+
+    RETURN_IF_NOT_OK(VerifyStreamFields(streamFields));
+    if (streamFields_ != streamFields) {
+        RETURN_IF_NOT_OK(UpdateStreamFields(streamFields));
+    }
+    streamFieldsRefcount_++;
+    streamFieldsVerified = true;
+
+    // Update topological structure on master node.
+    CHECK_FAIL_RETURN_STATUS(!topoManager_->GetStreamStatus(), StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS,
+                             FormatString("Stream <%s> is undergoing deletion, do not operate it", streamName_));
+    RETURN_IF_NOT_OK(topoManager_->PubIncreaseNode(producerMeta, isFirstProducer));
+    // Save it in memory and rocksdb.
+    RETURN_IF_NOT_OK(streamStore_->AddPubNode(producerMeta));
+    saveToRocksdb = true;
+
+    return Status::OK();
+}
+
+Status StreamMetadata::PubIncreaseNodeEnd(bool isFirstProducer, bool needsRollback, bool alreadyLocked,
+                                          const ProducerMetaPb &producerMeta, const HostPort &pubWorkerAddress,
+                                          bool streamFieldsVerified, const std::vector<HostPort> &notifyNodeSet,
+                                          bool saveToRocksdb)
+{
+    WriteLockHelper wlocker(DEFER_LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+    if (!alreadyLocked) {
+        wlocker.AcquireLock();
+    }
+
+    if (isFirstProducer) {
+        RETURN_IF_NOT_OK(topoManager_->PubNodeFirstOrLastDone(producerMeta));
+    }
+
+    // Early return if rollback is not needed.
+    RETURN_OK_IF_TRUE(!needsRollback);
+
+    // If other CreateProducer requests come with the same stream fields,
+    // then there is no need to rollback the stream field settings.
+    // So only rollback when ref count is 0 after decrement.
+    if (streamFieldsVerified) {
+        streamFieldsRefcount_--;
+        if (streamFieldsRefcount_ == 0) {
+            // Rollback stream fields to empty.
+            streamFields_ = StreamFields();
+            RETURN_IF_NOT_OK(streamStore_->UpdateStream(streamName_, StreamFields()));
+            // Retain data needs to be unset.
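+            // Presumably RollBackToInit() returns the retain-data state machine to INIT so that a
+            // retried CreateProducer can initialize it again from fresh stream fields.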
+            retainData_.RollBackToInit();
+        }
+    }
+
+    for (const auto &node : notifyNodeSet) {
+        NotifyPubPb pub;
+        pub.set_is_close(true);
+        pub.set_stream_name(streamName_);
+        pub.set_worker_addr(pubWorkerAddress.ToString());
+        RETURN_IF_NOT_OK(notifyWorkerManager_->AddAsyncPubNotification(node.ToString(), pub));
+    }
+
+    if (saveToRocksdb) {
+        RETURN_IF_NOT_OK(topoManager_->PubDecreaseNode(producerMeta, false));
+        RETURN_IF_NOT_OK(streamStore_->DelPubNode(producerMeta));
+    }
+    return Status::OK();
+}
+
+Status StreamMetadata::PubIncreaseNodeInternal(const ProducerMetaPb &producerMeta, StreamFields &streamFields,
+                                               const HostPort &pubWorkerAddress, bool isRecon, bool alreadyLocked)
+{
+    bool saveToRocksdb = false;
+    bool isFirstProducer = false;
+    bool streamFieldsVerified = false;
+    const auto &streamName = producerMeta.stream_name();
+    CHECK_FAIL_RETURN_STATUS(streamName == streamName_, K_INVALID,
+                             FormatString("Stream name mismatch, expected %s, received %s", streamName_, streamName));
+    std::vector<HostPort> notifyNodeSet;
+
+    bool needsRollback = true;
+    Raii pubIncreaseNodeEnd([&]() {
+        LOG_IF_ERROR(PubIncreaseNodeEnd(isFirstProducer, needsRollback, alreadyLocked, producerMeta, pubWorkerAddress,
+                                        streamFieldsVerified, notifyNodeSet, saveToRocksdb),
+                     "PubIncreaseNodeEnd rollback failed.");
+    });
+
+    RETURN_IF_NOT_OK(PubIncreaseNodeStart(streamFields, producerMeta, alreadyLocked, streamFieldsVerified,
+                                          saveToRocksdb, isFirstProducer));
+    if (isFirstProducer) {
+        Status rc = PubIncreaseNodeImpl(pubWorkerAddress, notifyNodeSet, isRecon);
+        LOG(INFO) << FormatString("[%s] Add new pub node:<%s> finish with %s.", LogPrefix(),
+                                  pubWorkerAddress.ToString(), rc.GetMsg());
+        RETURN_IF_NOT_OK(rc);
+    }
+    needsRollback = false;
+    return Status::OK();
+}
+
+RetainDataState::State StreamMetadata::CheckNUpdateNeedRetainData()
+{
+    bool retainStateChange = false;
+    const bool update = true;
+    return CheckNeedRetainData(retainStateChange, update);
+}
+
+RetainDataState::State StreamMetadata::CheckNeedRetainData(bool &retainStateChange, const bool update)
+{
+    // If stream is initialized (at least a producer is created)
+    // And we are currently retaining data
+    if (!streamFields_.Empty() && retainData_.GetRetainDataState() != RetainDataState::State::NOT_RETAIN) {
+        // We have at least as many consumers as the user asked for
+        if (topoManager_->GetConsumerCountForLife() >= streamFields_.retainForNumConsumers_) {
+            // then do not retain data
+            if (update) {
+                retainData_.SetRetainDataState(RetainDataState::State::NOT_RETAIN);
+            }
+            retainStateChange = true;
+        }
+    }
+    return retainData_.GetRetainDataState();
+}
+
+Status StreamMetadata::PubIncreaseNodeImpl(const HostPort &pubWorkerAddress, std::vector<HostPort> &notifyNodeSet,
+                                           bool isRecon)
+{
+    // Now we have two things to do. The first is to sync the subTopoSet (in the consumer sense) for the src node:
+    // consumerNodeSet = {consumer1, consumer2, ..., consumerN}, this set will sync to src pub node as remote consumers.
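+    // Consumers living on the new pub node itself are excluded (GetAllConsumerNotFromSrc):
+    // that worker already tracks its local consumers.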
+    auto consumerMetaList = topoManager_->GetAllConsumerNotFromSrc(pubWorkerAddress.ToString());
+    auto retainData = CheckNUpdateNeedRetainData(); // Check if we have to retain the data
+    std::shared_ptr masterWorkerApi = nullptr;
+    RETURN_IF_NOT_OK(rpcSessionManager_->GetRpcSession(pubWorkerAddress, masterWorkerApi, akSkManager_));
+    if (!consumerMetaList.empty()
+        || (retainData == RetainDataState::State::NOT_RETAIN && streamFields_.retainForNumConsumers_)) {
+        LOG(INFO) << "[RetainData] Master sending SyncConsumerNode for stream " << streamName_ << " Remote producer "
+                  << pubWorkerAddress.ToString() << " Current state " << retainData;
+        RETURN_IF_NOT_OK_EXCEPT(masterWorkerApi->SyncConsumerNode(streamName_, consumerMetaList, retainData, isRecon),
+                                StatusCode::K_DUPLICATED);
+    }
+
+    INJECT_POINT("master.PubIncreaseNodeImpl.beforeSendNotification");
+
+    // The second is to broadcast to each sub node of this stream that the new producer has arrived.
+    std::set<HostPort> subNodeSet;
+    RETURN_IF_NOT_OK(topoManager_->GetAllSubNode(subNodeSet));
+    RETURN_IF_NOT_OK(RemoveSourceWorker(pubWorkerAddress, subNodeSet));
+    for (const auto &subNode : subNodeSet) {
+        RETURN_IF_NOT_OK_EXCEPT(
+            notifyWorkerManager_->NotifyNewPubNode(subNode, streamName_, streamFields_, pubWorkerAddress),
+            StatusCode::K_DUPLICATED);
+        notifyNodeSet.emplace_back(subNode);
+        INJECT_POINT("master.PubIncreaseNodeImpl.afterSendNotification");
+    }
+    // Check after the UpdateTopoNotification: the RPC can be done through local bypass, which would run without
+    // scTimeoutDuration. If it times out, we should roll back because the worker side should have stopped waiting for
+    // the response.
+    CHECK_FAIL_RETURN_STATUS(scTimeoutDuration.CalcRealRemainingTime() > 0, K_RPC_DEADLINE_EXCEEDED,
+                             "CreateProducer RPC timeout.");
+    return Status::OK();
+}
+
+Status StreamMetadata::PubDecreaseNodeStart(const ProducerMetaPb &producerMeta, bool &isLastProducer)
+{
+    WriteLockHelper wlocker(LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+    // Check if this worker's last producer is already closed
+    CHECK_FAIL_RETURN_STATUS(topoManager_->ExistsProducer(producerMeta), K_SC_PRODUCER_NOT_FOUND,
+                             "producer does not exist");
+    RETURN_IF_NOT_OK(topoManager_->PubDecreaseNodeStart(producerMeta, isLastProducer));
+    return Status::OK();
+}
+
+Status StreamMetadata::PubDecreaseNode(const ProducerMetaPb &producerMeta, const bool forceClose)
+{
+    LOG(INFO) << "Topo for stream " << streamName_ << ":" << *topoManager_;
+    // Check topo manager status and return if stream is getting deleted.
+    CHECK_FAIL_RETURN_STATUS(!topoManager_->GetStreamStatus(), StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS,
+                             FormatString("Stream <%s> is undergoing deletion, do not operate it", streamName_));
+
+    // Pre-handling to check whether it is the last producer.
+    // Requests should have been serialized by the worker level create lock in most cases.
+    // But still set firstProducerProcessing_ to stop
+    // other CreateProducer requests of the same worker from happening in timeout cases.
+    bool isLastProducer = false;
+    RETURN_IF_NOT_OK(PubDecreaseNodeStart(producerMeta, isLastProducer));
+    std::unique_ptr<Raii> pubDecreaseNodeEnd = std::make_unique<Raii>([&]() {
+        if (isLastProducer) {
+            LOG_IF_ERROR(topoManager_->PubNodeFirstOrLastDone(producerMeta), "");
+        }
+    });
+    HostPort pubWorkerAddress(producerMeta.worker_address().host(), producerMeta.worker_address().port());
+    if (isLastProducer) {
+        INJECT_POINT("master.PubDecreaseNode.beforeSendNotification");
+        // Only send notification when force close is true.
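+        // forceClose means the pub node crashed rather than closing normally (see NotifyDelPubNode).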
+        if (forceClose) {
+            // Construct rpc channel between master and all sub node.
+            std::set<HostPort> subNodeSet;
+            RETURN_IF_NOT_OK(topoManager_->GetAllSubNode(subNodeSet));
+            RETURN_IF_NOT_OK(RemoveSourceWorker(pubWorkerAddress, subNodeSet));
+            for (const auto &subNode : subNodeSet) {
+                auto rc = notifyWorkerManager_->NotifyDelPubNode(subNode, streamName_, pubWorkerAddress, forceClose);
+                if (rc.IsError() && rc.GetCode() != StatusCode::K_NOT_FOUND
+                    && rc.GetCode() != StatusCode::K_SC_STREAM_NOT_FOUND) {
+                    LOG(ERROR) << "NotifyDelPubNode failed " << rc.GetMsg();
+                    return rc;
+                }
+            }
+        }
+        INJECT_POINT("master.PubDecreaseNode.afterSendNotification");
+    } else {
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Producer decrease S:%s. Not the last producer for worker %s",
+                                                  streamName_, pubWorkerAddress.ToString());
+    }
+
+    WriteLockHelper wlocker(LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+    // In the happy path, unset firstProducerProcessing_ state under lock protection to avoid a timing hole.
+    pubDecreaseNodeEnd.reset();
+    RETURN_IF_NOT_OK(topoManager_->PubDecreaseNode(producerMeta, false));
+    LOG_IF_ERROR(streamStore_->DelPubNode(producerMeta), "Delete pub node failed in rocksdb");
+    // If auto clean up is true and there is no more producer/consumer, delete the stream.
+    if (isLastProducer && streamFields_.autoCleanup_) {
+        RETURN_IF_NOT_OK(AutoCleanupIfNeededNotLocked(pubWorkerAddress));
+    }
+    return Status::OK();
+}
+
+Status StreamMetadata::SubIncreaseNode(const ConsumerMetaPb &consumerMeta, bool isRecon)
+{
+    HostPort subWorkerAddress(consumerMeta.worker_address().host(), consumerMeta.worker_address().port());
+    WriteLockHelper wlocker(LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+    INJECT_POINT("master.SubIncreaseNode.afterLock");
+    CHECK_FAIL_RETURN_STATUS(scTimeoutDuration.CalcRealRemainingTime() > 0, K_RPC_DEADLINE_EXCEEDED,
+                             FormatString("[%s] Subscribe Request timeout.", LogPrefix()));
+    LOG(INFO) << "Topo for stream " << streamName_ << ":" << *topoManager_;
+    // Check topo manager status and return if stream is getting deleted.
+    CHECK_FAIL_RETURN_STATUS(!topoManager_->GetStreamStatus(), StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS,
+                             FormatString("Stream <%s> is undergoing deletion, do not operate it", streamName_));
+    return SubIncreaseNodeUnlocked(consumerMeta, subWorkerAddress, isRecon);
+}
+
+Status StreamMetadata::SubIncreaseNodeUnlocked(const ConsumerMetaPb &consumerMeta, const HostPort &subWorkerAddress,
+                                               bool isRecon)
+{
+    bool saveToRocksdb = false;
+    bool sendToSrcNode = false;
+    std::vector<HostPort> notifyNodeSet;
+    CHECK_FAIL_RETURN_STATUS(
+        consumerMeta.stream_name() == streamName_, K_INVALID,
+        FormatString("Stream name mismatch, expected %s, received %s", streamName_, consumerMeta.stream_name()));
+    Status rc =
+        SubIncreaseNodeImpl(consumerMeta, subWorkerAddress, saveToRocksdb, sendToSrcNode, notifyNodeSet, isRecon);
+    LOG(INFO) << FormatString("[%s, C:%s] Add new consumer finish with %s.", LogPrefix(),
+                              LogHelper::IgnoreSensitive(consumerMeta), rc.GetMsg());
+
+    RETURN_OK_IF_TRUE(rc.IsOk());
+
+    // Add async sub close notification to rollback.
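+    // Every node already told about the new consumer gets a compensating close
+    // notification, delivered asynchronously.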
+    for (const auto &node : notifyNodeSet) {
+        NotifyConsumerPb sub;
+        *sub.mutable_consumer() = consumerMeta;
+        sub.set_is_close(true);
+        RETURN_IF_NOT_OK(notifyWorkerManager_->AddAsyncSubNotification(node.ToString(), sub));
+    }
+
+    if (sendToSrcNode) {
+        std::shared_ptr masterWorkerApi = nullptr;
+        RETURN_IF_NOT_OK(rpcSessionManager_->GetRpcSession(subWorkerAddress, masterWorkerApi, akSkManager_));
+        RETURN_IF_NOT_OK_EXCEPT(masterWorkerApi->ClearAllRemotePub(streamName_), K_NOT_FOUND);
+    }
+
+    if (saveToRocksdb) {
+        RETURN_IF_NOT_OK(topoManager_->SubDecreaseNode(consumerMeta, true, false));
+        RETURN_IF_NOT_OK(streamStore_->DelSubNode(streamName_, consumerMeta.consumer_id()));
+        // Update count for rollback
+        RETURN_IF_NOT_OK(
+            streamStore_->UpdateLifeTimeConsumerCount(streamName_, topoManager_->GetConsumerCountForLife()));
+    }
+    return rc;
+}
+
+Status StreamMetadata::NotifyStopRetainData(const HostPort &subWorkerAddress, bool retainStateChange)
+{
+    // Once the request is complete and the retainData state has changed, inform all related nodes.
+    if (retainStateChange) {
+        // Relative node set for the producers
+        std::set<HostPort> relatedNodeSet;
+        RETURN_IF_NOT_OK(topoManager_->GetAllPubNode(relatedNodeSet, retainStateChange));
+        RETURN_IF_NOT_OK(RemoveSourceWorker(subWorkerAddress, relatedNodeSet));
+        for (const auto &relatedNode : relatedNodeSet) {
+            LOG(INFO) << "[RetainData] Master sending NotifyStopRetention for stream " << streamName_
+                      << " Remote producer " << relatedNode.ToString();
+            RETURN_IF_NOT_OK(notifyWorkerManager_->AddAsyncStopDataRetentionNotification(relatedNode, streamName_));
+        }
+    }
+    return Status::OK();
+}
+
+Status StreamMetadata::SubIncreaseNodeImpl(const ConsumerMetaPb &consumerMeta, const HostPort &subWorkerAddress,
+                                           bool &saveToRocksdb, bool &sendToSrcNode,
+                                           std::vector<HostPort> &notifyNodeSet, bool isRecon)
+{
+    CHECK_FAIL_RETURN_STATUS(!topoManager_->GetStreamStatus(), StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS,
+                             FormatString("Stream <%s> is undergoing deletion, do not operate it", streamName_));
+    RETURN_IF_NOT_OK(topoManager_->CheckNewConsumer(consumerMeta));
+
+    // Save it in memory and rocksdb.
+    bool isFirstConsumer = false;
+    RETURN_IF_NOT_OK(topoManager_->SubIncreaseNode(consumerMeta, isFirstConsumer));
+    RETURN_IF_NOT_OK(streamStore_->AddSubNode(consumerMeta));
+    RETURN_IF_NOT_OK(streamStore_->UpdateLifeTimeConsumerCount(streamName_, topoManager_->GetConsumerCountForLife()));
+    saveToRocksdb = true;
+
+    bool retainStateChange = false;
+    (void)CheckNeedRetainData(retainStateChange);
+
+    // Notify the source node to add remote producer information.
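+    // Collect every worker hosting a producer; the consumer's own worker is removed
+    // below before the broadcast.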
+    std::set<HostPort> pubNodeSet;
+    RETURN_IF_NOT_OK(topoManager_->GetAllPubNode(pubNodeSet));
+    RETURN_IF_NOT_OK(RemoveSourceWorker(subWorkerAddress, pubNodeSet));
+    if (isFirstConsumer && !pubNodeSet.empty()) {
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("consumer meta:<%s>, isFirstConsumer:<%d>",
+                                                  LogHelper::IgnoreSensitive(consumerMeta), isFirstConsumer);
+        std::shared_ptr masterWorkerApi = nullptr;
+        RETURN_IF_NOT_OK(rpcSessionManager_->GetRpcSession(subWorkerAddress, masterWorkerApi, akSkManager_));
+        RETURN_IF_NOT_OK_EXCEPT(masterWorkerApi->SyncPubNode(streamName_, pubNodeSet, isRecon),
+                                StatusCode::K_DUPLICATED);
+        sendToSrcNode = true;
+    }
+
+    // Inform all nodes that have a producer or previously had a producer until retainData is flipped
+    if (retainData_.IsDataRetained() || retainStateChange) {
+        // When state changed, inform all producers (already closed ones too) of the new consumer
+        std::set<HostPort> relatedNodeSet;
+        RETURN_IF_NOT_OK(topoManager_->GetAllPubNode(relatedNodeSet, true));
+        RETURN_IF_NOT_OK(RemoveSourceWorker(subWorkerAddress, relatedNodeSet));
+        pubNodeSet = relatedNodeSet;
+        LOG(INFO) << "[RetainData] Informing all related nodes of SubIncrease " << pubNodeSet.size();
+    }
+    // Construct rpc channel between master and all pub nodes, broadcast this new consumer to all pub worker nodes.
+    // We send these notifications to all pub nodes, including the ones that have the previous producer closed.
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("New Sub of Stream: [%s], notify [%zu] pubNodes", streamName_,
+                                              pubNodeSet.size());
+    for (const auto &pubNode : pubNodeSet) {
+        RETURN_IF_NOT_OK_EXCEPT(notifyWorkerManager_->NotifyNewConsumer(pubNode, consumerMeta),
+                                StatusCode::K_DUPLICATED);
+        notifyNodeSet.emplace_back(pubNode);
+        INJECT_POINT("master.SubIncreaseNodeImpl.afterSendNotification");
+    }
+
+    // Notify all nodes of state change
+    RETURN_IF_NOT_OK(NotifyStopRetainData(subWorkerAddress, retainStateChange));
+    if (retainStateChange) {
+        retainData_.SetRetainDataState(RetainDataState::State::NOT_RETAIN);
+    }
+    CHECK_FAIL_RETURN_STATUS(scTimeoutDuration.CalcRealRemainingTime() > 0, K_RPC_DEADLINE_EXCEEDED,
+                             FormatString("[%s] Request timeout.", LogPrefix()));
+    return Status::OK();
+}
+
+Status StreamMetadata::SubDecreaseNode(const ConsumerMetaPb &consumerMeta)
+{
+    // Check topo manager status and return if stream is getting deleted.
+    WriteLockHelper wlocker(LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+    LOG(INFO) << "Topo for stream " << streamName_ << ":" << *topoManager_;
+    CHECK_FAIL_RETURN_STATUS(!topoManager_->GetStreamStatus(), StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS,
+                             FormatString("Stream <%s> is undergoing deletion, do not operate it", streamName_));
+    CHECK_FAIL_RETURN_STATUS(topoManager_->ExistsConsumer(consumerMeta.consumer_id()), K_SC_CONSUMER_NOT_FOUND,
+                             "consumer does not exist");
+
+    // Construct rpc channel between master and all pub node.
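+    // The notifications below go through notifyWorkerManager_, which falls back to an
+    // async sending task whenever the target worker is abnormal.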
+    HostPort subWorkerAddress(consumerMeta.worker_address().host(), consumerMeta.worker_address().port());
+
+    std::set<HostPort> pubNodeSet;
+    // Inform all nodes that have a producer or previously had a producer until retainData is flipped
+    if (retainData_.IsDataRetained()) {
+        // When state changed, inform all producers (already closed ones too) of the new consumer
+        RETURN_IF_NOT_OK(topoManager_->GetAllPubNode(pubNodeSet, true));
+        RETURN_IF_NOT_OK(RemoveSourceWorker(subWorkerAddress, pubNodeSet));
+        LOG(INFO) << "[RetainData] Informing all related nodes of SubDecrease " << pubNodeSet.size();
+    } else {
+        RETURN_IF_NOT_OK(topoManager_->GetAllPubNode(pubNodeSet));
+        RETURN_IF_NOT_OK(RemoveSourceWorker(subWorkerAddress, pubNodeSet));
+    }
+
+    for (const auto &pubNode : pubNodeSet) {
+        auto rc = notifyWorkerManager_->NotifyDelConsumer(pubNode, consumerMeta);
+        if (rc.IsError() && rc.GetCode() != StatusCode::K_SC_CONSUMER_NOT_FOUND
+            && rc.GetCode() != StatusCode::K_SC_STREAM_NOT_FOUND) {
+            LOG(ERROR) << "NotifyDelConsumer failed " << rc.GetMsg();
+            return rc;
+        }
+        INJECT_POINT("master.SubDecreaseNode.afterSendNotification");
+    }
+
+    RETURN_IF_NOT_OK(topoManager_->SubDecreaseNode(consumerMeta, false, false));
+    LOG_IF_ERROR(streamStore_->DelSubNode(streamName_, consumerMeta.consumer_id()),
+                 "Delete sub node failed in rocksdb");
+    LOG(INFO) << FormatString("[%s, C:%s] Delete consumer success.", LogPrefix(),
+                              LogHelper::IgnoreSensitive(consumerMeta));
+    // If there is no more consumer and auto clean up is on.
+    // Also consider the case that
+    // (a) the consumer is created first
+    // (b) the consumer crashes and the application will not create any producer
+    // We will take this chance to clean up the stream
+    bool lastConsumer = topoManager_->GetConsumerCountInWorker(subWorkerAddress.ToString()) == 0;
+    if ((streamFields_.Empty() || streamFields_.autoCleanup_) && lastConsumer) {
+        RETURN_IF_NOT_OK(AutoCleanupIfNeededNotLocked(subWorkerAddress));
+    }
+    return Status::OK();
+}
+
+Status StreamMetadata::DeleteStreamStart(const HostPort &srcNode, std::set<HostPort> &relatedWorkerSet)
+{
+    WriteLockHelper wlocker(LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+    // 1. Check if stream is already getting deleted
+    CHECK_FAIL_RETURN_STATUS(!topoManager_->GetStreamStatus(), StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS,
+                             FormatString("Stream <%s> is undergoing deletion, do not operate it", streamName_));
+    // 2. Deal with concurrent delete, directly return if this stream has been deleted.
+    RETURN_OK_IF_TRUE(alive_ == false);
+
+    // 3. Check if any producers and consumers exist for this stream.
+    // We do a check on the target stream to find out whether all producers and consumers in global scope have been
+    // closed.
+    std::set<HostPort> pubWorkerSet;
+    RETURN_IF_NOT_OK(topoManager_->GetAllPubNode(pubWorkerSet));
+    CHECK_FAIL_RETURN_STATUS(
+        pubWorkerSet.empty(), StatusCode::K_SC_STREAM_IN_USE,
+        FormatString("Stream:<%s>, State:, Number:<%d>, Request src:<%s>", streamName_,
+                     pubWorkerSet.size(), srcNode.ToString()));
+    std::set<HostPort> subWorkerSet;
+    RETURN_IF_NOT_OK(topoManager_->GetAllSubNode(subWorkerSet));
+    CHECK_FAIL_RETURN_STATUS(
+        subWorkerSet.empty(), StatusCode::K_SC_STREAM_IN_USE,
+        FormatString("Stream:<%s>, State:, Number:<%d>, Request src:<%s>", streamName_,
+                     subWorkerSet.size(), srcNode.ToString()));
+
+    // 4. Check if any notifications still need to be sent
+    CHECK_FAIL_RETURN_STATUS(!notifyWorkerManager_->ExistsPendingNotification(streamName_),
+                             StatusCode::K_SC_STREAM_NOTIFICATION_PENDING,
+                             FormatString("Stream <%s> has pending notifications, try again later", streamName_));
+
+    // If we pass all checks, we set this stream state as [isDeleting], and do the following process.
+    RETURN_IF_NOT_OK(topoManager_->SetDeletingStatus());
+    RETURN_IF_NOT_OK(topoManager_->GetAllRelatedNode(relatedWorkerSet));
+    relatedWorkerSet.erase(srcNode);
+    deleterRefCount_ += 1;
+    LOG(INFO) << FormatString("[%s] Delete stream request started, ref count increase to %d.", LogPrefix(),
+                              deleterRefCount_);
+    return Status::OK();
+}
+
+Status StreamMetadata::DeleteStreamEnd()
+{
+    WriteLockHelper wlocker(LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+
+    RETURN_IF_NOT_OK(streamStore_->DelStream(streamName_));
+    alive_ = false;
+    deleterRefCount_ -= 1;
+    LOG(INFO) << FormatString("[%s] Delete stream from streamStore success. Ref Count decrease to %d", LogPrefix(),
+                              deleterRefCount_);
+    return Status::OK();
+}
+
+void StreamMetadata::UndoDeleteStream(bool decrementRef)
+{
+    WriteLockHelper wlocker(LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+    if (decrementRef) {
+        deleterRefCount_ -= 1;
+        LOG(INFO) << FormatString("[%s] Ref Count decrease to %d", LogPrefix(), deleterRefCount_);
+    }
+    // Only undo if there are no other deletes running
+    if (deleterRefCount_ == 0) {
+        LOG(INFO) << FormatString("[%s] Undoing Delete stream.", LogPrefix());
+        this->topoManager_->UnsetDeletingStatus();
+    }
+}
+
+std::string StreamMetadata::LogPrefix() const
+{
+    return FormatString("S:%s", streamName_);
+}
+
+Status StreamMetadata::RecoveryPubMeta(const ProducerMetaPb &producerMetaPb)
+{
+    WriteLockHelper wlocker(LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+    bool isFirstProducer = false;
+    RETURN_IF_NOT_OK(topoManager_->PubIncreaseNode(producerMetaPb, isFirstProducer));
+    retainData_.Init(streamFields_.retainForNumConsumers_);
+    // The recovery code path would not need the protection, so directly call PubNodeFirstOrLastDone to unset if
+    // applicable.
+    if (isFirstProducer) {
+        RETURN_IF_NOT_OK(topoManager_->PubNodeFirstOrLastDone(producerMetaPb));
+    }
+    LOG(INFO) << FormatString("[%s] Recovery pub node meta:<%s> on master success, isFirstProducer on that node:%s",
+                              LogPrefix(), LogHelper::IgnoreSensitive(producerMetaPb),
+                              isFirstProducer ? "true" : "false");
+    return Status::OK();
+}
+
+Status StreamMetadata::RecoverySubMeta(const ConsumerMetaPb &consumerMeta)
+{
+    WriteLockHelper wlocker(LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+    bool isFirstConsumer = false;
+    RETURN_IF_NOT_OK(topoManager_->SubIncreaseNode(consumerMeta, isFirstConsumer));
+    LOG(INFO) << FormatString("[%s] Recovery consumer meta:<%s> on master success, isFirstConsumer on that node:%s",
+                              LogPrefix(), LogHelper::IgnoreSensitive(consumerMeta),
+                              isFirstConsumer ? "true" : "false");
"true" : "false"); + return Status::OK(); +} + +Status StreamMetadata::RemoveSourceWorker(const HostPort &srcWorkerAddress, std::set &nodeSet) +{ + auto totalSize = nodeSet.size(); + if (nodeSet.find(srcWorkerAddress) != nodeSet.end()) { + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Src worker node:<%s>, State:", + srcWorkerAddress.ToString()); + CHECK_FAIL_RETURN_STATUS(nodeSet.erase(srcWorkerAddress) == 1, StatusCode::K_RUNTIME_ERROR, + "Runtime error in set erase"); + } + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Origin set size:<%d>, Actual set size:<%d>, Operation:", + totalSize, nodeSet.size()); + return Status::OK(); +} + +Status StreamMetadata::ClearPubSubMetaData(const HostPort &workerAddr, + const std::unordered_map &producerMap, + const std::unordered_map &consumerMap, + const bool forceClose, const bool delWorker) +{ + INJECT_POINT("StreamMetadata.ClearPubSubMetaData.sleep"); + bool pubExists = !producerMap.empty(); + bool subExists = !consumerMap.empty(); + LOG(INFO) << FormatString("[%s] ClearPubSubMetaData for worker [%s] producerCount:%d, consumerCount:%d", + LogPrefix(), workerAddr.ToString(), producerMap.size(), consumerMap.size()); + CHECK_FAIL_RETURN_STATUS(pubExists || subExists, K_NOT_FOUND, + FormatString("[%s] No meta on worker %s", LogPrefix(), workerAddr.ToString())); + Status status; + for (const auto &kvPub : producerMap) { + ASSIGN_IF_NOT_OK_PRINT_ERROR_MSG(status, streamStore_->DelPubNode(kvPub.second), "DelPubNode failed"); + ASSIGN_IF_NOT_OK_PRINT_ERROR_MSG(status, topoManager_->PubDecreaseNode(kvPub.second, delWorker), + "PubDecreaseNode failed"); + } + bool producerCnt = topoManager_->GetProducerCountInWorker(workerAddr.ToString()); + bool pubNodeDelete = !producerMap.empty() && producerCnt == 0; + + for (const auto &kvSub : consumerMap) { + ASSIGN_IF_NOT_OK_PRINT_ERROR_MSG(status, streamStore_->DelSubNode(streamName_, kvSub.second.consumer_id()), + "DelSubNode failed"); + ASSIGN_IF_NOT_OK_PRINT_ERROR_MSG(status, topoManager_->SubDecreaseNode(kvSub.second, false, delWorker), + "SubDecreaseNode failed"); + } + + ASSIGN_IF_NOT_OK_PRINT_ERROR_MSG(status, + AddAsyncClearNotification(workerAddr, pubNodeDelete, consumerMap, forceClose), + "AddAsyncNotification failed"); + + // Check if we can delete the stream + bool lastProducerCleanup = (producerCnt == 0 && streamFields_.autoCleanup_); + bool lastConsumerCleanup = (topoManager_->GetConsumerCountInWorker(workerAddr.ToString()) == 0) + && (streamFields_.Empty() || streamFields_.autoCleanup_); + // If auto clean up is true and there is no more producer/consumer, delete the stream. 
+    if (lastProducerCleanup && lastConsumerCleanup) {
+        ASSIGN_IF_NOT_OK_PRINT_ERROR_MSG(status, AutoCleanupIfNeededNotLocked(workerAddr),
+                                         "AutoCleanupIfNeededNotLocked failed");
+    }
+    return status;
+}
+
+Status StreamMetadata::AddAsyncClearNotification(const HostPort &workerAddr, bool pubNodeDelete,
+                                                 const std::unordered_map<std::string, ConsumerMetaPb> &consumerMap,
+                                                 const bool forceClose)
+{
+    std::set<HostPort> pubNodeSet;
+    std::set<HostPort> subNodeSet;
+    std::set<HostPort> allNodeSet;
+    RETURN_IF_NOT_OK(topoManager_->GetAllPubNode(pubNodeSet));
+    RETURN_IF_NOT_OK(topoManager_->GetAllSubNode(subNodeSet));
+    RETURN_IF_NOT_OK(topoManager_->GetAllRelatedNode(allNodeSet));
+
+    Status status;
+    for (const auto &nodeAddr : allNodeSet) {
+        if (nodeAddr == workerAddr) {
+            continue;
+        }
+
+        if (pubNodeDelete && subNodeSet.count(nodeAddr) > 0) {
+            NotifyPubPb pub;
+            pub.set_is_close(true);
+            pub.set_force_close(forceClose);
+            pub.set_stream_name(streamName_);
+            pub.set_worker_addr(workerAddr.ToString());
+            pub.set_max_stream_size(streamFields_.maxStreamSize_);
+            pub.set_page_size(streamFields_.pageSize_);
+            ASSIGN_IF_NOT_OK_PRINT_ERROR_MSG(status,
+                                             notifyWorkerManager_->AddAsyncPubNotification(nodeAddr.ToString(), pub),
+                                             "AddAsyncPubNotification failed");
+        }
+
+        if (!consumerMap.empty() && pubNodeSet.count(nodeAddr) > 0) {
+            for (const auto &kv : consumerMap) {
+                NotifyConsumerPb sub;
+                *sub.mutable_consumer() = kv.second;
+                sub.set_is_close(true);
+                ASSIGN_IF_NOT_OK_PRINT_ERROR_MSG(
+                    status, notifyWorkerManager_->AddAsyncSubNotification(nodeAddr.ToString(), sub),
+                    "AddAsyncSubNotification failed");
+            }
+        }
+    }
+    return status;
+}
+
+Status StreamMetadata::ClearWorkerMetadata(const HostPort &workerAddr, const bool forceClose, bool delWorker)
+{
+    WriteLockHelper wlocker(LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+    std::unordered_map<std::string, ProducerMetaPb> producerMap;
+    RETURN_IF_NOT_OK(topoManager_->GetAllProducerFromWorker(workerAddr, producerMap));
+    auto consumerMap = topoManager_->GetAllConsumerFromWorker(workerAddr);
+    // Delete the related worker from master metadata, as it is faulty in this case.
+    RETURN_IF_NOT_OK(ClearPubSubMetaData(workerAddr, producerMap, consumerMap, forceClose, delWorker));
+    return Status::OK();
+}
+
+Status StreamMetadata::CheckMetadata(const GetStreamMetadataRspPb &meta, const HostPort &workerAddr)
+{
+    VLOG(SC_INTERNAL_LOG_LEVEL) << "Check metadata for stream " << streamName_;
+    WriteLockHelper wlocker(LOCK_ARGS_MSG_FN(mutex_, LogPrefix));
+
+    // Get worker metadata.
+    VLOG(SC_INTERNAL_LOG_LEVEL) << "Check metadata response from worker: " << LogHelper::IgnoreSensitive(meta);
+    std::vector<ProducerMetaPb> workerProducers(meta.producers().begin(), meta.producers().end());
+    std::vector<ConsumerMetaPb> workerConsumers(meta.consumers().begin(), meta.consumers().end());
+
+    // Get master metadata.
+    std::unordered_map<std::string, ProducerMetaPb> masterProducerMap;
+    RETURN_IF_NOT_OK(topoManager_->GetAllProducerFromWorker(workerAddr, masterProducerMap));
+    auto masterConsumerMap = topoManager_->GetAllConsumerFromWorker(workerAddr);
+
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("Check metadata master producer count: %d, consumer count: %d",
+                                                masterProducerMap.size(), masterConsumerMap.size());
+
+    // Compare and get the difference.
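+    // CompareAndErase drops the entries present on both sides. What remains in
+    // workerProducers/workerConsumers is known only to the worker and gets re-registered below,
+    // while the leftovers in masterProducerMap/masterConsumerMap are known only to the master
+    // and get cleared with async notifications.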
+    CompareAndErase(workerProducers, masterProducerMap,
+                    [](const ProducerMetaPb &meta) { return HostPb2Str(meta.worker_address()); });
+    CompareAndErase(workerConsumers, masterConsumerMap, [](const ConsumerMetaPb &meta) { return meta.consumer_id(); });
+
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString(
+        "Check result: worker-only producers:<%d>, worker-only consumers:<%d>, master-only producers:<%d>, "
+        "master-only consumers:<%d>",
+        workerProducers.size(), workerConsumers.size(), masterProducerMap.size(), masterConsumerMap.size());
+
+    // Update missing pub-sub metadata.
+    auto streamFields = ConvertGetStreamMetadataRspPb2StreamFields(meta);
+    if (!workerProducers.empty() || !workerConsumers.empty()) {
+        // Update consumers first, so that we have the count for retained consumers.
+        for (auto &consumerMeta : workerConsumers) {
+            RETURN_IF_NOT_OK(SubIncreaseNodeUnlocked(consumerMeta, workerAddr, true));
+        }
+        for (auto &producerMeta : workerProducers) {
+            RETURN_IF_NOT_OK(PubIncreaseNodeInternal(producerMeta, streamFields, workerAddr, true));
+        }
+    } else if (topoManager_->RecoverEmptyMetaIfNeeded(workerAddr)) {
+        RETURN_IF_NOT_OK(VerifyAndUpdateStreamFields(streamFields));
+    }
+
+    // Notify clearing of pubs and subs.
+    if (!masterProducerMap.empty() || !masterConsumerMap.empty()) {
+        // Do not delete the worker from metadata, as it is not faulty in this case.
+        RETURN_IF_NOT_OK_EXCEPT(ClearPubSubMetaData(workerAddr, masterProducerMap, masterConsumerMap, false, false),
+                                K_NOT_FOUND);
+    } else {
+        VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString(
+            "no need to clear pub and sub in master after comparison with worker addr %s", workerAddr.ToString());
+    }
+
+    // Clear remote pubs on the sub node if this is the last consumer.
+    bool isAllConsumerClosed = IsAllConsumerClosed(workerAddr.ToString());
+    if (!meta.is_remote_pub_empty() && isAllConsumerClosed && CheckWorkerStatus(workerAddr).IsOk()) {
+        std::shared_ptr<MasterWorkerSCApi> masterWorkerApi = nullptr;
+        RETURN_IF_NOT_OK(rpcSessionManager_->GetRpcSession(workerAddr, masterWorkerApi, akSkManager_));
+        RETURN_IF_NOT_OK(masterWorkerApi->ClearAllRemotePub(meta.stream_name()));
+    }
+
+    VLOG(SC_INTERNAL_LOG_LEVEL) << "Check metadata for stream " << streamName_ << " finish.";
+    return Status::OK();
+}
+
+void StreamMetadata::GetAllProducerConsumer(std::vector<ProducerMetaPb> &masterProducers,
+                                            std::vector<ConsumerMetaPb> &masterConsumers,
+                                            std::vector<std::string> &producerRelatedNodes,
+                                            std::vector<std::string> &consumerRelatedNodes)
+{
+    LOG_IF_ERROR(topoManager_->GetAllProducer(masterProducers), "failed to get producer data");
+    masterConsumers = topoManager_->GetAllConsumer();
+    topoManager_->GetAllRelatedNode(producerRelatedNodes, consumerRelatedNodes);
+}
+
+void StreamMetadata::PreparePubSubRelNodes(const std::vector<std::string> &producerRelatedNodes,
+                                           const std::vector<std::string> &consumerRelatedNodes)
+{
+    topoManager_->PreparePubSubRelNodes(producerRelatedNodes, consumerRelatedNodes);
+}
+
+Status StreamMetadata::CleanUpStreamPersistent(const std::string &streamName)
+{
+    RETURN_IF_NOT_OK(streamStore_->DelStream(streamName));
+    return Status::OK();
+}
+
+Status StreamMetadata::AutoCleanupIfNeeded(const HostPort &srcHost)
+{
+    std::lock_guard lock(mutex_);
+    RETURN_IF_NOT_OK(AutoCleanupIfNeededNotLocked(srcHost));
+    return Status::OK();
+}
+
+Status StreamMetadata::AutoCleanupIfNeededNotLocked(const HostPort &srcHost)
+{
+    RETURN_OK_IF_TRUE(!streamFields_.autoCleanup_ && !streamFields_.Empty());
+    auto cleanup = topoManager_->CheckIfAllPubSubHaveClosed();
+    RETURN_OK_IF_TRUE(!cleanup);
+    LOG(INFO) << FormatString("[%s] Driving auto stream delete%s", LogPrefix(),
+                              srcHost.Empty() ? "" : " from worker " + srcHost.ToString());
"" : " from worker " + srcHost.ToString()); + RETURN_IF_NOT_OK(notifyWorkerManager_->AddAsyncDeleteNotification(streamName_)); + return Status::OK(); +} + +Status StreamMetadata::CheckWorkerStatus(const HostPort &workerHostPort) +{ + if (etcdCM_ == nullptr) { + RETURN_STATUS_LOG_ERROR(StatusCode::K_INVALID, "ETCD cluster manager is nullptr."); + } + auto rc = etcdCM_->CheckConnection(workerHostPort); + if (rc.IsError()) { + RETURN_STATUS_LOG_ERROR(K_WORKER_ABNORMAL, FormatString("The worker %s is abnormal, detail: %s", + workerHostPort.ToString(), rc.GetMsg())); + } + return Status::OK(); +} + +Status StreamMetadata::ProcessClearAllRemotePub(const std::shared_ptr &masterWorkerApi, + const HostPort &subWorkerAddress) +{ + RETURN_RUNTIME_ERROR_IF_NULL(masterWorkerApi); + static const int RETRY_TIMEOUT_MS = 60'000; // 1 min + const std::unordered_set &retryOn = { StatusCode::K_TRY_AGAIN, StatusCode::K_RPC_CANCELLED, + StatusCode::K_RPC_DEADLINE_EXCEEDED, + StatusCode::K_RPC_UNAVAILABLE }; + Status rc; + switch (masterWorkerApi->TypeId()) { + case MasterWorkerSCApiType::MasterLocalWorkerSCApi: + RETURN_IF_NOT_OK(RetryOnError( + RETRY_TIMEOUT_MS, + [&masterWorkerApi, this](int32_t) { return masterWorkerApi->ClearAllRemotePub(streamName_); }, + []() { return Status::OK(); }, retryOn)); + break; + case MasterWorkerSCApiType::MasterRemoteWorkerSCApi: + auto masterRemoteWorkerOCApi = dynamic_cast(masterWorkerApi.get()); + int32_t maxRpcTimeoutMs = 5'000; // 5s + int64_t tagId; + auto timer = Timer(RETRY_TIMEOUT_MS); + RETURN_IF_NOT_OK(RetryOnError( + RETRY_TIMEOUT_MS, + [&masterRemoteWorkerOCApi, &subWorkerAddress, &tagId, this](int32_t timeoutMs) { + RETURN_IF_NOT_OK(CheckWorkerStatus(subWorkerAddress)); + scTimeoutDuration.Init(timeoutMs); + return masterRemoteWorkerOCApi->ClearAllRemotePubAsynWrite(streamName_, tagId); + }, + []() { return Status::OK(); }, retryOn, maxRpcTimeoutMs)); + int waitIntervalMs1 = 10, waitIntervalMs2 = 1'000, fastReadingMaxNum = 10; + int retryNum = 0; + do { + RETURN_IF_NOT_OK(CheckWorkerStatus(subWorkerAddress)); + rc = masterRemoteWorkerOCApi->ClearAllRemotePubAsynRead(tagId, RpcRecvFlags::DONTWAIT); + if (rc.IsOk()) { + break; + } + auto waitIntervalMs = retryNum > fastReadingMaxNum ? 
+                std::this_thread::sleep_for(std::chrono::milliseconds(waitIntervalMs));
+                ++retryNum;
+            } while (timer.GetRemainingTimeMs() > 0);
+            if (rc.IsError()) {
+                (void)masterRemoteWorkerOCApi->ClearAllRemotePubAsynRead(tagId, RpcRecvFlags::NONE);
+            }
+            break;
+        }
+    }
+    return rc;
+}
+
+Status StreamMetadata::InitStreamMetrics()
+{
+    return ScMetricsMonitor::Instance()->AddStreamMeta(streamName_, weak_from_this(), scStreamMetrics_);
+}
+
+void StreamMetadata::UpdateStreamMetrics()
+{
+    if (scStreamMetrics_) {
+        scStreamMetrics_->LogMetric(StreamMetric::NumProducersMaster, GetProducerCount());
+        scStreamMetrics_->LogMetric(StreamMetric::NumConsumersMaster, GetConsumerCount());
+    }
+}
+
+StreamFields ConvertGetStreamMetadataRspPb2StreamFields(const GetStreamMetadataRspPb &pb)
+{
+    return { pb.max_stream_size(), static_cast<size_t>(pb.page_size()),
+             pb.auto_cleanup(),    pb.retain_num_consumer(),
+             pb.encrypt_stream(),  pb.reserve_size(),
+             pb.stream_mode() };
+}
+} // namespace master
+} // namespace datasystem
diff --git a/src/datasystem/master/stream_cache/stream_metadata.h b/src/datasystem/master/stream_cache/stream_metadata.h
new file mode 100644
index 0000000..506c374
--- /dev/null
+++ b/src/datasystem/master/stream_cache/stream_metadata.h
@@ -0,0 +1,546 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: The stream metadata object.
+ */
+#ifndef DATASYSTEM_MASTER_STREAM_CACHE_STREAM_METADATA_H
+#define DATASYSTEM_MASTER_STREAM_CACHE_STREAM_METADATA_H
+
+#include <memory>
+#include <set>
+#include <shared_mutex>
+#include <unordered_map>
+#include <vector>
+
+#include "datasystem/common/log/log.h"
+#include "datasystem/common/eventloop/timer_queue.h"
+#include "datasystem/common/stream_cache/consumer_meta.h"
+#include "datasystem/common/stream_cache/stream_fields.h"
+#include "datasystem/common/util/file_util.h"
+#include "datasystem/master/stream_cache/master_worker_sc_api.h"
+#include "datasystem/master/stream_cache/rpc_session_manager.h"
+#include "datasystem/master/stream_cache/sc_notify_worker_manager.h"
+#include "datasystem/master/stream_cache/store/rocks_stream_meta_store.h"
+#include "datasystem/master/stream_cache/topology_manager.h"
+#include "datasystem/utils/status.h"
+#include "datasystem/worker/cluster_manager/etcd_cluster_manager.h"
+#include "datasystem/worker/stream_cache/metrics/sc_metrics.h"
+
+namespace datasystem {
+namespace master {
+using RocksStreamMetaStore = stream_cache::RocksStreamMetaStore;
+class StreamMetadata : public std::enable_shared_from_this<StreamMetadata> {
+public:
+    StreamMetadata(std::string streamName, const StreamFields &streamFields, RocksStreamMetaStore *streamMetaStore,
+                   std::shared_ptr<AkSkManager> akSkManager, std::shared_ptr<RpcSessionManager> rpcSessionManager,
+                   EtcdClusterManager *etcdCM, SCNotifyWorkerManager *notifyWorkerManager);
+
+    ~StreamMetadata();
+
+    /**
+     * @brief Increase a pub node for a stream and roll back on failure.
+     * @param[in] producerMeta The producer metadata.
+     * @param[in] streamFields The fields for the stream.
+     * @return Status of the call.
+     */
+    Status PubIncreaseNode(const ProducerMetaPb &producerMeta, StreamFields &streamFields);
+
+    /**
+     * @brief Increase a pub node for a stream and roll back on failure.
+     * @param[in] producerMeta The producer metadata.
+     * @param[in] streamFields The fields for the stream.
+     * @param[in] pubWorkerAddress The source worker address.
+     * @param[in] isRecon Whether this is part of the reconciliation process.
+     * @param[in] alreadyLocked Indicates whether the StreamMetadata lock is already acquired.
+     * @return Status of the call.
+     */
+    Status PubIncreaseNodeInternal(const ProducerMetaPb &producerMeta, StreamFields &streamFields,
+                                   const HostPort &pubWorkerAddress, bool isRecon, bool alreadyLocked = true);
+
+    /**
+     * @brief Pre-handling of PubIncreaseNode, including the stream fields update and the rocksdb update.
+     * @param[in] streamFields The fields for the stream.
+     * @param[in] producerMeta The producer metadata.
+     * @param[in] alreadyLocked Indicates whether the StreamMetadata lock is already acquired.
+     * @param[out] streamFieldsVerified Indicates whether the stream fields are verified and the ref count is
+     * incremented.
+     * @param[out] saveToRocksdb Indicates whether the save to rocksdb is done.
+     * @param[out] isFirstProducer Indicates whether it is the first producer from a worker.
+     * @return Status of the call.
+     */
+    Status PubIncreaseNodeStart(const StreamFields &streamFields, const ProducerMetaPb &producerMeta,
+                                bool alreadyLocked, bool &streamFieldsVerified, bool &saveToRocksdb,
+                                bool &isFirstProducer);
+
+    /**
+     * @brief Post-handling of PubIncreaseNode, including the rollback processing if applicable.
+     * @param[in] isFirstProducer Whether this is the first producer.
+     * @param[in] needsRollback Whether rollback is needed.
+     * @param[in] alreadyLocked Indicates whether the StreamMetadata lock is already acquired.
+     * @param[in] producerMeta The producer metadata.
+     * @param[in] pubWorkerAddress The source worker address.
+     * @param[in] streamFieldsVerified Indicates whether the stream fields are verified and the ref count is
+     * incremented.
+     * @param[in] notifyNodeSet The workers that were already notified successfully.
+     * @param[in] saveToRocksdb Indicates whether the save to rocksdb is done.
+     * @return Status of the call.
+     */
+    Status PubIncreaseNodeEnd(bool isFirstProducer, bool needsRollback, bool alreadyLocked,
+                              const ProducerMetaPb &producerMeta, const HostPort &pubWorkerAddress,
+                              bool streamFieldsVerified, const std::vector<HostPort> &notifyNodeSet,
+                              bool saveToRocksdb);
+
+    /**
+     * @brief Increase a pub node for a stream.
+     * @param[in] pubWorkerAddress The source worker address.
+     * @param[out] notifyNodeSet The workers that were already notified successfully.
+     * @param[in] isRecon Whether this is part of the reconciliation process.
+     * @return Status of the call.
+     */
+    Status PubIncreaseNodeImpl(const HostPort &pubWorkerAddress, std::vector<HostPort> &notifyNodeSet, bool isRecon);
+
+    /**
+     * @brief Decrease a pub node for a stream.
+     * @param[in] producerMeta The producer metadata.
+     * @param[in] forceClose Whether the pub node crashed (true) or closed normally (false).
+     * @return Status of the call.
+     */
+    Status PubDecreaseNode(const ProducerMetaPb &producerMeta, bool forceClose);
+
+    /**
+     * @brief Prepare for the decrease of a pub node for a stream.
+     * @param[in] producerMeta The producer metadata.
+     * @param[out] isLastProducer Indicates whether the producer to close is the last producer.
+     * @return Status of the call.
+     */
+    Status PubDecreaseNodeStart(const ProducerMetaPb &producerMeta, bool &isLastProducer);
+
+    /**
+     * @brief Increase a sub node for a stream and roll back on failure.
+     * @param[in] consumerMeta The consumer meta info which will be transformed into a sub node.
+     * @param[in] isRecon Whether this is part of the reconciliation (or migration) process.
+     * @return Status of the call.
+     */
+    Status SubIncreaseNode(const ConsumerMetaPb &consumerMeta, bool isRecon = false);
+
+    /**
+     * @brief Increase a sub node for a stream and roll back on failure.
+     * @param[in] consumerMeta The consumer meta info which will be transformed into a sub node.
+     * @param[in] subWorkerAddress The address of the sub worker.
+     * @param[in] isRecon Whether this is part of the reconciliation process.
+     * @return Status of the call.
+     */
+    Status SubIncreaseNodeUnlocked(const ConsumerMetaPb &consumerMeta, const HostPort &subWorkerAddress,
+                                   bool isRecon = false);
+
+    /**
+     * @brief Increase a sub node for a stream.
+     * @param[in] consumerMeta The consumer metadata.
+     * @param[in] subWorkerAddress The source worker address.
+     * @param[out] saveToRocksdb Indicates whether to save to rocksdb.
+     * @param[out] sendToSrcNode Indicates whether to send an rpc to the source node.
+     * @param[out] notifyNodeSet The workers that were already notified successfully.
+     * @param[in] isRecon Whether this is part of the reconciliation process.
+     * @return Status of the call.
+     */
+    Status SubIncreaseNodeImpl(const ConsumerMetaPb &consumerMeta, const HostPort &subWorkerAddress,
+                               bool &saveToRocksdb, bool &sendToSrcNode, std::vector<HostPort> &notifyNodeSet,
+                               bool isRecon);
+    /**
+     * @brief Decrease a sub node for a stream.
+     * @param[in] consumerMeta The consumer meta info which will be transformed into a sub node.
+     * @return Status of the call.
+     */
+    Status SubDecreaseNode(const ConsumerMetaPb &consumerMeta);
+
+    /**
+     * @brief Start a delete stream request.
+     * @param[in] srcNode The source worker address of the delete request.
+     * @param[out] relatedWorkerSet Set of nodes to notify.
+     * @return Status of the call.
+     */
+    Status DeleteStreamStart(const HostPort &srcNode, std::set<HostPort> &relatedWorkerSet);
+
+    /**
+     * @brief Delete the stream from RocksDb.
+     * @return Status of the call.
+     */
+    Status DeleteStreamEnd();
+
+    /**
+     * @brief Undo the delete state.
+     * @param[in] decrementRef Whether to decrement the reference count when undoing the delete.
+     */
+    void UndoDeleteStream(bool decrementRef);
+
+    /**
+     * @brief Recover pub worker node meta on the master.
+     * @param[in] producerMetaPb The producer metadata.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status RecoveryPubMeta(const ProducerMetaPb &producerMetaPb);
+
+    /**
+     * @brief Recover sub consumer node meta on the master.
+     * @param consumerMeta The consumer metadata information.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status RecoverySubMeta(const ConsumerMetaPb &consumerMeta);
+
+    /**
+     * @brief Check if a pub/sub exists on the worker node.
+     * @param[in] workerAddr The worker address.
+     * @return true if one exists.
+     */
+    bool CheckWorkerExistsPubSub(const std::string &workerAddr) const
+    {
+        return topoManager_->GetProducerCountInWorker(workerAddr) > 0
+               || topoManager_->GetConsumerCountInWorker(workerAddr) > 0;
+    }
+
+    /**
+     * @brief Clear metadata entries with an empty producer or consumer count.
+     * @param[in] workerAddr The worker address.
+     */
+    void ClearEmptyMeta(const std::string &workerAddr) const
+    {
+        topoManager_->ClearEmptyMeta(workerAddr);
+    }
+
+    /**
+     * @brief Get all worker addresses that have metadata.
+     * @param[out] nodeSet The worker addresses.
+     * @return Status of the call.
+     */
+    Status GetAllWorkerAddress(std::set<HostPort> &nodeSet) const
+    {
+        return topoManager_->GetAllRelatedNode(nodeSet);
+    }
+
+    /**
+     * @brief Get the stream name.
+     * @return Stream name.
+     */
+    const std::string &GetStreamName() const
+    {
+        return streamName_;
+    }
+
+    /**
+     * @brief Clear worker metadata.
+     * @param[in] workerAddr The worker address.
+     * @param[in] forceClose Whether the pub node crashed (true) or closed normally (false).
+     * @param[in] delWorker Delete the worker from relatedNodes.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status ClearWorkerMetadata(const HostPort &workerAddr, bool forceClose, bool delWorker = true);
+
+    /**
+     * @brief Check metadata against a worker.
+     * @param[in] meta The metadata response received for the stream from the worker.
+     * @param[in] workerAddr The worker address.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CheckMetadata(const GetStreamMetadataRspPb &meta, const HostPort &workerAddr);
+
+    /**
+     * @brief Get the producer count in this stream.
+     * @return size_t The producer count.
+     */
+    size_t GetProducerCount() const
+    {
+        return topoManager_->GetProducerCount();
+    }
+
+    /**
+     * @brief Get the consumer count in this stream.
+     * @return size_t The consumer count.
+     */
+    size_t GetConsumerCount() const
+    {
+        return topoManager_->GetConsumerCount();
+    }
+
+    /**
+     * @brief Get the stream metrics.
+     * @return The stream metrics.
+     */
+    auto GetSCStreamMetrics()
+    {
+        return scStreamMetrics_;
+    }
+
+    /**
+     * @brief Restores the consumer count for the stream lifetime from RocksDb.
+     * @param[in] consumerCount The consumer count to restore.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status RestoreConsumerLifeCount(const uint32_t consumerCount) const
+    {
+        LOG(INFO) << "[RetainData] Number of consumers for stream restored to " << consumerCount
+                  << " for stream: " << streamName_;
+        return topoManager_->RestoreConsumerLifeCount(consumerCount);
+    }
+
+    /**
+     * @brief Returns the consumer count for the stream lifetime.
+     * @return The consumer life count.
+     */
+    uint32_t GetConsumerLifeCount() const
+    {
+        return topoManager_->GetConsumerCountForLife();
+    }
+
+    /**
+     * @brief Verify the max stream size and the page size.
+     * @param streamFields The stream fields to check.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status VerifyStreamFields(const StreamFields &streamFields) const
+    {
+        CHECK_FAIL_RETURN_STATUS(
+            streamFields_.Empty() || (streamFields_ == streamFields), K_INVALID,
+            FormatString("[%s] Changing stream fields [max stream size, page size, auto cleanup, retain for num "
+                         "consumers, encrypt stream, reserve size] not supported:"
+                         "Current: [%zu, %zu, %s, %zu, %s, %zu] Invalid: [%zu, %zu, %s, %zu, %s, %zu]",
+                         LogPrefix(), streamFields_.maxStreamSize_, streamFields_.pageSize_,
+                         (streamFields_.autoCleanup_ ? "true" : "false"), streamFields_.retainForNumConsumers_,
+                         streamFields_.encryptStream_ ? "true" : "false", streamFields_.reserveSize_,
+                         streamFields.maxStreamSize_, streamFields.pageSize_,
+                         (streamFields.autoCleanup_ ? "true" : "false"), streamFields.retainForNumConsumers_,
+                         streamFields.encryptStream_ ? "true" : "false", streamFields.reserveSize_));
+        return Status::OK();
+    }
+
+    /**
+     * @brief Get the stream fields.
+     * @return Stream fields.
+     */
+    const StreamFields &GetStreamFields() const
+    {
+        return streamFields_;
+    }
+
+    /**
+     * @brief Get the producer and consumer metadata and the related nodes.
+     * @return All the producer and consumer metadata regarding this stream.
+     */
+    void GetAllProducerConsumer(std::vector<ProducerMetaPb> &masterProducers,
+                                std::vector<ConsumerMetaPb> &masterConsumers,
+                                std::vector<std::string> &producerRelatedNodes,
+                                std::vector<std::string> &consumerRelatedNodes);
+
+    /**
+     * @brief Initialize the producer count and consumer count with zero to keep the related node info.
+     * @param[in] producerRelatedNodes The producer related nodes.
+     * @param[in] consumerRelatedNodes The consumer related nodes.
+     */
+    void PreparePubSubRelNodes(const std::vector<std::string> &producerRelatedNodes,
+                               const std::vector<std::string> &consumerRelatedNodes);
+
+    /**
+     * @brief Clean up the stream from the rocks db store.
+     * @param[in] streamName The target stream name.
+     * @return Status of the call.
+     */
+    Status CleanUpStreamPersistent(const std::string &streamName);
+
+    /**
+     * @brief Checks if we need to retain data in workers, and also updates the state.
+     * @return The enum value result.
+     */
+    RetainDataState::State CheckNUpdateNeedRetainData();
+
+    /**
+     * @brief Checks if we need to retain data in workers,
+     * and returns whether the state changes; the caller can choose not to update the state yet.
+     * @param[out] retainStateChange Whether there is a state change.
+     * @param[in] update Whether to update the value upfront, defaults to false.
+     * @return The enum value result. If update is false, it is the current state.
+     */
+    RetainDataState::State CheckNeedRetainData(bool &retainStateChange, const bool update = false);
+
+    /**
+     * @brief Checks if all consumers on a worker are closed.
+     * @param[in] workerAddress The worker to check.
+     * @return True if all consumers on the worker are closed.
+     */
+    bool IsAllConsumerClosed(const std::string &workerAddress)
+    {
+        return topoManager_->GetConsumerCountInWorker(workerAddress) == 0;
+    }
+
+    /**
+     * @brief Initializes the stream metrics with the master stream metrics.
+     * @return Status of the call.
+     */
+    Status InitStreamMetrics();
+
+    /**
+     * @brief Updates the stream metrics with the master stream metrics.
+     */
+    void UpdateStreamMetrics();
+
+    /**
+     * @brief Auto stream clean up.
+     * @param[in] srcHost The host where the last consumer/producer was closed.
+     * @return Status of the call.
+     */
+    Status AutoCleanupIfNeeded(const HostPort &srcHost);
+
+private:
+    /**
+     * @brief Updates the stream fields.
+     * @param[in] streamFields New stream fields to update to.
+     * @return Status of the call.
+     */
+    Status UpdateStreamFields(const StreamFields &streamFields);
+
+    /**
+     * @brief Notify all related nodes to stop retaining data.
+     * @param[in] subWorkerAddress The target source node.
+     * @param[in] retainStateChange Whether the retain state changed.
+     * @return Status of the call.
+     */
+    Status NotifyStopRetainData(const HostPort &subWorkerAddress, bool retainStateChange);
+
+    /**
+     * @brief Remove the source node from a node set.
+     * @param[in] srcWorkerAddress The target source node.
+     * @param[out] nodeSet The node set to modify.
+     * @return Status of the call.
+     */
+    static Status RemoveSourceWorker(const HostPort &srcWorkerAddress, std::set<HostPort> &nodeSet);
+
+    /**
+     * @brief Get the log prefix.
+     * @return The log prefix.
+     */
+    std::string LogPrefix() const;
+
+    /**
+     * @brief Clear the metadata of pub nodes and sub nodes.
+     * @param[in] workerAddr The target source node.
+     * @param[in] producerMap The producers that should be deleted.
+     * @param[in] consumerMap The consumers that should be deleted.
+     * @param[in] forceClose Whether the node crashed (true) or closed normally (false).
+     * @param[in] delWorker Delete the worker from relatedNodes.
+     * @return Status of the call.
+     */
+    Status ClearPubSubMetaData(const HostPort &workerAddr,
+                               const std::unordered_map<std::string, ProducerMetaPb> &producerMap,
+                               const std::unordered_map<std::string, ConsumerMetaPb> &consumerMap, bool forceClose,
+                               bool delWorker);
+
+    /**
+     * @brief Add an async clear notification for a worker.
+     * @param[in] workerAddr The worker address.
+     * @param[in] pubNodeDelete Notify the worker to clear the pub node.
+     * @param[in] consumerMap The consumers list.
+     * @param[in] forceClose Whether the node crashed (true) or closed normally (false).
+     * @return Status of the call.
+     */
+    Status AddAsyncClearNotification(const HostPort &workerAddr, bool pubNodeDelete,
+                                     const std::unordered_map<std::string, ConsumerMetaPb> &consumerMap,
+                                     bool forceClose);
+
+    /**
+     * @brief Auto stream clean up.
+     * @param[in] srcHost The host where the last consumer/producer was closed.
+     * @return Status of the call.
+     */
+    Status AutoCleanupIfNeededNotLocked(const HostPort &srcHost);
+
+    /**
+     * @brief Check the worker status.
+     * @param[in] workerHostPort The target worker address.
+     * @return Status of the call.
+     */
+    Status CheckWorkerStatus(const HostPort &workerHostPort);
+
+    /**
+     * @brief Clear all remote pub nodes for the target stream on the src node.
+     * @param[in] masterWorkerApi The api of the src node.
+     * @param[in] subWorkerAddress The address of the src node.
+     * @return Status of the call.
+     */
+    Status ProcessClearAllRemotePub(const std::shared_ptr<MasterWorkerSCApi> &masterWorkerApi,
+                                    const HostPort &subWorkerAddress);
+
+    /**
+     * @brief Verify/update the stream fields.
+     * @param[in] streamFields New stream fields to update to.
+     * @return Status of the call.
+     */
+    Status VerifyAndUpdateStreamFields(const StreamFields &streamFields)
+    {
+        RETURN_IF_NOT_OK(VerifyStreamFields(streamFields));
+        if (streamFields_ != streamFields) {
+            RETURN_IF_NOT_OK(UpdateStreamFields(streamFields));
+        }
+        return Status::OK();
+    }
+
+    std::string streamName_;
+    StreamFields streamFields_;
+    // This reference count is for rollback purposes.
+    // Rollback on the stream fields will proceed only if the ref count is 1.
+    uint32_t streamFieldsRefcount_ = { 0 };
+    mutable std::shared_timed_mutex mutex_; // To lock the whole process
+    int deleterRefCount_ = 0;
+
+    std::unique_ptr<HostPort> masterAddress_{ nullptr };
+    // Topology manager for this stream.
+    std::unique_ptr<TopologyManager> topoManager_;
+    RocksStreamMetaStore *streamStore_;
+    bool alive_;
+    std::shared_ptr<AkSkManager> akSkManager_{ nullptr };
+    std::shared_ptr<RpcSessionManager> rpcSessionManager_{ nullptr };
+    RetainDataState retainData_;
+    EtcdClusterManager *etcdCM_{ nullptr };
+    SCNotifyWorkerManager *notifyWorkerManager_{ nullptr };
+    std::shared_ptr<ScStreamMetrics> scStreamMetrics_{ nullptr }; // Metrics type name assumed from sc_metrics.h.
+};
+
+/**
+ * @brief Compare the metadata and erase the entries that match.
+ * @tparam Meta The consumer or producer metadata type.
+ * @tparam F The function type.
+ * @param[in/out] workerMetas The metadata from the worker.
+ * @param[in/out] masterMetas The metadata from the master.
+ * @param f The function to get the id from the metadata.
+ */
+template <typename Meta, typename F>
+void CompareAndErase(std::vector<Meta> &workerMetas, std::unordered_map<std::string, Meta> &masterMetas, F &&f)
+{
+    for (auto iterWorker = workerMetas.begin(); iterWorker != workerMetas.end();) {
+        auto iterMaster = masterMetas.find(f(*iterWorker));
+        if (iterMaster != masterMetas.end()) {
+            masterMetas.erase(iterMaster);
+            iterWorker = workerMetas.erase(iterWorker);
+        } else {
+            ++iterWorker;
+        }
+    }
+}
+
+/**
+ * @brief Convert GetStreamMetadataRspPb to StreamFields.
+ * @param[in] pb GetStreamMetadataRspPb.
+ * @return StreamFields.
+ */
+StreamFields ConvertGetStreamMetadataRspPb2StreamFields(const GetStreamMetadataRspPb &pb);
+} // namespace master
+} // namespace datasystem
+#endif // DATASYSTEM_MASTER_STREAM_CACHE_STREAM_METADATA_H
diff --git a/src/datasystem/master/stream_cache/topology_manager.cpp b/src/datasystem/master/stream_cache/topology_manager.cpp
new file mode 100644
index 0000000..1accfcc
--- /dev/null
+++ b/src/datasystem/master/stream_cache/topology_manager.cpp
@@ -0,0 +1,478 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: The stream topology manager on the master.
+ */
+#include <utility>
+
+#include "datasystem/master/stream_cache/topology_manager.h"
+
+#include "datasystem/common/util/format.h"
+#include "datasystem/common/util/status_helper.h"
+
+namespace datasystem {
+namespace master {
+TopologyManager::TopologyManager(std::string streamName) : streamName_(std::move(streamName)), isDeleting_(false)
+{
+}
+
+Status TopologyManager::PubIncreaseNode(const ProducerMetaPb &producerMeta, bool &isFirstProducer)
+{
+    // Add the pub worker node into the pubTopo set.
+    HostPort workerAddress(producerMeta.worker_address().host(), producerMeta.worker_address().port());
+    std::string workerAddressStr = workerAddress.ToString();
+    std::lock_guard lock(mutex_);
+    auto iter = producerCount_.find(workerAddressStr);
+    if (iter == producerCount_.end()) {
+        iter = producerCount_.emplace(workerAddressStr, ProducerCount()).first;
+    }
+    CHECK_FAIL_RETURN_STATUS(iter->second.firstProducerProcessing_ == false, K_TRY_AGAIN,
+                             FormatString("First create producer request or the last close producer request from the "
+                                          "worker [%s] for stream <%s> is not done yet, try again later.",
+                                          workerAddressStr, streamName_));
+    // Check if this is the first request for the worker.
+    CHECK_FAIL_RETURN_STATUS(!ExistsProducerUnlocked(workerAddress), K_DUPLICATED,
+                             "producer for worker already exists");
+
+    auto preCount = iter->second.count_++;
+    currentProducerCount_++;
+    isFirstProducer = (preCount == 0);
+    if (isFirstProducer) {
+        iter->second.firstProducerProcessing_ = true;
+    }
+    return Status::OK();
+}
+
+Status TopologyManager::PubNodeFirstOrLastDone(const ProducerMetaPb &producerMeta)
+{
+    // Unset the first/last-producer processing flag for the pub worker node.
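+    // firstProducerProcessing_ acts as a barrier: while the first-create (or last-close) request
+    // from a worker is still in flight, concurrent pub requests for that worker fail with
+    // K_TRY_AGAIN until PubNodeFirstOrLastDone() clears the flag.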
+    HostPort workerAddress(producerMeta.worker_address().host(), producerMeta.worker_address().port());
+    std::lock_guard lock(mutex_);
+    auto iter = producerCount_.find(workerAddress.ToString());
+    CHECK_FAIL_RETURN_STATUS(
+        iter != producerCount_.end() && iter->second.count_ == 1 && iter->second.firstProducerProcessing_,
+        StatusCode::K_RUNTIME_ERROR,
+        FormatString("Invalid producer source node <%s>, should be the first or last producer",
+                     workerAddress.ToString()));
+    iter->second.firstProducerProcessing_ = false;
+    return Status::OK();
+}
+
+Status TopologyManager::PubDecreaseNodeStart(const ProducerMetaPb &producerMeta, bool &isLastProducer)
+{
+    HostPort workerAddress(producerMeta.worker_address().host(), producerMeta.worker_address().port());
+    std::string workerAddressStr = workerAddress.ToString();
+    std::lock_guard lock(mutex_);
+    auto iter = producerCount_.find(workerAddressStr);
+    // Validate the entry before dereferencing the iterator.
+    CHECK_FAIL_RETURN_STATUS(iter != producerCount_.end() && iter->second.count_ > 0, StatusCode::K_RUNTIME_ERROR,
+                             FormatString("Invalid producer source node <%s>", workerAddressStr));
+    CHECK_FAIL_RETURN_STATUS(iter->second.firstProducerProcessing_ == false, K_TRY_AGAIN,
+                             FormatString("First create producer request or the last close producer request from the "
+                                          "worker [%s] for stream <%s> is not done yet, try again later.",
+                                          workerAddressStr, streamName_));
+    isLastProducer = (iter->second.count_ == 1);
+    if (isLastProducer) {
+        iter->second.firstProducerProcessing_ = true;
+    }
+    return Status::OK();
+}
+
+Status TopologyManager::PubDecreaseNode(const ProducerMetaPb &producerMeta, const bool delWorker)
+{
+    HostPort workerAddress(producerMeta.worker_address().host(), producerMeta.worker_address().port());
+    std::lock_guard lock(mutex_);
+    auto iter = producerCount_.find(workerAddress.ToString());
+
+    CHECK_FAIL_RETURN_STATUS(iter != producerCount_.end() && iter->second.count_ > 0, StatusCode::K_RUNTIME_ERROR,
+                             FormatString("Invalid producer source node <%s>", workerAddress.ToString()));
+
+    // Keep the producer count entry when it drops to 0, for delete stream.
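+    // The zero-count entry is intentionally retained (unless delWorker is set below): closed
+    // producers still count as related nodes, so they can be notified on stream deletion and
+    // while data is retained for consumers that may still attach.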
+    iter->second.count_--;
+    currentProducerCount_--;
+    // If this was the last producer, delete the related node when the worker is lost.
+    if (delWorker && !iter->second.count_) {
+        LOG(INFO) << "Removing worker " << workerAddress.ToString() << " from producer related nodes";
+        producerCount_.erase(workerAddress.ToString());
+    }
+    return Status::OK();
+}
+
+Status TopologyManager::SubIncreaseNode(const ConsumerMetaPb &consumerMeta, bool &isFirstConsumer)
+{
+    HostPort workerAddress(consumerMeta.worker_address().host(), consumerMeta.worker_address().port());
+    std::lock_guard lock(mutex_);
+    auto ret = consumerTopo_.emplace(consumerMeta.consumer_id(), consumerMeta);
+    CHECK_FAIL_RETURN_STATUS(ret.second, StatusCode::K_RUNTIME_ERROR,
+                             FormatString("Fail to add consumer [%s] into topo structure for stream <%s>",
+                                          consumerMeta.consumer_id(), streamName_));
+
+    const auto &config = consumerMeta.sub_config();
+    if (config.subscription_type() == SubscriptionTypePb::STREAM_PB) {
+        const std::string &subName = config.subscription_name();
+        CHECK_FAIL_RETURN_STATUS(streamModeSubDict_.find(subName) == streamModeSubDict_.end(), StatusCode::K_DUPLICATED,
+                                 "In STREAM mode, one subscription can only contain one consumer");
+        streamModeSubDict_.emplace(subName);
+    }
+
+    auto preCount = consumerCount_[workerAddress.ToString()]++;
+    // We also maintain the consumer count over the stream's life.
+    consumerLifeCount_++;
+    LOG(INFO) << "[RetainData] Number of consumers for stream increased to " << consumerLifeCount_;
+    isFirstConsumer = (preCount == 0);
+    return Status::OK();
+}
+
+uint32_t TopologyManager::GetConsumerCountForLife()
+{
+    std::shared_lock lock(mutex_);
+    return consumerLifeCount_;
+}
+
+Status TopologyManager::SubDecreaseNode(const ConsumerMetaPb &consumerMeta, bool rollback, const bool delWorker)
+{
+    HostPort workerAddress(consumerMeta.worker_address().host(), consumerMeta.worker_address().port());
+    std::lock_guard lock(mutex_);
+    const auto &config = consumerMeta.sub_config();
+    if (config.subscription_type() == SubscriptionTypePb::STREAM_PB) {
+        const std::string &subName = config.subscription_name();
+        CHECK_FAIL_RETURN_STATUS(streamModeSubDict_.erase(subName) == 1, StatusCode::K_NOT_FOUND,
+                                 FormatString("Consumer:<%s>, Subscription:<%s>, Mode:<STREAM>, Status:<%s>",
+                                              consumerMeta.consumer_id(), subName, "Not found on master"));
+    }
+
+    auto iter = consumerCount_.find(workerAddress.ToString());
+
+    CHECK_FAIL_RETURN_STATUS(iter != consumerCount_.end() && iter->second > 0, StatusCode::K_RUNTIME_ERROR,
+                             FormatString("Invalid consumer source node <%s>", workerAddress.ToString()));
+
+    // Keep the consumer count entry when it drops to 0, for delete stream.
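+    // consumerLifeCount_ is only decremented on rollback; a real close keeps it monotonic, so the
+    // retain-for-num-consumers accounting counts consumers over the whole stream lifetime.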
+    iter->second--;
+
+    CHECK_FAIL_RETURN_STATUS(consumerTopo_.erase(consumerMeta.consumer_id()) == 1, StatusCode::K_RUNTIME_ERROR,
+                             "Fail to delete sub node");
+    if (rollback) {
+        consumerLifeCount_--;
+        LOG(INFO) << "[RetainData] Number of consumers for stream decreased to " << consumerLifeCount_;
+    }
+    // If this was the last consumer, delete the related node when the worker is lost.
+    if (delWorker && !iter->second) {
+        LOG(INFO) << "Removing worker " << workerAddress.ToString() << " from consumer related nodes";
+        consumerCount_.erase(workerAddress.ToString());
+    }
+    return Status::OK();
+}
+
+Status TopologyManager::GetAllPubNode(std::set<HostPort> &pubNodeSet, bool informAll) const
+{
+    std::shared_lock lock(mutex_);
+    for (const auto &kv : producerCount_) {
+        if (kv.second.count_ > 0 || informAll) {
+            HostPort nodeAddr;
+            RETURN_IF_NOT_OK(nodeAddr.ParseString(kv.first));
+            pubNodeSet.emplace(std::move(nodeAddr));
+        }
+    }
+    return Status::OK();
+}
+
+Status TopologyManager::GetAllSubNode(std::set<HostPort> &subNodeSet) const
+{
+    std::shared_lock lock(mutex_);
+    for (const auto &kv : consumerCount_) {
+        if (kv.second > 0) {
+            HostPort nodeAddr;
+            RETURN_IF_NOT_OK(nodeAddr.ParseString(kv.first));
+            subNodeSet.emplace(std::move(nodeAddr));
+        }
+    }
+    return Status::OK();
+}
+
+Status TopologyManager::GetAllRelatedNode(std::set<HostPort> &nodeSet) const
+{
+    std::shared_lock lock(mutex_);
+    for (const auto &kv : producerCount_) {
+        HostPort nodeAddr;
+        RETURN_IF_NOT_OK(nodeAddr.ParseString(kv.first));
+        (void)nodeSet.emplace(std::move(nodeAddr));
+    }
+
+    for (const auto &kv : consumerCount_) {
+        HostPort nodeAddr;
+        RETURN_IF_NOT_OK(nodeAddr.ParseString(kv.first));
+        (void)nodeSet.emplace(std::move(nodeAddr));
+    }
+    return Status::OK();
+}
+
+void TopologyManager::GetAllRelatedNode(std::vector<std::string> &producerRelatedNodes,
+                                        std::vector<std::string> &consumerRelatedNodes) const
+{
+    std::shared_lock lock(mutex_);
+    producerRelatedNodes.reserve(producerCount_.size());
+    for (const auto &kv : producerCount_) {
+        producerRelatedNodes.emplace_back(kv.first);
+    }
+
+    consumerRelatedNodes.reserve(consumerCount_.size());
+    for (const auto &kv : consumerCount_) {
+        consumerRelatedNodes.emplace_back(kv.first);
+    }
+}
+
+bool TopologyManager::RecoverEmptyMetaIfNeeded(const HostPort &nodeAddr)
+{
+    std::set<HostPort> nodeSet;
+    auto rc = GetAllRelatedNode(nodeSet);
+    if (rc.IsError()) {
+        LOG(WARNING) << "GetAllRelatedNode failed: " << rc.ToString();
+        return false;
+    }
+    if (nodeSet.count(nodeAddr) > 0) {
+        return false;
+    }
+    LOG(INFO) << FormatString("[S: %s] Recover empty meta for worker: %s", streamName_, nodeAddr.ToString());
+    std::lock_guard lock(mutex_);
+    if (producerCount_.find(nodeAddr.ToString()) == producerCount_.end()) {
+        (void)producerCount_.emplace(nodeAddr.ToString(), ProducerCount());
+    }
+    if (consumerCount_.find(nodeAddr.ToString()) == consumerCount_.end()) {
+        (void)consumerCount_.emplace(nodeAddr.ToString(), 0);
+    }
+    return true;
+}
+
+void TopologyManager::PreparePubSubRelNodes(const std::vector<std::string> &producerRelatedNodes,
+                                            const std::vector<std::string> &consumerRelatedNodes)
+{
+    for (const auto &producerAddr : producerRelatedNodes) {
+        producerCount_.emplace(producerAddr, ProducerCount());
+    }
+    for (const auto &consumerAddr : consumerRelatedNodes) {
+        consumerCount_.emplace(consumerAddr, 0);
+    }
+}
+
+std::vector<ConsumerMetaPb> TopologyManager::GetAllConsumerNotFromSrc(const std::string &srcNode) const
+{
+    std::vector<ConsumerMetaPb> consumerList;
+    std::shared_lock lock(mutex_);
+    for (const auto &kv : consumerTopo_) {
+        const auto hostPortPb = kv.second.worker_address();
+        HostPort addr(hostPortPb.host(), hostPortPb.port());
+        if (addr.ToString() != srcNode) {
+            consumerList.emplace_back(kv.second);
+        }
+    }
+    return consumerList;
+}
+
+std::unordered_map<std::string, ConsumerMetaPb> TopologyManager::GetAllConsumerFromWorker(
+    const HostPort &workerAddress) const
+{
+    std::unordered_map<std::string, ConsumerMetaPb> consumerMap;
+    std::shared_lock lock(mutex_);
+    for (auto &kv : consumerTopo_) {
+        const auto hostPortPb = kv.second.worker_address();
+        HostPort addr(hostPortPb.host(), hostPortPb.port());
+        if (addr == workerAddress) {
+            consumerMap.emplace(kv.first, kv.second);
+        }
+    }
+    return consumerMap;
+}
+
+Status TopologyManager::GetAllProducerFromWorker(const HostPort &workerAddress,
+                                                 std::unordered_map<std::string, ProducerMetaPb> &producerMap)
+{
+    std::shared_lock lock(mutex_);
+    for (const auto &kv : producerCount_) {
+        if (kv.first == workerAddress.ToString()) {
+            // Create a temporary ProducerMetaPb.
+            ProducerMetaPb producerMeta;
+
+            // Add the stream name, address, and producer count to it (which will be 0 or 1).
+            producerMeta.set_stream_name(streamName_);
+            HostPort workerHostPort;
+            RETURN_IF_NOT_OK(workerHostPort.ParseString(kv.first));
+            producerMeta.mutable_worker_address()->set_host(workerHostPort.Host());
+            producerMeta.mutable_worker_address()->set_port(workerHostPort.Port());
+            producerMeta.set_producer_count(kv.second.count_);
+
+            // Key: WorkerAddress Value: MetaPb
+            producerMap.emplace(kv.first, producerMeta);
+        }
+    }
+    return Status::OK();
+}
+
+Status TopologyManager::GetAllProducer(std::vector<ProducerMetaPb> &producerList)
+{
+    std::shared_lock lock(mutex_);
+    producerList.reserve(producerCount_.size());
+    for (const auto &kv : producerCount_) {
+        // Skip the closed producers; we get those through the related node info.
+        if (kv.second.count_ == 0) {
+            continue;
+        }
+        // Create a temporary ProducerMetaPb.
+        ProducerMetaPb producerMeta;
+
+        // Add the stream name, address, and producer count to it (which will be 0 or 1).
+        producerMeta.set_stream_name(streamName_);
+        HostPort workerHostPort;
+        RETURN_IF_NOT_OK(workerHostPort.ParseString(kv.first));
+        producerMeta.mutable_worker_address()->set_host(workerHostPort.Host());
+        producerMeta.mutable_worker_address()->set_port(workerHostPort.Port());
+        producerMeta.set_producer_count(kv.second.count_);
+
+        producerList.emplace_back(producerMeta);
+    }
+    return Status::OK();
+}
+
+std::vector<ConsumerMetaPb> TopologyManager::GetAllConsumer() const
+{
+    std::vector<ConsumerMetaPb> consumerList;
+    std::shared_lock lock(mutex_);
+    consumerList.reserve(consumerTopo_.size());
+    for (const auto &kv : consumerTopo_) {
+        consumerList.emplace_back(kv.second);
+    }
+    return consumerList;
+}
+
+bool TopologyManager::GetStreamStatus() const
+{
+    std::shared_lock lock(mutex_);
+    return isDeleting_;
+}
+
+Status TopologyManager::SetDeletingStatus()
+{
+    std::lock_guard lock(mutex_);
+    CHECK_FAIL_RETURN_STATUS(!isDeleting_, StatusCode::K_IO_ERROR,
+                             FormatString("Stream:<%s>, State:<already deleting>", streamName_));
+    isDeleting_ = true;
+    return Status::OK();
+}
+
+void TopologyManager::UnsetDeletingStatus()
+{
+    std::lock_guard lock(mutex_);
+    isDeleting_ = false;
+}
+
+Status TopologyManager::GlobalUniqueCheck(const ConsumerMetaPb &consumerMeta)
+{
+    std::shared_lock lock(mutex_);
+    bool isUnique = true;
+    const auto &config = consumerMeta.sub_config();
+    RETURN_OK_IF_TRUE(config.subscription_type() != SubscriptionTypePb::STREAM_PB);
+    RETURN_OK_IF_TRUE(consumerTopo_.empty());
+
+    const std::string &subName = config.subscription_name();
+    for (const auto &kv : consumerTopo_) {
+        if (subName == kv.second.sub_config().subscription_name()) {
+            isUnique = false;
+            break;
+        }
+    }
+    CHECK_FAIL_RETURN_STATUS(isUnique, StatusCode::K_RUNTIME_ERROR,
+                             FormatString("Stream:<%s>, SubscriptionName:<%s> is not unique in global scope",
+                                          streamName_, consumerMeta.sub_config().subscription_name()));
+    return Status::OK();
+}
+
+Status TopologyManager::CheckNewConsumer(const ConsumerMetaPb &consumerMeta)
+{
+    std::shared_lock lock(mutex_);
+    auto iter = consumerTopo_.find(consumerMeta.consumer_id());
+    if (iter != consumerTopo_.end()) {
+        auto retCode = iter->second.client_id() == consumerMeta.client_id() ? StatusCode::K_DUPLICATED
+                                                                            : StatusCode::K_RUNTIME_ERROR;
+        RETURN_STATUS(retCode, FormatString("The consumer [%s] already exists in stream <%s>",
+                                            consumerMeta.consumer_id(), streamName_));
+    }
+    return GlobalUniqueCheck(consumerMeta);
+}
+
+bool TopologyManager::CheckIfAllPubSubHaveClosed()
+{
+    std::unique_lock lock(mutex_);
+    return consumerTopo_.empty() && (currentProducerCount_ == 0);
+}
+
+namespace {
+template <typename T>
+std::string IntoString(const T &val);
+
+template <>
+std::string IntoString(const ConsumerMetaPb &val)
+{
+    ConsumerMetaPb meta = val;
+    meta.clear_consumer_id();
+    meta.clear_stream_name();
+    return "<" + meta.ShortDebugString() + ">";
+}
+
+template <>
+std::string IntoString(const uint32_t &val)
+{
+    return std::to_string(val);
+}
+
+template <>
+std::string IntoString(const TopologyManager::ProducerCount &val)
+{
+    return std::to_string(val.count_);
+}
+
+template <typename T>
+void GetTopoInformation(std::ostream &os, const std::unordered_map<std::string, T> &map)
+{
+    const size_t maxCount = 64;
+    size_t count = 0;
+    os << "{";
+    for (const auto &kv : map) {
+        if (count >= maxCount) {
+            break;
+        }
+        os << kv.first << ":" << IntoString(kv.second);
+        count++;
+    }
+    // Report the omitted entries against the full map size; `count` stops at the cap.
+    if (map.size() > maxCount) {
+        os << "...(" << (map.size() - maxCount) << ")";
+    }
+    os << "}";
+}
+} // namespace
+
+std::ostream &operator<<(std::ostream &os, const TopologyManager &obj)
+{
+    std::shared_lock lock(obj.mutex_);
+    os << "{producerCount:";
+    GetTopoInformation(os, obj.producerCount_);
+    os << ", consumers:";
+    GetTopoInformation(os, obj.consumerTopo_);
+    os << ", consumerCount:";
+    GetTopoInformation(os, obj.consumerCount_);
+    os << "}";
+    return os;
+}
+} // namespace master
+} // namespace datasystem
diff --git a/src/datasystem/master/stream_cache/topology_manager.h b/src/datasystem/master/stream_cache/topology_manager.h
new file mode 100644
index 0000000..42a2efa
--- /dev/null
+++ b/src/datasystem/master/stream_cache/topology_manager.h
@@ -0,0 +1,361 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: The stream topology manager interface on the master.
+ */
+#ifndef DATASYSTEM_MASTER_STREAM_CACHE_TOPOLOGY_MANAGER_H
+#define DATASYSTEM_MASTER_STREAM_CACHE_TOPOLOGY_MANAGER_H
+
+#include <numeric>
+#include <ostream>
+#include <set>
+#include <shared_mutex>
+#include <unordered_map>
+#include <vector>
+
+#include "datasystem/common/stream_cache/stream_fields.h"
+#include "datasystem/common/util/net_util.h"
+#include "datasystem/protos/worker_stream.pb.h"
+
+namespace datasystem {
+namespace master {
+class TopologyManager {
+public:
+    explicit TopologyManager(std::string streamName);
+    ~TopologyManager() = default;
+
+    /**
+     * @brief Increase a pub node for a stream.
+     * @param[in] producerMeta The producer metadata.
+     * @param[out] isFirstProducer Whether the target producer is the first one on the current local worker node.
+     * @return Status of the call.
+     */
+    Status PubIncreaseNode(const ProducerMetaPb &producerMeta, bool &isFirstProducer);
+
+    /**
+     * @brief The first producer processing is done for the worker.
+     * @param[in] producerMeta The producer metadata.
+     * @return Status of the call.
+     */
+    Status PubNodeFirstOrLastDone(const ProducerMetaPb &producerMeta);
+
+    /**
+     * @brief Prepare for the decrease of a pub node for a stream.
+     * @param[in] producerMeta The producer metadata.
+     * @param[out] isLastProducer Indicates whether the producer to close is the last producer.
+     * @return Status of the call.
+     */
+    Status PubDecreaseNodeStart(const ProducerMetaPb &producerMeta, bool &isLastProducer);
+
+    /**
+     * @brief Decrease a pub node for a stream.
+     * @param[in] producerMeta The producer metadata.
+     * @param[in] delWorker Delete the related worker as part of reconciliation.
+     * @return Status of the call.
+     */
+    Status PubDecreaseNode(const ProducerMetaPb &producerMeta, const bool delWorker);
+
+    /**
+     * @brief Increase a sub node for a stream.
+     * @param[in] consumerMeta The consumer meta info which will be transformed into a sub node.
+     * @param[out] isFirstConsumer Whether the target consumer is the first one on the current local worker node.
+     * @return Status of the call.
+     */
+    Status SubIncreaseNode(const ConsumerMetaPb &consumerMeta, bool &isFirstConsumer);
+
+    /**
+     * @brief Decrease a sub node for a stream.
+     * @param[in] consumerMeta The consumer meta info which will be transformed into a sub node.
+     * @param[in] rollback Whether this is a rollback or an actual close.
+     * @param[in] delWorker Delete the related worker as part of reconciliation.
+     * @return Status of the call.
+     */
+    Status SubDecreaseNode(const ConsumerMetaPb &consumerMeta, bool rollback, const bool delWorker);
+
+    /**
+     * @brief Get all pub nodes for a stream.
+     * @param[out] pubNodeSet The set of pub node addresses.
+     * @param[in] informAll Whether to also include nodes with no active producers.
+     * @return Status of the call.
+     */
+    Status GetAllPubNode(std::set<HostPort> &pubNodeSet, bool informAll = false) const;
+
+    /**
+     * @brief Get all sub nodes for a stream, only at the worker node layer.
+     * @param[out] subNodeSet The set of sub node addresses.
+     * @return Status of the call.
+     */
+    Status GetAllSubNode(std::set<HostPort> &subNodeSet) const;
+
+    /**
+     * @brief Get all worker addresses that have metadata.
+     * @param[out] nodeSet The set of all pub and sub node addresses.
+     * @return Status of the call.
+     */
+    Status GetAllRelatedNode(std::set<HostPort> &nodeSet) const;
+
+    /**
+     * @brief Recover empty meta for a node if needed. (Required by failure recovery scenarios.)
+     * @param[in] nodeAddr The node where both consumers and producers are closed.
+     * @return True if empty metadata was recovered, false otherwise.
+     */
+    bool RecoverEmptyMetaIfNeeded(const HostPort &nodeAddr);
+
+    /**
+     * @brief Get all worker addresses that have metadata.
+     * @param[out] producerRelatedNodes All of the producer related nodes.
+     * @param[out] consumerRelatedNodes All of the consumer related nodes.
+     */
+    void GetAllRelatedNode(std::vector<std::string> &producerRelatedNodes,
+                           std::vector<std::string> &consumerRelatedNodes) const;
+
+    /**
+     * @brief Initialize the producer count and consumer count with zero to keep the related node info.
+     * @param[in] producerRelatedNodes The producer related nodes.
+     * @param[in] consumerRelatedNodes The consumer related nodes.
+     */
+    void PreparePubSubRelNodes(const std::vector<std::string> &producerRelatedNodes,
+                               const std::vector<std::string> &consumerRelatedNodes);
+
+    /**
+     * @brief Get all sub nodes for a stream except the source node.
+     * @param[in] srcNode The source node that will be excluded.
+     * @return The consumer metadata for the stream.
+     */
+    std::vector<ConsumerMetaPb> GetAllConsumerNotFromSrc(const std::string &srcNode) const;
+
+    /**
+     * @brief Get all the consumer metadata from a worker.
+     * @param[in] workerAddress The worker address.
+     * @return The consumer metadata.
+     */
+    std::unordered_map<std::string, ConsumerMetaPb> GetAllConsumerFromWorker(const HostPort &workerAddress) const;
+
+    /**
+     * @brief Get all the producer metadata.
+     * @param[out] producerList The list of producers in the worker.
+     * @return Status of the call.
+     */
+    Status GetAllProducer(std::vector<ProducerMetaPb> &producerList);
+
+    /**
+     * @brief Get all the consumer metadata.
+     * @return The consumer metadata.
+     */
+    std::vector<ConsumerMetaPb> GetAllConsumer() const;
+
+    /**
+     * @brief Get all the producer metadata from a worker.
+     * @param[in] workerAddress The worker address.
+     * @param[out] producerMap The list of producers in the worker.
+     * @return Status of the call.
+     */
+    Status GetAllProducerFromWorker(const HostPort &workerAddress,
+                                    std::unordered_map<std::string, ProducerMetaPb> &producerMap);
+    /**
+     * @brief Get the status of a stream.
+     * @return True if the target stream is being deleted.
+     */
+    bool GetStreamStatus() const;
+
+    /**
+     * @brief Set the status of the current stream to isDeleting.
+     * @return Status of the call.
+     */
+    Status SetDeletingStatus();
+
+    /**
+     * @brief Unset the isDeleting status.
+     */
+    void UnsetDeletingStatus();
+
+    /**
+     * @brief Check whether the subscription name of the target consumer is unique in global scope.
+     * @param[in] consumerMeta The target consumer meta info.
+     * @return Status of the call.
+     */
+    Status GlobalUniqueCheck(const ConsumerMetaPb &consumerMeta);
+
+    /**
+     * @brief Clear metadata entries with an empty producer or consumer count.
+     * @param[in] workerAddr The worker address.
+     */
+    void ClearEmptyMeta(const std::string &workerAddr)
+    {
+        // Take the exclusive lock: this method mutates the maps.
+        std::lock_guard lock(mutex_);
+        auto producerIter = producerCount_.find(workerAddr);
+        if (producerIter != producerCount_.end() && !producerIter->second.count_) {
+            producerCount_.erase(workerAddr);
+        }
+        auto consumerIter = consumerCount_.find(workerAddr);
+        if (consumerIter != consumerCount_.end() && !consumerIter->second) {
+            consumerCount_.erase(workerAddr);
+        }
+    }
+
+    /**
+     * @brief Get the producer count in a worker.
+     * @param[in] workerAddr The worker address.
+     * @return The producer count in the worker.
+     */
+    size_t GetProducerCountInWorker(const std::string &workerAddr)
+    {
+        std::shared_lock lock(mutex_);
+        auto iter = producerCount_.find(workerAddr);
+        if (iter != producerCount_.end()) {
+            return iter->second.count_;
+        }
+        return 0;
+    }
+
+    /**
+     * @brief Get the consumer count in a worker.
+     * @param[in] workerAddr The worker address.
+     * @return The consumer count in worker.
+     */
+    size_t GetConsumerCountInWorker(const std::string &workerAddr) const
+    {
+        std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+        auto iter = consumerCount_.find(workerAddr);
+        if (iter != consumerCount_.end()) {
+            return iter->second;
+        }
+        return 0;
+    }
+
+    /**
+     * @brief Get the producer count in this stream.
+     * @return uint64_t The producer count.
+     */
+    uint64_t GetProducerCount() const
+    {
+        std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+        return std::accumulate(
+            producerCount_.begin(), producerCount_.end(), 0ul,
+            [](uint64_t init, const std::pair<const std::string, ProducerCount> &kv) { return init + kv.second.count_; });
+    }
+
+    /**
+     * @brief Get the consumer count in this stream.
+     * @return uint64_t The consumer count.
+     */
+    uint64_t GetConsumerCount() const
+    {
+        std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+        return std::accumulate(
+            consumerCount_.begin(), consumerCount_.end(), 0ul,
+            [](uint64_t init, const std::pair<const std::string, uint32_t> &kv) { return init + kv.second; });
+    }
+
+    /**
+     * @brief Check if the worker has a producer or not.
+     * @param[in] producerMeta The producer meta.
+     * @return The producer exists or not.
+     */
+    bool ExistsProducer(const ProducerMetaPb &producerMeta) const
+    {
+        HostPort workerAddress(producerMeta.worker_address().host(), producerMeta.worker_address().port());
+        std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+        return ((producerCount_.count(workerAddress.ToString()) > 0)
+                && (producerCount_.at(workerAddress.ToString()).count_ > 0));
+    }
+
+    /**
+     * @brief Check if the worker has a producer or not, without taking the lock.
+     * @param[in] workerAddress The worker address of the producer.
+     * @return The producer exists or not.
+     */
+    bool ExistsProducerUnlocked(const HostPort &workerAddress) const
+    {
+        return ((producerCount_.count(workerAddress.ToString()) > 0)
+                && (producerCount_.at(workerAddress.ToString()).count_ > 0));
+    }
+
+    /**
+     * @brief Check the consumer exists or not.
+     * @param[in] consumerId The consumer id.
+     * @return The consumer exists or not.
+     */
+    bool ExistsConsumer(const std::string &consumerId) const
+    {
+        std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+        return consumerTopo_.count(consumerId) > 0;
+    }
+
+    Status RestoreConsumerLifeCount(const uint32_t consumerCount)
+    {
+        std::lock_guard<std::shared_timed_mutex> lock(mutex_);
+        consumerLifeCount_ = consumerCount;
+        return Status::OK();
+    }
+
+    /**
+     * @brief Check before adding a new consumer.
+     * @param[in] consumerMeta The consumer metadata.
+     * @return Status of the call.
+     */
+    Status CheckNewConsumer(const ConsumerMetaPb &consumerMeta);
+
+    /**
+     * @brief Check if all the producers/consumers have closed.
+     * @return True if all producers and consumers have closed.
+     */
+    bool CheckIfAllPubSubHaveClosed();
+
+    /**
+     * @brief Returns the consumer count over the stream's lifetime.
+     * @return Consumer count.
+     */
+    uint32_t GetConsumerCountForLife();
+
+    struct ProducerCount {
+        uint32_t count_ = 0;
+        bool firstProducerProcessing_ = false;
+    };
+
+private:
+    friend std::ostream &operator<<(std::ostream &os, const TopologyManager &obj);
+
+    mutable std::shared_timed_mutex mutex_;
+
+    std::string streamName_;
+
+    // Key: workerAddress. Value: consumer number on that worker.
+    std::unordered_map<std::string, uint32_t> consumerCount_;
+
+    // Num of consumers over life of the stream.
+    uint32_t consumerLifeCount_{ 0 };
+
+    // Key: workerAddress. Value: producer number on that worker,
+    // and whether the first producer request is still processing.
+    std::unordered_map<std::string, ProducerCount> producerCount_;
+
+    // Topology for consumer. Key: consumerId. Value: ConsumerMetaPb.
+    std::unordered_map<std::string, ConsumerMetaPb> consumerTopo_;
+    // Measures the total number of workers with at least one producer.
+    // Used as a check before DeleteStreams.
+    uint64_t currentProducerCount_{ 0 };
+
+    std::set<std::string> streamModeSubDict_;  // Used to record all stream mode subscriptions in global scope.
+
+    bool isDeleting_{ false };
+};
+}  // namespace master
+}  // namespace datasystem
+
+#endif  // DATASYSTEM_MASTER_STREAM_CACHE_TOPOLOGY_MANAGER_H
diff --git a/src/datasystem/protos/CMakeLists.txt b/src/datasystem/protos/CMakeLists.txt
index 7254dcb..3020d21 100644
--- a/src/datasystem/protos/CMakeLists.txt
+++ b/src/datasystem/protos/CMakeLists.txt
@@ -106,20 +106,22 @@ GEN_GLOBAL_PROTO_LIB_DEPEND(CLIENT
 # Compile the proto file at the first layer, for example, object_posix.proto.
 GENERATE_ZMQ_CPP(zmq_proto_lib_depend OBJECT_POSIX_SRCS OBJECT_POSIX_HDRS ${PROTO_BUILD_DIR} ${PROTO_SRC_DIR} ${PROTO_SRC_DIR}/object_posix.proto)
+GENERATE_ZMQ_CPP(zmq_proto_lib_depend STREAM_POSIX_SRCS STREAM_POSIX_ZMQ_HDRS ${PROTO_BUILD_DIR} ${PROTO_SRC_DIR} ${PROTO_SRC_DIR}/stream_posix.proto)
 GEN_GLOBAL_PROTO_LIB_DEPEND(
         NAME posix_protos
-        SOURCE ${OBJECT_POSIX_SRCS}
+        SOURCE ${OBJECT_POSIX_SRCS} ${STREAM_POSIX_SRCS}
         DEPEND zmq_proto_lib_depend
         LINK ${PROTOBUF_LIBRARIES} zmq_meta_protos utils_protos p2p_subscribe_protos)
 GEN_GLOBAL_PROTO_LIB_DEPEND(CLIENT
         NAME posix_protos_client
-        SOURCE ${OBJECT_POSIX_SRCS}
+        SOURCE ${OBJECT_POSIX_SRCS} ${STREAM_POSIX_SRCS}
         DEPEND zmq_proto_lib_depend
         LINK ${PROTOBUF_LIBRARIES} zmq_meta_protos_client utils_protos_client p2p_subscribe_protos_client)

 # Compile the proto file at the second layer, for example, posix proto.
 GENERATE_ZMQ_CPP(zmq_proto_lib_depend WORKER_OBJECT_SRCS WORKER_OBJECT_HDRS ${PROTO_BUILD_DIR} ${PROTO_SRC_DIR} ${PROTO_SRC_DIR}/worker_object.proto)
+GENERATE_ZMQ_CPP(zmq_proto_lib_depend WORKER_STREAM_SRCS WORKER_STREAM_HDRS ${PROTO_BUILD_DIR} ${PROTO_SRC_DIR} ${PROTO_SRC_DIR}/worker_stream.proto)

 GEN_GLOBAL_PROTO_LIB_DEPEND(
         NAME worker_object_protos
@@ -132,8 +134,20 @@ GEN_GLOBAL_PROTO_LIB_DEPEND(CLIENT
         DEPEND zmq_proto_lib_depend
         LINK ${PROTOBUF_LIBRARIES} zmq_meta_protos_client hash_ring_protos_client)

-# Compile the proto file at the third layer, for example, master_object.proto
+GEN_GLOBAL_PROTO_LIB_DEPEND(
+        NAME worker_stream_protos
+        SOURCE ${WORKER_STREAM_SRCS}
+        DEPEND zmq_proto_lib_depend
+        LINK ${PROTOBUF_LIBRARIES} zmq_meta_protos utils_protos posix_protos)
+GEN_GLOBAL_PROTO_LIB_DEPEND(CLIENT
+        NAME worker_stream_protos_client
+        SOURCE ${WORKER_STREAM_SRCS}
+        DEPEND zmq_proto_lib_depend
+        LINK ${PROTOBUF_LIBRARIES} zmq_meta_protos_client utils_protos_client posix_protos_client)
+
+# Compile the proto file at the third layer, for example, master_object.proto, master_stream.proto
 GENERATE_ZMQ_CPP(zmq_proto_lib_depend MASTER_OBJECT_SRCS MASTER_OBJECT_HDRS ${PROTO_BUILD_DIR} ${PROTO_SRC_DIR} ${PROTO_SRC_DIR}/master_object.proto)
+GENERATE_ZMQ_CPP(zmq_proto_lib_depend MASTER_STREAM_SRCS MASTER_STREAM_HDRS ${PROTO_BUILD_DIR} ${PROTO_SRC_DIR} ${PROTO_SRC_DIR}/master_stream.proto)
 GENERATE_ZMQ_CPP(zmq_proto_lib_depend MASTER_HEARTBEAT_SRCS MASTER_HEARTBEAT_HDRS ${PROTO_BUILD_DIR} ${PROTO_SRC_DIR} ${PROTO_SRC_DIR}/master_heartbeat.proto)

 GEN_GLOBAL_PROTO_LIB_DEPEND(
@@ -147,6 +161,17 @@ GEN_GLOBAL_PROTO_LIB_DEPEND(CLIENT
         DEPEND zmq_proto_lib_depend
         LINK ${PROTOBUF_LIBRARIES} posix_protos_client worker_object_protos_client)

+GEN_GLOBAL_PROTO_LIB_DEPEND(
+        NAME master_stream_protos
+        SOURCE ${MASTER_STREAM_SRCS}
+        DEPEND zmq_proto_lib_depend
+        LINK ${PROTOBUF_LIBRARIES} worker_stream_protos utils_protos posix_protos)
+GEN_GLOBAL_PROTO_LIB_DEPEND(CLIENT
+        NAME master_stream_protos_client
+        SOURCE ${MASTER_STREAM_SRCS}
+        DEPEND zmq_proto_lib_depend
+        LINK ${PROTOBUF_LIBRARIES} worker_stream_protos_client utils_protos_client posix_protos_client)
+
 GEN_GLOBAL_PROTO_LIB_DEPEND(
         NAME master_heartbeat_protos
         SOURCE ${MASTER_HEARTBEAT_SRCS}
@@ -176,7 +201,6 @@ if (WITH_TESTS)
     set_target_properties(common_rpc_zmq_demo PROPERTIES COMPILE_FLAGS "-Wno-unused-parameter")
     GENERATE_ZMQ_CPP(zmq_proto_lib_depend UNIT_TEST_SRCS UNIT_TEST_HDRS ${PROTO_BUILD_DIR} ${PROTO_SRC_DIR} ${PROTO_SRC_DIR}/ut_object.proto)
     add_library(ut_object_protos STATIC ${UNIT_TEST_SRCS})
-    target_link_libraries(ut_object_protos common_rpc_zmq)
     add_dependencies(ut_object_protos zmq_proto_lib_depend)
     set_target_properties(ut_object_protos PROPERTIES COMPILE_FLAGS "-Wno-unused-parameter")
 endif()
diff --git a/src/datasystem/protos/README.md b/src/datasystem/protos/README.md
index 5d977a8..77fad0f 100644
--- a/src/datasystem/protos/README.md
+++ b/src/datasystem/protos/README.md
@@ -12,14 +12,14 @@

 ## 1.2 Hierarchical Design of Proto Files

-### 1.2.1 Object proto files
+### 1.2.1 Object/Stream proto files

 | layer | type              | protos files                                                     |
 | ----- | ----------------- |------------------------------------------------------------------|
 | 0     | utils proto files | rpc_options.proto、utils.proto、                                  |
-| 1     | POSIX API         | object_posix.proto/share_memory.proto                            |
-| 2     | worker RPC API    | worker_object.proto                                              |
-| 3     | master RPC API    | master_object.proto/master_heartbeat.protoc                      |
+| 1     | POSIX API         | object_posix.proto/stream_posix.proto/share_memory.proto         |
+| 2     | worker RPC API    | worker_object.proto/worker_stream.proto                          |
+| 3     | master RPC API    | master_object.proto/master_stream.proto/master_heartbeat.proto   |

 ### 1.2.2 Agent proto files

diff --git a/src/datasystem/protos/master_stream.proto b/src/datasystem/protos/master_stream.proto
new file mode 100644
index 0000000..9df900f
--- /dev/null
+++ b/src/datasystem/protos/master_stream.proto
@@ -0,0 +1,228 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Defines the RPC API of the master component.
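+ *
+ * A sketch of how the worker side builds a request against this service (hypothetical
+ * values; uses the standard protobuf-generated setters):
+ *
+ *     datasystem::master::CreateProducerReqPb req;
+ *     req.mutable_producer_meta()->set_stream_name("stream0");
+ *     req.set_page_size(4096);
+ *     req.set_max_stream_size(1 << 20);
+ *     req.set_timeout(5000);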
+ */ +syntax = "proto3"; +package datasystem.master; + +import "worker_stream.proto"; +import "utils.proto"; +import "rpc_option.proto"; + +message StreamMetaPb { + string stream_name = 1; + uint64 max_stream_size = 2; + int64 page_size = 3; + bool auto_cleanup = 4; + uint64 retain_num_consumer = 5; + uint64 consumer_life_count = 6; + bool encrypt_stream = 7; + uint64 producer_life_count = 8; + uint64 reserve_size = 9; + int32 stream_mode = 10; +} + +message CreateProducerReqPb { + ProducerMetaPb producer_meta = 1; + int64 timeout = 2; + uint64 max_stream_size = 3; + int64 page_size = 4; + bool auto_cleanup = 5; + bool redirect = 6; + uint64 retain_num_consumer = 7; + bool encrypt_stream = 8; + uint64 reserve_size = 9; + int32 stream_mode = 10; + + // put to the end, the previous data is used to generate AK and SK signatures. + uint64 timestamp = 100; + string signature = 101; + string access_key = 102; +} + +message CreateProducerRspPb { + RedirectMetaInfo info = 1; + bool meta_is_moving = 2; + uint32 retain_data = 3; + uint64 producer_no = 4; +} + +message ProducerInfoPb { + string stream_name = 1; +} + +message CloseProducerReqPb { + repeated ProducerInfoPb producer_infos = 1; + HostPortPb worker_address = 2; + int64 timeout = 3; + bool force_close = 4; + bool redirect = 5; + + // put to the end, the previous data is used to generate AK and SK signatures. + uint64 timestamp = 100; + string signature = 101; + string access_key = 102; +} + +message CloseProducerRspPb { + repeated ProducerInfoPb failed_producers = 1; + repeated ProducerInfoPb success_producers = 2; + ErrorInfoPb err = 3; + repeated RedirectMetaInfo info = 4; + bool meta_is_moving = 5; +} + +message SubscribeReqPb { + ConsumerMetaPb consumer_meta = 1; + int64 timeout = 2; + bool redirect = 3; + + // put to the end, the previous data is used to generate AK and SK signatures. + uint64 timestamp = 100; + string signature = 101; + string access_key = 102; +} + +message SubscribeRspPb { + uint64 max_stream_size = 1; + int64 page_size = 2; + bool auto_cleanup = 3; + RedirectMetaInfo info = 4; + bool meta_is_moving = 5; + uint64 retain_num_consumer = 6; + uint32 retain_data = 7; + bool encrypt_stream = 8; + uint64 reserve_size = 9; + int32 stream_mode = 10; +} + +message CloseConsumerReqPb { + ConsumerMetaPb consumer_meta = 1; + int64 timeout = 2; + bool redirect = 3; + + // put to the end, the previous data is used to generate AK and SK signatures. + uint64 timestamp = 100; + string signature = 101; + string access_key = 102; +} + +message CloseConsumerRspPb { + RedirectMetaInfo info = 1; + bool meta_is_moving = 2; +} + +message DeleteStreamReqPb { + string stream_name = 1; + HostPortPb src_node_addr = 2; + int64 timeout = 3; + bool redirect = 4; + + // put to the end, the previous data is used to generate AK and SK signatures. + uint64 timestamp = 100; + string signature = 101; + string access_key = 102; +} + +message DeleteStreamRspPb { + RedirectMetaInfo info = 1; + bool meta_is_moving = 2; +} + +message NotifyWorkerExpectedGlbConsumerNumReqPb{ + string stream_name = 1; + uint64 global_consumner_num = 2; + HostPortPb worker_addr =3; + + // put to the end, the previous data is used to generate AK and SK signatures. + uint64 timestamp = 100; + string signature = 101; + string access_key = 102; +} + +message QueryGlobalNumReqPb { + string stream_name = 1; + bool redirect = 2; + + // put to the end, the previous data is used to generate AK and SK signatures. 
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message QueryGlobalNumRsqPb {
+    uint64 global_count = 1;
+    RedirectMetaInfo info = 2;
+    bool meta_is_moving = 3;
+}
+
+message AsyncNotificationPb {
+    bool is_pub = 1;
+    bool is_close = 2;
+    string target_worker = 3;
+    string id = 4;  // Producer worker address or consumer id.
+    bool force_close = 5;
+    uint32 retain_data = 6;
+}
+
+message MetaForSCMigrationPb {
+    StreamMetaPb meta = 1;
+    repeated ProducerMetaPb producers = 2;
+    repeated ConsumerMetaPb consumers = 3;
+    repeated string producer_rel_nodes = 4;
+    repeated string consumer_rel_nodes = 5;
+    repeated AsyncNotificationPb notifications = 6;
+}
+
+message MigrateSCMetadataReqPb {
+    string source_addr = 1;
+    repeated MetaForSCMigrationPb stream_metas = 2;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message MigrateSCMetadataRspPb {
+    repeated Status results = 1;
+
+    enum Status {
+        SUCCESSFUL = 0;
+        FAILED = 1;
+    }
+}
+
+service MasterSCService {
+    rpc CreateProducer (CreateProducerReqPb) returns (CreateProducerRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+    rpc CloseProducer (CloseProducerReqPb) returns (CloseProducerRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+    rpc Subscribe (SubscribeReqPb) returns (SubscribeRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+    rpc CloseConsumer (CloseConsumerReqPb) returns (CloseConsumerRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+    rpc DeleteStream (DeleteStreamReqPb) returns (DeleteStreamRspPb) {}
+    rpc QueryGlobalProducersNum (QueryGlobalNumReqPb) returns (QueryGlobalNumRsqPb) {}
+    rpc QueryGlobalConsumersNum (QueryGlobalNumReqPb) returns (QueryGlobalNumRsqPb) {}
+    rpc MigrateSCMetadata(MigrateSCMetadataReqPb) returns (MigrateSCMetadataRspPb) {}
+}
diff --git a/src/datasystem/protos/meta_zmq.proto b/src/datasystem/protos/meta_zmq.proto
index 7472fc6..160659a 100644
--- a/src/datasystem/protos/meta_zmq.proto
+++ b/src/datasystem/protos/meta_zmq.proto
@@ -85,16 +85,27 @@ message PayloadDirectGetRspPb {
 }

 // Exchange of jfr. Send our local jfr and host port
+message JfrBondInfo {
+    message JettyId {
+        bytes eid = 1;
+        uint32 uasid = 2;
+        uint32 id = 3;
+    }
+    JettyId base_id = 1;
+    repeated JettyId slave_ids = 2;
+    int32 dev_num = 3;
+    bool is_in_matrix_server = 4;
+    bool is_multipath = 5;
+}
+
 message UrmaHandshakeReqPb {
     bytes eid = 1;
     uint32 uasid = 2;
     repeated uint32 jfr_ids = 3;
     HostPortPb address = 4;
+    repeated JfrBondInfo bond_infos = 5;
 }

 // Exchange of jfr. Remote's jfr
 message UrmaHandshakeRspPb {
-    bytes eid = 1;
-    uint32 uasid = 2;
-    repeated uint32 jfr_ids = 3;
 }
diff --git a/src/datasystem/protos/object_posix.proto b/src/datasystem/protos/object_posix.proto
index 7a46e21..7f776d3 100644
--- a/src/datasystem/protos/object_posix.proto
+++ b/src/datasystem/protos/object_posix.proto
@@ -68,6 +68,7 @@ message MultiCreateReqPb {
     repeated uint64 data_size = 3;
     string token = 4;
     string tenant_id = 5;
+    bool skip_check_existence = 6;

     // put to the end, the previous data is used to generate AK and SK signatures.
     uint64 timestamp = 100;
@@ -77,6 +78,7 @@ message MultiCreateReqPb {

 message MultiCreateRspPb {
     repeated CreateRspPb results = 1;
+    repeated bool exists = 2;
 }

 // For the meaning of the value, see 'ExistenceOpt' enum class.
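// Sketch of the intended pairing of the two new fields above (inferred from the field
// names; not stated in this hunk): when MultiCreateReqPb.skip_check_existence is set,
// the worker skips the per-key existence probe at create time and instead reports back,
// in request order, whether each object already existed via MultiCreateRspPb.exists.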
@@ -131,6 +133,7 @@ message MultiPublishReqPb {
     string tenant_id = 9;
     uint32 cache_type = 10;
     bool is_replica = 11;
+    bool auto_release_memory_ref = 12;

     // put to the end, the previous data is used to generate AK and SK signatures.
     uint64 timestamp = 100;
diff --git a/src/datasystem/protos/p2p_subscribe.proto b/src/datasystem/protos/p2p_subscribe.proto
index db0e3d1..1215301 100644
--- a/src/datasystem/protos/p2p_subscribe.proto
+++ b/src/datasystem/protos/p2p_subscribe.proto
@@ -177,6 +177,7 @@ message SendRootInfoReqPb {
 }

 message SendRootInfoRspPb{
+    bool is_dead_lock = 1;
 }

@@ -195,6 +196,7 @@ message RecvRootInfoReqPb {

 message RecvRootInfoRspPb{
     bytes root_info = 1;
+    bool is_dead_lock = 2;
 }

 message RemoveP2PLocationReqPb {
diff --git a/src/datasystem/protos/share_memory.proto b/src/datasystem/protos/share_memory.proto
index 3c4df3d..7f341c6 100644
--- a/src/datasystem/protos/share_memory.proto
+++ b/src/datasystem/protos/share_memory.proto
@@ -56,6 +56,7 @@ message RegisterClientReqPb {
 }

 message RegisterClientRspPb {
+    int32 page_size = 1;
     bytes client_id = 2;
     double quorum_timeout_mult = 3;
     uint32 lock_id = 4;
diff --git a/src/datasystem/protos/stream_posix.proto b/src/datasystem/protos/stream_posix.proto
new file mode 100644
index 0000000..09ffcda
--- /dev/null
+++ b/src/datasystem/protos/stream_posix.proto
@@ -0,0 +1,419 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Defines the POSIX API of the worker component, including the meta of the stream cache.
+ * These APIs are used for communication between clients and workers.
+ */
+syntax = "proto3";
+package datasystem;
+
+import "rpc_option.proto";
+import "utils.proto";
+
+message ShmViewPb {
+    int32 fd = 1;
+    uint64 mmap_size = 2;
+    uint64 offset = 3;
+    uint64 size = 4;
+}
+
+message QueryGlobalNumReqPb {
+    string stream_name = 1;
+    string token = 2;
+    string client_id = 3;
+    string tenant_id = 4;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message QueryGlobalNumRsqPb {
+    uint64 global_count = 1;
+}
+
+message NotifyStreamInvalidRspPb {
+}
+
+message CreateProducerReqPb {
+    string stream_name = 1;
+    int64 page_size = 2;
+    string client_id = 3;
+    string producer_id = 4;
+    uint64 max_stream_size = 5;  // stream size in worker.
+    string token = 6;
+    bool auto_cleanup = 7;
+    string tenant_id = 8;
+    uint64 retain_num_consumer = 9;
+    bool encrypt_stream = 10;
+    uint64 reserve_size = 11;
+    int32 stream_mode = 12;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message CreateProducerRspPb {
+    ShmViewPb page_view = 1;
+    uint64 sender_producer_no = 2;
+    bool enable_data_verification = 3;
+    uint64 stream_no = 4;
+    uint64 shared_page_size = 5;
+    bool enable_shared_page = 6;
+    ShmViewPb stream_meta_view = 7;
+}
+
+enum SubscriptionTypePb {
+    STREAM_PB = 0;
+    ROUND_ROBIN_PB = 1;
+    KEY_PARTITIONS_PB = 2;
+}
+
+message SubscriptionConfigPb {
+    string subscription_name = 1;
+    SubscriptionTypePb subscription_type = 2;
+    bool report_producer_fault = 3;
+}
+
+message SubscribeReqPb {
+    string stream_name = 1;
+    SubscriptionConfigPb subscription_config = 2;
+    string client_id = 3;
+    string consumer_id = 4;
+    string token = 5;
+    string tenant_id = 6;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message SubscribeRspPb {
+    uint64 last_recv_cursor = 1;
+    int32 worker_fd = 2;
+    uint64 mmap_size = 3;
+    uint64 offset = 4;
+    uint64 size = 5;
+}
+
+message ElementsMetaPb {
+    repeated uint32 element_sizes = 1;
+    uint64 total_flush_count = 2;
+    repeated bool header_bits = 3;
+}
+
+message GetDataPageReqPb {
+    string stream_name = 1;
+    string subscription_name = 2;
+    bytes consumer_id = 3;
+    uint64 last_recv_cursor = 4;
+    int64 timeout_ms = 5;
+    string token = 6;
+    string client_id = 7;
+    string tenant_id = 8;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message GetDataPageRspPb {
+    ShmViewPb page_view = 1;
+}
+
+message CreateShmPageReqPb {
+    string stream_name = 1;
+    bytes producer_id = 2;
+    string token = 3;
+    int64 sub_timeout = 4;
+    string client_id = 5;
+    ShmViewPb cur_view = 6;
+    string tenant_id = 7;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message CreateShmPageRspPb {
+    ShmViewPb last_page_view = 1;
+}
+
+message CloseProducerReqPb {
+    string stream_name = 1;
+    bytes producer_id = 2;
+    string client_id = 3;
+    string token = 4;
+    string tenant_id = 5;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message CloseProducerRspPb {
+}
+
+message CloseConsumerReqPb {
+    string stream_name = 1;
+    string subscription_name = 2;
+    bytes consumer_id = 3;
+    string client_id = 4;
+    string token = 5;
+    string tenant_id = 6;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message CloseConsumerRspPb {
+}
+
+message DeleteStreamReqPb {
+    string stream_name = 1;
+    string token = 2;
+    string client_id = 3;
+    string tenant_id = 4;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message DeleteStreamRspPb {
+}
+
+message PushReqPb {
+    string stream_name = 1;
+    string worker_addr = 2;
+    string producer_id = 5;
+    uint64 first_seq = 6;
+    repeated ElementsMetaPb element_meta = 3;
+    repeated uint64 seq = 4;
+    string trace_id = 7;
+    uint64 chunk_size = 8;
+    string worker_instance_id = 9;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message PushRspPb {
+    repeated ErrorInfoPb error = 1;
+}
+
+message StreamElementsMetaPb {
+    uint64 stream_index = 1;
+    uint64 seq = 2;
+    ElementsMetaPb element_meta = 3;
+}
+
+message SharedPagePushReqPb {
+    repeated string stream_names = 1;
+    string worker_addr = 2;
+    string producer_id = 3;
+    repeated StreamElementsMetaPb metas = 4;
+    string trace_id = 5;
+    string worker_instance_id = 6;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message UnblockProducerReqPb {
+    string stream_name = 1;
+    string worker_addr = 2;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message UnblockProducerRspPb {
+}
+
+message BlockProducerReqPb {
+    string stream_name = 1;
+    string worker_addr = 2;
+    int64 timeout = 3;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message BlockProducerRspPb {
+}
+
+message TopologyElementsPb {
+    HostPortPb worker_address = 1;
+    repeated string element_names = 2;
+}
+
+message LastAppendCursorReqPb {
+    string stream_name = 1;
+    string client_id = 2;
+    string token = 3;
+    string tenant_id = 4;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message LastAppendCursorRspPb {
+    uint64 last_append_cursor = 1;
+}
+
+message ResetOrResumeStreamsReqPb {
+    repeated string stream_names = 1;
+    string client_id = 2;
+    string token = 3;
+    string tenant_id = 4;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message ResetOrResumeStreamsRspPb {
+}
+
+message CreateLobPageReqPb {
+    string stream_name = 1;
+    string client_id = 2;
+    string token = 3;
+    string producer_id = 4;
+    uint64 page_size = 5;
+    string tenant_id = 6;
+    int64 sub_timeout = 7;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message CreateLobPageRspPb {
+    ShmViewPb page_view = 1;
+}
+
+message ReleaseLobPageReqPb {
+    string stream_name = 1;
+    string client_id = 2;
+    string token = 3;
+    string producer_id = 4;
+    ShmViewPb page_view = 5;
+    string tenant_id = 6;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message ReleaseLobPageRspPb {
+}
+
+// Service for clients and other workers.
+service ClientWorkerSCService {
+    rpc CreateProducer (CreateProducerReqPb) returns (CreateProducerRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+
+    rpc CloseProducer(CloseProducerReqPb) returns (CloseProducerRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+    rpc Subscribe (SubscribeReqPb) returns (SubscribeRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+    rpc CloseConsumer(CloseConsumerReqPb) returns (CloseConsumerRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+
+    rpc GetDataPage(GetDataPageReqPb) returns (GetDataPageRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+
+    rpc CreateShmPage(CreateShmPageReqPb) returns (CreateShmPageRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+
+    rpc DeleteStream(DeleteStreamReqPb) returns (DeleteStreamRspPb) {}
+
+    // Query the global producer/consumer counts.
+    rpc QueryGlobalProducersNum(QueryGlobalNumReqPb) returns (QueryGlobalNumRsqPb) {}
+    rpc QueryGlobalConsumersNum(QueryGlobalNumReqPb) returns (QueryGlobalNumRsqPb) {}

+    // Cross-node worker flow control RPCs.
+    rpc UnblockProducer (UnblockProducerReqPb) returns (UnblockProducerRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+    rpc BlockProducer(BlockProducerReqPb) returns (BlockProducerRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+    rpc GetLastAppendCursor(LastAppendCursorReqPb) returns (LastAppendCursorRspPb) {}
+
+    // Additional client-worker RPCs.
+    rpc ResetStreams(ResetOrResumeStreamsReqPb) returns (ResetOrResumeStreamsRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+    rpc ResumeStreams(ResetOrResumeStreamsReqPb) returns (ResetOrResumeStreamsRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+    // Extra RPC calls for big element (LOB) insertion.
+    rpc AllocBigShmMemory(CreateLobPageReqPb) returns (CreateLobPageRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+    rpc ReleaseBigShmMemory(ReleaseLobPageReqPb) returns (ReleaseLobPageRspPb) {
+        option (datasystem.unary_socket_option) = true;
+    }
+}
+
+service WorkerWorkerSCService {
+    option (datasystem.channel_number_option) = 2;
+    option (datasystem.multi_session_option) = true;
+    // Batch push data.
+    rpc PushElementsCursors (PushReqPb) returns (PushRspPb) {
+        option (datasystem.send_payload_option) = true;
+        option (datasystem.unary_socket_option) = true;
+    }
+    rpc PushSharedPageCursors (SharedPagePushReqPb) returns (PushRspPb) {
+        option (datasystem.send_payload_option) = true;
+        option (datasystem.unary_socket_option) = true;
+    }
+}
diff --git a/src/datasystem/protos/utils.proto b/src/datasystem/protos/utils.proto
index e5fb262..dae691a 100644
--- a/src/datasystem/protos/utils.proto
+++ b/src/datasystem/protos/utils.proto
@@ -35,12 +35,25 @@ message RedirectMetaInfo {
     repeated string change_meta_ids = 2;
 }

+message UrmaSegPb {
+    bytes eid = 1;
+    uint32 uasid = 2;
+    uint64 va = 3;
+    uint64 len = 4;
+    uint32 attr = 5;
+    uint32 token_id = 6;
+}
+
+message UrmaBondSegInfoPb {
+    UrmaSegPb base = 1;
+    repeated UrmaSegPb slaves = 2;
+    int32 dev_num = 3;
+}
+
 message UrmaImportSegmentPb {
     HostPortPb request_address = 1;

     /* segment */
-    uint64 seg_va = 3;
-    uint64 seg_len = 4;
-    uint32 seg_flag = 5;
-    uint32 seg_token_id = 6;
-    uint64 seg_data_offset = 7;
+    UrmaSegPb seg = 2;
+    uint64 seg_data_offset = 3;
+    UrmaBondSegInfoPb bond_info = 4;
 }
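// A sketch of how a sender could populate the reshaped segment message above
// (hypothetical values; package qualifier assumed; standard protobuf-generated setters):
//
//     datasystem::UrmaImportSegmentPb import;
//     UrmaSegPb *seg = import.mutable_seg();
//     seg->set_va(0x7f0000000000);  // registered segment start address
//     seg->set_len(4096);           // segment length in bytes
//     import.mutable_bond_info()->set_dev_num(2);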
diff --git a/src/datasystem/protos/worker_stream.proto b/src/datasystem/protos/worker_stream.proto
new file mode 100644
index 0000000..bc1c631
--- /dev/null
+++ b/src/datasystem/protos/worker_stream.proto
@@ -0,0 +1,181 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Defines the RPC API of the worker component, including the meta of the stream cache.
+ * These APIs are used for worker-to-worker and master-to-worker communication.
+ */
+syntax = "proto3";
+package datasystem;
+import "stream_posix.proto";
+import "utils.proto";
+
+message DelStreamContextReqPb {
+    string stream_name = 1;
+    bool force_delete = 2;
+    int64 timeout = 3;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message DelStreamContextRspPb {
+}
+
+message ProducerMetaPb {
+    string stream_name = 1;
+    HostPortPb worker_address = 2;
+    uint64 producer_count = 3;
+}
+
+message ConsumerMetaPb {
+    string stream_name = 1;
+    HostPortPb worker_address = 2;
+    string consumer_id = 3;
+    SubscriptionConfigPb sub_config = 4;
+    uint64 last_ack_cursor = 5;
+    string client_id = 6;
+}
+
+message GetStreamMetadataRspPb {
+    string stream_name = 1;
+    uint64 max_stream_size = 2;
+    int64 page_size = 3;
+    bool auto_cleanup = 4;
+    ErrorInfoPb error = 5;
+    repeated ProducerMetaPb producers = 6;
+    repeated ConsumerMetaPb consumers = 7;
+    uint64 retain_num_consumer = 8;
+    bool is_remote_pub_empty = 9;
+    bool encrypt_stream = 10;
+    uint64 reserve_size = 11;
+    int32 stream_mode = 12;
+}
+
+message GetMetadataAllStreamReqPb {
+    string master_address = 1;
+    message RangePb {
+        uint32 from = 1;
+        uint32 end = 2;
+    }
+    repeated RangePb hash_ranges = 2;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message GetMetadataAllStreamRspPb {
+    repeated GetStreamMetadataRspPb stream_meta = 1;
+}
+
+message SyncPubNodeReqPb {
+    string stream_name = 1;
+    bool is_reconciliation = 2;
+    repeated HostPortPb worker_address_vector = 3;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message SyncPubNodeRspPb {
+}
+
+// Clear all remote subs when the last producer closes; clear all remote pubs when the last consumer closes.
+message ClearRemoteInfoReqPb {
+    string stream_name = 1;
+    bool force_close = 2;
+    bool rollback = 3;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message ClearRemoteInfoRspPb {
+}
+
+// The metadata synchronization messages are as follows.
+message SyncConsumerNodeReqPb {
+    string stream_name = 1;
+    repeated ConsumerMetaPb consumer_meta_vector = 2;
+    uint32 retain_data = 3;
+    bool is_reconciliation = 4;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message SyncConsumerNodeRspPb {
+}
+
+message NotifyConsumerPb {
+    ConsumerMetaPb consumer = 1;
+    bool is_close = 2;
+}
+
+message NotifyPubPb {
+    string stream_name = 1;
+    string worker_addr = 2;
+    bool is_close = 3;
+    uint64 max_stream_size = 4;
+    int64 page_size = 5;
+    bool force_close = 6;
+    bool auto_cleanup = 7;
+    uint64 retain_num_consumer = 8;
+    bool encrypt_stream = 9;
+    uint64 reserve_size = 10;
+    int32 stream_mode = 11;
+}
+
+message UpdateTopoNotificationReq {
+    string stream_name = 1;
+    repeated NotifyPubPb pubs = 2;
+    repeated NotifyConsumerPb subs = 3;
+    uint32 retain_data = 4;
+
+    // put to the end, the previous data is used to generate AK and SK signatures.
+    uint64 timestamp = 100;
+    string signature = 101;
+    string access_key = 102;
+}
+
+message UpdateTopoNotificationRsp {
+}
+
+service MasterWorkerSCService {
+    // master --> worker topological change feedback
+    rpc UpdateTopoNotification(UpdateTopoNotificationReq) returns (UpdateTopoNotificationRsp) {}
+    rpc DelStreamContext(DelStreamContextReqPb) returns (DelStreamContextRspPb) {}
+    rpc SyncPubNode(SyncPubNodeReqPb) returns (SyncPubNodeRspPb) {}
+    rpc SyncConsumerNode(SyncConsumerNodeReqPb) returns (SyncConsumerNodeRspPb) {}
+    rpc ClearAllRemotePub(ClearRemoteInfoReqPb) returns (ClearRemoteInfoRspPb) {}
+    rpc ClearAllRemoteConsumer(ClearRemoteInfoReqPb) returns (ClearRemoteInfoRspPb) {}
+    rpc QueryMetadata(stream GetMetadataAllStreamReqPb) returns (stream GetStreamMetadataRspPb) {}
+}
diff --git a/src/datasystem/pybind_api/CMakeLists.txt b/src/datasystem/pybind_api/CMakeLists.txt
index 86fabee..0ed42aa 100644
--- a/src/datasystem/pybind_api/CMakeLists.txt
+++ b/src/datasystem/pybind_api/CMakeLists.txt
@@ -2,6 +2,7 @@ add_library(ds_client_py MODULE
         pybind_register.cpp
         pybind_register_common.cpp
         pybind_register_object.cpp
+        pybind_register_stream.cpp
         pybind_register_kv.cpp
         pybind_register_hetero.cpp
         pybind_register_context.cpp)
diff --git a/src/datasystem/pybind_api/pybind_register.cpp b/src/datasystem/pybind_api/pybind_register.cpp
index 0743555..dc05b55 100644
--- a/src/datasystem/pybind_api/pybind_register.cpp
+++ b/src/datasystem/pybind_api/pybind_register.cpp
@@ -36,7 +36,7 @@ PybindDefinedFunctionRegister &PybindDefinedFunctionRegister::GetSingleton()

 // Import all functions with priority = 0 as *client_lib.
 PYBIND11_MODULE(libds_client_py, m)
 {
-    m.doc() = "pybind11 for object_cache client";
+    m.doc() = "pybind11 for object_cache and stream_cache client";

     auto all_fns = datasystem::PybindDefinedFunctionRegister::AllFunctions();
diff --git a/src/datasystem/pybind_api/pybind_register_object.cpp b/src/datasystem/pybind_api/pybind_register_object.cpp
index ec23c77..55232c5 100644
--- a/src/datasystem/pybind_api/pybind_register_object.cpp
+++ b/src/datasystem/pybind_api/pybind_register_object.cpp
@@ -147,20 +147,20 @@ PybindDefineRegisterer g_pybind_define_f_Client("ObjectClient", PRIORITY_LOW, []
             })
         .def("create",
-             [](ObjectClient &client, const std::string &objectKey, uint64_t size, WriteMode writeMode,
+             [](ObjectClient &client, const std::string &objectKey, uint64_t size,
                 ConsistencyType consistency) {
                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
                 std::shared_ptr<Buffer> buffer;
-                CreateParam param{ .writeMode = writeMode, .consistencyType = consistency };
+                CreateParam param{ .consistencyType = consistency };
                 datasystem::Status status = client.Create(objectKey, size, param, buffer);
                 return std::make_pair(status, std::move(buffer));
             })
         .def("put",
-             [](ObjectClient &client, const std::string &objectKey, py::buffer value, WriteMode writeMode,
+             [](ObjectClient &client, const std::string &objectKey, py::buffer value,
                 ConsistencyType consistency, const std::vector<std::string> &refIds) {
                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
-                CreateParam param{ .writeMode = writeMode, .consistencyType = consistency };
+                CreateParam param{ .consistencyType = consistency };
                 py::buffer_info info(value.request());
                 std::unordered_set<std::string> nestedObjectKeys = { refIds.begin(), refIds.end() };
                 return client.Put(objectKey, reinterpret_cast<uint8_t *>(info.ptr), info.size, param,
diff --git a/src/datasystem/pybind_api/pybind_register_stream.cpp b/src/datasystem/pybind_api/pybind_register_stream.cpp
new file mode 100644
index 0000000..3637b2a
--- /dev/null
+++ b/src/datasystem/pybind_api/pybind_register_stream.cpp
@@ -0,0 +1,189 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Register function to python.
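+ *
+ * A minimal end-to-end sketch of the C++ API wrapped by the bindings below
+ * (hypothetical values; error handling elided; Element's constructor signature assumed):
+ *
+ *     StreamClient client(connectOpts);
+ *     RETURN_IF_NOT_OK(client.Init());
+ *     std::shared_ptr<Producer> producer;
+ *     RETURN_IF_NOT_OK(client.CreateProducer("stream0", producer, ProducerConf{}));
+ *     std::string payload = "hello";
+ *     Element element(reinterpret_cast<uint8_t *>(payload.data()), payload.size());
+ *     RETURN_IF_NOT_OK(producer->Send(element));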
+ */
+#include <climits>
+
+#include <memory>
+#include <string>
+
+#include "datasystem/common/log/log.h"
+#include "datasystem/common/log/trace.h"
+#include "datasystem/common/perf/perf_manager.h"
+#include "datasystem/common/util/net_util.h"
+#include "datasystem/common/util/strings_util.h"
+#include "datasystem/pybind_api/pybind_register.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/stream/element.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream_client.h"
+#include "datasystem/stream/stream_config.h"
+
+using datasystem::Consumer;
+using datasystem::Element;
+using datasystem::Producer;
+using datasystem::StreamClient;
+using datasystem::SubscriptionConfig;
+
+namespace datasystem {
+PybindDefineRegisterer g_pybind_define_f_StreamClient("StreamClient", PRIORITY_LOW, [](const py::module *m) {
+    py::class_<StreamClient, std::shared_ptr<StreamClient>>(*m, "StreamClient")
+        .def(py::init([](const std::string &host, int32_t port, const std::string &clientPublicKey,
+                         const std::string &clientPrivateKey, const std::string &serverPublicKey,
+                         const std::string &accessKey, const std::string &secretKey,
+                         const std::string &tenantId) {
+            ConnectOptions connectOpts{ .host = host, .port = port };
+            connectOpts.clientPublicKey = clientPublicKey;
+            connectOpts.clientPrivateKey = clientPrivateKey;
+            connectOpts.serverPublicKey = serverPublicKey;
+            connectOpts.accessKey = accessKey;
+            connectOpts.secretKey = secretKey;
+            connectOpts.tenantId = tenantId;
+            return std::make_unique<StreamClient>(connectOpts);
+        }))
+        .def("init",
+             [](StreamClient &client) {
+                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+                 return client.Init();
+             })
+        .def("CreateProducer",
+             [](StreamClient &client, const std::string &streamName, int64_t delayFlushTime, int64_t pageSize,
+                int64_t maxStreamSize, bool autoCleanup, int64_t retainForNumConsumers, bool encryptStream,
+                int64_t reserveSize) {
+                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+                 std::shared_ptr<Producer> outProducer;
+                 ProducerConf producerConf = { .delayFlushTime = delayFlushTime,
+                                               .pageSize = pageSize,
+                                               .maxStreamSize = static_cast<uint64_t>(maxStreamSize),
+                                               .autoCleanup = autoCleanup,
+                                               .retainForNumConsumers = static_cast<uint64_t>(retainForNumConsumers),
+                                               .encryptStream = encryptStream,
+                                               .reserveSize = static_cast<uint64_t>(reserveSize) };
+                 auto status = client.CreateProducer(streamName, outProducer, producerConf);
+                 if (status.IsError()) {
+                     LOG(ERROR) << FormatString("CreateProducer failed for stream %s with error %s", streamName,
+                                                status.ToString());
+                 }
+                 return std::make_pair(status, outProducer);
+             })
+        .def("Subscribe",
+             [](StreamClient &client, const std::string &streamName, const std::string &subName,
+                const int subscriptionType) {
+                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+                 std::shared_ptr<Consumer> outConsumer;
+                 const struct SubscriptionConfig config =
+                     SubscriptionConfig(subName, static_cast<SubscriptionType>(subscriptionType));
+                 auto status = client.Subscribe(streamName, config, outConsumer);
+                 if (status.IsError()) {
+                     LOG(ERROR) << FormatString("Subscribe failed for stream %s with error %s", streamName,
+                                                status.ToString());
+                 }
+                 return std::make_pair(status, outConsumer);
+             })
+        .def("DeleteStream",
+             [](StreamClient &client, const std::string &streamName) {
+                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+                 auto status = client.DeleteStream(streamName);
+                 if (status.IsError()) {
+                     LOG(ERROR) << FormatString("DeleteStream failed for stream %s with error %s", streamName,
+                                                status.ToString());
+                 }
+                 return status;
+             })
+        .def("QueryGlobalProducersNum",
+             [](StreamClient &client, const std::string &streamName) {
+                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+                 uint64_t globalProducerNum = 0;
+                 auto status = client.QueryGlobalProducersNum(streamName, globalProducerNum);
+                 if (status.IsError()) {
+                     LOG(ERROR) << FormatString("QueryGlobalProducersNum failed, stream name is %s", streamName);
+                 }
+                 return std::make_pair(status, globalProducerNum);
+             })
+        .def("QueryGlobalConsumersNum", [](StreamClient &client, const std::string &streamName) {
+            TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+            uint64_t globalConsumerNum = 0;
+            auto status = client.QueryGlobalConsumersNum(streamName, globalConsumerNum);
+            if (status.IsError()) {
+                LOG(ERROR) << FormatString("QueryGlobalConsumersNum failed, stream name is %s", streamName);
+            }
+            return std::make_pair(status, globalConsumerNum);
+        });
+});
+
+PybindDefineRegisterer g_pybind_define_f_Producer("Producer", PRIORITY_LOW, [](const py::module *m) {
+    py::class_<Producer, std::shared_ptr<Producer>>(*m, "Producer")
+        .def("Send",
+             [](Producer &producer, const py::buffer &buf, const int timeoutMs) {
+                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+                 py::buffer_info info = buf.request();
+                 Element element(static_cast<uint8_t *>(info.ptr), info.size);
+                 return producer.Send(element, timeoutMs);
+             })
+        .def("Send",
+             [](Producer &producer, const py::buffer &buf) {
+                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+                 py::buffer_info info = buf.request();
+                 Element element(static_cast<uint8_t *>(info.ptr), info.size);
+                 return producer.Send(element);
+             })
+        .def("Close", [](Producer &producer) {
+            TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+            return producer.Close();
+        });
+});
+
+PybindDefineRegisterer g_pybind_define_f_Consumer("Consumer", PRIORITY_LOW, [](const py::module *m) {
+    py::class_<Consumer, std::shared_ptr<Consumer>>(*m, "Consumer")
+        .def("Receive",
+             [](Consumer &consumer, const int expectNum, const int timeoutMs) {
+                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+                 std::vector<Element> outElement;
+                 auto status = consumer.Receive(expectNum, timeoutMs, outElement);
+                 return std::make_pair(status, outElement);
+             })
+        .def("ReceiveAny",
+             [](Consumer &consumer, const int timeoutMs) {
+                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+                 std::vector<Element> outElement;
+                 auto status = consumer.Receive(timeoutMs, outElement);
+                 return std::make_pair(status, outElement);
+             })
+        .def("Ack",
+             [](Consumer &consumer, const int element_id) {
+                 TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+                 return consumer.Ack(element_id);
+             })
+        .def("Close", [](Consumer &consumer) {
+            TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+            return consumer.Close();
+        });
+});
+
+PybindDefineRegisterer g_pybind_define_f_element("Element", PRIORITY_LOW, ([](const py::module *m) {
+    (void)py::class_<Element>(*m, "Element", pybind11::buffer_protocol())
+        .def("get_id",
+             [](Element &element) {
+                 return element.id != ULONG_MAX ? static_cast<int64_t>(element.id)
+                                                : -1;
+             })
+        .def_buffer([](Element &element) {
+            return py::buffer_info(element.ptr, element.size);
+        });
+}));
+}  // namespace datasystem
diff --git a/src/datasystem/worker/CMakeLists.txt b/src/datasystem/worker/CMakeLists.txt
index 4274d89..2057b59 100644
--- a/src/datasystem/worker/CMakeLists.txt
+++ b/src/datasystem/worker/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_subdirectory(object_cache)
+add_subdirectory(stream_cache)
 add_subdirectory(hash_ring)
 add_subdirectory(client_manager)
 add_subdirectory(cluster_manager)
@@ -35,6 +36,7 @@ set(WORKER_DEPEND_LIBS
         posix_protos
         share_memory_protos
         worker_object_cache
+        worker_stream_cache
         rpc_stub_cache_mgr
         worker_hash_ring
         worker_client_manager
diff --git a/src/datasystem/worker/client_manager/client_info.cpp b/src/datasystem/worker/client_manager/client_info.cpp
index 4696436..f8019c4 100644
--- a/src/datasystem/worker/client_manager/client_info.cpp
+++ b/src/datasystem/worker/client_manager/client_info.cpp
@@ -27,7 +27,7 @@
 #include "datasystem/common/util/uuid_generator.h"

 DS_DEFINE_uint64(client_dead_timeout_s, 120,
-                 "Maximum time interval for the worker to determine client death, value range: [15, UINT64_MAX)");
+                 "Maximum time interval for the worker to determine client death, value range: [15, UINT64_MAX/1000)");

 static bool ValidClientDeadTimeoutSecs(const char *flagName, uint64_t value)
 {
 #ifdef WITH_TESTS
diff --git a/src/datasystem/worker/cluster_manager/etcd_cluster_manager.cpp b/src/datasystem/worker/cluster_manager/etcd_cluster_manager.cpp
index 53816df..5d34fae 100644
--- a/src/datasystem/worker/cluster_manager/etcd_cluster_manager.cpp
+++ b/src/datasystem/worker/cluster_manager/etcd_cluster_manager.cpp
@@ -54,8 +54,8 @@ DS_DECLARE_uint32(node_timeout_s);
 DS_DECLARE_uint32(node_dead_timeout_s);
 DS_DECLARE_uint32(add_node_wait_time_s);
 DS_DECLARE_string(master_address);
-DS_DECLARE_string(other_az_names);
-DS_DECLARE_string(az_name);
+DS_DECLARE_string(other_cluster_names);
+DS_DECLARE_string(cluster_name);
 DS_DECLARE_bool(enable_distributed_master);
 DS_DECLARE_bool(auto_del_dead_node);
 DS_DEFINE_bool(cross_az_get_meta_from_worker, false, "cross az to get metadata from worker");
@@ -100,10 +100,10 @@ EtcdClusterManager::EtcdClusterManager(const HostPort &workerAddress, const Host
     eventPq_ = std::make_unique, CmEventCmp>>(pqSize);
     workerWaitPost_ = std::make_unique<WaitPost>();
-    if (!FLAGS_other_az_names.empty() && FLAGS_enable_distributed_master) {
+    if (!FLAGS_other_cluster_names.empty() && FLAGS_enable_distributed_master) {
         ConstructOtherAzHashRings();
-        for (const auto &azName : Split(FLAGS_other_az_names, ",")) {
-            if (azName != FLAGS_az_name) {
+        for (const auto &azName : Split(FLAGS_other_cluster_names, ",")) {
+            if (azName != FLAGS_cluster_name) {
                 otherAZNames_.emplace_back(azName);
             }
         }
@@ -134,8 +134,8 @@ EtcdClusterManager::~EtcdClusterManager()

 void EtcdClusterManager::ConstructOtherAzHashRings()
 {
-    for (const auto &azName : Split(FLAGS_other_az_names, ",")) {
-        if (azName != FLAGS_az_name) {
+    for (const auto &azName : Split(FLAGS_other_cluster_names, ",")) {
+        if (azName != FLAGS_cluster_name) {
             auto readRing = std::make_unique(azName, workerAddress_.ToString(), etcdDB_);
             (void)otherAzHashRings_.insert(std::make_pair(azName, std::move(readRing)));
         }
     }
@@ -157,6 +157,13 @@ Status EtcdClusterManager::Shutdown()
         thread_.reset();
     }

+    if (orphanNodeMonitorThread_) {
+        exitFlag_ = true;
+        orphanWaitPost_.Set();
+        orphanNodeMonitorThread_->join();
+        orphanNodeMonitorThread_.reset();
+    }
+
     return Status::OK();
 }

@@ -260,7 +267,7 @@ Status EtcdClusterManager::Init(const ClusterInfo &clusterInfo)
     // as early as possible. Also, cluster manager needs this thread to add nodes to its node table.
     // This thread monitors timed out nodes and demotes them to failed nodes. It also tries to generate hash tokens,
     // to give up reconciliation when there is timeout, and to handle etcd events (ring, node addition, node removal).
-    RETURN_IF_NOT_OK(StartNodeUtilThread());
+    RETURN_IF_NOT_OK(StartBackgroundThread());

     RETURN_IF_NOT_OK(SetupInitialClusterNodes(clusterInfo));
@@ -901,7 +908,7 @@ Status EtcdClusterManager::StartNodeUtilThread()
 {
     static const int CHECK_INTERVAL_MS = 100;
     auto traceId = GetStringUuid().substr(0, SHORT_TRACEID_SIZE);
-    LOG(INFO) << "Start background thread in cluster manager with traceId: " << traceId;
+    LOG(INFO) << "Start node util thread in cluster manager with traceId: " << traceId;
     const int clearScaledDownNodeInClusterTableMaxIntervelMs = 30'000;
     thread_ = std::make_unique<Thread>([this, traceId]() {
         Status rc;
@@ -941,6 +948,118 @@ Status EtcdClusterManager::StartNodeUtilThread()
     return Status::OK();
 }

+void EtcdClusterManager::GetToBeCleanNodes(const std::unordered_map<std::string, uint64_t> &orphanNodes,
+                                           std::set<std::pair<std::string, bool>> &toBeCleanNodes)
+{
+    static const int timeoutMs = 5000;
+    for (const auto &[orphanNode, timeEpoch] : orphanNodes) {
+        HostPort addr;
+        if (addr.ParseString(orphanNode).IsError()) {
+            continue;
+        }
+        // Check the nodes that are not found in the hash ring.
+        RangeSearchResult res;
+        auto status = etcdDB_->Get(ETCD_CLUSTER_TABLE, orphanNode, res, timeoutMs);
+
+        typename TbbNodeTable::const_accessor accessor;
+        std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+        auto nodeExist = clusterNodeTable_.find(accessor, addr);
+        if (!nodeExist) {
+            LOG(INFO) << "Node " << orphanNode << " is not found in cluster table";
+            continue;
+        }
+        if (timeEpoch != accessor->second->GetTimeEpoch()) {
+            LOG(INFO) << "Node " << orphanNode << " is updated in cluster table";
+            continue;
+        }
+        auto &node = accessor->second;
+        if (status.GetCode() == K_NOT_FOUND) {
+            // If the node is not found in etcd, it needs to be removed whatever its state in clusterNodeTable_
+            // is.
+            LOG(INFO) << "Ready to clear resource of worker " << orphanNode
+                      << " that not found in etcd, state in cluster node table before cleanup: "
+                      << node->ToString(addr);
+            if (node->NodeWasExiting()) {
+                toBeCleanNodes.emplace(orphanNode, false);
+            } else {
+                toBeCleanNodes.emplace(orphanNode, node->IsFailed());
+            }
+        } else if (status.IsOk()) {
+            // The node has been rejoined or is ready to rejoin.
+            if (node->IsFailed()) {
+                LOG(INFO) << "Ready to clear resource of worker " << orphanNode
+                          << " that has rejoined into etcd, state in cluster node table before cleanup: "
+                          << node->ToString(addr);
+                // Erase the failed node to prevent scaling down again, and wait for the new node to come.
+                toBeCleanNodes.emplace(orphanNode, true);
+            } else {
+                // Skip the erasure: we should wait at least until the lease expires to prevent removing a
+                // joined node incorrectly.
+                LOG(INFO) << "Skip to clear resource of worker " << orphanNode
+                          << ", state in cluster node table: " << node->ToString(addr)
+                          << ", state in etcd: " << res.value;
+            }
+        } else {
+            LOG(INFO) << "Failed to get node " << orphanNode << " from etcd, status: " << status.ToString();
+        }
+    }
+}
+
+Status EtcdClusterManager::StartOrphanNodeMonitorThread()
+{
+    auto traceId = GetStringUuid().substr(0, SHORT_TRACEID_SIZE);
+    LOG(INFO) << "Start orphan node monitor thread in cluster manager with traceId: " << traceId;
+    orphanNodeMonitorThread_ = std::make_unique<Thread>([this, traceId]() {
+        Timer timer;
+        while (!exitFlag_) {
+            TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId);
+            orphanWaitPost_.Wait();
+            std::unordered_map<std::string, uint64_t> orphanNodes;
+            {
+                std::lock_guard<std::shared_timed_mutex> lock(orphanNodeMutex_);
+                for (const auto &iter : orphanNodeTable_) {
+                    orphanNodes.emplace(iter.first, iter.second);
+                }
+                orphanNodeTable_.clear();
+            }
+            auto sz = orphanNodes.size();
+            if (sz == 0) {
+                continue;
+            }
+            if (etcdDB_->IsKeepAliveTimeout()) {
+                const int logEveryN = 1000;
+                LOG_EVERY_N(INFO, logEveryN)
+                    << "etcd is currently unavailable, synchronization cannot be completed, waiting for the next "
+                       "round of retry";
+                continue;
+            }
+            timer.Reset();
+            Raii raii([&timer, &sz]() {
+                static const int logThresholdMs = 1'000;
+                LOG_IF(INFO, timer.ElapsedMilliSecond() > logThresholdMs)
+                    << "Cleanup " << sz << " nodes elapsed: " << timer.ElapsedMilliSecond();
+            });
+            std::set<std::pair<std::string, bool>> toBeCleanNodes;
+            GetToBeCleanNodes(orphanNodes, toBeCleanNodes);
+            for (const auto &[addr, isFailed] : toBeCleanNodes) {
+                CleanupWorker(addr, isFailed);
+            }
+            if (!toBeCleanNodes.empty()) {
+                LOG(INFO) << "After sync with hash ring: " << NodesToString();
+            }
+        }
+    });
+    orphanNodeMonitorThread_->set_name("OrphanNodeMonitor");
+    return Status::OK();
+}
+
+Status EtcdClusterManager::StartBackgroundThread()
+{
+    RETURN_IF_NOT_OK(StartNodeUtilThread());
+    RETURN_IF_NOT_OK(StartOrphanNodeMonitorThread());
+    return Status::OK();
+}
+
 void EtcdClusterManager::HandleFailedNode(const HostPort &addr)
 {
     if (!IsCurrentNodeMaster()) {
@@ -1004,14 +1123,7 @@ std::unordered_set<std::string> EtcdClusterManager::GetFailedWorkers()

 void EtcdClusterManager::SyncNodeTableWithHashRing(const std::set<std::string> &workersInRing)
 {
     INJECT_POINT("SyncNodeTableWithHashRing", [] { return; });
-    const int32_t timeoutMs = 5'000;
-    const int maxLckTimeMs = 1'000;
-    Timer timer;
-    Raii raii([&timer]() {
-        LOG_IF(INFO, timer.ElapsedMilliSecond() > maxLckTimeMs)
-            << "SyncNodeTableWithHashRing ElapsedMilliSecond: " << timer.ElapsedMilliSecond();
-    });
-    std::set<std::pair<std::string, bool>> toBeCleanNodes;
+    bool isNotify = false;
     {
         std::lock_guard<std::shared_timed_mutex> lock(mutex_);
         std::string workerAddr;
@@ -1020,49 +1132,15 @@ void EtcdClusterManager::SyncNodeTableWithHashRing(const std::set<std::string> &
             if (ContainsKey(workersInRing, workerAddr)) {
                 continue;
             }
-            if (etcdDB_->IsKeepAliveTimeout()) {
-                const int logEveryN = 1000;
-                LOG_EVERY_N(INFO, logEveryN)
-                    << "etcd is currently unavailable, synchronization cannot be completed, waiting for the next round "
-                       "of retry";
-                break;
-            }
-            // check the nodes that not found in hash ring
-            RangeSearchResult res;
-            auto status = etcdDB_->Get(ETCD_CLUSTER_TABLE, workerAddr, res, timeoutMs);
-            if (status.GetCode() == K_NOT_FOUND) {
-                // if the node is not found in etcd, it needs to be removed whatever its state in clusterNodeTable_ is.
-                LOG(INFO) << "Ready to clear resource of worker " << workerAddr
-                          << " that not found in etcd, state in cluster node table before cleanup: "
-                          << iter.second->ToString(iter.first);
-                if (iter.second->NodeWasExiting()) {
-                    toBeCleanNodes.emplace(workerAddr, false);
-                } else {
-                    toBeCleanNodes.emplace(workerAddr, iter.second->IsFailed());
-                }
-            } else if (status.IsOk()) {
-                // the node has been rejoined or is ready to rejoin
-                if (iter.second->IsFailed()) {
-                    LOG(INFO) << "Ready to clear resource of worker " << workerAddr
-                              << " that has rejoined into etcd, state in cluster node table before cleanup: "
-                              << iter.second->ToString(iter.first);
-                    // erase the failed node to prevent scale down again and wait for the new node coming
-                    toBeCleanNodes.emplace(workerAddr, true);
-                } else {
-                    // skip the erasure. we should wait at least until the lease expires to prevent remove the joined
-                    // node incorrectly
-                    LOG(INFO) << "Skip to clear resource of worker " << workerAddr
-                              << ", state in cluster node table: " << iter.second->ToString(iter.first)
-                              << ", state in etcd: " << res.value << ", get result: " << status.ToString();
-                }
+            {
+                std::shared_lock<std::shared_timed_mutex> l(orphanNodeMutex_);
+                orphanNodeTable_.emplace(workerAddr, iter.second->GetTimeEpoch());
             }
+            isNotify = true;
         }
     }
-    for (const auto &[addr, isFailed] : toBeCleanNodes) {
-        CleanupWorker(addr, isFailed);
-    }
-    if (!toBeCleanNodes.empty()) {
-        LOG(INFO) << "After sync with hash ring: " << NodesToString();
+    if (isNotify) {
+        orphanWaitPost_.Set();
     }
 }
diff --git a/src/datasystem/worker/cluster_manager/etcd_cluster_manager.h b/src/datasystem/worker/cluster_manager/etcd_cluster_manager.h
index 1fb0f37..6ad7289 100644
--- a/src/datasystem/worker/cluster_manager/etcd_cluster_manager.h
+++ b/src/datasystem/worker/cluster_manager/etcd_cluster_manager.h
@@ -902,6 +902,26 @@ private:
      */
     Status StartNodeUtilThread();

+    /**
+     * @brief Starts a thread to monitor orphaned nodes.
+     * @return Status of the call.
+     */
+    Status StartOrphanNodeMonitorThread();
+
+    /**
+     * @brief Starts the background threads.
+     * @return Status of the call.
+     */
+    Status StartBackgroundThread();
+
+    /**
+     * @brief Get the nodes to be cleaned up.
+     * @param[in] orphanNodes The orphan nodes.
+     * @param[out] toBeCleanNodes The nodes to be cleaned up.
+     */
+    void GetToBeCleanNodes(const std::unordered_map<std::string, uint64_t> &orphanNodes,
+                           std::set<std::pair<std::string, bool>> &toBeCleanNodes);
+
     /**
      * @brief The function that executes the check for a timed out node to see if it needs to be demoted to a
      * failed node. Any node that is timed out and meets the criteria for demotion shall have its state changed.
@@ -1042,6 +1062,12 @@ private:
     TbbNodeTable otherClusterNodeTable_;  // Tracks node states of the other AZ's cluster nodes
     mutable std::shared_timed_mutex otherClusterNodeMutex_;

+    using TbbOrphanTable = tbb::concurrent_hash_map<std::string, uint64_t>;
+    TbbOrphanTable orphanNodeTable_;
+    mutable std::shared_timed_mutex orphanNodeMutex_;  // protects orphanNodeTable_
+    WaitPost orphanWaitPost_;                          // waits until orphanNodeTable_ is not empty
+    std::unique_ptr<Thread> orphanNodeMonitorThread_{ nullptr };
+
     // The timers that generate fake node removal event, used only in StartNodeUtilThread thread.
std::unordered_map nodeTableCompletionTimer_; diff --git a/src/datasystem/worker/hash_ring/hash_ring.cpp b/src/datasystem/worker/hash_ring/hash_ring.cpp index fb05422..45cb1cd 100644 --- a/src/datasystem/worker/hash_ring/hash_ring.cpp +++ b/src/datasystem/worker/hash_ring/hash_ring.cpp @@ -435,9 +435,13 @@ void HashRing::GenerateVoluntaryScaleDownChangingInfo() allocator.RemoveNodeVoluntarily(workerId, standbyWorker, hashFunction_(worker.worker_uuid()), allScaleDownWorkers, oldRing), "RemoveNodeVoluntarily failed"); + if (workerId == workerAddr_) { + // If the voluntary scale down node is not the current node, there is no need to clean up + // VoluntaryTaskId and re-execute the migration data task. + taskExecutor_->ClearVoluntaryTaskId(); + } } newValue = std::make_unique(oldRing.SerializeAsString()); - taskExecutor_->ClearVoluntaryTaskId(); return Status::OK(); }; HASH_RING_LOG_IF_ERROR(etcdStore_->CAS(ETCD_RING_PREFIX, "", funcHandler), " generate voluntary info failed"); diff --git a/src/datasystem/worker/object_cache/obj_cache_shm_unit.cpp b/src/datasystem/worker/object_cache/obj_cache_shm_unit.cpp index 497b29c..b1092f8 100644 --- a/src/datasystem/worker/object_cache/obj_cache_shm_unit.cpp +++ b/src/datasystem/worker/object_cache/obj_cache_shm_unit.cpp @@ -149,6 +149,23 @@ Status CopyAndSplitBuffer(const std::string &tenantId, const void *data, size_t return Status::OK(); } +static Status InitializeMetadataMemory(const std::string &objectKey, uint64_t metadataSize, bool populate, + ShmUnit &shmUnit) +{ + if (metadataSize > 0) { + auto ret = memset_s(shmUnit.GetPointer(), metadataSize, 0, metadataSize); + if (ret != EOK) { + if (!populate) { + shmUnit.SetHardFreeMemory(); + } + shmUnit.FreeMemory(); + RETURN_STATUS_LOG_ERROR(K_RUNTIME_ERROR, + FormatString("[ObjectKey %s] Memset failed, errno: %d", objectKey, ret)); + } + } + return Status::OK(); +} + Status AllocateMemoryForObject(const std::string &objectKey, const uint64_t dataSize, uint64_t metadataSize, bool populate, std::shared_ptr evictionManager, ShmUnit &shmUnit, CacheType cacheType) @@ -159,11 +176,12 @@ Status AllocateMemoryForObject(const std::string &objectKey, const uint64_t data FormatString("The size is overflow, size:%d + add:%d > UINT64_MAX:%d", dataSize, metadataSize, UINT64_MAX)); uint64_t needSize = dataSize + metadataSize; PerfPoint point(PerfKey::WORKER_MEMORY_ALLOCATE); - (void)EvictWhenMemoryExceedThrehold(objectKey, needSize, evictionManager, cacheType); + (void)EvictWhenMemoryExceedThrehold(objectKey, needSize, evictionManager, ServiceType::OBJECT, cacheType); // Allocate some memory into this shmUnit auto tenantId = TenantAuthManager::ExtractTenantId(objectKey); static const std::vector WAIT_MSECOND = { 1, 10, 50, 100, 200, 400, 800, 1600, 3200 }; - Status rc = shmUnit.AllocateMemory(tenantId, needSize, populate, static_cast(cacheType)); + Status rc = shmUnit.AllocateMemory(tenantId, needSize, populate, ServiceType::OBJECT, + static_cast(cacheType)); if (rc.GetCode() == K_OUT_OF_MEMORY) { INJECT_POINT("worker.AllocateMemory.afterOOM"); for (int t : WAIT_MSECOND) { @@ -179,29 +197,91 @@ Status AllocateMemoryForObject(const std::string &objectKey, const uint64_t data VLOG(1) << FormatString("OOM, sleep time: %ld, objectKey: %s, needSize %ld", sleepTime, objectKey, needSize); std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); - rc = shmUnit.AllocateMemory(tenantId, needSize, populate, static_cast(cacheType)); + rc = shmUnit.AllocateMemory(tenantId, needSize, populate, ServiceType::OBJECT, + 
static_cast(cacheType)); if (rc.GetCode() != K_OUT_OF_MEMORY) { break; } - (void)EvictWhenMemoryExceedThrehold(objectKey, needSize, evictionManager, cacheType); + (void)EvictWhenMemoryExceedThrehold(objectKey, needSize, evictionManager, ServiceType::OBJECT, cacheType); } } RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rc, FormatString("[ObjectKey %s] Error while allocating memory.", objectKey)); - if (metadataSize > 0) { - auto ret = memset_s(shmUnit.GetPointer(), metadataSize, 0, metadataSize); - if (ret != EOK) { - if (!populate) { - shmUnit.SetHardFreeMemory(); + RETURN_IF_NOT_OK(InitializeMetadataMemory(objectKey, metadataSize, populate, shmUnit)); + + point.Record(); + workerOperationTimeCost.Append("AllocateMemory", timer.ElapsedMilliSecond()); + return Status::OK(); +} + +Status DistributeMemoryForObject(const std::string &objectKey, const uint64_t dataSize, uint64_t metadataSize, + bool populate, std::shared_ptr shmOwner, ShmUnit &shmUnit) +{ + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + UINT64_MAX - metadataSize >= dataSize, K_RUNTIME_ERROR, + FormatString("The size is overflow, size:%d + add:%d > UINT64_MAX:%d", dataSize, metadataSize, UINT64_MAX)); + uint64_t needSize = dataSize + metadataSize; + PerfPoint point(PerfKey::WORKER_MEMORY_ALLOCATE); + RETURN_IF_NOT_OK(shmOwner->DistributeMemory(needSize, shmUnit)); + RETURN_IF_NOT_OK(InitializeMetadataMemory(objectKey, metadataSize, populate, shmUnit)); + return Status::OK(); +} + +Status AggregateAllocate( + const std::string &firstObjectKey, + std::function, bool &)> &traversalHelper, + std::shared_ptr evictionManager, std::vector> &shmOwners, + std::vector &shmIndexMapping) +{ + // Pre-allocate aggregated chunks of shared memory as ShmOwner, to reduce the number of allocation calls. + // 1. Only for URMA case, non-URMA case allocates the memory when response is handled. + // 2. Aggregate only if all the objects are small objects (< 1MB size), and batch up to 1024 keys and 2MB size. + // 3. Multi-tenancy not yet supported. + const uint64_t batchLimitKeys = 1024; + const uint64_t batchLimitSingleSize = 1024 * 1024; + const uint64_t batchLimitTotalSize = 2 * 1024 * 1024; + + bool needAggregate = false; + std::vector aggreatedSizes; + uint64_t currentBatchSize = 0; + uint64_t currentKeyCount = 0; + + std::function aggregateCollector = + [&](uint64_t dataSz, uint64_t shmSize, uint32_t objectId) { + // Skip any object that has size beyond 1MB. + if (dataSz >= batchLimitSingleSize) { + return; } - shmUnit.FreeMemory(); - RETURN_STATUS_LOG_ERROR(K_RUNTIME_ERROR, - FormatString("[ObjectKey %s] Memset failed, errno: %d", objectKey, ret)); + + // Seal the last batch and start the new batch. + if (currentKeyCount >= batchLimitKeys || (currentBatchSize + shmSize) > batchLimitTotalSize) { + aggreatedSizes.emplace_back(currentBatchSize); + currentBatchSize = 0; + currentKeyCount = 0; + } + // Record the size and num, and also map from object key to ShmOwners index. + currentBatchSize += shmSize; + currentKeyCount++; + shmIndexMapping[objectId] = aggreatedSizes.size(); + }; + + traversalHelper(aggregateCollector, needAggregate); + + if (needAggregate && currentBatchSize > 0) { + // Deal with the last batch. + aggreatedSizes.emplace_back(currentBatchSize); + // Allocate memory for each batch. + for (const auto &aggregateSize : aggreatedSizes) { + std::shared_ptr shmOwner = std::make_shared(); + // All keys in the batch request should belong to the same tenant. 
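+            // AllocateMemoryForObject derives the tenant from the key it is handed, so charging
+            // the whole batch to firstObjectKey is only valid while requests are single-tenant.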
+ RETURN_IF_NOT_OK( + AllocateMemoryForObject(firstObjectKey, aggregateSize, 0, false, evictionManager, *shmOwner)); + shmOwners.push_back(shmOwner); } + } else { + shmIndexMapping.clear(); } - point.Record(); - workerOperationTimeCost.Append("AllocateMemory", timer.ElapsedMilliSecond()); return Status::OK(); } diff --git a/src/datasystem/worker/object_cache/obj_cache_shm_unit.h b/src/datasystem/worker/object_cache/obj_cache_shm_unit.h index 63fa9a2..70e464b 100644 --- a/src/datasystem/worker/object_cache/obj_cache_shm_unit.h +++ b/src/datasystem/worker/object_cache/obj_cache_shm_unit.h @@ -149,7 +149,6 @@ private: ObjectLifeState lifeState_ = ObjectLifeState::OBJECT_INVALID; }; - /** * @brief Copy and split buffer into multiple rpc message which size small than 2G. * @param[in] tenantId The tenant of the data @@ -176,6 +175,34 @@ Status AllocateMemoryForObject(const std::string &objectKey, const uint64_t data bool populate, std::shared_ptr evictionManager, ShmUnit &shmUnit, CacheType cacheType = CacheType::MEMORY); +/** + * @brief Distribute memory from already allocated ShmOwner for object. + * @param[in] objectKey The object key of entry that need to allocate memory. + * @param[in] dataSize The data size of memory in bytes. + * @param[in] metadataSize The metadata size of memory in bytes. + * @param[in] populate Indicate need populate or not. + * @param[in] shmOwner The share memory owner. + * @param[out] shmUnit The share memory info of object. + * @return Status of the call. + */ +Status DistributeMemoryForObject(const std::string &objectKey, const uint64_t dataSize, uint64_t metadataSize, + bool populate, std::shared_ptr shmOwner, ShmUnit &shmUnit); + +/** + * @brief Allocate aggregated chunks of shared memory. + * @param[in] firstObjectKey The first object key. + * @param[in] traversalHelper Helper function that does the customized traversal work. + * @param[in] evictionManager Eviction manager. + * @param[out] shmOwners The allocated shared memory chunks. + * @param[out] shmIndexMapping The object id to shmOwners index mapping. + * @return Status of the call. + */ +Status AggregateAllocate( + const std::string &firstObjectKey, + std::function, bool &)> &traversalHelper, + std::shared_ptr evictionManager, std::vector> &shmOwners, + std::vector &shmIndexMapping); + /** * @brief Allocate Shm unit and init its id. * @param[in] objectKey The object key. 
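
A minimal sketch of the distribute step these declarations pair with: one large chunk is allocated up front and then handed out by offset, so N small objects cost one allocator call instead of N. ShmOwner internals are not shown in this patch; the Chunk type below is an assumption for illustration only.

#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a pre-allocated shared-memory chunk (the "ShmOwner").
struct Chunk {
    std::vector<uint8_t> buf;  // one big allocation
    size_t offset = 0;         // bump pointer

    explicit Chunk(size_t size) : buf(size) {}

    // Hand out the next sub-range; mirrors DistributeMemory(needSize, shmUnit).
    uint8_t *Distribute(size_t needSize)
    {
        if (offset + needSize > buf.size()) {
            throw std::runtime_error("chunk exhausted");
        }
        uint8_t *ptr = buf.data() + offset;
        offset += needSize;
        return ptr;
    }
};

int main()
{
    Chunk chunk(2 * 1024 * 1024);          // one 2MB batch, as in AggregateAllocate
    uint8_t *a = chunk.Distribute(4096);   // per-object "DistributeMemoryForObject"
    uint8_t *b = chunk.Distribute(8192);
    return (a != nullptr && b != nullptr) ? 0 : 1;
}
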
diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_batch_get_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_batch_get_impl.cpp
index 5d534ab..e52b3d6 100644
--- a/src/datasystem/worker/object_cache/service/worker_oc_service_batch_get_impl.cpp
+++ b/src/datasystem/worker/object_cache/service/worker_oc_service_batch_get_impl.cpp
@@ -22,6 +22,7 @@
 #include 
 #include 
 
+#include "datasystem/common/iam/tenant_auth_manager.h"
 #include "datasystem/common/log/log.h"
 #include "datasystem/common/inject/inject_point.h"
 #include "datasystem/common/perf/perf_manager.h"
@@ -82,6 +83,23 @@ Status WorkerOcServiceGetImpl::BatchGetRetrieveRemotePayload(uint64_t completeDa
     return Status::OK();
 }
 
+void WorkerOcServiceGetImpl::HandleGetFailureHelper(const std::string &objectKey, uint64_t version,
+                                                    std::shared_ptr &entry, bool isInsert)
+{
+    LOG(WARNING) << "Get object from remote failed, start to remove location from master";
+    (void)RemoveLocation(objectKey, version);
+    if (entry->Get() != nullptr && entry->Get()->GetShmUnit() != nullptr) {
+        entry->Get()->GetShmUnit()->SetHardFreeMemory();
+    }
+    if (isInsert) {
+        (void)objectTable_->Erase(objectKey, *entry);
+    } else if (entry->Get() != nullptr) {
+        entry->Get()->FreeResources();
+        entry->Get()->SetLifeState(ObjectLifeState::OBJECT_INVALID);
+        entry->Get()->stateInfo.SetCacheInvalid(true);
+    }
+}
+
 Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereBatched(
     std::vector &queryMetas, const std::map &readKeys,
     const std::shared_ptr &request, std::vector &payloads,
@@ -143,17 +161,16 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereBatched(
         auto &address = queryMeta->first;
         auto &metaList = queryMeta->second;
         futures.emplace_back(workerBatchThreadPool_->Submit([this, &lastRc, address, &metaList, readKeys, &request,
-                                                             &lockedEntries, &tempSuccessIds, &tempNeedRetryIds,
-                                                             &tempFailedIds, &tempFailedMetas, index, traceId] {
-            for (auto &metaPair : metaList) {
-                auto &metas = metaPair.first;
-                TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId);
-                lastRc =
-                    BatchGetObjectFromRemoteOnLock(address, metas, readKeys, request, lockedEntries,
-                                                   tempSuccessIds[index], tempNeedRetryIds[index], tempFailedIds[index],
-                                                   tempFailedMetas[index]);
-            }
-            return lastRc;
+                                                            &lockedEntries, &tempSuccessIds, &tempNeedRetryIds,
+                                                            &tempFailedIds, &tempFailedMetas, index, traceId] {
+            for (auto &metaPair : metaList) {
+                auto &metas = metaPair.first;
+                TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId);
+                lastRc = BatchGetObjectFromRemoteOnLock(address, metas, readKeys, request, lockedEntries,
+                                                        tempSuccessIds[index], tempNeedRetryIds[index],
+                                                        tempFailedIds[index], tempFailedMetas[index]);
+            }
+            return lastRc;
         }));
     }
     for (auto &fut : futures) {
@@ -195,14 +212,18 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereBatched(
     return lastRc;
 }
 
-void WorkerOcServiceGetImpl::GroupQueryMeta(master::QueryMetaInfoPb &queryMeta, std::unordered_map, uint64_t>>> &groupedQueryMetas)
+void WorkerOcServiceGetImpl::GroupQueryMeta(
+    master::QueryMetaInfoPb &queryMeta,
+    std::unordered_map, uint64_t>>> &groupedQueryMetas)
 {
     const static uint64_t maxPayloadSize = FLAGS_batch_get_threshold_mb * 1024 * 1024;
     const auto &meta = queryMeta.meta();
     auto &splitList = groupedQueryMetas[queryMeta.address()];
     if (!(FLAGS_enable_urma) && (FLAGS_batch_get_threshold_mb != 0)) {
-        if (splitList.empty() || splitList.back().second + meta.data_size() > maxPayloadSize) {
+        // Check empty() before touching back(), and guard the accumulation against overflow.
+        bool needNewSplit = splitList.empty();
+        if (!needNewSplit) {
+            auto payloadSize = meta.data_size() < UINT64_MAX - splitList.back().second
+                                   ? splitList.back().second + meta.data_size()
+                                   : UINT64_MAX;
+            needNewSplit = payloadSize > maxPayloadSize;
+        }
+        if (needNewSplit) {
             splitList.emplace_back(std::make_pair(std::list{}, 0));
         }
     } else {
@@ -457,5 +478,39 @@ Status WorkerOcServiceGetImpl::BatchGetObjectFromRemoteOnLock(
     return rc;
 }
 
+Status WorkerOcServiceGetImpl::AggregateAllocateHelper(
+    const std::list &metas,
+    std::map, bool>> &lockedEntries,
+    std::vector> &shmOwners, std::vector &shmIndexMapping)
+{
+    std::function<void(std::function<void(uint64_t, uint64_t, uint32_t)>, bool &)> traversalHelper =
+        [&metas, &lockedEntries](std::function<void(uint64_t, uint64_t, uint32_t)> collector,
+                                 bool &needAggregate) {
+            needAggregate = metas.size() > 1;
+            uint32_t objectId = 0;
+            for (auto metaIter = metas.begin(); metaIter != metas.end() && needAggregate; metaIter++, objectId++) {
+                auto &meta = *metaIter;
+                auto dataSz = meta->data_size();
+
+                const auto &objectKey = meta->object_key();
+                auto &pair = lockedEntries.at(objectKey);
+                auto &entry = *(pair.first);
+                auto metaSz = entry->GetMetadataSize();
+                uint64_t shmSize = dataSz + metaSz;
+
+                auto shmUnit = entry->GetShmUnit();
+                // Skip the aggregation if allocation is not needed for the object.
+                bool szChanged = (shmUnit == nullptr) || (shmUnit->size != shmSize);
+                if (!szChanged) {
+                    continue;
+                }
+                collector(dataSz, shmSize, objectId);
+            }
+        };
+    auto firstObjectKey = metas.front()->object_key();
+    RETURN_IF_NOT_OK(
+        AggregateAllocate(firstObjectKey, traversalHelper, evictionManager_, shmOwners, shmIndexMapping));
+    return Status::OK();
+}
 }  // namespace object_cache
 }  // namespace datasystem
diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.cpp
index 7c9da61..1ad6d41 100644
--- a/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.cpp
+++ b/src/datasystem/worker/object_cache/service/worker_oc_service_create_impl.cpp
@@ -29,6 +29,8 @@
 #include "datasystem/utils/status.h"
 #include "datasystem/worker/authenticate.h"
 
+DS_DECLARE_uint64(oc_shm_transfer_threshold_kb);
+
 namespace datasystem {
 namespace object_cache {
 
@@ -116,7 +118,30 @@ Status WorkerOcServiceCreateImpl::MultiCreate(const MultiCreateReqPb &req, Multi
                                       FormatString("object key count %zu not match with data size count %zu",
                                                    req.object_key_size(), req.data_size_size()));
     auto lastRc = Status::OK();
+    uint64_t totalSize = 0;
     for (int i = 0; i < req.object_key().size(); i++) {
+        auto objectKey = req.object_key(i);
+        auto dataSize = req.data_size(i);
+        // Check whether the object is in local.
+        {
+            auto key = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, objectKey);
+            std::shared_ptr entry;
+            if (!req.skip_check_existence() && objectTable_->Get(key, entry).IsOk() && entry->RLock(false).IsOk()) {
+                Raii unlock([&entry]() { entry->RUnlock(); });
+                if ((*entry)->IsBinary() && !(*entry)->IsInvalid()) {
+                    resp.add_exists(true);
+                    continue;
+                }
+            }
+        }
+        resp.add_exists(false);
+        totalSize += dataSize;
+    }
+    for (int i = 0; i < req.object_key().size(); i++) {
+        if (resp.exists(i) || totalSize < FLAGS_oc_shm_transfer_threshold_kb * KB) {
+            resp.add_results();
+            continue;
+        }
         const auto &objectKey = req.object_key(i);
         auto dataSize = req.data_size(i);
         // If some buffer create failed, need to rollback, remove shm-unit.
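
The batching rule AggregateAllocate applies is a plain greedy packing; a self-contained sketch under the limits quoted in its comments (skip objects of 1MB or more, seal a batch past 1024 keys or 2MB), with simplified types:

#include <cstdint>
#include <vector>

struct Batch {
    uint64_t totalSize = 0;
    uint64_t keyCount = 0;
};

// Greedy, order-preserving packing: oversized objects are excluded, and a batch
// is sealed as soon as adding the next object would exceed either cap.
std::vector<Batch> PackBatches(const std::vector<uint64_t> &sizes)
{
    const uint64_t maxKeys = 1024;
    const uint64_t maxSingle = 1024 * 1024;     // skip objects >= 1MB
    const uint64_t maxTotal = 2 * 1024 * 1024;  // seal a batch beyond 2MB

    std::vector<Batch> batches;
    Batch cur;
    for (uint64_t sz : sizes) {
        if (sz >= maxSingle) {
            continue;  // large objects keep their own allocation path
        }
        if (cur.keyCount >= maxKeys || cur.totalSize + sz > maxTotal) {
            batches.push_back(cur);
            cur = Batch{};
        }
        cur.totalSize += sz;
        cur.keyCount++;
    }
    if (cur.keyCount > 0) {
        batches.push_back(cur);  // seal the last, partially filled batch
    }
    return batches;
}
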
diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_delete_impl.h b/src/datasystem/worker/object_cache/service/worker_oc_service_delete_impl.h index 5d75a75..ba68522 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_delete_impl.h +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_delete_impl.h @@ -43,6 +43,12 @@ public: */ Status DeleteAllCopy(const DeleteAllCopyReqPb &req, DeleteAllCopyRspPb &resp); + /** + * @brief The rpc method to delete the objects. + * @param[in] req The rpc request protobuf. + * @param[out] resp The rpc response protobuf. + * @return Status of the call. + */ Status DeleteCopyNotification(const DeleteObjectReqPb &req, DeleteObjectRspPb &rsp); private: @@ -82,11 +88,18 @@ private: * @param failedObjectKeys The keys of failed objects. * @param needDeleteObjectKey The keys of need delete objects. * @param deleteRsp Delete response. - * @return + * @return Status of the call. */ Status InsertFailedId(Status &rpcStatus, Status &recvRc, std::unordered_set &failedObjectKeys, const std::vector &needDeleteObjectKey, master::DeleteAllCopyMetaRspPb &deleteRsp); + /** + * @brief Delete object from notification. + * @param objectKey The object key to delete. + * @param version The version of the key. + * @param async Whether it is asynchronous. + * @return Status of the call. + */ Status DeleteObjectFromNotification(const std::string &objectKey, uint64_t version, bool async); /** @@ -112,10 +125,9 @@ private: HostPort &localAddress_; - std::shared_ptr getProc_{ nullptr }; + std::shared_ptr getProc_{ nullptr }; // shared pointer to the workerocservicegetimpl }; - } // namespace object_cache } // namespace datasystem #endif // DATASYSTEM_OBJECT_CACHE_WORKER_SERVICE_DELETE_IMPL_H \ No newline at end of file diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_expire_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_expire_impl.cpp index f4d0709..648e425 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_expire_impl.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_expire_impl.cpp @@ -34,8 +34,8 @@ #include "datasystem/utils/status.h" #include "datasystem/worker/authenticate.h" -DS_DECLARE_string(other_az_names); -DS_DECLARE_string(az_name); +DS_DECLARE_string(other_cluster_names); +DS_DECLARE_string(cluster_name); DS_DECLARE_bool(cross_az_get_data_from_worker); DS_DECLARE_bool(cross_az_get_meta_from_worker); @@ -47,8 +47,8 @@ WorkerOcServiceExpireImpl::WorkerOcServiceExpireImpl(WorkerOcServiceCrudParam &i std::shared_ptr akSkManager) : WorkerOcServiceCrudCommonApi(initParam), etcdCM_(etcdCM), akSkManager_(std::move(akSkManager)) { - for (const auto &azName : Split(FLAGS_other_az_names, ",")) { - if (azName != FLAGS_az_name) { + for (const auto &azName : Split(FLAGS_other_cluster_names, ",")) { + if (azName != FLAGS_cluster_name) { otherAZNames_.emplace_back(azName); } } diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.cpp index 2ceb0ad..95871f0 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.cpp @@ -51,8 +51,8 @@ #include "datasystem/worker/object_cache/object_kv.h" #include "datasystem/worker/object_cache/worker_worker_oc_api.h" -DS_DECLARE_string(other_az_names); -DS_DECLARE_string(az_name); 
+DS_DECLARE_string(other_cluster_names); +DS_DECLARE_string(cluster_name); DS_DECLARE_bool(cross_az_get_data_from_worker); DS_DECLARE_bool(cross_az_get_meta_from_worker); DS_DECLARE_bool(oc_io_from_l2cache_need_metadata); @@ -77,9 +77,10 @@ WorkerOcServiceGetImpl::WorkerOcServiceGetImpl(WorkerOcServiceCrudParam &initPar akSkManager_(std::move(akSkManager)), localAddress_(std::move(localAddress)) { + remoteGetThreadPool_ = std::make_unique(1, FLAGS_rpc_thread_num, "RemoteGetThreadPool"); if (HaveOtherAZ()) { - for (const auto &azName : Split(FLAGS_other_az_names, ",")) { - if (azName != FLAGS_az_name) { + for (const auto &azName : Split(FLAGS_other_cluster_names, ",")) { + if (azName != FLAGS_cluster_name) { otherAZNames_.emplace_back(azName); } } @@ -265,9 +266,7 @@ Status WorkerOcServiceGetImpl::ProcessGetObjectRequest( MarkObjectsInGetProcess(objectKeys); - Raii getProcessGuard([this, &objectKeys]() { - UnmarkObjectsInGetProcess(objectKeys); - }); + Raii getProcessGuard([this, &objectKeys]() { UnmarkObjectsInGetProcess(objectKeys); }); // Try get from local. TryGetObjectFromLocal(offsetInfos, request, objectsNeedGetRemote); @@ -775,9 +774,9 @@ Status WorkerOcServiceGetImpl::GetObjectFromRemoteWorkerAndDump(const std::strin template Status WorkerOcServiceGetImpl::PrepareUrmaInfo(uint64_t dataSize, ReadObjectKV &objectKV, Req &reqPb, - bool &shmUnitAllocated) + bool &shmUnitAllocated, std::shared_ptr shmOwner) { - if (!IsUrmaEnabled()) { + if (!IsUrmaEnabled() && shmOwner == nullptr) { return Status::OK(); } reqPb.set_data_size(dataSize); @@ -786,6 +785,7 @@ Status WorkerOcServiceGetImpl::PrepareUrmaInfo(uint64_t dataSize, ReadObjectKV & return Status::OK(); }); // Allocate the memory for the remote worker to urma_write. + // Or early distribute memory for general code path. const auto &objectKey = objectKV.GetObjKey(); auto &entry = objectKV.GetObjEntry(); auto metaSz = entry->GetMetadataSize(); @@ -795,7 +795,13 @@ Status WorkerOcServiceGetImpl::PrepareUrmaInfo(uint64_t dataSize, ReadObjectKV & // Only create new shm if size changed or not exist. if (szChanged) { shmUnit = std::make_shared(); - RETURN_IF_NOT_OK(AllocateMemoryForObject(objectKey, dataSize, metaSz, false, evictionManager_, *shmUnit)); + bool populate = false; + if (shmOwner) { + RETURN_IF_NOT_OK(DistributeMemoryForObject(objectKey, dataSize, metaSz, populate, shmOwner, *shmUnit)); + } else { + RETURN_IF_NOT_OK( + AllocateMemoryForObject(objectKey, dataSize, metaSz, populate, evictionManager_, *shmUnit)); + } shmUnit->id = GetStringUuid(); entry->SetShmUnit(shmUnit); shmUnitAllocated = true; @@ -813,8 +819,14 @@ Status WorkerOcServiceGetImpl::ConstructBatchGetRequest( PerfPoint point(PerfKey::WORKER_CONSTRUCT_BATCH_GET_REQ); // The function is placed together with PrepareUrmaInfo, as the template definition does not suit in header file. Status lastRc = Status::OK(); + // Pre-allocate an aggregated chunk of shared memory as ShmOwner, to reduce the number of allocation calls. + std::vector> shmOwners; + std::vector shmIndexMapping(metas.size(), std::numeric_limits::max()); + RETURN_IF_NOT_OK(AggregateAllocateHelper(metas, lockedEntries, shmOwners, shmIndexMapping)); + bool requestReady = false; - for (auto metaIter = metas.begin(); metaIter != metas.end();) { + uint32_t objectId = 0; + for (auto metaIter = metas.begin(); metaIter != metas.end(); objectId++) { auto &meta = *metaIter; const auto &objectKey = meta->object_key(); // Checked availability when metas are grouped, so it should be safe to just access the entry here. 
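
The hunk below consumes the pre-allocated chunks through shmIndexMapping; a short sketch of the lookup it performs, with Owner as a simplified, hypothetical stand-in for the patch's shared-memory chunk type:

#include <cstdint>
#include <memory>
#include <vector>

struct Owner { /* pre-allocated shared-memory chunk */ };

std::shared_ptr<Owner> PickOwner(uint32_t objectId,
                                 const std::vector<size_t> &shmIndexMapping,
                                 const std::vector<std::shared_ptr<Owner>> &shmOwners)
{
    // The mapping is pre-filled with an out-of-range sentinel, so both checks are
    // needed: the id may never have joined a batch, or aggregation was skipped.
    if (objectId < shmIndexMapping.size() && shmIndexMapping[objectId] < shmOwners.size()) {
        return shmOwners[shmIndexMapping[objectId]];
    }
    return nullptr;  // caller falls back to a per-object allocation
}
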
@@ -839,7 +851,11 @@ Status WorkerOcServiceGetImpl::ConstructBatchGetRequest( // Prepare the protobuf with urma info for data transfer if applicable. // BatchGetObjectHandleIndividualStatus will free ShmUnit upon error, so no need to actually record it here. bool shmUnitAllocated = false; - status = PrepareUrmaInfo(meta->data_size(), objectKV, subReq, shmUnitAllocated); + std::shared_ptr shmOwner = nullptr; + if (shmIndexMapping.size() > objectId && shmOwners.size() > shmIndexMapping[objectId]) { + shmOwner = shmOwners[shmIndexMapping[objectId]]; + } + status = PrepareUrmaInfo(meta->data_size(), objectKV, subReq, shmUnitAllocated, shmOwner); if (status.IsError()) { BatchGetObjectHandleIndividualStatus(status, objectKey, readKey, successIds, needRetryIds, failedIds); metaIter = metas.erase(metaIter); @@ -1346,7 +1362,7 @@ Status WorkerOcServiceGetImpl::QueryMetaDataFromEtcd(const std::unordered_set &queryMetas, const std::map &readKeys, + const std::shared_ptr &request, std::vector &payloads, + std::map, bool>> &lockedEntries, + std::unordered_set &failedIds, std::vector &needRetryIds) +{ + const size_t kMinParallelRequests = 2; + if (queryMetas.size() < kMinParallelRequests) { + return GetObjectsFromAnywhereSerially(queryMetas, readKeys, request, payloads, lockedEntries, failedIds, + needRetryIds); + } + Status lastRc = Status::OK(); + std::vector successIds; + successIds.reserve(queryMetas.size()); + + std::vector> futures; + std::atomic abortAllTasks{ false }; + std::mutex commonMutex; + + for (size_t i = 0; i < queryMetas.size(); ++i) { + if (abortAllTasks.load()) { + break; + } + + const auto &queryMeta = queryMetas[i]; + const auto &meta = queryMeta.meta(); + + const auto dataFormat = static_cast(queryMeta.meta().config().data_format()); + if (dataFormat != DataFormat::BINARY && dataFormat != DataFormat::HETERO) { + lastRc = Status(K_INVALID, "object data format not match."); + failedIds.emplace(meta.object_key()); + LOG(ERROR) << lastRc; + continue; + } + auto iter = lockedEntries.find(meta.object_key()); + if (iter == lockedEntries.end()) { + LOG(ERROR) << FormatString("[ObjectKey %s] QueryMeta exist but lock entry absent, should not happen", + meta.object_key()); + lastRc = Status(K_UNKNOWN_ERROR, "QueryMeta exist but lock entry absent, should not happen"); + continue; + } + if (readKeys.find(meta.object_key()) == readKeys.end()) { + LOG(ERROR) << FormatString("[ObjectKey %s] cant find offset and size to get", meta.object_key()); + lastRc = Status(K_UNKNOWN_ERROR, "Can not find offset or size to get object"); + continue; + } + ReadKey readKey = readKeys.at(meta.object_key()); + + Timer timer; + int64_t realTimeoutMs = reqTimeoutDuration.CalcRealRemainingTime(); + + futures.emplace_back(remoteGetThreadPool_->Submit([=, &lockedEntries, &commonMutex, &abortAllTasks, &request, + &payloads, &lastRc, &successIds, &needRetryIds, + &failedIds]() { + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + int64_t elapsed = timer.ElapsedMilliSecond(); + reqTimeoutDuration.Init(realTimeoutMs - elapsed); + if (abortAllTasks.load()) { + return Status::OK(); + } + const auto &queryMeta = queryMetas[i]; + const auto &meta = queryMeta.meta(); + auto subIter = lockedEntries.find(meta.object_key()); + if (subIter == lockedEntries.end()) { + std::lock_guard lock(commonMutex); + LOG(INFO) << FormatString("[ObjectKey %s] Object not found in locked entries", meta.object_key()); + lastRc = Status(K_NOT_FOUND, + FormatString("[ObjectKey %s] Object not found in locked entries", meta.object_key())); + 
return lastRc; + } + std::shared_ptr &subEntry = subIter->second.first; + bool isInsert = subIter->second.second; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(subEntry->TransferWLockToCurrentThread(), "Lock failed"); + Status status = GetObjectFromAnywhereWithLock(readKey, request, subEntry, isInsert, queryMeta, payloads); + + // Protects access to successIds, needRetryIds, failedIds, and lastRc + std::lock_guard lock(commonMutex); + + if (status.IsOk()) { + LOG(INFO) << FormatString("[ObjectKey %s] Get from remote success.", meta.object_key()); + successIds.push_back(meta.object_key()); + } else if (status.GetCode() == K_WORKER_PULL_OBJECT_NOT_FOUND) { + LOG(INFO) << FormatString("[ObjectKey %s] Object not found in remote worker.", meta.object_key()); + needRetryIds.emplace_back(readKey); + } else if (status.GetCode() == K_OUT_OF_MEMORY) { + LOG(INFO) << FormatString("[ObjectKey %s] Out of memory, get remote abort.", meta.object_key()); + lastRc = status; + abortAllTasks.store(true); + } else { + LOG(ERROR) << FormatString("[ObjectKey %s] Get from remote failed: %s.", meta.object_key(), + status.ToString()); + failedIds.emplace(meta.object_key()); + lastRc = status; + } + + return status; + })); + } + + for (auto &f : futures) { + f.wait(); + } + + if (successIds.size() != queryMetas.size()) { + LOG(ERROR) << "Failed to get object data from remote. " << successIds.size() << " objects pulled success: [" + << VectorToString(successIds) << "], meta data num: " << queryMetas.size() + << " lastRc: " << lastRc.ToString(); + } + return lastRc; } Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereSerially( @@ -1529,23 +1656,6 @@ Status WorkerOcServiceGetImpl::GetObjectsFromAnywhereSerially( return lastRc; } -void WorkerOcServiceGetImpl::HandleGetFailureHelper(const std::string &objectKey, uint64_t version, - std::shared_ptr &entry, bool isInsert) -{ - LOG(WARNING) << "Get object from remote failed, start to remove location from master"; - (void)RemoveLocation(objectKey, version); - if (entry->Get() != nullptr && entry->Get()->GetShmUnit() != nullptr) { - entry->Get()->GetShmUnit()->SetHardFreeMemory(); - } - if (isInsert) { - (void)objectTable_->Erase(objectKey, *entry); - } else if (entry->Get() != nullptr) { - entry->Get()->FreeResources(); - entry->Get()->SetLifeState(ObjectLifeState::OBJECT_INVALID); - entry->Get()->stateInfo.SetCacheInvalid(true); - } -} - Status WorkerOcServiceGetImpl::GetObjectFromAnywhereWithLock(const ReadKey &readKey, const std::shared_ptr &request, std::shared_ptr &entry, bool isInsert, @@ -1898,7 +2008,7 @@ bool WorkerOcServiceGetImpl::IsGetFromL2Storage(bool canNotFindInWorker, bool wr bool WorkerOcServiceGetImpl::HaveOtherAZ() { - return !FLAGS_other_az_names.empty(); + return !FLAGS_other_cluster_names.empty(); } Status WorkerOcServiceGetImpl::BatchLockForGet( diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.h b/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.h index ab9bd87..64491c3 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.h +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_get_impl.h @@ -174,7 +174,7 @@ private: * @return K_OK on success; the error code otherwise. */ Status GetMapOfObjectKeys(const std::vector> &objectKeys, - std::unordered_map &result, Status &lastRc); + std::unordered_map &result, Status &lastRc); /** * @brief Process a get request from client. 
@@ -293,17 +293,32 @@ private: */ Status UpdateLocation(const std::string &objectKey, ObjectKV &objectKV); + /** + * @brief Helper function to allocate aggregated memory for objects at Batch Get. + * @param[in] metas The batched object meta info contains data size. + * @param[in] lockedEntries The object lock entries. + * @param[out] shmOwners The allocated shared memory chunks. + * @param[out] shmIndexMapping The object key to shmOwners index mapping. + * @return Status of the call. + */ + Status AggregateAllocateHelper(const std::list &metas, + std::map, bool>> &lockedEntries, + std::vector> &shmOwners, + std::vector &shmIndexMapping); + /** * @brief Pull object data from remote worker. * @note The request protobuf needs to contain data_size and urma_info fields. * @param[in] dataSize The object data size. * @param[in] kv The reserved and locked safe object and its corresponding objectKey. * @param[in] reqPb The remote GetObject rpc req protobuf. - * @param[out] shmUnitAllocated did memory allocated during this call + * @param[out] shmUnitAllocated did memory allocated during this call. + * @param[in] shmOwner The allocated shared memory chunks. * @return Status of the call. */ template - Status PrepareUrmaInfo(uint64_t dataSize, ReadObjectKV &objectKV, Req &reqPb, bool &shmUnitAllocated); + Status PrepareUrmaInfo(uint64_t dataSize, ReadObjectKV &objectKV, Req &reqPb, bool &shmUnitAllocated, + std::shared_ptr shmOwner = nullptr); /** * @brief Pull object data from remote worker. @@ -421,6 +436,23 @@ private: std::map, bool>> &lockedEntries, std::unordered_set &failedIds, std::vector &needRetryIds); + /** + * @brief Get objects from anywhere parallelly. + * @param[in] queryMetas QueryMeta result requested from master. + * @param[in] readKeys read key info, contain offset, size, objKey. + * @param[in] request Get request instance. + * @param[in] payloads Get payloads that contains object data. + * @param[in] lockedEntries Object lock entries. + * @param[out] failedIds Failed get object keys. + * @param[out] needRetryIds Need retry get id list. + * @return Status of the call. + */ + Status GetObjectsFromAnywhereParallelly( + const std::vector &queryMetas, const std::map &readKeys, + const std::shared_ptr &request, std::vector &payloads, + std::map, bool>> &lockedEntries, + std::unordered_set &failedIds, std::vector &needRetryIds); + /** * @brief Get objects from anywhere serially. * @param[in] queryMetas QueryMeta result requested from master. @@ -527,8 +559,9 @@ private: * @param[out] groupedQueryMeta Grouped meta by address and payload split by threshold. * @return Status of the call. */ - void GroupQueryMeta(master::QueryMetaInfoPb &queryMeta, std::unordered_map, uint64_t>>> &groupedQueryMetas); + void GroupQueryMeta( + master::QueryMetaInfoPb &queryMeta, + std::unordered_map, uint64_t>>> &groupedQueryMetas); /** * @brief Helper function to handle individual status returned from the batch get request. 
@@ -835,8 +868,7 @@ private: */ void ProcessQueryMetaFailedObjsWhenMetaStoredInEtcd( const std::unordered_map> &objKeysUndecidedMaster, - std::unordered_set &&objectKeysNotExist, - const std::unordered_set &objectKeysPuzzled, + std::unordered_set &&objectKeysNotExist, const std::unordered_set &objectKeysPuzzled, const std::unordered_set &objectKeysMayInOtherAz, std::vector &queryMetas, std::vector &absentObjectKeys); @@ -868,17 +900,19 @@ private: std::shared_ptr threadPool_{ nullptr }; + std::unique_ptr remoteGetThreadPool_{ nullptr }; + std::shared_ptr akSkManager_{ nullptr }; HostPort localAddress_; - std::shared_timed_mutex inRemoteGetIdsMutex_; + std::shared_timed_mutex inRemoteGetIdsMutex_; // the mutex for inRemoteGetIds_ - std::unordered_set inRemoteGetIds_; + std::unordered_set inRemoteGetIds_; // the object keys that in remote get std::vector otherAZNames_; - std::shared_mutex objectsInGetProcessMutex_; + std::shared_mutex objectsInGetProcessMutex_; // the mutex for objectsInGetProcess_ std::unordered_map objectsInGetProcess_; diff --git a/src/datasystem/worker/object_cache/service/worker_oc_service_migrate_impl.cpp b/src/datasystem/worker/object_cache/service/worker_oc_service_migrate_impl.cpp index be21dc6..b1451cd 100644 --- a/src/datasystem/worker/object_cache/service/worker_oc_service_migrate_impl.cpp +++ b/src/datasystem/worker/object_cache/service/worker_oc_service_migrate_impl.cpp @@ -21,6 +21,7 @@ #include "datasystem/worker/object_cache/service/worker_oc_service_migrate_impl.h" #include +#include #include #include #include @@ -582,7 +583,7 @@ Status WorkerOcServiceMigrateImpl::AllocateAndAssignData( auto needSize = size + metaSize; auto tenantId = TenantAuthManager::ExtractTenantId(objectKey); RETURN_IF_NOT_OK_PRINT_ERROR_MSG( - shmUnit->AllocateMemory(tenantId, needSize, false, + shmUnit->AllocateMemory(tenantId, needSize, false, ServiceType::OBJECT, static_cast((*entry)->modeInfo.GetCacheType())), FormatString("[Migrate Data] %s allocate memory failed, size: %ld", objectKey, needSize)); shmUnit->id = GetStringUuid(); @@ -813,7 +814,9 @@ bool WorkerOcServiceMigrateImpl::IsDiskAvailable(uint64_t size) const LOG_EVERY_T(INFO, freq) << "[Migrate Data] Disk now is not available"; return false; } - uint64_t used = memory::Allocator::Instance()->GetTotalRealMemoryUsage(memory::CacheType::DISK) + size; + auto realMemoryUsage = + memory::Allocator::Instance()->GetTotalRealMemoryUsage(ServiceType::OBJECT, memory::CacheType::DISK); + uint64_t used = size < UINT64_MAX - realMemoryUsage ? 
realMemoryUsage + size : UINT64_MAX; uint64_t total = memory::Allocator::Instance()->GetMaxMemoryLimit(memory::CacheType::DISK); return used <= total * MIGRATE_HIGH_WATER_FACTOR; } diff --git a/src/datasystem/worker/object_cache/worker_oc_eviction_manager.cpp b/src/datasystem/worker/object_cache/worker_oc_eviction_manager.cpp index 283c8ea..e050271 100644 --- a/src/datasystem/worker/object_cache/worker_oc_eviction_manager.cpp +++ b/src/datasystem/worker/object_cache/worker_oc_eviction_manager.cpp @@ -229,9 +229,9 @@ Status WorkerOcEvictionManager::EvictClearObject(ObjectKV &objectKV) uint64_t WorkerOcEvictionManager::GetLowWaterMark(CacheType cacheType) { memory::CacheType memCacheType = static_cast(cacheType); - auto maxMemorySize = datasystem::memory::Allocator::Instance()->GetMaxMemorySize(memCacheType); + auto maxMemorySize = datasystem::memory::Allocator::Instance()->GetMaxMemorySize(ServiceType::OBJECT, memCacheType); auto usedMemorySize = - datasystem::memory::Allocator::Instance()->GetTotalRealMemoryUsage(memCacheType); + datasystem::memory::Allocator::Instance()->GetTotalRealMemoryUsage(ServiceType::OBJECT, memCacheType); auto lowWater = static_cast( std::min(datasystem::memory::Allocator::Instance()->GetTotalRealMemoryFree(memCacheType) + usedMemorySize, maxMemorySize) @@ -243,7 +243,7 @@ bool WorkerOcEvictionManager::IsAboveLowWaterMark(uint64_t needSize, size_t pend { uint64_t max = std::numeric_limits::max(); auto realMemoryUsage = datasystem::memory::Allocator::Instance()->GetTotalRealMemoryUsage( - static_cast(cacheType)); + ServiceType::OBJECT, static_cast(cacheType)); realMemoryUsage = (realMemoryUsage > max - needSize) ? max : realMemoryUsage + needSize; auto lowWater = GetLowWaterMark(cacheType); lowWater = (lowWater > max - pendingSpillSize) ? max : lowWater + pendingSpillSize; @@ -726,14 +726,16 @@ std::string WorkerOcEvictionManager::GetActionName(Action action) } bool EvictWhenMemoryExceedThrehold(const std::string &keyInfo, uint64_t needSize, - const std::shared_ptr &evictionManager, CacheType cacheType) + const std::shared_ptr &evictionManager, ServiceType type, + CacheType cacheType) { uint64_t realMemoryUsed = 0; uint64_t memOccupied = 0; uint64_t maxAvailableMemorySize = 0; memory::CacheType memCacheType = static_cast(cacheType); uint64_t memThreshold = 0; - auto realObjMemoryUsed = datasystem::memory::Allocator::Instance()->GetTotalRealMemoryUsage(memCacheType); + auto realObjMemoryUsed = + datasystem::memory::Allocator::Instance()->GetTotalRealMemoryUsage(ServiceType::OBJECT, memCacheType); auto getMemThresInitVal = [](uint64_t maxAvailableMemorySize, uint64_t evictionThresholdMB) { return std::max(static_cast(maxAvailableMemorySize * HIGH_WATER_FACTOR), maxAvailableMemorySize > evictionThresholdMB * MB_TO_BYTES @@ -745,16 +747,24 @@ bool EvictWhenMemoryExceedThrehold(const std::string &keyInfo, uint64_t needSize // it could never be success, so skip evict. 
return false; } - - realMemoryUsed = realObjMemoryUsed; - memOccupied = realMemoryUsed + needSize; - maxAvailableMemorySize = - std::min(datasystem::memory::Allocator::Instance()->GetMaxMemorySize(memCacheType), - (datasystem::memory::Allocator::Instance()->GetTotalRealMemoryFree(memCacheType) + realMemoryUsed)); - static uint64_t memThresInitVal = - getMemThresInitVal(maxAvailableMemorySize, FLAGS_eviction_reserve_mem_threshold_mb); - memThreshold = memThresInitVal; - + if (type == ServiceType::OBJECT) { + realMemoryUsed = realObjMemoryUsed; + memOccupied = realMemoryUsed + needSize; + maxAvailableMemorySize = std::min( + datasystem::memory::Allocator::Instance()->GetMaxMemorySize(type, memCacheType), + (datasystem::memory::Allocator::Instance()->GetTotalRealMemoryFree(memCacheType) + realMemoryUsed)); + static uint64_t memThresInitVal = + getMemThresInitVal(maxAvailableMemorySize, FLAGS_eviction_reserve_mem_threshold_mb); + memThreshold = memThresInitVal; + } else if (type == ServiceType::STREAM) { + realMemoryUsed = + datasystem::memory::Allocator::Instance()->GetTotalRealMemoryUsage(ServiceType::STREAM) + realObjMemoryUsed; + memOccupied = realMemoryUsed + needSize; + maxAvailableMemorySize = datasystem::memory::Allocator::Instance()->GetMaxMemoryLimit(); + static uint64_t memThresInitVal = + getMemThresInitVal(maxAvailableMemorySize, FLAGS_eviction_reserve_mem_threshold_mb); + memThreshold = memThresInitVal; + } VLOG(1) << FormatString("Allocate memory for %s, size = %lu, memOccupied = %lu, memThreshold = %lu", keyInfo, needSize, memOccupied, memThreshold); if (memOccupied >= memThreshold && realObjMemoryUsed > 0) { diff --git a/src/datasystem/worker/object_cache/worker_oc_eviction_manager.h b/src/datasystem/worker/object_cache/worker_oc_eviction_manager.h index 3acac43..1c42265 100644 --- a/src/datasystem/worker/object_cache/worker_oc_eviction_manager.h +++ b/src/datasystem/worker/object_cache/worker_oc_eviction_manager.h @@ -164,8 +164,7 @@ private: if (!info.empty()) { ss << info << ", "; } - ss << "evict action " << actionName << ", total cost " << elapsed << " ms, " - << "obj size: " << objectSize; + ss << "evict action " << actionName << ", total cost " << elapsed << " ms, " << "obj size: " << objectSize; if (action == Action::SPILL) { ss << "spill cost " << spillCost << " ms, "; } @@ -377,12 +376,13 @@ private: * @param[in] keyInfo The ID of the object need to allocate. * @param[in] needSize The size need to allocate. * @param[in] evictionManager The class of eviction process. + * @param[in] type The service type. * @param[in] cacheType The type of cache. * @return True if eviction is triggered. 
*/ bool EvictWhenMemoryExceedThrehold(const std::string &keyInfo, uint64_t needSize, const std::shared_ptr &evictionManager, - CacheType cacheType = CacheType::MEMORY); + ServiceType type = ServiceType::OBJECT, CacheType cacheType = CacheType::MEMORY); } // namespace object_cache } // namespace datasystem diff --git a/src/datasystem/worker/object_cache/worker_oc_service_impl.cpp b/src/datasystem/worker/object_cache/worker_oc_service_impl.cpp index e56f799..1dc683a 100644 --- a/src/datasystem/worker/object_cache/worker_oc_service_impl.cpp +++ b/src/datasystem/worker/object_cache/worker_oc_service_impl.cpp @@ -108,7 +108,7 @@ DS_DEFINE_uint32(data_migrate_rate_limit_mb, 40, "Data migrate rate limit for ev DS_DECLARE_uint32(max_client_num); DS_DECLARE_string(worker_address); DS_DECLARE_string(master_address); -DS_DECLARE_string(az_name); +DS_DECLARE_string(cluster_name); DS_DECLARE_string(etcd_address); DS_DECLARE_bool(cross_az_get_data_from_worker); DS_DECLARE_bool(cross_az_get_meta_from_worker); @@ -346,7 +346,21 @@ Status WorkerOCServiceImpl::MultiPublish(const MultiPublishReqPb &req, MultiPubl ReadLock noRecon; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(noRecon, reqTimeoutDuration.CalcRemainingTime()), "validate worker state failed"); - return multiPublishProc_->MultiPublish(req, resp, payloads); + RETURN_IF_NOT_OK(multiPublishProc_->MultiPublish(req, resp, payloads)); + if (req.auto_release_memory_ref()) { + std::set failedSet{ resp.failed_object_keys().begin(), resp.failed_object_keys().end() }; + std::vector shmIds; + shmIds.reserve(req.object_info_size()); + for (auto &info : req.object_info()) { + if (failedSet.find(info.object_key()) != failedSet.end()) { + continue; + } + shmIds.emplace_back(info.shm_id()); + } + VLOG(1) << "auto release ref " << VectorToString(shmIds); + return DecreaseMemoryRef(req.client_id(), shmIds); + } + return Status::OK(); } void WorkerOCServiceImpl::GetObjectsMatch(std::function matchFunc, @@ -450,7 +464,7 @@ Status WorkerOCServiceImpl::ProcessVoluntaryScaledown(const std::string &taskId) std::vector needWaitIds; RETURN_IF_NOT_OK(BeforeMigrateData(taskId, needMigrateDataIds, needWaitIds)); INJECT_POINT("VoluntaryScaledown.MigrateData.Delay"); - MigrateData(needMigrateDataIds, taskId); + RETURN_IF_NOT_OK(MigrateData(needMigrateDataIds, taskId)); // When we have finish migrate data task, we can remove the location. 
std::vector removeFailedIds; GroupAndRemoveMeta(needMigrateDataIds, master::RemoveMetaReqPb::NORMAL, removeFailedIds, needMigrateDataIds, @@ -684,14 +698,16 @@ Status WorkerOCServiceImpl::MigrateData(const MigrateDataReqPb &req, MigrateData return gMigrateProc_->MigrateData(req, rsp, std::move(payloads)); } -void WorkerOCServiceImpl::MigrateData(const std::vector &objectKeys, const std::string &taskId, - MigrateStrategy::MigrationStrategyStage stage) +Status WorkerOCServiceImpl::MigrateData(const std::vector &objectKeys, const std::string &taskId, + MigrateStrategy::MigrationStrategyStage stage) { - INJECT_POINT_NO_RETURN("WorkerOCServiceImpl.MigrateData.Delay", - [](int sleepMs) { std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs)); }); + INJECT_POINT("WorkerOCServiceImpl.MigrateData.Delay", [](int sleepMs) { + std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs)); + return Status::OK(); + }); if (objectKeys.empty()) { LOG(INFO) << "[Migrate Data] No object data need to be migrated, we have finish the job, task id: " << taskId; - return; + return Status::OK(); } LOG(INFO) << "[Migrate Data] Processing valuntary scale down data migrate begin, object size: " << objectKeys.size() @@ -724,6 +740,7 @@ void WorkerOCServiceImpl::MigrateData(const std::vector &objectKeys objKeysGrpByMaster.clear(); MetaAddrInfo info; (void)objKeysGrpByMaster.emplace(info, objectKeys); + return Status::OK(); }); for (const auto &item : objKeysGrpByMaster) { futures.emplace_back(MigrateDataByNode(item.first, item.second, progress, threadPool, MigrateStrategy(stage))); @@ -731,12 +748,13 @@ void WorkerOCServiceImpl::MigrateData(const std::vector &objectKeys while (!futures.empty()) { std::vector> newFutures; - HandleMigrateDataResult(taskId, progress, threadPool, futures, newFutures); + RETURN_IF_NOT_OK(HandleMigrateDataResult(taskId, progress, threadPool, futures, newFutures)); futures.swap(newFutures); } + return Status::OK(); } -void WorkerOCServiceImpl::HandleMigrateDataResult( +Status WorkerOCServiceImpl::HandleMigrateDataResult( const std::string &taskId, const std::shared_ptr progress, const std::unique_ptr &threadPool, std::vector> &futures, std::vector> &newFutures) @@ -746,9 +764,11 @@ void WorkerOCServiceImpl::HandleMigrateDataResult( LOG(INFO) << MigrateDataHandler::ResultToString(result); if (!result.failedIds.empty()) { if (!taskId.empty() && etcdCM_->CheckVoluntaryTaskExpired(taskId)) { - LOG(ERROR) << "task id has expired, no need to excute voluntary scale down migrate data task, task id:" - << taskId; - break; + RETURN_STATUS_LOG_ERROR( + K_RUNTIME_ERROR, + FormatString( + "task id has expired, no need to excute voluntary scale down migrate data task, task id: %s", + taskId)); } if (!taskId.empty() && etcdCM_->CheckVoluntaryScaleDown()) { LOG(ERROR) << "this node maybe failed or only one node left, no need to excute voluntary scale down " @@ -760,6 +780,7 @@ void WorkerOCServiceImpl::HandleMigrateDataResult( result.migrateDataStrategy)); } } + return Status::OK(); } std::future WorkerOCServiceImpl::RedirectMigrateData( @@ -2127,9 +2148,19 @@ Status WorkerOCServiceImpl::GetP2PMeta( std::string traceID = Trace::Instance().GetTraceID(); devThreadPool_->Execute([=]() mutable { TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID); - auto objectKey = req.dev_obj_meta(0).object_key(); - LOG(INFO) << "Worker processes GetP2PMeta from client: " << clientId << ", objects: " << objectKey - << ", threads Statistics: " << devThreadPool_->GetStatistics(); + std::stringstream 
allKeys;
+            bool first = true;
+            for (const auto &dev_obj_meta : *req.mutable_dev_obj_meta()) {
+                if (!first) {
+                    allKeys << ", ";
+                }
+                allKeys << dev_obj_meta.object_key();
+                first = false;
+            }
+            LOG(INFO) << FormatString("Worker processes GetP2PMeta from client: %s, allKeys: [%s], threads Statistics: %s",
+                                      clientId,
+                                      allKeys.str(),
+                                      devThreadPool_->GetStatistics());
         int64_t elapsed = timer.ElapsedMilliSecond();
         if (elapsed >= timeout) {
             LOG(ERROR) << "GetP2PMeta RPC timeout. time elapsed " << elapsed << ", subTimeout:" << timeout
@@ -2180,8 +2211,8 @@ Status WorkerOCServiceImpl::RecvRootInfo(
     int64_t timeout = reqTimeoutDuration.CalcRealRemainingTime();
     devThreadPool_->Execute([=]() mutable {
         TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID);
-        LOG(INFO) << "Worker processes RecvRootInfo from dstClientId: " << req.dst_client_id()
-                  << ", dst_device_id: " << req.dst_device_id()
+        LOG(INFO) << "Worker processes RecvRootInfo from srcClientId: " << req.src_client_id()
+                  << ", src_device_id: " << req.src_device_id()
                   << ", threads Statistics: " << devThreadPool_->GetStatistics();
         int64_t elapsed = timer.ElapsedMilliSecond();
         if (elapsed >= timeout) {
diff --git a/src/datasystem/worker/object_cache/worker_oc_service_impl.h b/src/datasystem/worker/object_cache/worker_oc_service_impl.h
index fc94ec0..ef32f15 100644
--- a/src/datasystem/worker/object_cache/worker_oc_service_impl.h
+++ b/src/datasystem/worker/object_cache/worker_oc_service_impl.h
@@ -205,8 +205,9 @@ public:
      * @param[in] taskId task id of voluntary scale down task, if task id is empty, it means we
      * careless about the task id.
      * @param[in] stage Migration sttrategy stage.
+     * @return Status of the call
      */
-    void MigrateData(const std::vector &objectKeys, const std::string &taskId,
+    Status MigrateData(const std::vector &objectKeys, const std::string &taskId,
                      MigrateStrategy::MigrationStrategyStage stage = MigrateStrategy::MigrationStrategyStage::FIRST);
 
     /**
@@ -216,8 +217,9 @@ public:
      * @param[in] threadPool Migrate data thread pool.
      * @param[in] futures Migrate data futures.
      * @param[out] newFutures New added migrate data futures.
+     * @return Status of the call.
      */
-    void HandleMigrateDataResult(const std::string &taskId, const std::shared_ptr progress,
+    Status HandleMigrateDataResult(const std::string &taskId, const std::shared_ptr progress,
                                   const std::unique_ptr &threadPool, std::vector> &futures,
                                   std::vector> &newFutures);
@@ -1131,8 +1133,8 @@ private:
     std::mutex circularQueueMutex_;  // To protect circularQueueManager_
     std::vector> circularQueueManager_;
 
-    std::shared_timed_mutex clearIdsMutex_;
-    std::vector voluntaryScaleDownClearIds_ = {};
+    std::shared_timed_mutex clearIdsMutex_;  // to protect voluntaryScaleDownClearIds_
+    std::vector voluntaryScaleDownClearIds_ = {};  // ids that need to be cleared before voluntary scale down
 
     /**
      * the thread pool is only use for delete old version of object in l2cache.
* diff --git a/src/datasystem/worker/stream_cache/CMakeLists.txt b/src/datasystem/worker/stream_cache/CMakeLists.txt new file mode 100644 index 0000000..a53a16c --- /dev/null +++ b/src/datasystem/worker/stream_cache/CMakeLists.txt @@ -0,0 +1,47 @@ +add_subdirectory(metrics) +add_subdirectory(page_queue) + +set(WORKER_SC_SRCS + stream_manager.cpp + subscription.cpp + consumer.cpp + worker_master_sc_api.cpp + producer.cpp + remote_worker_manager.cpp + client_worker_sc_service_impl.cpp + master_worker_sc_service_impl.cpp + worker_worker_sc_service_impl.cpp + buffer_pool.cpp + stream_data_pool.cpp + usage_monitor.cpp + worker_sc_allocate_memory.cpp + ) + +set(WORKER_SC_DEPEND_LIBS + ${TBB_LIBRARY} + ${SECUREC_LIBRARY} + common_log + common_util + common_perf + common_rpc_zmq + common_inject + common_event_loop + common_shared_memory + common_sc + sc_metrics + sc_page_queue + posix_protos + worker_stream_protos + master_stream_protos + common_encrypt + common_ak_sk + httpclient + worker_health_check + ) + +add_library(worker_stream_cache STATIC ${WORKER_SC_SRCS}) +target_link_libraries(worker_stream_cache PRIVATE ${WORKER_SC_DEPEND_LIBS}) +add_dependencies(worker_stream_cache + posix_protos + worker_stream_protos + master_stream_protos) diff --git a/src/datasystem/worker/stream_cache/buffer_pool.cpp b/src/datasystem/worker/stream_cache/buffer_pool.cpp new file mode 100644 index 0000000..8b7ef38 --- /dev/null +++ b/src/datasystem/worker/stream_cache/buffer_pool.cpp @@ -0,0 +1,661 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/log/trace.h" +#include "datasystem/common/util/container_util.h" +#include "datasystem/common/util/request_counter.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/worker/stream_cache/buffer_pool.h" +#include "datasystem/stream/stream_config.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { + +void BufferPool::Insert(const std::shared_ptr &ele) +{ + // Hash using the producer id and append to one of the dirty lists. + // Order of the elements from the same producer is important. 
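+    // Hashing on the producer identity (rather than round-robin) pins all of a
+    // producer's elements to one partition, whose dirty list is FIFO, so the
+    // per-producer order survives even with multiple flusher threads.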
+ auto hashVal = ele->StreamHash(); + auto partitionID = hashVal % static_cast(numPartitions_); + auto &dirtyList = partitionList_[partitionID]->dirtyList_; + dirtyList.Append(ele); + // Wake up the i/o cleaner + dirtyList.cv_.notify_all(); +} + +void BufferPool::PurgeSortHeap(int partitionID, const std::string &streamName) +{ + auto &heapSortMapPartition = heapSortMapDict_[partitionID]; + WriteLock xlock(&heapSortMapPartition->mux_); + for (auto &m : heapSortMapPartition->heapSortMap_) { + if (m.first.firstKey_ != streamName) { + continue; + } + auto &heapSort = m.second; + WriteLock writeLock(&heapSort->mux_); + // Pop all elements and get rid of them + while (!heapSort->que_.empty()) { + auto v = heapSort->que_.top(); + heapSort->que_.pop(); + // We can discard it, but we will still pass it to the async thread + // so UsageMonitor can keep track of it. + partitionList_[partitionID]->dirtyList_.Append(v.first); + } + // Clear the sequence number + heapSort->expectedSeqNo_ = 0; + } +} + +void BufferPool::PurgeBuffer(const std::string &streamName, const EndOfStreamCallbackFn &fn) +{ + // Go through each partition + auto eos = std::make_shared(streamName, fn); + eos->numJobs_ = numPartitions_; + for (auto partitionID = 0; partitionID < numPartitions_; ++partitionID) { + PurgeSortHeap(partitionID, streamName); + auto &part = partitionList_[partitionID]; + // Send the EOS + part->dirtyList_.Append(eos); + part->dirtyList_.cv_.notify_all(); + } + // Now we wait + const int waitMs = 100; + while (eos->numJobs_ > 0 && !interrupt_) { + std::unique_lock lock(eos->mux_); + eos->cv_.wait_for(lock, std::chrono::milliseconds(waitMs), [&eos]() { return eos->numJobs_ == 0; }); + } +} + +void BufferPool::RemoveStream(const std::string &keyName, const std::string &sharedPageName) +{ + for (auto i = 0; i < numPartitions_; i++) { + // 1. get the producer keys to clean + auto &producerKeyMap = producerKeyMaps_[i]; + std::vector keysToErase; + std::vector keysToResetSeqNo; + { + std::shared_lock lk(producerKeyMap->mapMutex_); + auto it = producerKeyMap->producerKeyMap_.find(keyName); + if (it != producerKeyMap->producerKeyMap_.end()) { + keysToErase = it->second; + } + it = producerKeyMap->producerKeyMap_.find(sharedPageName); + if (it != producerKeyMap->producerKeyMap_.end()) { + keysToResetSeqNo = it->second; + } + } + // 2. clear heapSortMap_ + if (!keysToErase.empty()) { + auto &heapSortMapPartition = heapSortMapDict_[i]; + WriteLock xlock(&heapSortMapPartition->mux_); + for (auto &key : keysToErase) { + heapSortMapPartition->heapSortMap_.erase(key); + } + } + // 3. 
send notification to clear producerDirtyMap_ and producerKeyMap_ + if (!keysToErase.empty() || !keysToResetSeqNo.empty()) { + auto streamDestructData = + std::make_shared(keyName, std::move(keysToErase), std::move(keysToResetSeqNo)); + auto &part = partitionList_[i]; + part->dirtyList_.Append(streamDestructData); + part->dirtyList_.cv_.notify_all(); + } + } +} + +Status BufferPool::UnsortedInsert(std::shared_ptr ele, uint64_t seqNo, uint64_t firstSeqNo) +{ + auto hashVal = ele->StreamHash(); + int partitionID = hashVal % numPartitions_; + auto key = StreamProducerKey(ele->KeyName(), ele->ProducerName(), ele->ProducerInstanceId()); + auto &heapSortMapPartition = heapSortMapDict_[partitionID]; + WriteLock xlock(&heapSortMapPartition->mux_); + auto it = heapSortMapPartition->heapSortMap_.find(key); + if (it == heapSortMapPartition->heapSortMap_.end()) { + bool success; + std::tie(it, success) = heapSortMapPartition->heapSortMap_.emplace(key, std::make_unique()); + auto &producerKeyMap = producerKeyMaps_[partitionID]; + std::lock_guard lk(producerKeyMap->mapMutex_); + producerKeyMap->producerKeyMap_[key.firstKey_].emplace_back(key); + } + auto &heapSort = it->second; + xlock.UnlockIfLocked(); + WriteLock writeLock(&heapSort->mux_); + BaseData p(std::move(ele), seqNo); + const std::string &streamName = p.first->StreamName(); + const std::string &producerId = p.first->ProducerName(); + // If the sequence number is duplicated, ignore it. We already have a copy + if (seqNo < heapSort->expectedSeqNo_) { + RETURN_STATUS_LOG_ERROR(K_DUPLICATED, FormatString("[S:%s P:%s] duplicated seqNo %zu expecting %zu.", + streamName, producerId, seqNo, heapSort->expectedSeqNo_)); + } + // Insert the element into the heap sort + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[S:%s P:%s] Push seqNo %zu to heapsort.", streamName, producerId, + seqNo); + heapSort->que_.push(std::move(p)); + // Check if the top is what we are looking for. + // Remote worker will send us the range of sequence it is sending + auto topSeq = heapSort->que_.top().second; + VLOG(SC_DEBUG_LOG_LEVEL) << FormatString( + "[S:%s P:%s] Top %zu. Expecting %zu or %zu", heapSort->que_.top().first->StreamName(), + heapSort->que_.top().first->ProducerName(), topSeq, firstSeqNo, heapSort->expectedSeqNo_); + while (topSeq == firstSeqNo || topSeq == heapSort->expectedSeqNo_) { + // Pop the element and increment the next expected + auto v = heapSort->que_.top(); + heapSort->que_.pop(); + heapSort->expectedSeqNo_ = v.second + 1; + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[S:%s P:%s] Pop seqNo %zu from heapsort.", v.first->StreamName(), + v.first->ProducerName(), v.second); + partitionList_[partitionID]->dirtyList_.Append(v.first); + if (heapSort->que_.empty()) { + break; + } + topSeq = heapSort->que_.top().second; + } + partitionList_[partitionID]->dirtyList_.cv_.notify_all(); + return Status::OK(); +} + +void BufferPool::ProcessEoSEntries(const StreamProducerKey &key, std::list &producerDirtyList) +{ + // For eos, just pass the remaining buffer to the call back function. + LOG(INFO) << FormatString("[%s] Processing EoS entries. 
+void BufferPool::ProcessEoSEntries(const StreamProducerKey &key, std::list<BaseData> &producerDirtyList)
+{
+    // For eos, just pass the remaining buffers to the call back function.
+    LOG(INFO) << FormatString("[%s] Processing EoS entries. Size of list %zu", key.ToString(),
+                              producerDirtyList.size());
+    do {
+        auto iter = std::find_if(producerDirtyList.begin(), producerDirtyList.end(),
+                                 [](const auto &kv) { return kv.first->IsEoS(); });
+        if (iter == producerDirtyList.end()) {
+            break;
+        }
+        std::list<BaseData> dataLst;
+        dataLst.splice(dataLst.end(), producerDirtyList, producerDirtyList.begin(), iter);
+        auto eos = std::move(*iter);
+        producerDirtyList.erase(iter);
+        auto func = std::static_pointer_cast<EndOfStreamBufferData>(eos.first);
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Purging %zu buffers", key.ToString(), dataLst.size());
+        LOG_IF_ERROR((*func)(dataLst, key.firstKey_, key.producerId_), "");
+        func->numJobs_--;
+    } while (!producerDirtyList.empty());
+    if (!producerDirtyList.empty()) {
+        LOG(INFO) << FormatString("[%s] %zu buffers remained after EoS", key.ToString(), producerDirtyList.size());
+    }
+}
+
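ProcessEoSEntries carves everything ahead of the EoS marker out of the dirty list with std::list::splice, which relinks nodes in O(1) without copying elements, then hands the carved-off run to the callback. A minimal illustration of that splice-at-sentinel pattern (the -1 sentinel and the int payloads are stand-ins):

    #include <algorithm>
    #include <iostream>
    #include <list>

    int main()
    {
        // -1 stands in for the EoS marker.
        std::list<int> dirty{ 10, 11, -1, 12 };
        auto iter = std::find(dirty.begin(), dirty.end(), -1);
        std::list<int> purged;
        // Move [begin, iter) into purged; no element is copied, nodes are relinked.
        purged.splice(purged.end(), dirty, dirty.begin(), iter);
        dirty.erase(iter);                   // drop the marker itself
        std::cout << purged.size() << '\n';  // 2 -> handed to the callback
        std::cout << dirty.size() << '\n';   // 1 -> entries after EoS stay queued
        return 0;
    }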
+Status BufferPool::BatchAsyncFlush(int partitionID, std::vector<StreamProducerKey> &streamList)
+{
+    INJECT_POINT("BufferPool.BatchAsyncFlush");
+    if (VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL)) {
+        std::ostringstream oss;
+        oss << FormatString("Flushing partition %d. Number of producers %zu: ", partitionID, streamList.size());
+        for (auto &ele : streamList) {
+            oss << FormatString("[%s] ", ele.ToString());
+        }
+        LOG(INFO) << oss.str();
+    }
+    auto &myPartition = partitionList_[partitionID];
+    PendingFlushList flushList;
+    bool eosInjected = true;
+    if (myPartition->eosInjected_.compare_exchange_strong(eosInjected, false)) {
+        LOG(INFO) << FormatString("EoS detected for partition %d", partitionID);
+    }
+    for (auto &key : streamList) {
+        auto dirtyIt = myPartition->producerDirtyMap_.find(key);
+        CHECK_FAIL_RETURN_STATUS(dirtyIt != myPartition->producerDirtyMap_.end(), K_RUNTIME_ERROR,
+                                 key.ToString() + " not found in producerDirtyMap_");
+        auto &producerDirtyList = dirtyIt->second->list_;
+        if (producerDirtyList.empty()) {
+            continue;
+        }
+        if (eosInjected) {
+            // Buffers before the eos will be passed to the call back function.
+            ProcessEoSEntries(key, producerDirtyList);
+        }
+        if (!producerDirtyList.empty()) {
+            flushList.emplace_back(key, producerDirtyList);
+        }
+    }
+    if (flushList.empty()) {
+        // streamList is the list of streams failing to flush out on return.
+        streamList.clear();
+        return Status::OK();
+    }
+    Status rc = batchFlushFn_(partitionID, flushList);
+    // Check which one still has pending elements to flush
+    std::vector<StreamProducerKey> failList;
+    for (auto &ele : flushList) {
+        if (ele.second.empty()) {
+            continue;
+        }
+        VLOG(SC_DEBUG_LOG_LEVEL) << FormatString("[%s] fails to flush out %zu elements", ele.first.ToString(),
+                                                 ele.second.size());
+        failList.emplace_back(ele.first);
+    }
+    streamList.swap(failList);
+    return rc;
+}
+
+void BufferPool::InjectEoS(const std::shared_ptr<EndOfStreamBufferData> &eos,
+                           std::unordered_map<StreamProducerKey, std::unique_ptr<ProducerDirtyList>> &map,
+                           std::vector<StreamProducerKey> &fifo)
+{
+    const std::string streamName = eos->streamName_;
+    for (auto &ele : map) {
+        if (ele.first.firstKey_ == streamName) {
+            eos->numJobs_++;
+            ele.second->list_.emplace_back(eos, std::numeric_limits<uint64_t>::max());
+            auto key = ele.first;
+            VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("EoS inserted for producer/worker %s", ele.first.producerId_);
+            if (std::find_if(fifo.begin(), fifo.end(), [&key](StreamProducerKey &ele) { return ele == key; })
+                == fifo.end()) {
+                fifo.emplace_back(key);
+            }
+        }
+    }
+}
+
+void BufferPool::ClearProducerKeyMap(int partitionID, const std::string &streamName,
+                                     const std::vector<StreamProducerKey> &keys)
+{
+    auto &producerKeyMap = producerKeyMaps_[partitionID];
+    std::lock_guard<std::shared_timed_mutex> lk(producerKeyMap->mapMutex_);
+    (void)EraseIf(producerKeyMap->producerKeyMap_[streamName],
+                  [&keys](const StreamProducerKey &key) { return ContainsKey(keys, key); });
+    if (producerKeyMap->producerKeyMap_[streamName].empty()) {
+        producerKeyMap->producerKeyMap_.erase(streamName);
+    }
+}
+
+std::vector<StreamProducerKey> BufferPool::FetchDirtyList(int partitionID,
+                                                          std::vector<StreamProducerKey> &discardKeys)
+{
+    std::vector<StreamProducerKey> fifo;  // Track uniqueness and maintain fifo order
+    auto &myPartition = partitionList_[partitionID];
+    auto &dirtyList = myPartition->dirtyList_;
+    auto fifoList = dirtyList.GetAll();
+    auto &producerDirtyMap = myPartition->producerDirtyMap_;
+    auto it = fifoList.begin();
+    while (it != fifoList.end()) {
+        auto keyName = it->first->KeyName();
+        auto producerId = it->first->ProducerName();
+        auto instanceId = it->first->ProducerInstanceId();
+        auto streamDestructData = std::dynamic_pointer_cast<StreamDestructData>(it->first);
+        if (streamDestructData) {
+            auto &eraseKeys = streamDestructData->GetProducerKeysToErase();
+            auto &resetKeys = streamDestructData->GetProducerKeysToReset();
+            VLOG(SC_NORMAL_LOG_LEVEL) << FormatString(
+                "Process streamDestructData: keyName=%s, eraseKeys=[%s], resetKeys=[%s]", keyName,
+                VectorToString(eraseKeys), VectorToString(resetKeys));
+            discardKeys.insert(discardKeys.end(), eraseKeys.begin(), eraseKeys.end());
+            // 1. clear producerDirtyMap_
+            for (auto &key : eraseKeys) {
+                producerDirtyMap.erase(key);
+            }
+            for (auto &key : resetKeys) {
+                auto mapIt = producerDirtyMap.find(key);
+                if (mapIt != producerDirtyMap.end()) {
+                    (void)mapIt->second->seqNo_.erase(keyName);
+                }
+            }
+            // 2. clear producer key map
+            ClearProducerKeyMap(partitionID, keyName, eraseKeys);
+            // 3. discard keys in fifo
+            std::string log;
+            EraseIf(fifo, [&keyName, &log](const StreamProducerKey &key) {
+                if (keyName == key.firstKey_) {
+                    log += key.ToString() + ", ";
+                    return true;
+                }
+                return false;
+            });
+            LOG_IF(INFO, !log.empty()) << "Discard stream data of " << log;
+        } else if (!it->first->IsEoS()) {
+            StreamProducerKey key(keyName, producerId, instanceId);
+            auto iter = producerDirtyMap.find(key);
+            if (iter == producerDirtyMap.end()) {
+                iter = producerDirtyMap.emplace(key, std::make_unique<ProducerDirtyList>()).first;
+                auto &producerKeyMap = producerKeyMaps_[partitionID];
+                std::lock_guard<std::shared_timed_mutex> lk(producerKeyMap->mapMutex_);
+                producerKeyMap->producerKeyMap_[key.firstKey_].emplace_back(key);
+            }
+            if (std::find_if(fifo.begin(), fifo.end(), [&key](StreamProducerKey &ele) { return ele == key; })
+                == fifo.end()) {
+                fifo.emplace_back(key);
+            }
+            auto &producerDirtyList = iter->second->list_;
+            // Assign a sequence number relative to this stream/producer
+            auto fetchAddSeqNoFunc = [iter](const std::string &streamName) {
+                return iter->second->FetchAddSeqNo(streamName);
+            };
+            it->second = it->first->RecordSeqNo(fetchAddSeqNoFunc);
+            producerDirtyList.emplace_back(*it);
+        } else {
+            // EoS
+            VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("EoS detected for S:%s", keyName);
+            myPartition->eosInjected_ = true;
+            auto eos = std::static_pointer_cast<EndOfStreamBufferData>(it->first);
+            InjectEoS(eos, producerDirtyMap, fifo);
+            eos->numJobs_--;
+        }
+        it = fifoList.erase(it);
+    }
+    return fifo;
+}
+
+void BufferPool::ReleaseBuffers(int partitionID)
+{
+    auto &part = partitionList_[partitionID];
+    auto &dirtyList = part->dirtyList_;
+    std::lock_guard<std::mutex> lock(dirtyList.mux_);
+    if (!dirtyList.Empty()) {
+        LOG(WARNING) << FormatString("AsyncFlush thread %d exits but %zu remaining fifo requests", partitionID,
+                                     dirtyList.Size());
+        for (auto &ele : dirtyList.list_) {
+            (void)ele.first->ReleasePage();
+        }
+    }
+    for (auto &dirtyStream : part->producerDirtyMap_) {
+        auto &streamDirtyList = dirtyStream.second->list_;
+        if (!streamDirtyList.empty()) {
+            LOG(WARNING) << FormatString("AsyncFlush thread %d exits but stream %s producer %s %zu remaining requests",
+                                         partitionID, dirtyStream.first.firstKey_, dirtyStream.first.producerId_,
+                                         streamDirtyList.size());
+            for (auto &ele : streamDirtyList) {
+                LOG(WARNING) << FormatString("[%s] Element seqNo %zu not sent", dirtyStream.first.ToString(),
+                                             ele.second);
+                (void)ele.first->ReleasePage();
+            }
+        }
+    }
+}
+
+bool BufferPool::HaveTasksToProcess()
+{
+    for (const auto &partition : partitionList_) {
+        std::lock_guard<std::mutex> lck(partition->dirtyList_.mux_);
+        if (!partition->dirtyList_.Empty()) {
+            return true;
+        }
+    }
+    return isAsynFlushing_;
+}
+
+void BufferPool::AsyncFlushEntry(int partitionID)
+{
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("AsyncFlush thread %d starts up", partitionID);
+    auto &part = partitionList_[partitionID];
+    auto &dirtyList = part->dirtyList_;
+    const uint64_t timeoutMs = 10;
+    std::vector<StreamProducerKey> pendingFlushList;
+    Status rc;
+    while (true) {
+        // Wait on the cv for up to 10ms for work or interrupt
+        auto hasWorkToDo = dirtyList.WaitForNotEmpty(timeoutMs);
+        if (interrupt_) {
+            VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("AsyncFlush thread %d exits", partitionID);
+            break;
+        }
+        auto traceGuard = Trace::Instance().SetTraceUUID();
+        std::vector<StreamProducerKey> flushList;
+        std::vector<StreamProducerKey> discardKeys;
+        if (hasWorkToDo) {
+            flushList = FetchDirtyList(partitionID, discardKeys);
+        }
+        // Append whatever we need to resume from last time
+        for (auto &ele : pendingFlushList) {
+            if (!ContainsKey(flushList, ele) && !ContainsKey(discardKeys, ele)) {
+                VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] resume", ele.ToString());
+                flushList.push_back(ele);
+            }
+        }
+        pendingFlushList.clear();
+
+        if (flushList.empty()) {
+            isAsynFlushing_ = false;
+            continue;
+        }
+        isAsynFlushing_ = true;
+        // Ensure that no data is sent to the remote worker.
+        RequestCounter::GetInstance().ResetLastArrivalTime("BufferPool::AsyncFlushEntry");
+        rc = BatchAsyncFlush(partitionID, flushList);
+        if (rc.IsError()) {
+            const int logPerCount = VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL) ? 1 : 1000;
+            LOG_EVERY_N(WARNING, logPerCount) << "BatchAsyncFlush failed " << rc.ToString();
+        }
+        // flushList on return is the list of streams we failed to flush out.
+        // Save them and resume them in the future.
+        pendingFlushList = std::move(flushList);
+        isAsynFlushing_ = !pendingFlushList.empty();
+        for (auto &ele : pendingFlushList) {
+            VLOG(SC_DEBUG_LOG_LEVEL) << FormatString("[%s] remains pending flush", ele.ToString());
+        }
+        if (rc.IsOk() && !pendingFlushList.empty()) {
+            const int logPerCount = VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL) ? 1 : 1000;
+            LOG_EVERY_N(INFO, logPerCount) << FormatString("%zu elements pending flush", pendingFlushList.size());
+        }
+    }
+    isAsynFlushing_ = false;
+    ReleaseBuffers(partitionID);
+}
+
+Status BufferPool::Init()
+{
+    try {
+        for (auto i = 0; i < numPartitions_; ++i) {
+            thrd_->Execute([this, i]() {
+                TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+                AsyncFlushEntry(i);
+            });
+        }
+        return Status::OK();
+    } catch (const std::exception &e) {
+        RETURN_STATUS(K_RUNTIME_ERROR, e.what());
+    }
+}
+
+void BufferPool::Stop()
+{
+    interrupt_ = true;
+    for (auto &part : partitionList_) {
+        part->dirtyList_.cv_.notify_all();
+    }
+    if (thrd_) {
+        thrd_.reset();
+    }
+}
+
+BufferPool::BufferPool(int numPartitions, const std::string &name, BatchFlushCallbackFn f)
+    : name_(name),
+      interrupt_(false),
+      numPartitions_(std::max(1, numPartitions)),
+      thrd_(std::make_unique<ThreadPool>(numPartitions_, 0, name)),
+      batchFlushFn_(std::move(f))
+{
+    partitionList_.reserve(numPartitions_);
+    heapSortMapDict_.reserve(numPartitions_);
+    producerKeyMaps_.reserve(numPartitions_);
+    for (auto i = 0; i < numPartitions_; ++i) {
+        partitionList_.emplace_back(std::make_unique<Partition>());
+        heapSortMapDict_.emplace_back(std::make_unique<HeapSortPartition>());
+        producerKeyMaps_.emplace_back(std::make_unique<ProducerKeyMap>());
+    }
+}
+
+BufferPool::~BufferPool()
+{
+    Stop();
+
+    std::string logs;
+    for (auto partitionID = 0; partitionID < numPartitions_; partitionID++) {
+        std::string log;
+        auto size = heapSortMapDict_[partitionID]->heapSortMap_.size();
+        log += size == 0 ? "" : "heapSortMap_ size = " + std::to_string(size);
+        size = partitionList_[partitionID]->dirtyList_.Size();
+        log += size == 0 ? "" : ", dirtyList_ size = " + std::to_string(size);
+        size = partitionList_[partitionID]->producerDirtyMap_.size();
+        log += size == 0 ? "" : ", producerDirtyMap_ size = " + std::to_string(size);
+        size = producerKeyMaps_[partitionID]->producerKeyMap_.size();
+        log += size == 0 ? "" : ", producerKeyMap_ size = " + std::to_string(size);
+        if (!log.empty()) {
+            logs += "Partition " + std::to_string(partitionID) + ": " + log;
+        }
+    }
+    LOG_IF(WARNING, !logs.empty()) << "~BufferPool " << name_ << " with: " << logs;
+}
+
+void ConcurrentList::Append(const std::shared_ptr<BaseBufferData> &p)
+{
+    std::unique_lock<std::mutex> lock(mux_);
+    list_.emplace_back(p, 0);
+}
+
+std::list<BaseData> ConcurrentList::GetAll()
+{
+    std::list<BaseData> dataList;
+    std::unique_lock<std::mutex> lock(mux_);
+    dataList.swap(list_);
+    return dataList;
+}
+
+bool ConcurrentList::WaitForNotEmpty(uint64_t timeoutMs)
+{
+    std::unique_lock<std::mutex> lock(mux_);
+    return cv_.wait_for(lock, std::chrono::milliseconds(timeoutMs), [this]() { return !Empty(); });
+}
+
+bool StreamProducerKey::operator==(const StreamProducerKey &rhs) const
+{
+    return rhs.firstKey_ == firstKey_ && rhs.producerId_ == producerId_
+           && rhs.producerInstanceId_ == producerInstanceId_;
+}
+
+EndOfStreamBufferData::EndOfStreamBufferData(std::string streamName, EndOfStreamCallbackFn fn)
+    : BaseBufferData(), callbackFn_(std::move(fn)), streamName_(std::move(streamName)), numJobs_(0)
+{
+    eos = true;
+}
+
+std::string EndOfStreamBufferData::StreamName() const
+{
+    return streamName_;
+}
+
+std::string EndOfStreamBufferData::ProducerName() const
+{
+    return "";
+}
+
+std::string EndOfStreamBufferData::ProducerInstanceId() const
+{
+    return "";
+}
+
+uint64_t EndOfStreamBufferData::StreamHash() const
+{
+    return 0;
+}
+
+Status EndOfStreamBufferData::ReleasePage()
+{
+    return Status::OK();
+}
+
+StreamDestructData::StreamDestructData(std::string streamName, std::vector<StreamProducerKey> producerKeysToErase,
+                                       std::vector<StreamProducerKey> producerKeysToReset)
+    : BaseBufferData(),
+      streamName_(std::move(streamName)),
+      producerKeysToErase_(std::move(producerKeysToErase)),
+      producerKeysToReset_(std::move(producerKeysToReset))
+{
+}
+
+std::string StreamDestructData::StreamName() const
+{
+    return streamName_;
+}
+
+std::string StreamDestructData::ProducerName() const
+{
+    return "";
+}
+
+std::string StreamDestructData::ProducerInstanceId() const
+{
+    return "";
+}
+
+uint64_t StreamDestructData::StreamHash() const
+{
+    return 0;
+}
+
+Status StreamDestructData::ReleasePage()
+{
+    return Status::OK();
+}
+
+uint64_t BufferPool::ProducerDirtyList::FetchAddSeqNo(const std::string &streamName)
+{
+    std::unique_lock<std::shared_mutex> lock(mux_);
+    auto it = seqNo_.find(streamName);
+    if (it == seqNo_.end()) {
+        // Sequence numbers start from 1.
+        bool success;
+        std::tie(it, success) = seqNo_.emplace(streamName, 1);
+    }
+    uint64_t ret = it->second;
+    it->second++;
+    return ret;
+}
+}  // namespace stream_cache
+}  // namespace worker
+}  // namespace datasystem
+
+namespace std {
+size_t hash<datasystem::worker::stream_cache::StreamProducerKey>::operator()(
+    const datasystem::worker::stream_cache::StreamProducerKey &key) const
+{
+    // This is the golden ratio: (2^64) / ((1 + sqrt(5)) / 2)
+    constexpr static uint64_t MAGIC_HASH = 0x9E3779B97F4A7C15ul;
+    constexpr static uint64_t LEFT_SHIFT = 6;
+    constexpr static uint64_t RIGHT_SHIFT = 2;
+    std::vector<std::string> v{ key.firstKey_, key.producerId_, key.producerInstanceId_ };
+    size_t seed = 0;
+    for (auto &str : v) {
+        auto val = std::hash<std::string>{}(str);
+        seed ^= val + MAGIC_HASH + (seed << LEFT_SHIFT) + (seed >> RIGHT_SHIFT);
+    }
+    return seed;
+}
+
+bool less<datasystem::worker::stream_cache::StreamProducerKey>::operator()(
+    const datasystem::worker::stream_cache::StreamProducerKey &lhs,
+    const datasystem::worker::stream_cache::StreamProducerKey &rhs) const
+{
+    if (lhs.firstKey_ == rhs.firstKey_ && lhs.producerId_ == rhs.producerId_) {
+        return std::less<std::string>{}(lhs.producerInstanceId_, rhs.producerInstanceId_);
+    }
+    if (lhs.firstKey_ == rhs.firstKey_) {
+        return std::less<std::string>{}(lhs.producerId_, rhs.producerId_);
+    }
+    return std::less<std::string>{}(lhs.firstKey_, rhs.firstKey_);
+}
+}  // namespace std
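The std::hash specialization above folds the three key strings with a boost-style hash_combine: the 64-bit golden-ratio constant and the shift terms scatter bits so that equal fields in different positions do not cancel each other out under XOR. The standalone sketch below isolates that combine step; HashCombine and the sample field values are illustrative names, not part of the patch.

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <string>

    // Boost-style hash_combine with the golden-ratio constant 2^64 / phi.
    size_t HashCombine(size_t seed, const std::string &str)
    {
        constexpr uint64_t MAGIC_HASH = 0x9E3779B97F4A7C15UL;
        return seed ^ (std::hash<std::string>{}(str) + MAGIC_HASH + (seed << 6) + (seed >> 2));
    }

    int main()
    {
        size_t seed = 0;
        for (const auto &field : { std::string("stream"), std::string("producer"), std::string("instance") }) {
            seed = HashCombine(seed, field);
        }
        std::cout << seed << '\n';
        return 0;
    }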
diff --git a/src/datasystem/worker/stream_cache/buffer_pool.h b/src/datasystem/worker/stream_cache/buffer_pool.h
new file mode 100644
index 0000000..3389adc
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/buffer_pool.h
@@ -0,0 +1,344 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Buffer pool
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_BUFFER_POOL_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_BUFFER_POOL_H
+
+#include <atomic>
+#include <condition_variable>
+#include <functional>
+#include <list>
+#include <memory>
+#include <mutex>
+#include <queue>
+#include <shared_mutex>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "datasystem/common/log/log.h"
+#include "datasystem/common/util/format.h"
+#include "datasystem/common/util/locks.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/common/util/thread_pool.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+/**
+ * @brief Simple key/hash for std::unordered_map, consisting of worker address and producer id.
+ */
+struct StreamProducerKey {
+    std::string firstKey_;
+    std::string producerId_;
+    std::string producerInstanceId_;
+
+    StreamProducerKey(std::string firstKey, std::string producerId, std::string producerInstanceId)
+        : firstKey_(std::move(firstKey)),
+          producerId_(std::move(producerId)),
+          producerInstanceId_(std::move(producerInstanceId))
+    {
+    }
+    ~StreamProducerKey() = default;
+
+    bool operator==(const StreamProducerKey &rhs) const;
+
+    [[nodiscard]] std::string ToString() const
+    {
+        return FormatString("K:%s P:%s I:%s", firstKey_, producerId_, producerInstanceId_);
+    }
+
+    friend std::ostream &operator<<(std::ostream &out, const StreamProducerKey &key)
+    {
+        out << key.ToString();
+        return out;
+    }
+};
+}  // namespace stream_cache
+}  // namespace worker
+}  // namespace datasystem
+
+namespace std {
+template <>
+struct hash<datasystem::worker::stream_cache::StreamProducerKey> {
+    size_t operator()(const datasystem::worker::stream_cache::StreamProducerKey &key) const;
+};
+template <>
+struct less<datasystem::worker::stream_cache::StreamProducerKey> {
+    bool operator()(const datasystem::worker::stream_cache::StreamProducerKey &lhs,
+                    const datasystem::worker::stream_cache::StreamProducerKey &rhs) const;
+};
+}  // namespace std
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+class BaseBufferData {
+public:
+    BaseBufferData() = default;
+    virtual ~BaseBufferData() = default;
+
+    [[nodiscard]] virtual std::string StreamName() const = 0;
+    [[nodiscard]] virtual std::string ProducerName() const = 0;
+    [[nodiscard]] virtual std::string ProducerInstanceId() const = 0;
+    [[nodiscard]] virtual uint64_t StreamHash() const = 0;
+    virtual Status ReleasePage() = 0;
+
+    bool IsEoS() const
+    {
+        return eos;
+    }
+
+    [[nodiscard]] virtual std::string KeyName() const
+    {
+        return StreamName();
+    }
+
+    virtual uint64_t RecordSeqNo(std::function<uint64_t(const std::string &)> fetchAddSeqNo)
+    {
+        return fetchAddSeqNo(StreamName());
+    }
+
+    std::string traceId_;
+
+protected:
+    bool eos = false;
+};
+
+using BaseData = std::pair<std::shared_ptr<BaseBufferData>, uint64_t>;
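Every payload entering the pool implements the BaseBufferData interface above. The hypothetical subclass below compiles against this header (it is not code from the patch) and shows the minimal contract: identity for the stream and producer, a stable hash for partitioning, and a hook to return the backing page on shutdown.

    // Hypothetical minimal payload; names are illustrative only.
    class DemoBufferData : public BaseBufferData {
    public:
        DemoBufferData(std::string streamName, std::string producerId)
            : streamName_(std::move(streamName)), producerId_(std::move(producerId))
        {
        }
        std::string StreamName() const override { return streamName_; }
        std::string ProducerName() const override { return producerId_; }
        std::string ProducerInstanceId() const override { return ""; }
        // Same hash for the same stream keeps a stream on one partition.
        uint64_t StreamHash() const override { return std::hash<std::string>{}(streamName_); }
        Status ReleasePage() override { return Status::OK(); }  // nothing to release in this demo

    private:
        std::string streamName_;
        std::string producerId_;
    };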
+ * @return Status object + */ + Status operator()(std::list p, const std::string &s1, const std::string &s2) + { + return callbackFn_(std::move(p), s1, s2); + } + +private: + friend class BufferPool; + EndOfStreamCallbackFn callbackFn_; + const std::string streamName_; + std::atomic numJobs_; + std::mutex mux_; + std::condition_variable cv_; +}; + +class StreamDestructData : public BaseBufferData { +public: + StreamDestructData(std::string streamName, std::vector producerKeysToErase, + std::vector producerKeysToReset); + ~StreamDestructData() override = default; + std::string StreamName() const override; + std::string ProducerName() const override; + std::string ProducerInstanceId() const override; + uint64_t StreamHash() const override; + Status ReleasePage() override; + const std::vector &GetProducerKeysToErase() const + { + return producerKeysToErase_; + } + const std::vector &GetProducerKeysToReset() const + { + return producerKeysToReset_; + } + +private: + const std::string streamName_; + const std::vector producerKeysToErase_; + const std::vector producerKeysToReset_; +}; + +class ConcurrentList { +public: + ConcurrentList() = default; + virtual ~ConcurrentList() = default; + + /** + * @brief Check if the list is empty + * @return T/F + */ + bool Empty() const + { + return list_.empty(); + } + + /** + * @brief Return the size of the list + * @return number of element in the list + */ + auto Size() const + { + return list_.size(); + } + + bool WaitForNotEmpty(uint64_t timeoutMs); + virtual void Append(const std::shared_ptr &p); + std::list GetAll(); + +protected: + friend class BufferPool; + mutable std::mutex mux_; + std::condition_variable cv_; + std::list list_; +}; + +using PendingFlushList = std::vector &>>; +using BatchFlushCallbackFn = std::function; + +class BufferPool { +public: + /** + * @brief Constructor + * @param[in] numPartitions Number of partitions + * @param[in] name The name of the buffer pool. + * @param[in] f Call back function + */ + explicit BufferPool(int numPartitions, const std::string &name, BatchFlushCallbackFn f); + + ~BufferPool(); + + /** + * @brief Init function + * @return + */ + Status Init(); + + /** + * @brief Shutdown buffer pool + */ + void Stop(); + + /** + * @brief Insert a buffer + * @param ele + */ + void Insert(const std::shared_ptr &ele); + + Status UnsortedInsert(std::shared_ptr ele, uint64_t seqNo, uint64_t firstSeqNo); + + void PurgeBuffer(const std::string &streamName, const EndOfStreamCallbackFn &fn); + + /** + * @brief Remove the info of useless stream from BufferPool + * @param keyName The stream name or page name. + * @param sharedPageName The shared page name. Empty if the stream uses exclusive page or the keyName is page. + */ + void RemoveStream(const std::string &keyName, const std::string &sharedPageName); + + /** + * @brief Check if there are tasks to be processed + * @return T/F + */ + bool HaveTasksToProcess(); + + // For heap sort + struct Compare { + bool operator()(const BaseData &a, const BaseData &b) + { + return a.second > b.second; + } + }; + +private: + void AsyncFlushEntry(int partitionID); + + /** + * @brief Get dirty producerKeys to be processed. + * @param[in] numPartitions Number of partitions. + * @param[in] discardKeys The discard producerKeys. + * @return The producerKeys to be processed. 
+
+class BufferPool {
+public:
+    /**
+     * @brief Constructor.
+     * @param[in] numPartitions Number of partitions.
+     * @param[in] name The name of the buffer pool.
+     * @param[in] f Call back function.
+     */
+    explicit BufferPool(int numPartitions, const std::string &name, BatchFlushCallbackFn f);
+
+    ~BufferPool();
+
+    /**
+     * @brief Init function. Starts one async flush thread per partition.
+     * @return Status of the call.
+     */
+    Status Init();
+
+    /**
+     * @brief Shutdown buffer pool.
+     */
+    void Stop();
+
+    /**
+     * @brief Insert a buffer.
+     * @param ele The buffer element to insert.
+     */
+    void Insert(const std::shared_ptr<BaseBufferData> &ele);
+
+    Status UnsortedInsert(std::shared_ptr<BaseBufferData> ele, uint64_t seqNo, uint64_t firstSeqNo);
+
+    void PurgeBuffer(const std::string &streamName, const EndOfStreamCallbackFn &fn);
+
+    /**
+     * @brief Remove the info of a useless stream from the BufferPool.
+     * @param keyName The stream name or page name.
+     * @param sharedPageName The shared page name. Empty if the stream uses an exclusive page or the keyName is a page.
+     */
+    void RemoveStream(const std::string &keyName, const std::string &sharedPageName);
+
+    /**
+     * @brief Check if there are tasks to be processed.
+     * @return T/F
+     */
+    bool HaveTasksToProcess();
+
+    // For heap sort
+    struct Compare {
+        bool operator()(const BaseData &a, const BaseData &b) const
+        {
+            return a.second > b.second;
+        }
+    };
+
+private:
+    void AsyncFlushEntry(int partitionID);
+
+    /**
+     * @brief Get dirty producerKeys to be processed.
+     * @param[in] partitionID The partition id.
+     * @param[out] discardKeys The discarded producerKeys.
+     * @return The producerKeys to be processed.
+     */
+    std::vector<StreamProducerKey> FetchDirtyList(int partitionID, std::vector<StreamProducerKey> &discardKeys);
+    void ReleaseBuffers(int partitionID);
+    Status BatchAsyncFlush(int partitionID, std::vector<StreamProducerKey> &streamList);
+    void PurgeSortHeap(int partitionID, const std::string &streamName);
+    void ProcessEoSEntries(const StreamProducerKey &key, std::list<BaseData> &producerDirtyList);
+
+    /**
+     * @brief Clear producerKeyMap_.
+     * @param[in] partitionID The partition id.
+     * @param[in] streamName The stream name.
+     * @param[in] keys The producerKeys.
+     */
+    void ClearProducerKeyMap(int partitionID, const std::string &streamName,
+                             const std::vector<StreamProducerKey> &keys);
+
+    std::string name_;
+    std::atomic<bool> interrupt_;
+    const int numPartitions_;
+
+    struct ProducerKeyMap {
+        std::shared_timed_mutex mapMutex_;  // Protect producerKeyMap_
+        std::unordered_map<std::string, std::vector<StreamProducerKey>>
+            producerKeyMap_;  // Key: streamName/sharedPageName
+    };
+    std::vector<std::unique_ptr<ProducerKeyMap>> producerKeyMaps_;  // The reverse lookup table of StreamProducerKey
+
+    struct ProducerDirtyList {
+        std::list<BaseData> list_;
+        mutable std::shared_mutex mux_;
+        std::unordered_map<std::string, uint64_t> seqNo_;  // <streamName, seqNo>, used for shared pages
+        ProducerDirtyList() = default;
+        ~ProducerDirtyList() = default;
+        uint64_t FetchAddSeqNo(const std::string &streamName);
+    };
+    struct HeapSort {
+        mutable WriterPrefRWLock mux_;
+        std::priority_queue<BaseData, std::vector<BaseData>, Compare> que_;
+        std::atomic<uint64_t> expectedSeqNo_{ 0 };
+    };
+    // Each partition consists of a FIFO dirty list, a per-producer dirty list, and a heap sort
+    struct Partition {
+        ConcurrentList dirtyList_;  // FIFO
+        std::unordered_map<StreamProducerKey, std::unique_ptr<ProducerDirtyList>>
+            producerDirtyMap_;  // Key: streamName/sharedPageName
+        std::atomic<bool> eosInjected_{ false };
+    };
+    void InjectEoS(const std::shared_ptr<EndOfStreamBufferData> &eos,
+                   std::unordered_map<StreamProducerKey, std::unique_ptr<ProducerDirtyList>> &map,
+                   std::vector<StreamProducerKey> &fifo);
+
+    std::vector<std::unique_ptr<Partition>> partitionList_;
+    struct HeapSortPartition {
+        mutable WriterPrefRWLock mux_;
+        std::unordered_map<StreamProducerKey, std::unique_ptr<HeapSort>> heapSortMap_;
+    };
+    std::vector<std::unique_ptr<HeapSortPartition>> heapSortMapDict_;
+    std::unique_ptr<ThreadPool> thrd_;  // One flush thread for each partition
+    BatchFlushCallbackFn batchFlushFn_;
+    std::atomic<bool> isAsynFlushing_{ false };
+};
+}  // namespace stream_cache
+}  // namespace worker
+}  // namespace datasystem
+#endif  // DATASYSTEM_WORKER_STREAM_CACHE_BUFFER_POOL_H
diff --git a/src/datasystem/worker/stream_cache/client_worker_sc_service_impl.cpp b/src/datasystem/worker/stream_cache/client_worker_sc_service_impl.cpp
new file mode 100644
index 0000000..298a573
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/client_worker_sc_service_impl.cpp
@@ -0,0 +1,2585 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h" + +#include +#include +#include +#include + +#include "datasystem/common/constants.h" +#include "datasystem/common/flags/flags.h" +#include "datasystem/common/iam/tenant_auth_manager.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/log/access_recorder.h" +#include "datasystem/common/log/log_helper.h" +#include "datasystem/common/perf/perf_manager.h" +#include "datasystem/common/rpc/rpc_auth_key_manager.h" +#include "datasystem/common/rpc/rpc_stub_cache_mgr.h" +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/common/stream_cache/util.h" +#include "datasystem/common/util/format.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/common/util/thread_local.h" +#include "datasystem/common/util/validator.h" +#include "datasystem/master/stream_cache/master_sc_service_impl.h" +#include "datasystem/protos/stream_posix.service.rpc.pb.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/utils/optional.h" +#include "datasystem/utils/status.h" +#include "datasystem/worker/cluster_event_type.h" +#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h" +#include "datasystem/worker/stream_cache/stream_manager.h" +#include "datasystem/worker/cluster_manager/worker_health_check.h" +#include "datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h" + +DS_DECLARE_uint32(page_size); +DS_DECLARE_int32(zmq_chunk_sz); +DS_DECLARE_uint64(client_dead_timeout_s); +DS_DEFINE_int32(sc_thread_num, 128, "Number of threads for (non rpc) stream cache service work"); +DS_DEFINE_validator(sc_thread_num, &Validator::ValidateThreadNum); +DS_DEFINE_uint32(sc_gc_interval_ms, 50, "Memory resource clean up interval. 
Default to 50ms"); +DS_DEFINE_bool(enable_stream_data_verification, false, "Option to verify if data from a producer is out of order"); +DS_DECLARE_uint32(sc_shared_page_size_mb); + +namespace datasystem { +namespace worker { +namespace stream_cache { +static const std::string CLIENT_WORKER_SC_SERVICE_IMPL = "ClientWorkerSCServiceImpl"; +template class BlockedCreateRequest; +template class MemAllocRequestList; +template class BlockedCreateRequest; +template class MemAllocRequestList; +ClientWorkerSCServiceImpl::ClientWorkerSCServiceImpl(HostPort serverAddr, HostPort masterAddr, + master::MasterSCServiceImpl *masterSCService, + std::shared_ptr akSkManager, + std::shared_ptr manager) + : localWorkerAddress_(std::move(serverAddr)), + masterAddress_(std::move(masterAddr)), + scAllocateManager_(std::move(manager)), + akSkManager_(std::move(akSkManager)), + interrupt_(false) +{ + workerMasterApiManager_ = + std::make_shared(localWorkerAddress_, akSkManager_, masterSCService); +} + +Status ClientWorkerSCServiceImpl::Init() +{ + remoteWorkerManager_ = std::make_unique(this, akSkManager_, scAllocateManager_); + RETURN_IF_NOT_OK(remoteWorkerManager_->Init()); + LOG(INFO) << FormatString("[%S] Initialize success", LogPrefix()); + // Create a thread pool for async request handling in the service + const size_t MIN_THREADS = 4; + size_t minThreads = std::min(MIN_THREADS, FLAGS_sc_thread_num); + RETURN_IF_EXCEPTION_OCCURS(threadPool_ = + std::make_shared(minThreads, FLAGS_sc_thread_num, "ScThreads")); + // Also create a similar pool but just the purpose of managing CreateShmPage and AllocBigShmMemory rpc + // (both internal and external) + RETURN_IF_EXCEPTION_OCCURS(memAllocPool_ = + std::make_shared(minThreads, FLAGS_sc_thread_num, "memThreads")); + // We will further let another thread pool to do the memory cleanup work. + // But we don't want to overload the pool if we have thousands of streams, and + // we will limit to a small number of threads at a time. 
+ constexpr size_t NUM_ACK_THREADS = 2; + RETURN_IF_EXCEPTION_OCCURS(ackPool_ = std::make_shared(NUM_ACK_THREADS, NUM_ACK_THREADS, "ackThreads")); + // Kick off a thread to do garbage collection + autoAck_ = threadPool_->Submit([this]() { + auto traceId = GetStringUuid(); + auto traceGuard = Trace::Instance().SetTraceNewID(traceId); + LOG(INFO) << FormatString("[%s] Ack thread starts.", LogPrefix()); + const uint64_t timeoutS = FLAGS_client_dead_timeout_s + 5; + std::deque ackList; + while (!interrupt_) { + AutoAckImpl(ackList, timeoutS); + std::this_thread::sleep_for(std::chrono::milliseconds(FLAGS_sc_gc_interval_ms)); + } + }); + threadPool_->SetWarnLevel(ThreadPool::WarnLevel::LOW); + ackPool_->SetWarnLevel(ThreadPool::WarnLevel::LOW); + EraseFailedNodeApiEvent::GetInstance().AddSubscriber(CLIENT_WORKER_SC_SERVICE_IMPL, + [this](HostPort &node) { EraseFailedWorkerMasterApi(node); }); + LOG_IF_ERROR(ScMetricsMonitor::Instance()->StartMonitor(), "Failed to start ScMetrics monitor"); + return ClientWorkerSCService::Init(); +} + +Status ClientWorkerSCServiceImpl::ValidateWorkerState() +{ + if (!IsHealthy()) { + RETURN_STATUS(K_NOT_READY, "Worker not ready"); + } + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::CreateProducer( + std::shared_ptr> serverApi) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + CreateProducerReqPb req; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + + auto recorder = std::make_shared(AccessRecorderKey::DS_POSIX_CREATE_PRODUCER); + recorder->reqParam.streamName = req.stream_name(); + recorder->reqParam.producerId = req.producer_id(); + recorder->reqParam.pageSize = Optional(req.page_size()); + recorder->reqParam.maxStreamSize = Optional(req.max_stream_size()); + recorder->reqParam.autoCleanup = Optional(req.auto_cleanup()); + recorder->reqParam.retainForNumConsumers = Optional(req.retain_num_consumer()); + recorder->reqParam.encryptStream = Optional(req.encrypt_stream()); + recorder->reqParam.reserveSize = Optional(req.reserve_size()); + recorder->reqParam.streamMode = Optional(req.stream_mode()); + auto rc = CreateProducerInternal(req, recorder, serverApi); + recorder->SetStatus(rc); + return rc; +} + +Status ClientWorkerSCServiceImpl::CreateProducerInternal( + const CreateProducerReqPb &req, std::shared_ptr recorder, + std::shared_ptr> serverApi) +{ + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + LOG(INFO) << "Worker received CreateProducer request: " << LogHelper::IgnoreSensitive(req); + TimeoutDuration parentDuration = scTimeoutDuration; + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CheckConnection(namespaceUri), "worker check connection failed"); + + PerfPoint point(PerfKey::WORKER_CREATE_PRODUCER_ALL); + + // The real work of the close will be driven in another thread. Launch it now and then release this current thread + // so that it does not hold up the rpc threads. + auto traceId = Trace::Instance().GetTraceID(); + threadPool_->Execute([=]() mutable { + scTimeoutDuration = parentDuration; // lambda capture gets the parents copy. 
assign the copy to thread local + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + CreateProducerRspPb rsp; + Status rc = CreateProducerImpl(namespaceUri, req, rsp); + CheckErrorReturn(rc, rsp, FormatString("[S:%s] CreateProducerImpl failed with rc ", namespaceUri), serverApi); + recorder->SetStatus(rc); + if (rc.IsOk()) { + recorder->rspParam.senderProducerNo = Optional(rsp.sender_producer_no()); + recorder->rspParam.enableDataVerification = Optional(rsp.enable_data_verification()); + recorder->rspParam.streamNo = Optional(rsp.stream_no()); + recorder->rspParam.sharedPageSize = Optional(rsp.shared_page_size()); + recorder->rspParam.enableSharedPage = Optional(rsp.enable_shared_page()); + } + // recorder should destroy before traceGuard, otherwise, the traceid will be cleaned up + recorder.reset(); + }); + + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::GetPrimaryReplicaAddr(const std::string &srcAddr, HostPort &destAddr) +{ + std::string dbName; + RETURN_IF_NOT_OK(etcdCM_->GetPrimaryReplicaLocationByAddr(srcAddr, destAddr, dbName)); + g_MetaRocksDbName = dbName; + return Status::OK(); +} + +void ClientWorkerSCServiceImpl::ConstructCreateProducerPb(const std::string &streamName, + const Optional &streamFields, + master::CreateProducerReqPb &out) const noexcept +{ + DLOG(INFO) << "Start to construct ProducerMetaPb"; + auto &producerMetaPb = *out.mutable_producer_meta(); + producerMetaPb.set_stream_name(streamName); + producerMetaPb.mutable_worker_address()->set_host(localWorkerAddress_.Host()); + producerMetaPb.mutable_worker_address()->set_port(localWorkerAddress_.Port()); + out.set_max_stream_size(streamFields->maxStreamSize_); + out.set_page_size(streamFields->pageSize_); + out.set_auto_cleanup(streamFields->autoCleanup_); + out.set_retain_num_consumer(streamFields->retainForNumConsumers_); + out.set_encrypt_stream(streamFields->encryptStream_); + out.set_reserve_size(streamFields->reserveSize_); + out.set_stream_mode(streamFields->streamMode_); + out.set_redirect(true); +} + +Status ClientWorkerSCServiceImpl::CreateProducerHandleSend(std::shared_ptr api, + const std::string &streamName, + const Optional &streamFields) +{ + auto createProducerFn = [&]() { + master::CreateProducerReqPb masterReq; + ConstructCreateProducerPb(streamName, streamFields, masterReq); + master::CreateProducerRspPb masterRsp; + std::function func = + [&api](master::CreateProducerReqPb &req, master::CreateProducerRspPb &rsp) { + return api->CreateProducer(req, rsp); + }; + RETURN_IF_NOT_OK(RedirectRetryWhenMetaMoving(masterReq, masterRsp, api, func)); + return Status::OK(); + }; + return WorkerMasterOcApiManager::RetryForReplicaNotReady(scTimeoutDuration.CalcRealRemainingTime(), + std::move(createProducerFn)); +} + +Status ClientWorkerSCServiceImpl::CreateProducerImpl(const std::string &namespaceUri, const CreateProducerReqPb &req, + CreateProducerRspPb &rsp) +{ + // We need to serialize on CreateProducer and Subscribe. + // That is, we must wait for the previous one to finish (or rollback) successfully + // before the next pub/sub can proceed + CreatePubSubCtrl::Accessor createLock; + createStreamLocks_.Insert(createLock, namespaceUri); + Raii releaseLock([this, &createLock] { createStreamLocks_.TryErase(createLock); }); + + // If reserve size is 0, default it to page size. + uint64_t reserveSize = + req.reserve_size() == 0 ? 
static_cast(req.page_size()) : static_cast(req.reserve_size()); + const Optional streamFields(req.max_stream_size(), req.page_size(), req.auto_cleanup(), + req.retain_num_consumer(), req.encrypt_stream(), reserveSize, + req.stream_mode()); + const std::string &producerId = req.producer_id(); + const std::string &clientId = req.client_id(); + + // Next to create the stream manager if it doesn't exist. + INJECT_POINT("ClientWorkerSCServiceImpl.CreateProducerImpl.sleep"); + std::shared_ptr streamMgrWithLock; + bool streamExisted; + RETURN_IF_NOT_OK(CreateStreamManagerIfNotExist(namespaceUri, streamFields, streamMgrWithLock, streamExisted)); + bool blockMemoryReclaim = false; + bool rollbackProducer = false; + auto streamMgr = streamMgrWithLock->mgr_; + uint64_t streamNo = streamMgr->GetStreamNo(); + // If we hit any error below, we will erase the StreamManager from the tbb provided it is this thread + // that creates the stream manager. We will need an exclusive accessor + Raii raii([this, &namespaceUri, &streamMgrWithLock, &streamMgr, &blockMemoryReclaim, &rollbackProducer, + &producerId]() { + if (blockMemoryReclaim) { + streamMgr->UnblockMemoryReclaim(); + } + // Unblock reclaim first, because CloseProducer can trigger early reclaim. + if (rollbackProducer) { + LOG_IF_ERROR(streamMgr->CloseProducer(producerId, true), "StreamManager rollback close producer failed"); + } + streamMgrWithLock->CleanUp(std::bind(&ClientWorkerSCServiceImpl::EraseFromStreamMgrDictWithoutLck, this, + namespaceUri, std::placeholders::_1)); + }); + if (streamExisted) { + // An existing stream was found. + // Serialize with EarlyReclaim() so our reserved pages will not be reclaimed. + streamMgr->BlockMemoryReclaim(); + blockMemoryReclaim = true; + bool existsLocalConsumer = streamMgr->CheckConsumerExist(localWorkerAddress_.ToString()).IsOk(); + bool reserveShm = !StreamManager::EnableSharedPage(streamFields->streamMode_) || existsLocalConsumer; + RETURN_IF_NOT_OK(PostCreateStreamManager(streamMgr, streamFields, reserveShm)); + } + bool firstProducer = (streamMgr->GetLocalProducerCount() == 0); + CHECK_FAIL_RETURN_STATUS(firstProducer || streamFields->streamMode_ != StreamMode::SPSC, K_INVALID, + FormatString("There can be at most one producer in this stream mode: %d.", + static_cast(streamFields->streamMode_))); + INJECT_POINT("ClientWorkerSCServiceImpl.CreateProducerImpl.WaitBeforeAdd"); + // Get a unique number to identify the new producer within the stream locally. + DataVerificationHeader::SenderProducerNo senderProducerNo; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(streamMgr->AddProducer(producerId, senderProducerNo), + "streamMgr add producer failed"); + // We will let go the accessor at this point to prevent deadlock. The master may send back a SyncConsumerNode + // rpc back to this worker if this is the first producer. 
We are still protected by the createLock + streamMgrWithLock->Release(); + if (firstProducer) { + LOG(INFO) << FormatString("[%s, S:%s, P:%s] First CreateProducer request sending to master.", LogPrefix(), + namespaceUri, producerId); + auto api = workerMasterApiManager_->GetWorkerMasterApi(namespaceUri, etcdCM_); + CHECK_FAIL_RETURN_STATUS(api != nullptr, K_RUNTIME_ERROR, "Get WorkerMasterApi failed of " + namespaceUri); + // Only first producer sends CreateProducer request, so use local address as producer id + Status rc = CreateProducerHandleSend(api, namespaceUri, streamFields); + if (rc.IsError() && rc.GetCode() != StatusCode::K_DUPLICATED) { + LOG(ERROR) << FormatString("Create Producer [%s] failed in master %s: %s", producerId, api->Address(), + rc.GetMsg()); + // If fail on master, we should roll back this operation + rollbackProducer = true; + return rc; + } + streamMgr->InitRetainData(req.retain_num_consumer()); + LOG(INFO) << FormatString("[%s, S:%s, P:%s] CreateProducer success on master %s.", LogPrefix(), namespaceUri, + producerId, api->Address()); + } + + ShmView shmViewOfCursor, shmViewOfStreamMeta; + RETURN_IF_NOT_OK(streamMgr->AddCursorForProducer(producerId, shmViewOfCursor)); + + if (StreamManager::EnableSharedPage(streamFields->streamMode_)) { + RETURN_IF_NOT_OK(streamMgr->GetOrCreateShmMeta(TenantAuthManager::Instance()->ExtractTenantId(namespaceUri), + shmViewOfStreamMeta)); + ShmViewPb shmViewOfStreamMetaPb; + shmViewOfStreamMetaPb.set_fd(shmViewOfStreamMeta.fd); + shmViewOfStreamMetaPb.set_mmap_size(shmViewOfStreamMeta.mmapSz); + shmViewOfStreamMetaPb.set_size(shmViewOfStreamMeta.sz); + shmViewOfStreamMetaPb.set_offset(shmViewOfStreamMeta.off); + rsp.mutable_stream_meta_view()->CopyFrom(shmViewOfStreamMetaPb); + } + + ShmViewPb shmViewOfCursorPb; + shmViewOfCursorPb.set_fd(shmViewOfCursor.fd); + shmViewOfCursorPb.set_mmap_size(shmViewOfCursor.mmapSz); + shmViewOfCursorPb.set_offset(shmViewOfCursor.off); + shmViewOfCursorPb.set_size(shmViewOfCursor.sz); + rsp.mutable_page_view()->CopyFrom(shmViewOfCursorPb); + rsp.set_enable_data_verification(FLAGS_enable_stream_data_verification); + rsp.set_sender_producer_no(senderProducerNo); + rsp.set_stream_no(streamNo); + rsp.set_shared_page_size(FLAGS_sc_shared_page_size_mb * MB_TO_BYTES); + rsp.set_enable_shared_page(StreamManager::EnableSharedPage(static_cast(req.stream_mode()))); + + { + std::unique_lock lock(clearMutex_); + (void)clientProducers_[clientId].emplace_back(namespaceUri, producerId); + } + streamMgrWithLock->needCleanUp = false; + createStreamLocks_.BlockingErase(createLock); + LOG(INFO) << FormatString("[%s, S:%s, P:%s] CreateProducer success.", LogPrefix(), namespaceUri, producerId); + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::CloseProducer( + std::shared_ptr> serverApi) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + CloseProducerReqPb req; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + auto recorder = std::make_shared(AccessRecorderKey::DS_POSIX_CLOSE_PRODUCER); + recorder->reqParam.streamName = req.stream_name(); + recorder->reqParam.producerId = req.producer_id(); + recorder->reqParam.clientId = req.client_id(); + auto rc = CloseProducerInternal(req, recorder, serverApi); + recorder->SetStatus(rc); + return rc; +} + +Status ClientWorkerSCServiceImpl::CloseProducerInternal( + const CloseProducerReqPb &req, std::shared_ptr recorder, + std::shared_ptr> serverApi) +{ + std::string tenantId; + 
RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + TimeoutDuration parentDuration = scTimeoutDuration; + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + LOG(INFO) << "Worker received CloseProducer request:" << LogHelper::IgnoreSensitive(req); + const std::string &producerId = req.producer_id(); + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CheckConnection(namespaceUri), "worker check connection failed"); + const std::string &clientId = req.client_id(); + PerfPoint point(PerfKey::WORKER_CLOSE_PRODUCER_ALL); + + // The real work of the close will be driven in another thread. Launch it now and then release this current thread + // so that it does not hold up the rpc threads. + auto traceId = Trace::Instance().GetTraceID(); + threadPool_->Execute([=]() mutable { + scTimeoutDuration = parentDuration; // lambda capture gets the parents copy. assign the copy to thread local + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + CloseProducerRspPb rsp; + Status rc = CloseProducerImpl(producerId, namespaceUri, true); + if (rc.IsOk() || rc.GetCode() == StatusCode::K_SC_PRODUCER_NOT_FOUND) { + std::unique_lock lock(clearMutex_); + clientProducers_[clientId].remove_if([&namespaceUri, &producerId](const StreamProducer &data) { + return (data.streamName_ == namespaceUri && data.producerId_ == producerId); + }); + } + LOG(INFO) << FormatString("[%s, S:%s, P:%s] CloseProducer finish with %s", LogPrefix(), namespaceUri, + producerId, rc.ToString()); + CheckErrorReturn(rc, rsp, FormatString("[S:%s] CloseProducerImpl failed with rc ", namespaceUri), serverApi); + recorder->SetStatus(rc); + // recorder should destroy before traceGuard, otherwise, the traceid will be cleaned up + recorder.reset(); + }); + + return Status::OK(); +} + +void ClientWorkerSCServiceImpl::ConstructCloseProducerReq(std::list &streamList, bool forceClose, + master::CloseProducerReqPb &req) const noexcept +{ + // Common fields for all producers in this list-based CloseProducer call + req.mutable_worker_address()->set_host(localWorkerAddress_.Host()); + req.mutable_worker_address()->set_port(localWorkerAddress_.Port()); + req.set_force_close(forceClose); + req.set_redirect(true); + + // Loop over all of the producers and construct the repeating field of the protobuf + for (const auto &currStream : streamList) { + auto producerInfoPb = req.add_producer_infos(); + producerInfoPb->set_stream_name(currStream); + } + // Clear the stream list. After the master call, this list may be repopulated with any of the producers that got + // a failure (or remain empty if all closes were successful). + streamList.clear(); +} + +Status ClientWorkerSCServiceImpl::HandleCloseProducerRsp(std::list &failedList, + const master::CloseProducerRspPb &rsp) const +{ + Status rc = Status::OK(); + if (rsp.has_err()) { + rc = Status(static_cast(rsp.err().error_code()), rsp.err().error_msg()); + std::string failedProds(" ["); + for (const auto &currProducer : rsp.failed_producers()) { + failedProds += " "; + failedList.emplace_back(currProducer.stream_name()); + } + LOG(ERROR) << "Worker->Master CloseProducer request failed. 
Failed producer list:" << failedProds << " ]" + << "\nrc: " << rc.ToString(); + } + return rc; +} + +Status ClientWorkerSCServiceImpl::CloseProducerHandleSend(std::shared_ptr api, + std::list &streamList, bool forceClose) +{ + auto closeProducerFn = [&] { + std::list needCloseList = streamList; + master::CloseProducerReqPb req; + ConstructCloseProducerReq(needCloseList, forceClose, req); + master::CloseProducerRspPb rsp; + std::function func = + [&api](master::CloseProducerReqPb &req, master::CloseProducerRspPb &rsp) { + return api->CloseProducer(req, rsp); + }; + // Even for the list batch version of CloseProducer currently groups by the same stream name, + // so if redirect is needed, all should be retried, so here the retry is part of RedirectRetryWhenMetasMoving + RETURN_IF_NOT_OK(RedirectRetryWhenMetasMoving(req, rsp, api, func)); + // CloseProducer packs a return code in its rsp. Handle this now and populate the producerList with the + // producers that failed to close and return if error. + RETURN_IF_NOT_OK(HandleCloseProducerRsp(needCloseList, rsp)); + INJECT_POINT("CloseProducer.TimeoutInMaster"); + streamList = needCloseList; + return Status::OK(); + }; + return WorkerMasterOcApiManager::RetryForReplicaNotReady(scTimeoutDuration.CalcRealRemainingTime(), + std::move(closeProducerFn)); +} + +Status ClientWorkerSCServiceImpl::CloseProducerImpl(const std::string &producerId, const std::string &streamName, + bool notifyMaster) +{ + // Serialize CreateProducer requests + CreatePubSubCtrl::Accessor createLock; + createStreamLocks_.Insert(createLock, streamName); + Raii releaseLock([this, &createLock] { createStreamLocks_.TryErase(createLock); }); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK(GetStreamManager(streamName, accessor)); + std::shared_ptr streamMgr = accessor->second; + bool lastProducer = (streamMgr->GetLocalProducerCount() == 1); + // We only need to inform master of producer close on last local producer + if (notifyMaster && lastProducer) { + auto api = workerMasterApiManager_->GetWorkerMasterApi(streamName, etcdCM_); + if (api != nullptr) { + std::list streamList; + streamList.emplace_back(streamName); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + CloseProducerHandleSend(api, streamList, false), + FormatString("Close Producer [%s] failed in master %s", producerId, api->Address())); + } else { + // api object is nullptr and its not force close so dont delete local data + RETURN_STATUS(StatusCode::K_RUNTIME_ERROR, "Worker disconnected from master, stream name = " + streamName); + } + } + // This is a normal close + RETURN_IF_NOT_OK(streamMgr->CloseProducer(producerId, false)); + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::CloseProducerLocallyOnForceClose(std::list &producerList, + std::set &streamListForNotifications) +{ + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Closing %d producers in local stream manager", producerList.size()); + Status returnRc = Status::OK(); + auto iter = std::begin(producerList); + while (iter != std::end(producerList)) { + // If any of these calls give an error, record the latest error to the returnRc but continue looping. 
+ Status rc; + StreamManagerMap::const_accessor accessor; + rc = GetStreamManager(iter->streamName_, accessor); + if (rc.IsError()) { + returnRc = rc; + iter++; + continue; + } + + std::shared_ptr streamMgr = accessor->second; + // This is a force close + rc = streamMgr->CloseProducer(iter->producerId_, true); + if (rc.IsError()) { + returnRc = rc; + iter++; + continue; + } + + // We will later check if its a last producer under lock to send notifications to master + (void)streamListForNotifications.emplace(iter->streamName_); + + // This one closed successfully. Remove it from the list so that at the end of the call, only the failed ones + // will remain in the list + iter = producerList.erase(iter); + } + return returnRc; +} + +Status ClientWorkerSCServiceImpl::UnlockAndProtect(std::list &producerList, uint32_t lockId, + ProduceGrpByStreamList &producersGrpStreamName) +{ + // Before driving the close work, protect the delete code path from concurrent closes for all of these producers + auto iter = std::begin(producerList); + while (iter != std::end(producerList)) { + auto streamProducer = *iter; + auto &streamName = streamProducer.streamName_; + LOG(INFO) << FormatString("Start close producer [%s] in stream [%s] for client lost.", + streamProducer.producerId_, streamName); + // Only try to unlock once for each stream. + if (producersGrpStreamName.find(streamName) == producersGrpStreamName.end()) { + StreamManagerMap::const_accessor accessor; + Status rc = GetStreamManager(streamName, accessor); + // The only error here is K_SC_STREAM_NOT_FOUND + // The stream is likely already deleted so we can remove the producer from the list. + if (rc.IsError()) { + iter = producerList.erase(iter); + continue; + } + std::shared_ptr streamMgr = accessor->second; + streamMgr->ForceUnlockByCursor(streamProducer.producerId_, true, lockId); + } + producersGrpStreamName[streamName].push_back(streamProducer); + iter++; + } + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::SendBatchedCloseProducerReq(std::set &streamList, + std::vector &failedList) +{ + Status masterCloseRc = Status::OK(); + for (auto &streamName : streamList) { + CreatePubSubCtrl::Accessor createLock; + (void)createStreamLocks_.Insert(createLock, streamName); + Raii releaseLock([this, &createLock] { createStreamLocks_.BlockingErase(createLock); }); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK(GetStreamManager(streamName, accessor)); + std::shared_ptr streamMgr = accessor->second; + // If not last producer dont send the request to master yet + if (streamMgr->GetLocalProducerCount()) { + // Ignore the stream dont add it to the failedList + continue; + } + // A custom retry loop to provide some protection around close failures on the master call. + const int maxRetries = 5; + int numRetries = 0; + Status masterCloseRcPerCall; + std::string masterAddr; + do { + CHECK_FAIL_RETURN_STATUS(scTimeoutDuration.CalcRealRemainingTime() > 0, K_RPC_DEADLINE_EXCEEDED, + "Rpc timeout"); + + // We only send a CloseProducer Request on last producer close + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[S:%s] Sending close producer to master. 
Attempt: %d", + streamName, numRetries); + auto api = workerMasterApiManager_->GetWorkerMasterApi(streamName, etcdCM_); + if (api != nullptr) { + // force close is true + std::list streamList; + streamList.emplace_back(streamName); + masterCloseRcPerCall = CloseProducerHandleSend(api, streamList, true); + } else { + masterCloseRcPerCall = { StatusCode::K_RPC_UNAVAILABLE, "Master not available for the stream." }; + } + ++numRetries; + } while (masterCloseRcPerCall.IsError() && numRetries < maxRetries); + + if (masterCloseRcPerCall.IsError()) { + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("CloseProducer failed for stream %s with status: %s", + streamName, masterCloseRcPerCall.ToString()); + masterCloseRc = masterCloseRcPerCall; + failedList.emplace_back(streamName); + } + } + return masterCloseRc; +} + +Status ClientWorkerSCServiceImpl::CloseProducerImplForceClose(uint32_t lockId, std::list &producerList) +{ + ProduceGrpByStreamList producersGrpStreamName; + + // Unlock page lock if its hold by crashed producer + // Get a stream manager accessor on each stream + // group producers by stream name + RETURN_IF_NOT_OK(UnlockAndProtect(producerList, lockId, producersGrpStreamName)); + + // First locally delete the metadata in worker + // And Get List of streams that have 0 producers after closing force closed ones. + std::set streamList; + Status streamCloseRc = CloseProducerLocallyOnForceClose(producerList, streamList); + + INJECT_POINT("ClientWorkerSCServiceImpl.CloseProducerImplForceClose.sleep"); + + // Then if required send Close Producer/Consumer to master + std::vector failedStreamList; + Status masterCloseRc = SendBatchedCloseProducerReq(streamList, failedStreamList); + + producerList.clear(); + // Get all failed producers from streamList and return back to caller + for (auto &streamName : failedStreamList) { + producerList.splice(producerList.end(), producersGrpStreamName[streamName]); + } + + // As this is force close, we do not handle any errors. We just return errors for logging. + // if both masterCloseRc and StreamCloseRc is set, return the master rc. + // master closes worked fine, but there were problems closing the producers locally. + return masterCloseRc.IsError() ? 
masterCloseRc : streamCloseRc; +} + +Status ClientWorkerSCServiceImpl::Subscribe( + std::shared_ptr> serverApi) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + SubscribeReqPb req; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + auto recorder = std::make_shared(AccessRecorderKey::DS_POSIX_SUBSCRIBE); + recorder->reqParam.streamName = req.stream_name(); + recorder->reqParam.consumerId = req.consumer_id(); + recorder->reqParam.clientId = req.client_id(); + auto rc = SubscribeInternal(req, recorder, serverApi); + recorder->SetStatus(rc); + return rc; +} + +Status ClientWorkerSCServiceImpl::SubscribeInternal( + const SubscribeReqPb &req, std::shared_ptr recorder, + std::shared_ptr> serverApi) +{ + LOG(INFO) << "Worker received Subscribe request:" << LogHelper::IgnoreSensitive(req); + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + TimeoutDuration parentDuration = scTimeoutDuration; + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + PerfPoint point(PerfKey::WORKER_CREATE_SUB_ALL); + + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CheckConnection(namespaceUri), "worker check connection failed"); + + // The real work of the close will be driven in another thread. Launch it now and then release this current thread + // so that it does not hold up the rpc threads. + auto traceId = Trace::Instance().GetTraceID(); + threadPool_->Execute([=]() mutable { + scTimeoutDuration = parentDuration; // lambda capture gets the parents copy. assign the copy to thread local + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + SubscribeRspPb rsp; + Status rc = SubscribeImpl(namespaceUri, req, rsp); + CheckErrorReturn(rc, rsp, "SubscribeImpl failed with rc ", serverApi); + recorder->SetStatus(rc); + // recorder should destroy before traceGuard, otherwise, the traceid will be cleaned up + recorder.reset(); + }); + + return Status::OK(); +} + +void ClientWorkerSCServiceImpl::ConstructConsumerMetaPb(const std::string &streamName, const std::string &consumerId, + uint64_t lastAckCursor, const SubscriptionConfig &config, + const std::string &clientId, ConsumerMetaPb &out) const noexcept +{ + out.set_stream_name(streamName); + out.mutable_worker_address()->set_host(localWorkerAddress_.Host()); + out.mutable_worker_address()->set_port(localWorkerAddress_.Port()); + out.set_consumer_id(consumerId); + out.mutable_sub_config()->set_subscription_name(config.subscriptionName); + out.mutable_sub_config()->set_subscription_type(SubscriptionTypePb(config.subscriptionType)); + out.set_last_ack_cursor(lastAckCursor); + out.set_client_id(clientId); +} + +Status ClientWorkerSCServiceImpl::SubscribeHandleSend(std::shared_ptr streamMgr, + const std::string &streamName, const std::string &consumerId, + uint64_t lastAckCursor, const SubscriptionConfig &config, + const std::string &clientId, Optional &streamFields, + std::string &masterAddress) +{ + auto subscribeFn = [&] { + std::shared_ptr api; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(workerMasterApiManager_->GetWorkerMasterApi(streamName, etcdCM_, api), + "Getting master api failed. 
stream name = " + streamName); + masterAddress = api->Address(); + master::SubscribeReqPb masterReq; + auto &consumerMetaPb = *masterReq.mutable_consumer_meta(); + ConstructConsumerMetaPb(streamName, consumerId, lastAckCursor, config, clientId, consumerMetaPb); + masterReq.set_redirect(true); + master::SubscribeRspPb masterRsp; + std::function func = + [&api](master::SubscribeReqPb &req, master::SubscribeRspPb &rsp) { return api->Subscribe(req, rsp); }; + Status rc = RedirectRetryWhenMetaMoving(masterReq, masterRsp, api, func); + // The purpose of the SetRetainData is to notify the local producer, + // in that case we should not set it from INIT to RETAIN if there is no local producer. + if (rc.IsOk() && masterRsp.retain_data() == RetainDataState::State::NOT_RETAIN) { + streamMgr->SetRetainData(masterRsp.retain_data()); + } + Optional rspStreamFields(masterRsp.max_stream_size(), masterRsp.page_size(), + masterRsp.auto_cleanup(), masterRsp.retain_num_consumer(), + masterRsp.encrypt_stream(), masterRsp.reserve_size(), + masterRsp.stream_mode()); + // Assign the initialized Optional to the output fields if the data is not empty. + // Otherwise the optional value will remain false. + if (!rspStreamFields.value().Empty()) { + streamFields = rspStreamFields; + } + INJECT_POINT("ClientWorkerSC.Subscribe.TimingHole"); + return rc; + }; + return WorkerMasterOcApiManager::RetryForReplicaNotReady(scTimeoutDuration.CalcRealRemainingTime(), + std::move(subscribeFn)); +} + +Status ClientWorkerSCServiceImpl::SubscribeImpl(const std::string &namespaceUri, const SubscribeReqPb &req, + SubscribeRspPb &rsp) +{ + // We need to serialize on CreateProducer and Subscribe. + // That is, we must wait for the previous one to finish (or rollback) successfully + // before the next pub/sub can proceed + CreatePubSubCtrl::Accessor createLock; + createStreamLocks_.Insert(createLock, namespaceUri); + Raii releaseLock([this, &createLock] { createStreamLocks_.TryErase(createLock); }); + + // Find the StreamManager by request.streamName. + const std::string &subName = req.subscription_config().subscription_name(); + const std::string &consumerId = req.consumer_id(); + + std::shared_ptr streamMgr; + Optional streamFields; // optional is false to start + + // Next to create the stream manager if it doesn't exist. + std::shared_ptr streamMgrWithLock; + bool streamExisted; + RETURN_IF_NOT_OK(CreateStreamManagerIfNotExist(namespaceUri, streamFields, streamMgrWithLock, streamExisted)); + streamMgr = streamMgrWithLock->mgr_; + if (streamExisted) { + // An existing stream was found. + RETURN_IF_NOT_OK(streamMgr->CheckIfStreamActive()); + } + // If we hit any error below, we will erase the StreamManager from the tbb provided it is this thread + // that creates the stream manager. We will need an exclusive accessor + Raii raii([this, &namespaceUri, &streamMgrWithLock]() { + streamMgrWithLock->CleanUp(std::bind(&ClientWorkerSCServiceImpl::EraseFromStreamMgrDictWithoutLck, this, + namespaceUri, std::placeholders::_1)); + }); + + // Add consumer into the SubscriptionGroup we found above. + SubscriptionConfig config(subName, static_cast(req.subscription_config().subscription_type())); + + uint64_t lastAckCursor = 0; + // Request master to update topological structure for every new consumer. + ShmView waView; + // Shm page reservation happens after adding to subs_, so we do not need to synchronize with EarlyReclaim(). 
+ RETURN_IF_NOT_OK_PRINT_ERROR_MSG(streamMgr->AddConsumer(config, consumerId, lastAckCursor, waView), + "streamMgr add consumer failed"); + + // We will let go the accessor at this point to prevent deadlock. The master may send back a SyncPubNode + // rpc back to this worker if this is the first consumer. We are still protected by the createLock + streamMgrWithLock->Release(); + std::string masterAddress; + Status rc = SubscribeHandleSend(streamMgr, namespaceUri, consumerId, lastAckCursor, config, req.client_id(), + streamFields, masterAddress); + // Ignore if the master doesn't have the stream fields yet. These will be set by CreateProducer later and the master + // will drive notification to set it here in the consumer. + if ((rc.IsOk() || rc.GetCode() == StatusCode::K_DUPLICATED) && streamFields && !streamFields->Empty()) { + // Reserve one page of memory if we know the page size from the master. + rc = PostCreateStreamManager(streamMgr, streamFields, true); + if (rc.IsError()) { + LOG(ERROR) << FormatString("[%s, S:%s, Sub:%s, C:%s] AddConsumer results in %s. Undo master meta data", + LogPrefix(), namespaceUri, subName, consumerId, rc.ToString()); + LOG_IF_ERROR(CloseConsumerImpl(consumerId, namespaceUri, subName, true, WORKER_LOCK_ID, false), + "Undo subscription"); + return rc; + } + if (StreamManager::EnableSharedPage(streamFields->streamMode_)) { + ShmView shmViewOfStreamMeta; + RETURN_IF_NOT_OK(streamMgr->GetOrCreateShmMeta(TenantAuthManager::Instance()->ExtractTenantId(namespaceUri), + shmViewOfStreamMeta)); + } + } + if (rc.IsError() && rc.GetCode() != StatusCode::K_DUPLICATED) { + LOG(ERROR) << FormatString("Create Consumer [%s] failed in master [%s], error msg: %s", consumerId, + masterAddress, rc.GetMsg()); + // If fail on master, we should roll back this operation on local node. + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(streamMgr->CloseConsumer(subName, consumerId), + "streamMgr close consumer failed"); + return rc; + } + + // Response for client. 
+ rsp.set_last_recv_cursor(lastAckCursor); + rsp.set_worker_fd(waView.fd); + rsp.set_mmap_size(waView.mmapSz); + rsp.set_offset(waView.off); + rsp.set_size(waView.sz); + { + std::unique_lock lock(clearMutex_); + clientConsumers_[req.client_id()].emplace_back(namespaceUri, subName, consumerId); + } + LOG(INFO) << FormatString("[%s, S:%s, C:%s] Subscribe(create consumer) success on master %s", LogPrefix(), + namespaceUri, consumerId, masterAddress); + streamMgrWithLock->needCleanUp = false; + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::CloseConsumer( + std::shared_ptr> serverApi) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + CloseConsumerReqPb req; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + auto recorder = std::make_shared(AccessRecorderKey::DS_POSIX_CLOSE_CONSUMER); + recorder->reqParam.streamName = req.stream_name(); + recorder->reqParam.consumerId = req.consumer_id(); + recorder->reqParam.subscriptionName = req.subscription_name(); + recorder->reqParam.clientId = req.client_id(); + auto rc = CloseConsumerInternal(req, recorder, serverApi); + recorder->SetStatus(rc); + return rc; +}
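CloseConsumerInternal below replies on the stream itself once the pool task finishes: the response message on success, the status on error. A small sketch of that unary reply convention, with ServerApi as an illustrative stub for whatever writer the real serverApi wraps (not the patch's actual type):

#include <iostream>
#include <string>

struct ServerApi {
    void Write(const std::string &rsp) { std::cout << "rsp: " << rsp << "\n"; }
    void SendStatus(int code) { std::cout << "status: " << code << "\n"; }
};

// On success flow the response back (an OK rc is inferred by the client);
// on error flow the rc back instead.
void Reply(ServerApi &api, int rc, const std::string &rsp)
{
    if (rc == 0) {
        api.Write(rsp);
    } else {
        api.SendStatus(rc);
    }
}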
+ +Status ClientWorkerSCServiceImpl::CloseConsumerInternal( + const CloseConsumerReqPb &req, std::shared_ptr recorder, + std::shared_ptr> serverApi) +{ + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + TimeoutDuration parentDuration = scTimeoutDuration; + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + LOG(INFO) << "Worker received CloseConsumer request:" << LogHelper::IgnoreSensitive(req); + const std::string &consumerId = req.consumer_id(); + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CheckConnection(namespaceUri), "worker check connection failed"); + const std::string &clientId = req.client_id(); + + // The real work of the close will be driven in another thread. Launch it now and then release this current thread + // so that it does not hold up the rpc threads. + auto traceId = Trace::Instance().GetTraceID(); + threadPool_->Execute([=]() mutable { + scTimeoutDuration = parentDuration; // the lambda capture gets the parent's copy; assign it to the thread local + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + Status rc = CloseConsumerImpl(consumerId, namespaceUri, req.subscription_name(), true, WORKER_LOCK_ID); + if (rc.IsOk() || rc.GetCode() == StatusCode::K_SC_CONSUMER_NOT_FOUND) { + std::unique_lock lock(clearMutex_); + clientConsumers_[clientId].remove_if([&namespaceUri, &consumerId](const SubInfo &data) { + return (data.streamName == namespaceUri && data.consumerId == consumerId); + }); + } + LOG(INFO) << FormatString("[%s, S:%s, C:%s] CloseConsumer finish with %s", LogPrefix(), namespaceUri, + consumerId, rc.ToString()); + // Flow replies back for the unary rpc + if (rc.IsOk()) { + // Success case: flow the response back to the client (an rc of OK is inferred) + CloseConsumerRspPb rsp; + LOG_IF_ERROR(serverApi->Write(rsp), "Write reply to client stream failed"); + } else { + // Error case: flow the rc back to the client + LOG_IF_ERROR(serverApi->SendStatus(rc), "Write reply to client stream failed"); + } + recorder->SetStatus(rc); + // recorder must be destroyed before traceGuard; otherwise the trace id will already be cleaned up + recorder.reset(); + }); + + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::CloseConsumerImpl(const std::string &consumerId, const std::string &streamName, + const std::string &subName, bool notifyMaster, uint32_t lockId, + bool forceClose) +{ + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK(GetStreamManager(streamName, accessor)); + std::shared_ptr<StreamManager> streamMgr = accessor->second; + if (lockId > WORKER_LOCK_ID) { + streamMgr->ForceUnlockByCursor(consumerId, false, lockId); + } + SubscriptionType subType; + // Obtain the subscription type to construct the subscription config. + RETURN_IF_NOT_OK(streamMgr->GetSubType(subName, subType)); + SubscriptionConfig subConfig(subName, subType); + Status rc; + if (notifyMaster) { + std::string masterAddr; + auto closeConsumerFn = [&] { + // We don't care about the lastAckCursor change when closing a consumer, so we set it to 0. + std::shared_ptr api; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(workerMasterApiManager_->GetWorkerMasterApi(streamName, etcdCM_, api), + "Getting master api failed. stream name = " + streamName); + masterAddr = api->Address(); + master::CloseConsumerReqPb req; + auto &consumerMetaPb = *req.mutable_consumer_meta(); + ConstructConsumerMetaPb(streamName, consumerId, 0, subConfig, "", consumerMetaPb); + req.set_redirect(true); + master::CloseConsumerRspPb rsp; + std::function func = + [&api](master::CloseConsumerReqPb &req, master::CloseConsumerRspPb &rsp) { + return api->CloseConsumer(req, rsp); + }; + return RedirectRetryWhenMetaMoving(req, rsp, api, func); + }; + rc = WorkerMasterOcApiManager::RetryForReplicaNotReady(scTimeoutDuration.CalcRealRemainingTime(), + std::move(closeConsumerFn)); + LOG_IF_ERROR(rc, FormatString("Close Consumer [%s] failed in master [%s]: %s. ForceClose: %s", consumerId, + masterAddr, rc.GetMsg(), forceClose ?
"true" : "false")); + } + if (rc.IsError() && !forceClose) { + return rc; + } + return streamMgr->CloseConsumer(subName, consumerId); +} + +Status ClientWorkerSCServiceImpl::GetDataPage( + std::shared_ptr> serverApi) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + GetDataPageReqPb req; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Worker(%s) receive GetDataPage request, namespaceUri: %s", + localWorkerAddress_.ToString(), namespaceUri); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(GetStreamManager(namespaceUri, accessor), "worker get stream manager failed"); + std::shared_ptr streamManager = accessor->second; + RETURN_IF_NOT_OK(streamManager->CheckIfStreamActive()); // Check for Reset or Delete + std::shared_ptr subscription; + RETURN_IF_NOT_OK(streamManager->GetSubscription(req.subscription_name(), subscription)); + return streamManager->GetDataPage(req, subscription, serverApi); +} + +Status ClientWorkerSCServiceImpl::GetLastAppendCursor(const LastAppendCursorReqPb &req, LastAppendCursorRspPb &rsp) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(GetStreamManager(namespaceUri, accessor), "worker get stream manager failed"); + std::shared_ptr streamManager = accessor->second; + rsp.set_last_append_cursor(streamManager->GetLastAppendCursor()); + + return Status::OK(); +} + +template +void ClientWorkerSCServiceImpl::AsyncSendMemReq(const std::string &namespaceUri) +{ + auto traceId = Trace::Instance().GetTraceID(); + memAllocPool_->Execute([this, namespaceUri, traceId]() { + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); + Status rc = this->template HandleBlockedRequestImpl(namespaceUri); + if (rc.IsError()) { + LOG(ERROR) << FormatString("HandleBlockedRequestImpl failed for %s. 
%s", namespaceUri, rc.ToString()); + } + }); +} + +Status ClientWorkerSCServiceImpl::CreateShmPage( + std::shared_ptr> serverApi) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + PerfPoint point(PerfKey::WORKER_CREATE_WRITE_PAGE_ALL); + CreateShmPageReqPb req; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] CreateShmPage request: %s", LogPrefix(), namespaceUri, + LogHelper::IgnoreSensitive(req)); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK(GetStreamManager(namespaceUri, accessor)); + std::shared_ptr streamManager = accessor->second; + RETURN_IF_NOT_OK(streamManager->CheckIfStreamActive()); // Check for Reset or Delete + // Create a blocked request for FIFO. There can be some previous requests that got OOM and are still waiting. + auto fn = std::bind(&StreamManager::AllocDataPage, streamManager, std::placeholders::_1); + auto blockedReq = std::make_unique>( + namespaceUri, req, streamManager->GetStreamPageSize(), serverApi, fn); + // Lock to compete with StreamManager::UnblockProducers + RETURN_IF_NOT_OK(streamManager->AddBlockedCreateRequest(this, std::move(blockedReq), true)); + AsyncSendMemReq(namespaceUri); + return Status::OK(); +} + +template +Status ClientWorkerSCServiceImpl::HandleBlockedRequestImpl(const std::string &streamName) +{ + StreamManagerMap::const_accessor accessor; + PerfPoint point1(PerfKey::WORKER_CREATE_PAGE_GET_STREAM_MANAGER); + RETURN_IF_NOT_OK(GetStreamManager(streamName, accessor)); + std::shared_ptr streamMgr = accessor->second; + point1.Record(); + // Check the next request. + size_t nextReqSz; + bool bigElement; + std::tie(nextReqSz, bigElement) = streamMgr->GetNextBlockedRequestSize(); + // Big element doesn't go through the ack chain, and we will focus more + // on CreateShmPage request which incur a lot of lock conflicts with the ack thread + if (!(bigElement || streamMgr->CheckHadEnoughMem(nextReqSz))) { + // We will try to release pages (if any) as if this thread is doing ack manually + // Part of the ack process may also call StreamManager::HandleBlockedRequestImpl + // in which case we can expect StreamMgr::GetBlockedCreateRequest below can return + // K_TRY_AGAIN + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s S:%s] Most likely OOM. Reclaim memory", LogPrefix(), streamName); + streamMgr->AckCursors(); + } + std::shared_ptr> blockedReq; + INJECT_POINT("HandleBlockedRequestImpl.sleep"); + auto rc = streamMgr->GetBlockedCreateRequest(blockedReq); + RETURN_OK_IF_TRUE(rc.GetCode() == K_TRY_AGAIN); + RETURN_IF_NOT_OK(rc); + // Treat OOM as normal. 
HandleBlockedRequestImpl will re-queue the request + RETURN_IF_NOT_OK_EXCEPT(streamMgr->template HandleBlockedRequestImpl(std::move(blockedReq), true), K_OUT_OF_MEMORY); + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::AllocBigShmMemory( + std::shared_ptr> serverApi) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + CreateLobPageReqPb req; + RETURN_IF_NOT_OK(serverApi->Read(req)); + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Worker(%s) receive AllocBigShmMemory request, namespaceUri: %s", + localWorkerAddress_.ToString(), namespaceUri); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(GetStreamManager(namespaceUri, accessor), "worker get stream manager failed"); + std::shared_ptr streamManager = accessor->second; + RETURN_IF_NOT_OK(streamManager->CheckIfStreamActive()); // Check for Reset or Delete + // Create a blocked request for FIFO. There can be some previous requests that got OOM and are still waiting. + auto fn = std::bind(&StreamManager::AllocBigShmMemory, streamManager, std::placeholders::_1); + auto blockedReq = std::make_unique>( + namespaceUri, req, req.page_size(), serverApi, fn); + // Lock to compete with StreamManager::UnblockProducers + RETURN_IF_NOT_OK(streamManager->AddBlockedCreateRequest(this, std::move(blockedReq), true)); + AsyncSendMemReq(namespaceUri); + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::ReleaseBigShmMemory( + std::shared_ptr> serverApi) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + ReleaseLobPageReqPb req; + RETURN_IF_NOT_OK(serverApi->Read(req)); + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Worker(%s) receive ReleaseBigShmMemory request, namespaceUri: %s", + localWorkerAddress_.ToString(), namespaceUri); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(GetStreamManager(namespaceUri, accessor), "worker get stream manager failed"); + std::shared_ptr streamManager = accessor->second; + RETURN_IF_NOT_OK(streamManager->CheckIfStreamActive()); // Check for Reset or Delete + RETURN_IF_NOT_OK(streamManager->ReleaseBigShmMemory(serverApi, req)); + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::DeleteStreamLocally(StreamManagerMap::accessor &accessor) +{ + std::shared_ptr streamManager = accessor->second; + auto streamName = streamManager->GetStreamName(); + // Earlier topo change should have removed the stream from RWM. 
There is no harm in calling it again; + // just expect a stream-not-found error + RETURN_IF_NOT_OK_EXCEPT(remoteWorkerManager_->DeleteStream(streamName), K_SC_STREAM_NOT_FOUND); + bool success = streamMgrDict_.erase(accessor); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(success, K_RUNTIME_ERROR, + FormatString("Failed to erase stream %s from streamMgrDict_", streamName)); + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::DeleteStream(const DeleteStreamReqPb &req, DeleteStreamRspPb &rsp) +{ + AccessRecorder recorder(AccessRecorderKey::DS_POSIX_DELETE_STREAM); + auto rc = DeleteStreamImpl(req, rsp); + StreamRequestParam reqParam; + reqParam.streamName = req.stream_name(); + reqParam.clientId = req.client_id(); + StreamResponseParam rspParam; + rspParam.msg = rc.GetMsg(); + recorder.Record(rc.GetCode(), reqParam, rspParam); + return rc; +} + +Status ClientWorkerSCServiceImpl::DeleteStreamImpl(const DeleteStreamReqPb &req, DeleteStreamRspPb &rsp) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + (void)rsp; + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + PerfPoint point(PerfKey::WORKER_DELETE_STREAM_ALL); + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CheckConnection(namespaceUri), "worker check connection failed"); + LOG(INFO) << FormatString("Worker(%s) received DeleteStream request, namespaceUri: %s", + localWorkerAddress_.ToString(), namespaceUri); + bool needsRollback = true; + { + StreamManagerMap::const_accessor accessor; + if (GetStreamManager(namespaceUri, accessor).IsOk()) { // If it exists on the local worker node + std::shared_ptr<StreamManager> streamManager = accessor->second; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(streamManager->CheckDeleteStreamCondition(), + "streamManager check delete stream condition failed"); + // Set the state to DELETE_IN_PROGRESS to prevent a timing hole after + // the master is updated but before the stream manager is removed from the tbb map. + Status rc = streamManager->SetDeleteState(); + if (rc.IsError()) { + // If delete is already in progress + if (rc.GetCode() == StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS) { + // Then this function is a no-op + // No need to decrement because without ignore, RefCount does not increase + return Status::OK(); + } + // No need to decrement because on error, RefCount does not increase + return rc; // Else return any other error + } + } + accessor.release(); + } + Raii unsetDelete([this, &needsRollback, &namespaceUri]() { + StreamManagerMap::const_accessor accessor; + if (needsRollback && GetStreamManager(namespaceUri, accessor).IsOk()) { + std::shared_ptr<StreamManager> streamManager = accessor->second; + LOG(INFO) << FormatString("[S:%s] Setting Active State", namespaceUri); + streamManager->SetActiveState(); + } + }); + INJECT_POINT("ClientWorkerSCServiceImpl.DELETE_IN_PROGRESS.sleep"); + // We ask the master to broadcast to all other workers related to this stream, except this node. + Status rc = DeleteStreamHandleSend(namespaceUri); + // Still proceed to handle DeleteStreamLocally even if K_NOT_FOUND is returned.
+ INJECT_POINT("ClientWorkerSCServiceImpl.DeleteStreamHandleSend.sleep"); + RETURN_IF_NOT_OK_EXCEPT(rc, K_NOT_FOUND); + // Now we need an exclusive accessor lock when we remove the stream from the tbb map + LOG(INFO) << FormatString("[S:%s] Waiting for stream to be free of use", namespaceUri); + INJECT_POINT("ClientWorkerSCServiceImpl.DeleteStreamLocally.sleep"); + bool success; + { + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + StreamManagerMap::accessor accessor; + success = streamMgrDict_.find(accessor, namespaceUri); + if (success) { + RETURN_IF_NOT_OK(DeleteStreamLocally(accessor)); + } + } + LOG_IF(INFO, !success) << FormatString("[S:%s] Stream manager is already gone", namespaceUri); + LOG(INFO) << FormatString("[%s, S:%s] DeleteStream success.", LogPrefix(true), namespaceUri); + needsRollback = false; + return success ? Status::OK() : rc; +} + +Status ClientWorkerSCServiceImpl::DeleteStreamHandleSend(const std::string &streamName) +{ + auto deleteFn = [&] { + std::shared_ptr api; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(workerMasterApiManager_->GetWorkerMasterApi(streamName, etcdCM_, api), + "Getting master api failed. stream name = " + streamName); + master::DeleteStreamReqPb masterReq; + masterReq.set_stream_name(streamName); + masterReq.mutable_src_node_addr()->set_host(localWorkerAddress_.Host()); + masterReq.mutable_src_node_addr()->set_port(localWorkerAddress_.Port()); + masterReq.set_redirect(true); + master::DeleteStreamRspPb masterRsp; + std::function func = + [&api](master::DeleteStreamReqPb &req, master::DeleteStreamRspPb &rsp) { + return api->DeleteStream(req, rsp); + }; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(RedirectRetryWhenMetaMoving(masterReq, masterRsp, api, func), + "workerMasterApi delete stream failed on master " + api->Address()); + return Status::OK(); + }; + return WorkerMasterOcApiManager::RetryForReplicaNotReady(scTimeoutDuration.CalcRealRemainingTime(), + std::move(deleteFn)); +} + +Status ClientWorkerSCServiceImpl::QueryGlobalProducersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) +{ + AccessRecorder recorder(AccessRecorderKey::DS_POSIX_QUERY_PRODUCERS_NUM); + auto rc = QueryGlobalProducersNumImpl(req, rsp); + StreamRequestParam reqParam; + reqParam.streamName = req.stream_name(); + reqParam.clientId = req.client_id(); + StreamResponseParam rspParam; + rspParam.msg = rc.GetMsg(); + rspParam.count = Optional(rsp.global_count()); + recorder.Record(rc.GetCode(), reqParam, rspParam); + return rc; +} + +Status ClientWorkerSCServiceImpl::QueryGlobalProducersNumImpl(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + LOG(INFO) << FormatString("Worker(%s) received QueryGlobalProducersNum request, namespaceUri: %s", + localWorkerAddress_.ToString(), namespaceUri); + auto queryFn = [&] { + std::shared_ptr api; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(workerMasterApiManager_->GetWorkerMasterApi(namespaceUri, etcdCM_, api), + "Getting master api failed. 
stream name = " + namespaceUri); + master::QueryGlobalNumReqPb masterReq; + masterReq.set_stream_name(namespaceUri); + masterReq.set_redirect(true); + master::QueryGlobalNumRsqPb masterRsp; + std::function func = + [&api](master::QueryGlobalNumReqPb &req, master::QueryGlobalNumRsqPb &rsp) { + return api->QueryGlobalProducersNum(req, rsp); + }; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + RedirectRetryWhenMetaMoving(masterReq, masterRsp, api, func), + "workerMasterApi query global producers number failed on master" + api->Address()); + rsp.set_global_count(masterRsp.global_count()); + LOG(INFO) << "worker QueryGlobalProducersNum done, namespaceUri: " << namespaceUri; + return Status::OK(); + }; + return WorkerMasterOcApiManager::RetryForReplicaNotReady(scTimeoutDuration.CalcRealRemainingTime(), + std::move(queryFn)); +} + +Status ClientWorkerSCServiceImpl::QueryGlobalConsumersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) +{ + AccessRecorder recorder(AccessRecorderKey::DS_POSIX_QUERY_CONSUMERS_NUM); + auto rc = QueryGlobalConsumersNumImpl(req, rsp); + StreamRequestParam reqParam; + reqParam.streamName = req.stream_name(); + reqParam.clientId = req.client_id(); + StreamResponseParam rspParam; + rspParam.msg = rc.GetMsg(); + rspParam.count = Optional(rsp.global_count()); + recorder.Record(rc.GetCode(), reqParam, rspParam); + return rc; +} + +Status ClientWorkerSCServiceImpl::QueryGlobalConsumersNumImpl(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + std::string namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.stream_name()); + LOG(INFO) << FormatString("Worker(%s) received QueryGlobalConsumersNum request, namespaceUri: %s", + localWorkerAddress_.ToString(), namespaceUri); + auto queryFn = [&] { + std::shared_ptr api; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(workerMasterApiManager_->GetWorkerMasterApi(namespaceUri, etcdCM_, api), + "Getting master api failed. 
stream name = " + namespaceUri); + master::QueryGlobalNumReqPb masterReq; + masterReq.set_stream_name(namespaceUri); + masterReq.set_redirect(true); + master::QueryGlobalNumRsqPb masterRsp; + std::function func = + [&api](master::QueryGlobalNumReqPb &req, master::QueryGlobalNumRsqPb &rsp) { + return api->QueryGlobalConsumersNum(req, rsp); + }; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + RedirectRetryWhenMetaMoving(masterReq, masterRsp, api, func), + "workerMasterApi query global consumers number failed on master " + api->Address()); + rsp.set_global_count(masterRsp.global_count()); + LOG(INFO) << "worker QueryGlobalConsumersNum done, namespaceUri: " << namespaceUri; + return Status::OK(); + }; + + return WorkerMasterOcApiManager::RetryForReplicaNotReady(scTimeoutDuration.CalcRealRemainingTime(), + std::move(queryFn)); +} + +std::string ClientWorkerSCServiceImpl::LogPrefix(bool withAddress) const +{ + if (withAddress) { + return FormatString("ClientWorkerSvc, Node:%s", localWorkerAddress_.ToString()); + } else { + return "ClientWorkerSvc"; + } +} + +Status ClientWorkerSCServiceImpl::SendBlockProducerReq(const std::string &streamName, + const std::string &remoteWorkerAddr) +{ + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK(GetStreamManager(streamName, accessor)); + std::shared_ptr streamMgr = accessor->second; + RETURN_IF_NOT_OK(streamMgr->BlockProducer(remoteWorkerAddr, false)); + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::SendUnBlockProducerReq(const std::string &streamName, + const std::string &remoteWorkerAddr) +{ + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK(GetStreamManager(streamName, accessor)); + std::shared_ptr streamMgr = accessor->second; + RETURN_IF_NOT_OK(streamMgr->UnBlockProducer(remoteWorkerAddr)); + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::CreateStreamManagerIfNotExist( + const std::string &streamName, const Optional &streamFields, + std::shared_ptr &streamMgrWithLock, bool &streamExisted) +{ + auto rlock = std::make_unique(); + streamExisted = streamMgrDict_.find(*rlock, streamName); + if (streamExisted) { + // An existing stream was found. + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Stream already exists in StreamManager", LogPrefix(), + streamName); + auto streamMgr = (*rlock)->second; + streamMgrWithLock = + std::make_shared(streamMgr, rlock.release(), false, shared_from_this()); + } else { + auto xlock = std::make_unique(); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CreateStreamManagerImpl(streamName, streamFields, *xlock), + "worker create stream manager failed"); + auto streamMgr = (*xlock)->second; + streamMgrWithLock = + std::make_shared(streamMgr, xlock.release(), true, shared_from_this()); + } + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::PostCreateStreamManager(const std::shared_ptr &streamManager, + const Optional &streamFields, bool reserveShm) +{ + RETURN_IF_NOT_OK(streamManager->CheckIfStreamActive()); + // If the stream fields are passed in then assign them to this existing stream + // manager. This code may fail a verification check if the existing stream was not empty and has mismatching + // settings. 
+ if (streamFields && !streamFields.value().Empty()) { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(streamManager->UpdateStreamFields(streamFields.value(), reserveShm), + "streamMgr verify failed"); + } + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::CreateStreamManagerImpl(const std::string &streamName, + const Optional &streamFields, + StreamManagerMap::accessor &accessor) +{ + PerfPoint point(PerfKey::WORKER_CREATE_STREAM_MANAGER_LOGIC); + auto streamNo = ++lifetimeLocalStreamCount_; + bool needRollback = streamMgrDict_.emplace( + accessor, std::make_pair(streamName, std::make_shared( + streamName, remoteWorkerManager_.get(), localWorkerAddress_.ToString(), + akSkManager_, weak_from_this(), scAllocateManager_, + workerWorkerSCService_, streamNo))); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + needRollback, K_DUPLICATED, + FormatString("[%s, S:%s] Stream already exists in StreamManager", LogPrefix(), streamName)); + // If we hit any error below, we will erase the StreamManager from the tbb + Raii raii([this, &accessor, &needRollback]() { + if (needRollback) { + streamMgrDict_.erase(accessor); + } + }); + INJECT_POINT("ClientWorkerSCServiceImpl.CreateStreamManagerImpl.StreamNo_Sleep"); + RETURN_IF_NOT_OK(AddStreamNo(streamNo, streamName)); + auto &streamManager = accessor->second; + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Create stream in StreamManager.", LogPrefix(), streamName); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + streamManager->CreatePageQueueHandler(streamFields), + FormatString("[%s, S:%s] Fail to create page queue handler", LogPrefix(), streamName)); + if (ScMetricsMonitor::Instance()->IsEnabled()) { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + streamManager->InitStreamMetrics(), + FormatString("[%s, S:%s] Fail to init stream metrics", LogPrefix(), streamName)); + } + needRollback = false; + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::GetStreamManager(const std::string &streamName, + StreamManagerMap::const_accessor &accessor) +{ + PerfPoint point(PerfKey::WORKER_CREATE_STREAM_MANAGER_GET_LOCK); + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + point.RecordAndReset(PerfKey::WORKER_GET_STREAM_MANAGER); + auto success = streamMgrDict_.find(accessor, streamName); + CHECK_FAIL_RETURN_STATUS(success, K_SC_STREAM_NOT_FOUND, + FormatString("[%s] Stream %s not found", LogPrefix(), streamName)); + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::ClosePubSubForClientLost(const std::string &clientId) +{ + std::shared_ptr clientInfo; + clientInfo = ClientManager::Instance().GetClientInfo(clientId); + if (clientInfo == nullptr) { + RETURN_STATUS_LOG_ERROR(K_RUNTIME_ERROR, "invalid client id"); + } + uint32_t lockId; + RETURN_IF_NOT_OK(clientInfo->GetLockId(lockId)); + LOG(INFO) << "Client Lost: Begin to close producers and consumers for client: " << clientId + << " lock id: " << lockId; + + std::list producerList; + std::list consumerList; + Status returnRc = Status::OK(); + { + std::lock_guard lock(clearMutex_); + producerList = std::move(clientProducers_[clientId]); + consumerList = std::move(clientConsumers_[clientId]); + (void)clientProducers_.erase(clientId); + (void)clientConsumers_.erase(clientId); + } + std::set streams; + for (const auto &producer : producerList) { + (void)streams.emplace(producer.streamName_); + } + for (const auto &consumer : consumerList) { + (void)streams.emplace(consumer.streamName); + } + + ForceUnlockMemViemForPages(streams, lockId); + + if (!producerList.empty()) { + Status rc = CloseProducerImplForceClose(lockId, 
producerList); + LOG(INFO) << FormatString("Close producers for client %s finished with %s", clientId, rc.ToString()); + if (rc.IsError()) { + // Failed to close at least one of the producers. Add any failed producers back to the list + // for this client. Do not quit this function yet; continue trying to close any consumers as well. + std::lock_guard lock(clearMutex_); + clientProducers_[clientId] = std::move(producerList); + returnRc = rc; + } + } + + // Consumer close does not have a list-based close at this time. Loop over each consumer and close it. + std::list<SubInfo> failedConsumers; + const bool forceClose = true; + for (const auto &consumerInfo : consumerList) { + Status rc = CloseConsumerImpl(consumerInfo.consumerId, consumerInfo.streamName, consumerInfo.subName, true, + lockId, forceClose); + LOG(INFO) << FormatString("Close consumer [%s] in stream [%s] for client %s finished with %s", + consumerInfo.consumerId, consumerInfo.streamName, clientId, rc.ToString()); + if (rc.IsError()) { + // Track the consumer that failed. + failedConsumers.emplace_back(consumerInfo.streamName, consumerInfo.subName, consumerInfo.consumerId); + returnRc = rc; + } + } + + if (!failedConsumers.empty()) { + // Add any consumers that failed to close back to the client tracking + std::lock_guard lock(clearMutex_); + clientConsumers_[clientId] = std::move(failedConsumers); + } + + // At this point, if returnRc is not OK, then the client has not been properly closed and there still exist + // producers and/or consumers that have not been cleaned up yet. + return returnRc; +} + +void ClientWorkerSCServiceImpl::ForceUnlockMemViemForPages(const std::set<std::string> &streams, uint32_t lockId) +{ + Timer timer; + for (const auto &streamName : streams) { + StreamManagerMap::const_accessor accessor; + Status rc = GetStreamManager(streamName, accessor); + if (rc.IsError()) { + continue; + } + accessor->second->ForceUnlockMemViemForPages(lockId); + } + LOG(INFO) << "ForceUnlockMemViemForPages for stream count: " << streams.size() + << ", cost:" << timer.ElapsedMilliSecond() << "ms"; +} + +void ClientWorkerSCServiceImpl::GetProducerConsumerMetadata( + std::vector &localProducers, std::vector> &localConsumers, + GetStreamMetadataRspPb *meta, const std::string &streamName, HostPortPb &hostPortPb) +{ + // If there is at least one local producer, set the ProducerPb + if (localProducers.size()) { + auto producerPb = meta->add_producers(); + producerPb->set_stream_name(streamName); + producerPb->mutable_worker_address()->CopyFrom(hostPortPb); + // This number will be 1, as the master doesn't need to know the actual count + producerPb->set_producer_count(1); + } + + for (const auto &consumer : localConsumers) { + std::string foundClientId = ""; + auto comparator = [&consumer](const SubInfo &conInfo) { return consumer.first == conInfo.consumerId; }; + for (const auto &[clientId, conList] : clientConsumers_) { + auto iter = std::find_if(conList.begin(), conList.end(), comparator); + if (iter != conList.end()) { + foundClientId = clientId; + break; + } + } + if (foundClientId.empty()) { + LOG(ERROR) << "Client id not found for consumer: " << consumer.first; + continue; + } + auto consumerPb = meta->add_consumers(); + consumerPb->set_client_id(std::move(foundClientId)); + consumerPb->set_consumer_id(consumer.first); + consumerPb->set_stream_name(streamName); + const auto &config = consumer.second; + consumerPb->mutable_sub_config()->set_subscription_name(config.subscriptionName); +
consumerPb->mutable_sub_config()->set_subscription_type(SubscriptionTypePb(config.subscriptionType)); + consumerPb->mutable_worker_address()->CopyFrom(hostPortPb); + } +} + +Status ClientWorkerSCServiceImpl::GetStreamMetadata(const std::string &streamName, GetStreamMetadataRspPb *meta) +{ + std::shared_lock lock(clearMutex_); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(GetStreamManager(streamName, accessor), "GetStreamManager failed"); + std::shared_ptr streamManager = accessor->second; + + // Get pub and sub metadata. + std::vector localProducers; + std::vector> localConsumers; + streamManager->GetLocalProducers(localProducers); + streamManager->GetLocalConsumers(localConsumers); + meta->set_is_remote_pub_empty(streamManager->IsRemotePubEmpty()); + + StreamFields fields; + streamManager->GetStreamFields(fields); + meta->set_max_stream_size(fields.maxStreamSize_); + meta->set_page_size(fields.pageSize_); + meta->set_auto_cleanup(fields.autoCleanup_); + meta->set_retain_num_consumer(fields.retainForNumConsumers_); + meta->set_encrypt_stream(fields.encryptStream_); + meta->set_reserve_size(fields.reserveSize_); + meta->set_stream_mode(fields.streamMode_); + + HostPortPb hostPortPb; + hostPortPb.set_host(localWorkerAddress_.Host()); + hostPortPb.set_port(localWorkerAddress_.Port()); + + GetProducerConsumerMetadata(localProducers, localConsumers, meta, streamName, hostPortPb); + return Status::OK(); +} + +std::vector ClientWorkerSCServiceImpl::GetStreamNameList() +{ + std::vector streamNames; + Timer t; + WriteLockHelper wlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Time to acquire mutex_: [%.6lf]s", LogPrefix(), + t.ElapsedSecond()); + auto iter = streamMgrDict_.begin(); + while (iter != streamMgrDict_.end()) { + streamNames.emplace_back(iter->first); + ++iter; + } + return streamNames; +} + +Status ClientWorkerSCServiceImpl::SendAllStreamMetadata( + const GetMetadataAllStreamReqPb &req, + std::shared_ptr> &streamRpc) +{ + Status status; + const std::string &masterAddr = req.master_address(); + worker::HashRange hashRanges; + hashRanges.reserve(req.hash_ranges_size()); + for (const auto &range : req.hash_ranges()) { + hashRanges.emplace_back(range.from(), range.end()); + } + + auto streamNames = GetStreamNameList(); + for (const auto &streamName : streamNames) { + if (CheckConditionsForStream(streamName, masterAddr, hashRanges)) { + // This stream goes to the requesting master. + GetStreamMetadataRspPb rsp; + Status rc = GetStreamMetadata(streamName, &rsp); + rsp.set_stream_name(streamName); + rsp.mutable_error()->set_error_code(rc.GetCode()); + rsp.mutable_error()->set_error_msg(rc.GetMsg()); + ASSIGN_IF_NOT_OK_PRINT_ERROR_MSG(status, streamRpc->Write(rsp), + FormatString("Write metadata to master failed for stream %s", streamName)); + } + } + return status; +} + +Status ClientWorkerSCServiceImpl::GetAllStreamMetadata(const GetMetadataAllStreamReqPb &req, + GetMetadataAllStreamRspPb &rsp) +{ + const std::string &masterAddr = req.master_address(); + worker::HashRange hashRanges; + hashRanges.reserve(req.hash_ranges_size()); + for (const auto &range : req.hash_ranges()) { + hashRanges.emplace_back(range.from(), range.end()); + } + + auto streamNames = GetStreamNameList(); + for (const auto &streamName : streamNames) { + if (CheckConditionsForStream(streamName, masterAddr, hashRanges)) { + // This stream goes to the requesting master. 
+ auto streamMetaPb = rsp.add_stream_meta(); + streamMetaPb->set_stream_name(streamName); + Status rc = GetStreamMetadata(streamName, streamMetaPb); + streamMetaPb->mutable_error()->set_error_code(rc.GetCode()); + streamMetaPb->mutable_error()->set_error_msg(rc.GetMsg()); + } + } + return Status::OK(); +} + +bool ClientWorkerSCServiceImpl::CheckConditionsForStream(const std::string &streamName, const std::string &masterAddr, + const worker::HashRange &hashRanges) +{ + if (!masterAddr.empty()) { + MetaAddrInfo metaAddrInfo; + auto rc = etcdCM_->GetMetaAddress(streamName, metaAddrInfo); + if (rc.IsError()) { + LOG(ERROR) << rc.ToString(); + return false; + } + auto masterAddress = metaAddrInfo.GetAddressAndSaveDbName(); + return masterAddress.ToString() == masterAddr; + } + return etcdCM_->IsInRange(hashRanges, streamName, ""); +} + +Status ClientWorkerSCServiceImpl::CheckConnection(const std::string &streamName) +{ + auto func = [&] { + Status status = etcdCM_->CheckConnection(streamName); + if (status.IsError()) { + std::stringstream ss; + ss << "Worker disconnected from master, error msg: " << status.ToString(); + LOG(ERROR) << ss.str(); + RETURN_STATUS(StatusCode::K_RPC_UNAVAILABLE, ss.str()); + } + return Status::OK(); + }; + return WorkerMasterOcApiManager::RetryForReplicaNotReady(scTimeoutDuration.CalcRealRemainingTime(), + std::move(func)); +} + +Status ClientWorkerSCServiceImpl::UnblockProducer( + std::shared_ptr> serverApi) +{ + UnblockProducerReqPb req; + RETURN_IF_NOT_OK(serverApi->Read(req)); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + const auto &streamName = req.stream_name(); + const auto &workerAddr = req.worker_addr(); + LOG(INFO) << "UnBlocking Producer for stream: " << streamName << " From remote worker: " << workerAddr; + HostPort workerHostPort; + RETURN_IF_NOT_OK(workerHostPort.ParseString(workerAddr)); + RETURN_IF_NOT_OK(remoteWorkerManager_->ToggleStreamBlocking(workerAddr, streamName, false)); + RETURN_IF_NOT_OK(serverApi->Write(UnblockProducerRspPb())); + VLOG(SC_NORMAL_LOG_LEVEL) << "UnBlocking Producer Request for stream: " << streamName + << " From remote worker: " << workerAddr << " Successful"; + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::BlockProducer( + std::shared_ptr> serverApi) +{ + BlockProducerReqPb req; + RETURN_IF_NOT_OK(serverApi->Read(req)); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + const auto &streamName = req.stream_name(); + const auto &workerAddr = req.worker_addr(); + LOG(INFO) << "Blocking Producer for stream: " << streamName << " From remote worker: " << workerAddr; + HostPort workerHostPort; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(workerHostPort.ParseString(workerAddr), "ParseString error"); + RETURN_IF_NOT_OK(remoteWorkerManager_->ToggleStreamBlocking(workerAddr, streamName, true)); + RETURN_IF_NOT_OK(serverApi->Write(BlockProducerRspPb())); + VLOG(SC_NORMAL_LOG_LEVEL) << "Blocking Producer for stream: " << streamName << " From remote worker: " << workerAddr + << " Successful"; + return Status::OK(); +} + +std::string ClientWorkerSCServiceImpl::GetTotalStreamCount() +{ + return std::to_string(streamMgrDict_.size()); +} + +Status ClientWorkerSCServiceImpl::DeleteStreamContext(const std::string &streamName, bool forceDelete, int64_t timeout) +{ + LOG(INFO) << FormatString("[%s, S:%s] DelStreamContext request started with forceDelete: %d", LogPrefix(true), + streamName, forceDelete); + 
scTimeoutDuration.Init(timeout); + Raii outerResetDuration([]() { scTimeoutDuration.Reset(); }); + std::shared_ptr<StreamManager> streamManager; + INJECT_POINT("ClientWorkerSCServiceImpl.DeleteStreamContext.timeout", [](int timeout) { + scTimeoutDuration.Init(timeout); + return Status::OK(); + }); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK(GetStreamManager(streamName, accessor)); + streamManager = accessor->second; + if (forceDelete) { + // Since the pub/sub is closed before calling DeleteStreamContext, the pub/sub list is empty for the stream. + std::vector<std::string> pubSubList; + RETURN_IF_NOT_OK(streamManager->ResetStreamStart(pubSubList)); + streamManager->ForceCloseClients(); + } + + Status rc = streamManager->SetDeleteState(true); + if (rc.GetCode() == StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS) { + LOG(INFO) << FormatString("[%s] Ignore DELETE_IN_PROGRESS", streamName); + } else if (rc.IsError()) { + // Any error except delete-in-progress + return rc; + } + bool needsRollback = true; + Raii unsetDelete([this, &needsRollback, &streamName]() { + StreamManagerMap::const_accessor accessor; + if (needsRollback && GetStreamManager(streamName, accessor).IsOk()) { + std::shared_ptr<StreamManager> streamManager = accessor->second; + LOG(INFO) << FormatString("[S:%s] Context Undo, Setting Active State", streamName); + streamManager->SetActiveState(); + } + }); + // Now we need an exclusive accessor lock when we remove the stream from the tbb map + accessor.release(); + LOG(INFO) << FormatString("[%s, S:%s] Starting to delete stream due to DeleteStreamContext", LogPrefix(true), + streamName); + { + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + StreamManagerMap::accessor xAccessor; + LOG(INFO) << FormatString("[%s, S:%s] Waiting for stream to be free of use", LogPrefix(true), streamName); + // At this point + // the master should have checked that the stream has no consumers or producers, + // and we have already set the delete state to block new operations after the master call, + // so this should be fast + bool success = streamMgrDict_.find(xAccessor, streamName); + // Earlier topo change should have removed the stream from RWM. There is no harm in calling it again; + // just expect a stream-not-found error + RETURN_IF_NOT_OK_EXCEPT(remoteWorkerManager_->DeleteStream(streamName), K_SC_STREAM_NOT_FOUND); + // If the stream is not found, ignore it and return OK + if (success) { + success = streamMgrDict_.erase(xAccessor); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + success, K_RUNTIME_ERROR, FormatString("Failed to erase stream %s from streamMgrDict_", streamName)); + } + } + LOG(INFO) << FormatString("[%s, S:%s] DelStreamContext (Notified by master to clear stream data) success.", + LogPrefix(true), streamName);
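The ClearBlockedList() call below exists because each blocked create request keeps a shared_ptr back to its StreamManager, so erasing the stream from the map alone never drops the last reference. A self-contained sketch of that retention cycle and how purging the queue breaks it (Manager and BlockedRequest are illustrative stand-ins, not the patch's types):

#include <cassert>
#include <deque>
#include <memory>

struct Manager;
struct BlockedRequest {
    std::shared_ptr<Manager> owner;  // keeps the manager alive while queued
};
struct Manager {
    std::deque<BlockedRequest> blocked;
    void ClearBlockedList() { blocked.clear(); }
};

int main()
{
    auto mgr = std::make_shared<Manager>();
    mgr->blocked.push_back({mgr});      // self-retaining reference
    std::weak_ptr<Manager> probe = mgr;
    mgr->ClearBlockedList();            // break the cycle before dropping mgr
    mgr.reset();
    assert(probe.expired());            // the manager was actually destroyed
    return 0;
}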
+ + // Blocked requests contain a shared_ptr to the StreamManager. To prevent a memory leak, we need to clear + // all blocked requests in the StreamManager. + streamManager->ClearBlockedList(); + + needsRollback = false; + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::GetWorkerStub(const HostPort &workerHostPort, + std::shared_ptr &stub) +{ + std::lock_guard lock(remotePubStubMutex_); + auto workerAddr = workerHostPort.ToString(); + auto it = remotePubStubs_.find(workerAddr); + if (it == remotePubStubs_.end()) { + RpcCredential cred; + RETURN_IF_NOT_OK(RpcAuthKeyManager::CreateCredentials(WORKER_SERVER_NAME, cred)); + auto channel = std::make_shared(workerHostPort, cred); + stub = std::make_shared(channel); + remotePubStubs_.emplace(workerAddr, stub); + } else { + stub = it->second; + } + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::ResetStreams( + std::shared_ptr> serverApi) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + ResetOrResumeStreamsReqPb req; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + + const auto &streamNamesRepeated = req.stream_names(); + std::unordered_set<std::string> streams{ streamNamesRepeated.begin(), streamNamesRepeated.end() }; + + LOG(INFO) << "ResetStreams request start, clientId: " << req.client_id() + << ", streams: " << VectorToString(streams); + // Divide the streams into two lists based on their current reset state + std::unordered_set<std::string> doneList, errorList; + + // Start the reset for all streams + for (auto &stream : streams) { + // Get the stream manager for the stream + auto namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, stream); + LOG(INFO) << "Begin to clear data for stream Id: " << namespaceUri; + StreamManagerMap::const_accessor accessor; + Status rc = GetStreamManager(namespaceUri, accessor); + if (rc.IsError()) { + LOG(ERROR) << "Could not get stream manager for stream: " << namespaceUri; + // If the stream manager is not found, the stream is considered to be reset. + doneList.insert(stream); + continue; + } + std::shared_ptr<StreamManager> streamMgr = accessor->second; + // If the stream is being deleted, don't allow a reset on it + rc = streamMgr->CheckIfStreamActive(); + if (rc.GetCode() == StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS) { + LOG(ERROR) << "Delete in progress for stream: " << namespaceUri; + // If the stream is being deleted, the stream is considered to be reset. + doneList.insert(stream); + continue; + } + + std::vector<std::string> prodConList; + (void)GetPubSubForClientStream(req.client_id(), namespaceUri, prodConList); + // Send the Reset Start request + rc = streamMgr->ResetStreamStart(prodConList); + if (rc.IsError()) { + // Check if the reset completed or is in progress by another client.
+ if (streamMgr->CheckIfStreamInState(StreamState::RESET_COMPLETE) + || streamMgr->CheckIfStreamInState(StreamState::DELETE_IN_PROGRESS)) { + // Reset is done for the stream + doneList.insert(stream); + } else { + LOG(ERROR) << rc.GetMsg(); + errorList.insert(stream); + } + continue; + } + // Reset stream start successful + doneList.insert(stream); + } + + // We are done starting the reset; now we wait for the reset to finish + RETURN_IF_NOT_OK(ResetStreamsReply(serverApi, streams.size(), doneList.size(), errorList.size())); + LOG(INFO) << "ResetStreams request end"; + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::ResetStreamsReply( + std::shared_ptr> serverApi, + size_t streamsSize, size_t doneListSize, size_t errListSize) +{ + Status retStatus; + // We got OK for all streams requested + ResetOrResumeStreamsRspPb rsp; + if (doneListSize == streamsSize) { // All the streams were reset + VLOG(SC_INTERNAL_LOG_LEVEL) << "ResetStreams Done"; + retStatus = Status::OK(); + serverApi->Write(rsp); + } else if (errListSize != 0) { // Some of them have errors + retStatus = Status(K_RUNTIME_ERROR, "Got error while resetting stream"); + CheckErrorReturn(retStatus, rsp, "Reset failed with rc ", serverApi); + } else { + LOG(ERROR) << "The number of completed streams is not the same as the total number of resetting streams"; + } + // We do not reply on a timeout; that is handled by the client + return Status::OK(); +} + +Status ClientWorkerSCServiceImpl::ResumeStreams( + std::shared_ptr> serverApi) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ValidateWorkerState(), "validate worker state failed"); + ResetOrResumeStreamsReqPb req; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Read(req), "serverApi read request failed"); + + std::string tenantId; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(Authenticate(akSkManager_, req, tenantId, req.client_id()), + "Authenticate failed."); + + const auto &streamNamesRepeated = req.stream_names(); + std::unordered_set<std::string> streams{ streamNamesRepeated.begin(), streamNamesRepeated.end() }; + LOG(INFO) << "ResumeStreams request start, clientId: " << req.client_id() + << ", streams: " << VectorToString(streams); + Status rc = Status::OK(); + for (const auto &stream : streams) { + auto namespaceUri = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, stream); + LOG(INFO) << "Begin to resume stream " << namespaceUri; + StreamManagerMap::const_accessor accessor; + // If the StreamManager is not found, the stream may have been deleted while the reset was in progress. + // Log an error and return OK. + if (GetStreamManager(namespaceUri, accessor).IsOk()) { + std::shared_ptr<StreamManager> streamMgr = accessor->second; + Status status = streamMgr->ResumeStream(); + if (status.IsError()) { + LOG(ERROR) << status.GetMsg(); + rc = status; + } + } else { + LOG(ERROR) << "Could not get stream manager for stream: " << namespaceUri; + } + } + ResetOrResumeStreamsRspPb rsp; + CheckErrorReturn(rc, rsp, "Resume failed with rc ", serverApi); + LOG(INFO) << "ResumeStreams request end"; + return Status::OK(); +} + +std::string ClientWorkerSCServiceImpl::GetSCRemoteSendSuccessRate() const +{ + if (remoteWorkerManager_ == nullptr) { + return ""; + } + return remoteWorkerManager_->GetSCRemoteSendSuccessRate(); +}
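WaitForAckTask below walks a deque of futures, waiting on each with a bounded wait_for: entries that time out stay queued for the next round, completed ones are drained and erased. A minimal sketch of that polling shape over std::future<int> (the real code carries a tuple of future, stream name, and start time; DrainCompleted is an illustrative name):

#include <chrono>
#include <deque>
#include <future>
#include <iostream>

// Waits up to `wait` on each pending future; keeps the ones that time out
// for the next round and erases the ones that completed.
void DrainCompleted(std::deque<std::future<int>> &pending, std::chrono::seconds wait)
{
    for (auto it = pending.begin(); it != pending.end();) {
        if (it->wait_for(wait) == std::future_status::timeout) {
            ++it;  // still running; retry on the next round
            continue;
        }
        std::cout << "task returned " << it->get() << "\n";
        it = pending.erase(it);
    }
}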
+ +void ClientWorkerSCServiceImpl::WaitForAckTask(std::deque &ackList, const uint64_t waitTimeS) +{ + typedef std::chrono::duration<double, std::milli> millisecond; + const int logPerCount = VLOG_IS_ON(SC_NORMAL_LOG_LEVEL) ? 1 : 1000; + constexpr static int K_FUTURE = 0; + constexpr static int K_NAME = 1; + constexpr static int K_TIME = 2; + auto iter = ackList.begin(); + while (iter != ackList.end()) { + if (interrupt_) { + return; + } + auto streamName = std::get<K_NAME>(*iter); + auto &fut = std::get<K_FUTURE>(*iter); + auto now = std::chrono::high_resolution_clock::now(); + auto duration = now - std::get<K_TIME>(*iter); + // Wait for them to come back + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[S:%s] Wait for AckCursors thread execution", streamName); + // Report a warning if the ack thread is blocked for a long time + auto realWaitTimeS = waitTimeS > std::numeric_limits::max() + ? std::chrono::seconds::max() + : std::chrono::seconds(waitTimeS); + auto status = fut.wait_for(realWaitTimeS); + INJECT_POINT("AutoAckImpl.WaitAndRetry", [&status]() mutable { + status = std::future_status::timeout; + return; + }); + if (status == std::future_status::timeout) { + LOG(WARNING) << FormatString( + "[S:%s] Waited for %zu seconds and the AckCursors thread hasn't returned yet. Will try again", streamName, + realWaitTimeS.count()); + // Save this future for the next round and move on. + ++iter; + continue; + } + const Status rc = fut.get(); + LOG_IF_EVERY_N(ERROR, rc.IsError(), logPerCount) + << FormatString("[S:%s] Auto Ack error %s", streamName, rc.ToString()); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[S:%s] Done AckCursors thread execution. Execution time [%6lf]ms", + streamName, + std::chrono::duration_cast<millisecond>(duration).count()); + iter = ackList.erase(iter); + } +} + +void ClientWorkerSCServiceImpl::AutoAckImpl(std::deque &ackList, const uint64_t waitTimeS) +{ + constexpr static int K_NAME = 1; + // Get a list of the streams that haven't acked back yet from the previous round. + std::set<std::string> streamNoAckResponse; + for (const auto &ele : ackList) { + streamNoAckResponse.insert(std::get<K_NAME>(ele)); + } + // If some streams did not ack back in the previous round, keep them + // in the FIFO queue and preserve their order in the queue. + // Another reason is that these streams are already holding the + // const accessor; we would deadlock ourselves. + const int logPerCount = VLOG_IS_ON(SC_NORMAL_LOG_LEVEL) ? 1 : 1000; + // Get a snapshot of all the streams without holding the lock for too long + std::vector<std::string> streams = GetStreamNameList(); + if (!streams.empty()) { + LOG_EVERY_N(INFO, logPerCount) << FormatString("Begin to process AutoAckImpl ack logic for %d streams", + streams.size() - streamNoAckResponse.size()); + } + auto traceId = Trace::Instance().GetTraceID(); + for (auto const &streamName : streams) { + if (interrupt_) { + return; + } + if (streamNoAckResponse.count(streamName) > 0) { + continue; // Already handled. + } + auto ackFunc = [this, streamName, traceId]() { + auto traceGuard = Trace::Instance().SetTraceNewID(traceId); + auto rlock = std::make_unique<StreamManagerMap::const_accessor>(); + RETURN_IF_NOT_OK(GetStreamManager(streamName, *rlock)); + auto streamMgr = (*rlock)->second; + return streamMgr->AckCursors(); + }; + ackList.emplace_back(ackPool_->Submit(ackFunc), streamName, std::chrono::high_resolution_clock::now()); + } + // Wait for the tasks to come back. + WaitForAckTask(ackList, waitTimeS); + VLOG(SC_INTERNAL_LOG_LEVEL) << "Done processing AutoAckImpl ack logic"; +}
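AutoAckImpl above deliberately snapshots the stream names via GetStreamNameList() so the dictionary lock is held only for the copy, and the slow per-stream ack work runs against the snapshot. A small sketch of that snapshot-then-iterate idiom, with a std::unordered_map as an illustrative stand-in for the stream dictionary:

#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

// Copies the key set under the lock, then lets the caller iterate the copy
// without blocking writers of the map.
std::vector<std::string> SnapshotKeys(std::mutex &mu,
                                      const std::unordered_map<std::string, int> &dict)
{
    std::lock_guard<std::mutex> lk(mu);
    std::vector<std::string> keys;
    keys.reserve(dict.size());
    for (const auto &kv : dict) {
        keys.push_back(kv.first);
    }
    return keys;
}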
+    WaitForAckTask(ackList, waitTimeS);
+    VLOG(SC_INTERNAL_LOG_LEVEL) << "Done processing AutoAckImpl ack logic";
+}
+
+ClientWorkerSCServiceImpl::~ClientWorkerSCServiceImpl()
+{
+    EraseFailedNodeApiEvent::GetInstance().RemoveSubscriber(CLIENT_WORKER_SC_SERVICE_IMPL);
+    interrupt_ = true;
+    if (autoAck_.valid()) {
+        autoAck_.get();
+    }
+}
+
+void ClientWorkerSCServiceImpl::EraseFailedWorkerMasterApi(HostPort &masterAddr)
+{
+    workerMasterApiManager_->EraseFailedWorkerMasterApi(masterAddr, StubType::WORKER_MASTER_SC_SVC);
+}
+
+Status ClientWorkerSCServiceImpl::GetPubSubForClientStream(const std::string &clientId, const std::string &streamName,
+                                                           std::vector<std::string> &prodConList)
+{
+    prodConList.clear();
+    bool found = false;
+    std::shared_lock locker(clearMutex_);
+    // First collect all the producer Ids, if any exist for the given client.
+    auto iter = clientProducers_.find(clientId);
+    if (iter != clientProducers_.end()) {
+        for (auto &streamProducer : iter->second) {
+            if (streamProducer.streamName_ == streamName) {
+                prodConList.emplace_back(streamProducer.producerId_);
+            }
+        }
+        found = true;
+    }
+
+    // Next collect all the consumer Ids, if any exist for the given client.
+    auto iter2 = clientConsumers_.find(clientId);
+    if (iter2 != clientConsumers_.end()) {
+        for (auto &subInfo : iter2->second) {
+            if (subInfo.streamName == streamName) {
+                prodConList.emplace_back(subInfo.consumerId);
+            }
+        }
+        found = true;
+    }
+    if (!found) {
+        RETURN_STATUS(K_NOT_FOUND, FormatString("No producer or consumer found for the client: %s", clientId));
+    }
+    return Status::OK();
+}
+
+Status ClientWorkerSCServiceImpl::ReserveMemoryFromUsageMonitor(const std::string &streamName, size_t reserveSize)
+{
+    auto workerWorkerSCServicePtr = workerWorkerSCService_.lock();
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(workerWorkerSCServicePtr, K_RUNTIME_ERROR,
+                                         FormatString("WorkerWorkerSCService shutdown"));
+    RETURN_IF_NOT_OK(workerWorkerSCServicePtr->GetUsageMonitor().ReserveMemory(streamName, reserveSize));
+    return Status::OK();
+}
+
+Status ClientWorkerSCServiceImpl::UndoReserveMemoryFromUsageMonitor(const std::string &streamName)
+{
+    auto workerWorkerSCServicePtr = workerWorkerSCService_.lock();
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(workerWorkerSCServicePtr, K_RUNTIME_ERROR,
+                                         FormatString("WorkerWorkerSCService shutdown"));
+    workerWorkerSCServicePtr->GetUsageMonitor().UndoReserveMemory(streamName);
+    return Status::OK();
+}
+
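+// Timer-expiry entry points for the two kinds of blocked allocation requests. Each specialization
+// looks up the stream manager and lets it return the saved OOM status to the waiting client.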
+template <>
+Status ClientWorkerSCServiceImpl::HandleBlockedCreateTimeout(
+    const std::string &streamName, const std::string &producerId, const std::string &traceId, int64_t subTimeout,
+    const std::chrono::steady_clock::time_point &startTime)
+{
+    INJECT_POINT("ClientWorkerSCServiceImpl.HandleBlockedCreateTimeout.sleep");
+    TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId);
+    LOG(INFO) << "Blocked CreateShmPage request timer expired. Return OOM to client for stream " << streamName
+              << " with producer " << producerId << " and timeout " << subTimeout;
+    StreamManagerMap::const_accessor accessor;
+    RETURN_IF_NOT_OK(GetStreamManager(streamName, accessor));
+    std::shared_ptr<StreamManager> streamMgr = accessor->second;
+    RETURN_IF_NOT_OK(streamMgr->CheckIfStreamActive());
+    streamMgr->HandleBlockedCreateTimeout(producerId, subTimeout, startTime);
+    return Status::OK();
+}
+
+template <>
+Status ClientWorkerSCServiceImpl::HandleBlockedCreateTimeout(
+    const std::string &streamName, const std::string &producerId, const std::string &traceId, int64_t subTimeout,
+    const std::chrono::steady_clock::time_point &startTime)
+{
+    TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId);
+    LOG(INFO) << "Blocked AllocBigElement request timer expired. Return OOM to client for stream " << streamName
+              << " with producer " << producerId << " and timeout " << subTimeout;
+    StreamManagerMap::const_accessor accessor;
+    RETURN_IF_NOT_OK(GetStreamManager(streamName, accessor));
+    std::shared_ptr<StreamManager> streamMgr = accessor->second;
+    RETURN_IF_NOT_OK(streamMgr->CheckIfStreamActive());
+    streamMgr->HandleBlockedCreateTimeout(producerId, subTimeout, startTime);
+    return Status::OK();
+}
+
+Status ClientWorkerSCServiceImpl::StreamNoToName(uint64_t streamNo, std::string &streamName)
+{
+    std::shared_lock locker(mappingMutex_);
+    auto iter = streamNum2StreamName_.find(streamNo);
+    CHECK_FAIL_RETURN_STATUS(iter != streamNum2StreamName_.end(), K_SC_STREAM_NOT_FOUND,
+                             FormatString("Stream number %zu not found", streamNo));
+    streamName = iter->second;
+    return Status::OK();
+}
+
+Status ClientWorkerSCServiceImpl::AddStreamNo(uint64_t streamNo, const std::string &streamName)
+{
+    std::lock_guard locker(mappingMutex_);
+    bool success = streamNum2StreamName_.emplace(streamNo, streamName).second;
+    CHECK_FAIL_RETURN_STATUS(success, K_RUNTIME_ERROR, FormatString("Duplicated stream number %zu", streamNo));
+    return Status::OK();
+}
+
+void ClientWorkerSCServiceImpl::RemoveStreamNo(uint64_t streamNo)
+{
+    std::lock_guard locker(mappingMutex_);
+    streamNum2StreamName_.erase(streamNo);
+}
+
+template <typename R, typename W>
+BlockedCreateRequest<R, W>::BlockedCreateRequest(std::string streamName, const R &req, size_t reqSz,
+                                                 std::shared_ptr> serverApi,
+                                                 BlockedCreateReqFn fn)
+    : req_(req),
+      reqSize_(reqSz),
+      serverApi_(std::move(serverApi)),
+      startTime_(std::chrono::steady_clock::now()),
+      retryCount_(0),
+      streamName_(std::move(streamName)),
+      traceId_(Trace::Instance().GetTraceID()),
+      callBackFn_(fn),
+      ack_(AckVal::NONE),
+      timer_(nullptr),
+      timeSpent_(req_.sub_timeout())
+{
+    // Set up a return rc. If we get a chance to execute this request, defaultRc_ will be
+    // overridden by the AllocMemory call. Otherwise, this request will time out, and we return OOM.
+    if (req_.sub_timeout() > 0) {
+        auto subTimeout = req_.sub_timeout();
+        auto producerId = req_.producer_id();
+        defaultRc_ = Status(K_OUT_OF_MEMORY,
+                            FormatString("[S:%s, P:%s] Blocked CreateShmPage request timer expired. timeout %d.",
timeout %d.", + streamName_, producerId, subTimeout)); + } +} + +template +void BlockedCreateRequest::SetTimer(std::unique_ptr timer) +{ + timer_ = std::move(timer); +} + +template +void BlockedCreateRequest::CancelTimer() +{ + if (timer_) { + INJECT_POINT("do.not.cancel.timer", [] { return; }); + (void)TimerQueue::GetInstance()->Cancel(*timer_); + timer_.reset(); + } +} + +template +int64_t BlockedCreateRequest::GetRemainingTimeMs() +{ + return timeSpent_.GetRemainingTimeMs(); +} + +template +R BlockedCreateRequest::GetCreateRequest() const +{ + return req_; +} + +template +Status BlockedCreateRequest::SendStatus(const Status &rc) +{ + if (serverApi_) { + return serverApi_->SendStatus(rc); + } else { + wp_.Set(); + return rc; + } +} + +template +Status BlockedCreateRequest::Write() +{ + if (serverApi_) { + return serverApi_->Write(rsp_); + } else { + // Before wake up the receiver, flip the atomic variable from 0 to 1 + uint32_t val = AckVal::NONE; + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(ack_.compare_exchange_strong(val, AckVal::DONE), K_RUNTIME_ERROR, + FormatString("Unexpected CAS error from 0 -> 1. Current val %d", val)); + wp_.Set(); + } + return Status::OK(); +} + +template +Status BlockedCreateRequest::Wait(uint64_t timeoutMs) +{ + auto success = wp_.WaitFor(timeoutMs); + if (success) { + // Map the return code. Depends on the timing, we may get K_NOT_FOUND + // when the memory was allocated but then undone because this thread is slow + // to get to the ack phase. + switch (defaultRc_.GetCode()) { + case K_OK: + case K_OUT_OF_MEMORY: + return defaultRc_; + default: + RETURN_STATUS(K_TRY_AGAIN, defaultRc_.GetMsg()); + } + } + return { StatusCode::K_TRY_AGAIN, FormatString("Waited for %zu ms. No response", timeoutMs) }; +} + +template +Status BlockedCreateRequest::SenderHandShake() +{ + // Can't do handshake with RPC caller. + RETURN_OK_IF_TRUE(serverApi_); + INJECT_POINT("BlockedCreateRequest.ReceiverHandShake.Rollback"); + // Wait for a few ms for the Receiver to acknowledge + auto okay = ackWp_.WaitFor(RPC_POLL_TIME); + // No need to do anything if we are waken up by the receiver + RETURN_OK_IF_TRUE(okay); + // Now we flip back from 1 to 0. + uint32_t expectedVal = AckVal::DONE; + bool success = ack_.compare_exchange_strong(expectedVal, AckVal::NONE); + // There is a still a chance that the receiver wake up the same time as us. + // If the receiver can flip the ack from 1 to 0, we still treat the whole handshake + // as successful. That means, if our CAS fail, we treat it as handshake success + RETURN_OK_IF_TRUE(!success); + // Lastly we inform the caller the receiver has gone. + return { StatusCode::K_NOT_FOUND, "Receiver has gone" }; +} + +template +Status BlockedCreateRequest::ReceiverHandShake() +{ + // Can't do handshake with RPC caller. + RETURN_OK_IF_TRUE(serverApi_); + INJECT_POINT("BlockedCreateRequest.ReceiverHandShake.sleep"); + // Drive a compare and swap. We expect it is 1, and will flip it back to 0. + // If it is 0, that means the sender has already undone the memory allocation + uint32_t expectedVal = AckVal::DONE; + bool success = ack_.compare_exchange_strong(expectedVal, AckVal::NONE); + ackWp_.Set(); + RETURN_OK_IF_TRUE(success); + return { StatusCode::K_TRY_AGAIN, "Sender has undone the changes" }; +} + +template +Status BlockedCreateRequest::HandleBlockedCreateTimeout() +{ + if (req_.sub_timeout() > 0 && GetRemainingTimeMs() == 0) { + auto elapsed = static_cast(timeSpent_.ElapsedMilliSecond()); + LOG(ERROR) << FormatString("[S:%s, P:%s] RPC timeout. 
+template <typename R, typename W>
+Status BlockedCreateRequest<R, W>::HandleBlockedCreateTimeout()
+{
+    if (req_.sub_timeout() > 0 && GetRemainingTimeMs() == 0) {
+        auto elapsed = static_cast<size_t>(timeSpent_.ElapsedMilliSecond());
+        LOG(ERROR) << FormatString("[S:%s, P:%s] RPC timeout. time elapsed %zu, subTimeout: %zu", streamName_,
+                                   req_.producer_id(), elapsed, req_.sub_timeout());
+        if (serverApi_) {
+            LOG_IF_ERROR(serverApi_->SendStatus(defaultRc_), "send status failed");
+        } else {
+            wp_.Set();
+        }
+        return defaultRc_;
+    }
+    return Status::OK();
+}
+
+template <typename R, typename W>
+Status BlockedCreateRequest<R, W>::operator()()
+{
+    ++retryCount_;
+    defaultRc_ = callBackFn_(this);
+    return defaultRc_;
+}
+
+template <typename R, typename W>
+bool BlockedCreateRequest<R, W>::HasStartTime(const std::chrono::steady_clock::time_point &startTime)
+{
+    return startTime == startTime_;
+}
+
+template <typename R, typename W>
+bool BlockedCreateRequest<R, W>::HasRequestPbOlderThan(const BlockedCreateRequest<R, W> &inputRequest)
+{
+    uint64_t inputRequestPbTimestamp = inputRequest.req_.timestamp();
+    uint64_t currentRequestPbTimestamp = req_.timestamp();
+    LOG(INFO) << FormatString(
+        "Producer: %s, Other request: traceid %s timestamp %llu, "
+        "This request: traceid %s timestamp %llu",
+        req_.producer_id(), inputRequest.traceId_, inputRequestPbTimestamp, traceId_, currentRequestPbTimestamp);
+    return currentRequestPbTimestamp < inputRequestPbTimestamp;
+}
+
+template <typename R, typename W>
+Status MemAllocRequestList<R, W>::AddBlockedCreateRequest(ClientWorkerSCServiceImpl *scSvc,
+                                                          std::shared_ptr<BlockedCreateRequest<R, W>> blockedReq)
+{
+    auto subTimeout = blockedReq->GetRemainingTimeMs();
+    const auto req = blockedReq->GetCreateRequest();
+    const auto &producerId = req.producer_id();
+    const auto streamName = blockedReq->streamName_;
+    const std::chrono::steady_clock::time_point startTime = blockedReq->startTime_;
+    VLOG(SC_NORMAL_LOG_LEVEL) << "Adding a blocked request to the blocked queue for stream " << streamName
+                              << " with producer " << producerId << " and timeout " << subTimeout;
+    std::unique_lock lock(blockedListMutex_);
+    std::shared_ptr<BlockedCreateRequest<R, W>> savedReq = blockedReq;  // save the ptr for later
+    // Add the entry to the blocked list and queue
+    auto it = blockedList_.find(producerId);
+    if (it != blockedList_.end()) {
+        // There is an old request by the same producer that has not been processed yet.
+        LOG(WARNING) << FormatString("Duplicate blocked create page requests for stream %s with producer %s",
+                                     streamName, producerId);
+        Status rc = Status(K_DUPLICATED, FormatString("[S:%s, P:%s] Received a newer request from the producer; "
+                                                      "this request has expired and will not be processed.",
+                                                      streamName, producerId));
+        if (it->second->HasRequestPbOlderThan(*blockedReq)) {
+            // We have a new request with a newer timestamp; remove the old request.
+            // Get the old request
+            std::shared_ptr<BlockedCreateRequest<R, W>> oldBlockedReq = std::move(it->second);
+
+            // Cancel the timer for the old request since we will take action here
+            oldBlockedReq->CancelTimer();
+
+            // Remove the old request.
+            // A memory request is stored in two data structures: a hashed map and a priority queue.
+            // We can't take one out but leave the other.
+            (void)blockedList_.erase(it);
+            RemoveBlockedCreateRequestFromQueueLocked(oldBlockedReq.get());
+
+            // Send status for the old request.
+            LOG_IF_ERROR(oldBlockedReq->SendStatus(rc), "Send status to client failed");
+        } else {
+            // We have a new request with an older timestamp than the request currently in the blocked
+            // list; do not process the new request with the older timestamp.
+            LOG_IF_ERROR(blockedReq->SendStatus(rc), "Send status to client failed");
+            return Status::OK();
+        }
+    }
+
+    // Now any old request by the same producer has been processed, insert the new request.
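+    // A blocked request is tracked in two structures that must stay in sync: blockedList_ keys it by
+    // producer id for duplicate detection, and queue_ orders requests by start time so the oldest
+    // blocked request is served first by GetBlockedCreateRequest().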
+    bool success;
+    std::tie(std::ignore, success) = blockedList_.emplace(producerId, std::move(blockedReq));
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        success, StatusCode::K_RUNTIME_ERROR,
+        FormatString("Fail to insert BlockedCreateRequest for stream %s with producer %s and timeout %lld", streamName,
+                     producerId, subTimeout));
+    queue_.emplace(savedReq.get());
+    // Early exit if we don't have to create a timer
+    RETURN_OK_IF_TRUE(req.sub_timeout() == 0);
+    // Create a timer to throw OOM and clean up if the timer expires
+    TimerQueue::TimerImpl timer;
+    auto traceID = savedReq->traceId_;
+    INJECT_POINT("AddBlockedCreateRequest.subTimeout", [&subTimeout]() mutable {
+        subTimeout = 0;
+        return Status::OK();
+    });
+    // If the timer has already expired, let StreamManager::HandleBlockedRequestImpl handle it. No need
+    // to create a timer.
+    RETURN_OK_IF_TRUE(subTimeout == 0);
+    RETURN_IF_NOT_OK(TimerQueue::GetInstance()->AddTimer(
+        subTimeout,
+        [scSvc, streamName, producerId, traceID, subTimeout, startTime]() {
+            LOG_IF_ERROR(
+                (scSvc->HandleBlockedCreateTimeout(streamName, producerId, traceID, subTimeout, startTime)),
+                "HandleBlockedCreateTimeout");
+        },
+        timer));
+
+    savedReq->SetTimer(std::make_unique<TimerQueue::TimerImpl>(timer));
+    return Status::OK();
+}
+
+template <typename R, typename W>
+void MemAllocRequestList<R, W>::RemoveBlockedCreateRequestFromQueueLocked(
+    const BlockedCreateRequest<R, W> *blockedReqPtr)
+{
+    // The tricky part is the priority queue, which only lets us pop the top.
+    std::vector<BlockedCreateRequest<R, W> *> list;
+    while (!queue_.empty()) {
+        auto *ptr = queue_.top();
+        queue_.pop();
+        if (ptr == blockedReqPtr) {
+            break;
+        } else {
+            // requests that we must put back
+            list.push_back(ptr);
+        }
+    }
+    for (auto p : list) {
+        queue_.push(p);
+    }
+}
+
+template <typename R, typename W>
+void MemAllocRequestList<R, W>::HandleBlockedCreateTimeout(const std::string &streamName,
+                                                           const std::string &producerId, int64_t subTimeout,
+                                                           const std::chrono::steady_clock::time_point &startTime)
+{
+    std::lock_guard lock(blockedListMutex_);
+    auto it = blockedList_.find(producerId);
+    if (it == std::end(blockedList_) || !(it->second->HasStartTime(startTime))) {
+        // A race between a thread doing the free vs this timeout. The other thread won, so this is a
+        // no-op. Log it for information only.
+        LOG(INFO) << "Blocked CreateShmPage request timer expired. The page was already handled for stream "
+                  << streamName << " with producer " << producerId << " and timeout " << subTimeout;
+        return;
+    }
+    // Return the original OOM here and log a timer-expired message into the worker log
+    std::shared_ptr<BlockedCreateRequest<R, W>> blockedReq = std::move(it->second);
+    auto rc = blockedReq->defaultRc_;
+    LOG(ERROR) << FormatString("[S:%s, P:%s] timeout %zu, %s", streamName, producerId, subTimeout, rc.ToString());
+    // A memory request is stored in two data structures: a hashed map and a priority queue.
+    // We can't take one out but leave the other.
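+    // Both structures are updated below under the same exclusive lock, so readers never observe
+    // the map and the queue out of sync.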
+    (void)blockedList_.erase(it);
+    RemoveBlockedCreateRequestFromQueueLocked(blockedReq.get());
+    LOG_IF_ERROR(blockedReq->SendStatus(rc), "Send status to client failed");
+}
+
+template <typename R, typename W>
+Status MemAllocRequestList<R, W>::GetBlockedCreateRequest(std::shared_ptr<BlockedCreateRequest<R, W>> &out)
+{
+    INJECT_POINT("GetBlockedCreateRequest.sleep");
+    std::lock_guard lock(blockedListMutex_);
+    CHECK_FAIL_RETURN_STATUS(!queue_.empty(), StatusCode::K_TRY_AGAIN, "No outstanding memory request");
+    auto *blockedReq = queue_.top();
+    queue_.pop();
+    TraceGuard traceGuard = Trace::Instance().SetTraceNewID(blockedReq->traceId_);
+    const auto req = blockedReq->GetCreateRequest();
+    const auto producerId = req.producer_id();
+    const std::string streamName = blockedReq->streamName_;
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[S:%s, P:%s] Handle alloc memory request", streamName, producerId);
+    auto it = blockedList_.find(producerId);
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        it != blockedList_.end(), K_RUNTIME_ERROR,
+        FormatString("[S:%s, P:%s] Alloc memory request not found", streamName, producerId));
+    // Cancel the timer for this entry since we will take action here
+    it->second->CancelTimer();
+    out = std::move(it->second);
+    (void)blockedList_.erase(it);
+    return Status::OK();
+}
+
+template <typename R, typename W>
+bool MemAllocRequestList<R, W>::Empty()
+{
+    std::shared_lock lock(blockedListMutex_);
+    return queue_.empty();
+}
+
+template <typename R, typename W>
+size_t MemAllocRequestList<R, W>::Size()
+{
+    std::shared_lock lock(blockedListMutex_);
+    return queue_.size();
+}
+
+template <typename R, typename W>
+size_t MemAllocRequestList<R, W>::GetNextBlockedRequestSize()
+{
+    std::shared_lock lock(blockedListMutex_);
+    if (queue_.empty()) {
+        return 0;
+    }
+    return queue_.top()->reqSize_;
+}
+
+template <typename R, typename W>
+void MemAllocRequestList<R, W>::ClearBlockedList()
+{
+    // Clearing mutates both structures, so take an exclusive lock rather than a shared one.
+    std::lock_guard lock(blockedListMutex_);
+    blockedList_.clear();
+    queue_ = std::priority_queue<BlockedCreateRequest<R, W> *, std::vector<BlockedCreateRequest<R, W> *>, Compare>();
+}
+
+StreamManagerWithLock::StreamManagerWithLock(std::shared_ptr<StreamManager> mgr, void *accessor, bool exclusive,
+                                             std::shared_ptr<ClientWorkerSCServiceImpl> service)
+    : mgr_(std::move(mgr)), accessor_(accessor), exclusive_(exclusive), service_(std::move(service))
+{
+    rlock_ = std::make_unique<ReadLockHelperType>(LOCK_ARGS_MSG_FN(service_->mutex_, service_->LogPrefix));
+}
+
+StreamManagerWithLock::~StreamManagerWithLock()
+{
+    Release();
+}
+
+void StreamManagerWithLock::Release()
+{
+    if (accessor_) {
+        if (exclusive_) {
+            auto accessor =
+                std::unique_ptr<StreamManagerMap::accessor>(reinterpret_cast<StreamManagerMap::accessor *>(accessor_));
+            accessor->release();
+        } else {
+            auto accessor = std::unique_ptr<StreamManagerMap::const_accessor>(
+                reinterpret_cast<StreamManagerMap::const_accessor *>(accessor_));
+            accessor->release();
+        }
+        accessor_ = nullptr;
+    }
+    if (rlock_->owns_lock()) {
+        rlock_->unlock();
+    }
+}
+
+void StreamManagerWithLock::CleanUp(std::function<void(StreamManagerMap::accessor *)> &&callback)
+{
+    if (!needCleanUp || !exclusive_) {
+        return;
+    }
+
+    if (!rlock_->owns_lock()) {
+        rlock_->AcquireLock();
+    }
+    if (accessor_ != nullptr) {
+        auto accessor =
+            std::unique_ptr<StreamManagerMap::accessor>(reinterpret_cast<StreamManagerMap::accessor *>(accessor_));
+        callback(accessor.get());
+        accessor_ = nullptr;
+    } else {
+        callback(nullptr);
+    }
+    rlock_->unlock();
+}
+
+void ClientWorkerSCServiceImpl::EraseFromStreamMgrDictWithoutLck(const std::string &namespaceUri,
+                                                                 StreamManagerMap::accessor *accessor)
+{
+    if (accessor) {
+        streamMgrDict_.erase(*accessor);
+    } else {
+        streamMgrDict_.erase(namespaceUri);
+    }
+}
+} // namespace stream_cache
+} // namespace worker
+} // namespace datasystem
diff --git a/src/datasystem/worker/stream_cache/client_worker_sc_service_impl.h b/src/datasystem/worker/stream_cache/client_worker_sc_service_impl.h
new file mode 100644
index 0000000..e2f97fa
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/client_worker_sc_service_impl.h
@@ -0,0 +1,1058 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_CLIENT_WORKER_SC_SERVICE_IMPL_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_CLIENT_WORKER_SC_SERVICE_IMPL_H
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "datasystem/common/ak_sk/ak_sk_manager.h"
+#include "datasystem/common/stream_cache/stream_fields.h"
+#include "datasystem/common/util/lock_helper.h"
+#include "datasystem/common/util/lock_map.h"
+#include "datasystem/common/util/raii.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/master/stream_cache/master_sc_service_impl.h"
+#include "datasystem/protos/stream_posix.service.rpc.pb.h"
+#include "datasystem/protos/stream_posix.stub.rpc.pb.h"
+#include "datasystem/utils/optional.h"
+#include "datasystem/worker/stream_cache/remote_worker_manager.h"
+#include "datasystem/worker/stream_cache/stream_producer.h"
+#include "datasystem/worker/stream_cache/worker_master_sc_api.h"
+#include "datasystem/worker/stream_cache/worker_sc_allocate_memory.h"
+#include "datasystem/worker/authenticate.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+class MasterWorkerSCServiceImpl;
+class WorkerWorkerSCServiceImpl;
+class UsageMonitor;
+
+/**
+ * @brief A simple class to save an AllocMemory request so that it can be executed at a later time.
+ */
+template <typename R, typename W>
+class BlockedCreateRequest {
+public:
+    enum AckVal : uint32_t { NONE = 0, DONE = 1 };
+    using BlockedCreateReqFn = std::function<Status(BlockedCreateRequest<R, W> *)>;
+
+    /**
+     * @brief Constructor for the blocked request. Initializes the elapsed-timer start time.
+     * @param[in] streamName The name of the stream the request belongs to.
+     * @param[in] req The request details to save.
+     * @param[in] reqSz The size of the requested allocation.
+     * @param[in] serverApi The UnaryWriterReader api to associate with the request.
+     * @param[in] fn The callback to execute when the request is run.
+     */
+    BlockedCreateRequest(std::string streamName, const R &req, size_t reqSz,
+                         std::shared_ptr> serverApi, BlockedCreateReqFn fn);
+
+    /**
+     * @brief Default destructor.
+     */
+    ~BlockedCreateRequest() = default;
+
+    /**
+     * @brief Sets the reference to the timer queue entry for this request
+     */
+    void SetTimer(std::unique_ptr<TimerQueue::TimerImpl> timer);
+
+    /**
+     * @brief Cancels the running timer queue entry for this request
+     */
+    void CancelTimer();
+
+    /**
+     * @brief Uses the internal timer to compute how much time has passed, and then subtracts that from the input
+     * timeout arg.
+ * @return the new computed timeout value + */ + int64_t GetRemainingTimeMs(); + + /** + * @brief A getter function for the request info (deep copy return) + * @return A copy of the request info + */ + R GetCreateRequest() const; + + Status SendStatus(const Status &rc); + + Status Write(); + + Status Wait(uint64_t timeoutMs); + + Status SenderHandShake(); + + Status ReceiverHandShake(); + + /** + * @brief Handle a timeout request + * @return + */ + Status HandleBlockedCreateTimeout(); + + /** + * @brief Functor to execute the call back + */ + Status operator()(); + + /** + * @brief Check if the BlockedCreateRequest is created at startTime + */ + bool HasStartTime(const std::chrono::steady_clock::time_point &startTime); + + /** + * @brief Check if the internal request pb is older than the input request pb. + * Require inputRequest is made by the same producer as this request. + */ + bool HasRequestPbOlderThan(const BlockedCreateRequest &inputRequest); + + W rsp_; + R req_; + size_t reqSize_; + std::shared_ptr> serverApi_; + const std::chrono::steady_clock::time_point startTime_; + std::atomic retryCount_; + std::string streamName_; + std::string traceId_; + Status defaultRc_; + BlockedCreateReqFn callBackFn_; + WaitPost wp_; // If serverApi is null + WaitPost ackWp_; // Handshake + std::atomic_uint32_t ack_; // Handshake + +private: + std::unique_ptr timer_; + Timer timeSpent_; // Start time initialized at construction time +}; + +using CreatePubSubCtrl = LockMap; +using ProduceGrpByStreamList = std::unordered_map>; + +/** + * A class to wrap StreamManager with an accessor which can be an exclusive or a read accessor + * Used only in CreateProducer/Subscribe api + */ +class StreamManagerWithLock { +public: + StreamManagerWithLock(std::shared_ptr mgr, void *accessor, bool exclusive, + std::shared_ptr service); + ~StreamManagerWithLock(); + void Release(); + + /** + * @brief If necessary, clean up any effects that occurred during the lock period. + * @param[in] callback The implementation of cleanup work. + */ + void CleanUp(std::function &&callback); + + bool needCleanUp = true; + std::shared_ptr mgr_; + +private: + using ReadLockHelperType = ReadLockHelper>; + + void *accessor_; + bool exclusive_; + std::shared_ptr service_; + std::unique_ptr rlock_; +}; +class ClientWorkerSCServiceImpl : public ClientWorkerSCService, + public std::enable_shared_from_this { +public: + /** + * @brief Construct the rpc service of ClientWorkerSCServiceImpl. + * @param[in] serverAddr The address of worker. + * @param[in] masterAddr The address of master. + * @param[in] masterSCService The master service. + * @param[in] akSkManager Used to do AK/SK authenticate. + */ + ClientWorkerSCServiceImpl(HostPort serverAddr, HostPort masterAddr, master::MasterSCServiceImpl *masterSCService, + std::shared_ptr akSkManager, + std::shared_ptr manager); + + /** + * @brief Init the service. + * @return Status of the call. + */ + Status Init() override; + + ~ClientWorkerSCServiceImpl() override; + + /** + * @brief Check if there are tasks to be processed + * @return T/F + */ + bool HaveTasksToProcess() + { + return remoteWorkerManager_->HaveTasksToProcess(); + } + + /** + * @brief Create a producer, i.e., register a publisher to a stream. + * @param[in] serverApi Used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. 
+ */ + Status CreateProducer( + std::shared_ptr> serverApi) override; + + /** + * @brief Close a producer, force flushing and page seal, unregister a publisher to a stream. + * @param[in] serverApi Used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status CloseProducer( + std::shared_ptr> serverApi) override; + + /** + * @brief Subscribe to a stream, using a subscription name, i.e., register a consumer to a subscription. + * @param[in] serverApi Used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status Subscribe(std::shared_ptr> serverApi) override; + + /** + * @brief Close a consumer, trigger subscription cursor change and unregister a subscribed consumer to a stream. + * @param[in] serverApi Used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status CloseConsumer( + std::shared_ptr> serverApi) override; + + /** + * @brief Create a stream page and get its related shared memory meta to perform zero-copy send. + * @param[in] serverApi Used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status CreateShmPage( + std::shared_ptr> serverApi) override; + + Status GetDataPage(std::shared_ptr> serverApi) override; + + template + void AsyncSendMemReq(const std::string &namespaceUri); + + template + Status HandleBlockedRequestImpl(const std::string &streamName); + + /** + * @brief Delete stream manager and related sessions when it is not used anymore. + * @param[in] req The rpc request protobuf. + * @param[out] rsp The rpc response protobuf. + * @return K_OK on success; the error code otherwise. + */ + Status DeleteStream(const DeleteStreamReqPb &req, DeleteStreamRspPb &rsp) override; + + /** + * @brief Query producer count in global scope for one stream + * @param[in] req The rpc request protobuf + * @param[out] rsp The rpc response protobuf + * @return K_OK on success; the error code otherwise. + */ + Status QueryGlobalProducersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) override; + + /** + * @brief Query consumer count in global scope for one stream + * @param[in] req The rpc request protobuf + * @param[out] rsp The rpc response protobuf + * @return K_OK on success; the error code otherwise. + */ + Status QueryGlobalConsumersNum(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp) override; + + /** + * @brief Closes all of the producers and consumers for a given client when client crashes + * - Any consumers of the producers from this client will cease to function and give error SC_PRODUCER_NOT_FOUND. + * - Any in-progress data flowing from the closed producers will be dropped/freed and not sent to the remote. + * @param[in] clientId The ID of client. + * @return Status of the call. + */ + Status ClosePubSubForClientLost(const std::string &clientId); + + /** + * @brief Unlock mem view on all pages for given streams. + * @param[in] streams The stream name list. + * @param[in] lockId The lock id. + */ + void ForceUnlockMemViemForPages(const std::set &streams, uint32_t lockId); + + /** + * @brief Get the Stream Metadata object + * @param[in] streamName The stream name. + * @param[out] meta The rpc protobuf for stream metadata + * @return K_OK on success; the error code otherwise. 
+ */ + Status GetStreamMetadata(const std::string &streamName, GetStreamMetadataRspPb *meta); + + /** + * @brief Collect the producer consumer metadata for the given list of producers and consumers. + * @param[in] localProducers The list of producers from this worker for the given stream + * @param[in] localConsumers The list of consumers from this worker for the given stream + * @param[in] meta The rpc protobuf for stream metadata + * @param[in] streamName The stream name of the producers and consumers + * @param[in] hostPortPb The protobuf message for the address of this worker + */ + void GetProducerConsumerMetadata(std::vector &localProducers, + std::vector> &localConsumers, + GetStreamMetadataRspPb *meta, const std::string &streamName, + HostPortPb &hostPortPb); + + /** + * @brief Get metadata for all streams for the requesting master. + * @param[in] masterAddr The GetMetadataAllStreamReqPb request protobuf. + * @param[out] rsp The rpc response protobuf. + * @return K_OK on success; the error code otherwise. + */ + Status GetAllStreamMetadata(const GetMetadataAllStreamReqPb &req, GetMetadataAllStreamRspPb &rsp); + + /** + * @brief Send metadata for all streams for the requesting master using the stream rpc. + * @param[in] req The GetMetadataAllStreamReqPb request protobuf. + * @param[in/out] streamRpc Used to read request from master and write response to master. + * @return K_OK on success; the error code otherwise. + */ + Status SendAllStreamMetadata( + const GetMetadataAllStreamReqPb &req, + std::shared_ptr> &streamRpc); + + /** + * @brief Unblock producer sending stream + * @param[in] serverApi The api used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status UnblockProducer( + std::shared_ptr> serverApi) override; + + /** + * @brief Blocks producer sending stream + * @param[in] serverApi The api used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status BlockProducer( + std::shared_ptr> serverApi) override; + + /** + * @brief Get a pointer to the threadpool + * @return raw pointer to the threadpool + */ + ThreadPool *GetThreadPool() + { + return threadPool_.get(); + } + + /** + * @brief Return the number of streams. + * @return Usage: numStream. + */ + std::string GetTotalStreamCount(); + + Status GetStreamManager(const std::string &streamName, StreamManagerMap::const_accessor &accessor); + + /** + * @brief Blocks a remote producer that belongs to a stream + * @param[in] streamName Stream Name. + * @param[in] remoteWorkerAddr Remote Worker address. + * @return K_OK on success; the error code otherwise. + */ + Status SendBlockProducerReq(const std::string &streamName, const std::string &remoteWorkerAddr); + + /** + * @brief UnBlocks a remote producer that belongs to a stream + * @param[in] streamName Stream Name. + * @param[in] remoteWorkerAddr Remote Worker address. + * @return K_OK on success; the error code otherwise. + */ + Status SendUnBlockProducerReq(const std::string &streamName, const std::string &remoteWorkerAddr); + + /** + * @brief Get last append cursor in worker consumer. + * @param[in] req The LastAppendCursorReqPb Request. + * @param[out] rsp The LastAppendCursorRspPb Response. + * @return Status of the call. + */ + Status GetLastAppendCursor(const LastAppendCursorReqPb &req, LastAppendCursorRspPb &rsp) override; + + /** + * @brief The main delete stream driver is DeleteStream() call in the worker. 
This call DeleteStreamContext() is an + * internal call when the master sends a delete to the worker and MasterWorkerSCServiceImpl wants to delete the + * stream. + * @param[in] streamName is the name/key into the streamManagerDict_ to erase + * @param[in] forceDelete Force deletion + * @param[in] timeout delete exits if exceeded the timeout + * @return Status of the call + */ + Status DeleteStreamContext(const std::string &streamName, bool forceDelete, int64_t timeout); + + /** + * @brief Helper function to create a client stub for worker. + * @param[in] workerHostPort worker Address. + * @param[out] stub creates a client stub. + * @return K_OK on success; the error code otherwise. + */ + Status GetWorkerStub(const HostPort &workerHostPort, std::shared_ptr &stub); + + /** + * @brief Cleanup cached data and metadata for the requested streams. + * @param[in] serverApi Used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status ResetStreams(std::shared_ptr> + serverApi) override; + + /** + * @brief Resume streams to allow regular data flow. + * @param[in] serverApi Used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status ResumeStreams(std::shared_ptr> + serverApi) override; + + /** + * @brief Allocate shared memory for big element insert + */ + Status AllocBigShmMemory( + std::shared_ptr> serverApi) override; + + /** + * @brief Release big element + */ + Status ReleaseBigShmMemory( + std::shared_ptr> serverApi) override; + + /** + * @brief Setter method for assigning worker-worker service + * @param[in] impl The pointer to worker-worker stream cache service + */ + void SetWorkerWorkerSCServiceImpl(std::weak_ptr impl) + { + workerWorkerSCService_ = impl; + } + + /** + * @brief Setter method for assigning cluster manager + * @param[in] cm The pointer to etcd cluster manager + */ + void SetClusterManager(EtcdClusterManager *cm) + { + etcdCM_ = cm; + } + + /** + * @brief erase failed worker master api. + * @param[in] masterAddr failed master addr. + */ + void EraseFailedWorkerMasterApi(HostPort &masterAddr); + + /** + * @brief Get remote worker manager + */ + auto GetRemoteWorkerManager() + { + return remoteWorkerManager_.get(); + } + + /** + * @brief Reserve memory from the usage monitor. + * @return Status of the call. + */ + Status ReserveMemoryFromUsageMonitor(const std::string &streamName, size_t reserveSize); + + /** + * @brief Undo the memory reservation from the usage monitor. + * @return Status of the call. + */ + Status UndoReserveMemoryFromUsageMonitor(const std::string &streamName); + + /** + * @brief Obtain the success rate of sending data to the remote worker. + * @return The success rate string. + */ + std::string GetSCRemoteSendSuccessRate() const; + + /** + * @brief Called by TimerQueue for expired memory allocation request. + * @return + */ + template + Status HandleBlockedCreateTimeout(const std::string &streamName, const std::string &producerId, + const std::string &traceId, int64_t subTimeout, + const std::chrono::steady_clock::time_point &startTime) + { + (void)streamName; + (void)producerId; + (void)traceId; + (void)subTimeout; + (void)startTime; + return Status::OK(); + } + + /** + * @brief Convert from stream number to the corresponding stream name. + * @param[in] streamNo The stream number. + * @param[out] streamName The stream name. + * @return Status of the call. 
+ */ + Status StreamNoToName(uint64_t streamNo, std::string &streamName); + + /** + * @brief Record the stream number to stream name mapping. + * @param[in] streamNo The stream number. + * @param[out] streamName The stream name. + * @return Status of the call. + */ + Status AddStreamNo(uint64_t streamNo, const std::string &streamName); + + /** + * @brief Take stream number out from the mapping. + * @param[in] streamNo The stream number to remove. + */ + void RemoveStreamNo(uint64_t streamNo); + +private: + /** + * @brief Get the stream name list. + * @return The stream name list. + */ + std::vector GetStreamNameList(); + + /** + * @brief Check workers health status + * @return K_OK on success; the error code otherwise. + */ + Status ValidateWorkerState(); + + /** + * @brief Create stream manager if not exist. + * @param[in] streamName The name of the stream. + * @param[in] streamFields Optional argument to pre-assign stream fields after stream construction. + * @param[out] streamManager The output stream manager. + * @param[out] streamExisted True if the stream already existed and a new one was not created. + * @return K_OK on success; the error code otherwise. + */ + Status CreateStreamManagerImpl(const std::string &streamName, const Optional &streamFields, + StreamManagerMap::accessor &accessor); + Status CreateStreamManagerIfNotExist(const std::string &streamName, const Optional &streamFields, + std::shared_ptr &streamMgrWithLock, + bool &streamExisted); + + /** + * @brief Get current log with local worker address. + * @param[in] withAddress This value is used to decide whether to add local address, default is false. + * @return The head of log. + */ + std::string LogPrefix(bool withAddress = false) const; + + /** + * @brief Close a producer, force flushing and page seal, unregister a publisher to a stream. + * @param[in] producerId The producer id. + * @param[in] streamName The stream name. + * @param[in] notifyMaster Notify master or not. + * @return K_OK on success; the error code otherwise. + */ + Status CloseProducerImpl(const std::string &producerId, const std::string &streamName, bool notifyMaster); + + /** + * @brief Close a list of producers, force flushing and page seal, unregister a publisher to a stream. + * @param[in] lockId The lock id. + * @param[in/out] producerList A list of StreamProducers + * On success, the producerList will be empty. If any of the producers failed to close, this list will contain + * the producers that did not close successfully. + * @return K_OK on success; the error code otherwise. In the case of multiple producers getting errors, the + * returned error will be the first error that was encountered. + */ + Status CloseProducerImplForceClose(uint32_t lockId, std::list &producerList); + + /** + * @brief Helper function to send CloseProducer request through worker to master api. + * @param[in] api The stream cache worker to master api. + * @param[in/out] streamList A list of streams that failed the request to master. + * @param[in] forceClose Force close in worker. + * @return K_OK on success; the error code otherwise. + */ + Status CloseProducerHandleSend(std::shared_ptr api, std::list &streamList, + bool forceClose); + + /** + * @brief Helper function for the list version of CloseProducerImpl. + * @param[in/out] producerList A list of StreamProducers that failed the request to master. + * On success, the producerList will be empty. If any of the producers failed to close, this list will contain + * the producers that did not close successfully. 
* @param[in/out] successList A list of StreamProducers whose requests were successful.
+     * @param[in] forceClose Force close in worker.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CloseProducerHandleFailure(std::list<StreamProducer> &producerList, std::list<StreamProducer> &successList,
+                                      bool forceClose);
+
+    /**
+     * @brief A helper function to look up the stream and close each producer from the input list.
+     * @param[in/out] producerList A list of producers to close in the stream manager. Successfully removed entries
+     * will be removed from the producerList, leaving only the unsuccessful ones in the producerList.
+     * @param[in] forceClose Whether the pub node crashed or closed regularly.
+     * @return Status of the call.
+     */
+    Status CloseStreamProducerList(std::list<StreamProducer> &producerList, bool forceClose);
+
+    /**
+     * @brief When a client crashes, gets the stream manager const accessor and unlocks the page lock.
+     * @param[in] producerList List of producers that have crashed.
+     * @param[in] lockId The lock id for the producer.
+     * @param[out] producersGrpStreamName The producerList grouped by stream names.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status UnlockAndProtect(std::list<StreamProducer> &producerList, uint32_t lockId,
+                            ProduceGrpByStreamList &producersGrpStreamName);
+
+    /**
+     * @brief Closes all producers locally and gets the list of streams that need master notifications.
+     * @param[in] producerList List of producers that have crashed.
+     * @param[out] streamListForNotifications Set of streams that got their last producer closed.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CloseProducerLocallyOnForceClose(std::list<StreamProducer> &producerList,
+                                            std::set<std::string> &streamListForNotifications);
+
+    /**
+     * @brief Sends master notifications to all streams in the list.
+     * @param[in] streamList Set of streams that have no producers.
+     * @param[out] failedList List of streams that failed to send master notifications.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status SendBatchedCloseProducerReq(std::set<std::string> &streamList, std::vector<std::string> &failedList);
+
+    /**
+     * @brief Close a consumer, trigger a subscription cursor change, and unregister a subscribed consumer from a
+     * stream.
+     * @param[in] consumerId The consumer id.
+     * @param[in] streamName The stream name.
+     * @param[in] subName The subscription name.
+     * @param[in] notifyMaster Notify master or not.
+     * @param[in] lockId The lockId for the client.
+     * @param[in] forceClose Force close in worker.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CloseConsumerImpl(const std::string &consumerId, const std::string &streamName, const std::string &subName,
+                             bool notifyMaster, uint32_t lockId, bool forceClose = false);
+
+    /**
+     * @brief Check the connection to the master.
+     * @param[in] streamName The stream name.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CheckConnection(const std::string &streamName);
+
+    /**
+     * @brief Helper function of the CreateProducer logic.
+     * @param[in] api The stream cache worker to master api.
+     * @param[in] streamName The stream to be constructed.
+     * @param[in] streamFields Stream fields after stream construction.
+     * @return Status of the call.
+     */
+    Status CreateProducerHandleSend(std::shared_ptr api, const std::string &streamName,
+                                    const Optional<StreamFields> &streamFields);
+
+    /**
+     * @brief Implementation of the CreateProducer logic.
* @param[in] namespaceUri The stream name for the producer.
+     * @param[in] req The request info for the create producer.
+     * @param[out] rsp The response info for the create producer.
+     * @return Status of the call.
+     */
+    Status CreateProducerImpl(const std::string &namespaceUri, const CreateProducerReqPb &req,
+                              CreateProducerRspPb &rsp);
+
+    /**
+     * @brief Implementation of the Subscribe logic.
+     * @param[in] namespaceUri The stream name.
+     * @param[in] req The request info for the subscribe.
+     * @param[out] rsp The response info from the subscribe.
+     * @return Status of the call.
+     */
+    Status SubscribeImpl(const std::string &namespaceUri, const SubscribeReqPb &req, SubscribeRspPb &rsp);
+
+    /**
+     * @brief Implementation of the Ack logic.
+     * @param[in] streamName The name of the stream.
+     * @param[in] streamManager The stream manager of the consumer to ack with.
+     * @param[in] subscription The subscription to use for the Ack.
+     * @param[in] consumerId The id of the consumer to use for the ack.
+     * @param[in] elementId The cursor position for the Ack.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status AckImpl(const std::string &streamName, std::shared_ptr<StreamManager> streamManager,
+                   std::shared_ptr subscription, const std::string &consumerId, uint64_t elementId);
+
+    /**
+     * Implementation functions for auto ack.
+     */
+    using AckTask = std::tuple<std::future<Status>, std::string, std::chrono::high_resolution_clock::time_point>;
+    void AutoAckImpl(std::deque<AckTask> &ackList, const uint64_t waitTimeS);
+    void WaitForAckTask(std::deque<AckTask> &ackList, const uint64_t waitTimeS);
+
+    /**
+     * @brief Send the reply for streams to get reset.
+     * @param[in] serverApi Used to read the request from the client and write the response to the client.
+     * @param[in] streamsSize The number of streams to get reset.
+     * @param[in] doneListSize The number of streams that already completed the reset operation.
+     * @param[in] errListSize The number of streams that encountered an error while doing the reset.
+     * @return K_OK on success; the error code otherwise.
+     */
+    static Status ResetStreamsReply(
+        std::shared_ptr> serverApi,
+        size_t streamsSize, size_t doneListSize, size_t errListSize);
+
+    /**
+     * @brief Get the primary replica addr.
+     * @param[in] srcAddr The source address.
+     * @param[out] destAddr The dest address.
+     * @return Status of this call.
+     */
+    Status GetPrimaryReplicaAddr(const std::string &srcAddr, HostPort &destAddr);
+
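+    // The two RedirectRetry helpers below poll the master while stream metadata is migrating:
+    // they follow the redirect address carried in the response and retry every 200 ms until the
+    // request deadline tracked by reqTimeoutDuration is exhausted.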
+    /**
+     * @brief Retry and redirect.
+     * @tparam Req Request to master.
+     * @tparam Rsp Response of master.
+     * @param[in] req Request of the redirect.
+     * @param[out] rsp Response of the redirect.
+     * @param[in] workerMasterApi The worker master api.
+     * @param[in] fun Create, update, or copy meta to master.
+     * @return Status of the call.
+     */
+    template <typename Req, typename Rsp>
+    Status RedirectRetryWhenMetaMoving(Req &req, Rsp &rsp, std::shared_ptr &workerMasterApi,
+                                       std::function<Status(Req &, Rsp &)> fun)
+    {
+        CHECK_FAIL_RETURN_STATUS(fun != nullptr, K_RUNTIME_ERROR, "function is nullptr");
+        while (reqTimeoutDuration.CalcRealRemainingTime() > 0) {
+            RETURN_IF_NOT_OK(fun(req, rsp));
+            if (rsp.info().redirect_meta_address().empty()) {
+                return Status::OK();
+            } else if (!rsp.meta_is_moving()) {
+                HostPort newMetaAddr;
+                RETURN_IF_NOT_OK(GetPrimaryReplicaAddr(rsp.info().redirect_meta_address(), newMetaAddr));
+                LOG(INFO) << "meta has been migrated to the new master " << newMetaAddr.ToString();
+                RETURN_IF_NOT_OK_APPEND_MSG(workerMasterApiManager_->GetWorkerMasterApi(newMetaAddr, workerMasterApi),
+                                            "hash master get failed, RedirectRetryWhenMetaMoving failed");
+                if (etcdCM_->MultiReplicaEnabled()) {
+                    req.set_redirect(false);
+                    RETURN_IF_NOT_OK(fun(req, rsp));
+                    return Status::OK();
+                }
+            }
+            static const int sleepTimeMs = 200;
+            rsp.Clear();
+            std::this_thread::sleep_for(std::chrono::milliseconds(sleepTimeMs));
+        }
+        return Status(K_RPC_DEADLINE_EXCEEDED, "Rpc timeout");
+    }
+
+    /**
+     * @brief Retry when meta is moving.
+     * @tparam Req Request of the redirect.
+     * @tparam Rsp Response of the redirect.
+     * @param[in] req Request of the redirect.
+     * @param[out] rsp Response of the redirect.
+     * @param[in] workerMasterApi The worker master api.
+     * @param[in] fun Query or delete request to master.
+     * @return Status of the call.
+     */
+    template <typename Req, typename Rsp>
+    Status RedirectRetryWhenMetasMoving(Req &req, Rsp &rsp, std::shared_ptr &workerMasterApi,
+                                        std::function<Status(Req &, Rsp &)> fun)
+    {
+        CHECK_FAIL_RETURN_STATUS(fun != nullptr, K_RUNTIME_ERROR, "function is nullptr");
+        while (reqTimeoutDuration.CalcRealRemainingTime() > 0) {
+            RETURN_IF_NOT_OK(fun(req, rsp));
+            if (rsp.info().empty()) {
+                return Status::OK();
+            } else if (!rsp.meta_is_moving()) {
+                HostPort newMetaAddr;
+                RETURN_IF_NOT_OK(GetPrimaryReplicaAddr(rsp.info(0).redirect_meta_address(), newMetaAddr));
+                LOG(INFO) << "meta has been migrated to the new master " << newMetaAddr.ToString();
+                workerMasterApi = workerMasterApiManager_->GetWorkerMasterApi(newMetaAddr);
+                CHECK_FAIL_RETURN_STATUS(workerMasterApi != nullptr, K_RUNTIME_ERROR,
+                                         "hash master get failed, RedirectRetryWhenMetaMoving failed");
+                if (etcdCM_->MultiReplicaEnabled()) {
+                    req.set_redirect(false);
+                    RETURN_IF_NOT_OK(fun(req, rsp));
+                    return Status::OK();
+                }
+            }
+            static const int sleepTimeMs = 200;
+            rsp.Clear();
+            std::this_thread::sleep_for(std::chrono::milliseconds(sleepTimeMs));
+        }
+        return Status(K_RPC_DEADLINE_EXCEEDED, "Rpc timeout");
+    }
+
+    /**
+     * @brief Construct protobuf struct for producer.
+     * @param[in] streamName The stream to be constructed.
+     * @param[in] streamFields Optional argument to pre-assign stream fields after stream construction.
+     * @param[out] out The output CreateProducerReqPb.
+     */
+    void ConstructCreateProducerPb(const std::string &streamName, const Optional<StreamFields> &streamFields,
+                                   master::CreateProducerReqPb &out) const noexcept;
+
+    /**
+     * @brief Construct protobuf struct for the CloseProducerReqPb.
+     * @param[in/out] streamList The list of streams to add into the request. The input list will be
+     * erased after this call.
+     * @param[in] forceClose T/F if force close mode will be used.
+     * @param[out] req The close producer request structure that is now populated with data.
+     */
+    void ConstructCloseProducerReq(std::list<std::string> &streamList, bool forceClose,
+                                   master::CloseProducerReqPb &req) const noexcept;
+
+    /**
+     * @brief Parse the CloseProducerRspPb to check it for error.
If error, populate the failed list with the failed + * producers and then return the rc of the failure. If no errors, failedList will remain unchanged. + * @param[out] failedList The list of failed streams if there was an error. + * @param[in] rsp The response pb to parse for errors. + * @return The error code from the response, otherwise OK + */ + Status HandleCloseProducerRsp(std::list &failedList, const master::CloseProducerRspPb &rsp) const; + + /** + * @brief Construct protobuf struct for consumer. + * @param[in] streamName The stream to be constructed. + * @param[in] consumerId The id of consumer. + * @param[in] lastAckCursor The cursor of last ack. + * @param[in] config The config of the Subscription. + * @param[in] clientId The client id. + * @param[out] out The output ConsumerMetaPb. + */ + void ConstructConsumerMetaPb(const std::string &streamName, const std::string &consumerId, uint64_t lastAckCursor, + const SubscriptionConfig &config, const std::string &clientId, + ConsumerMetaPb &out) const noexcept; + + /** + * @brief Helper function to send Subscribe request through worker to master api. + * @param[in] streamMgr The stream manager for the stream. + * @param[in] streamName The stream to be constructed. + * @param[in] consumerId The id of consumer. + * @param[in] lastAckCursor The cursor of last ack. + * @param[in] config The config of the Subscription. + * @param[in] clientId The client id. + * @param[out] streamFields Stream fields after stream construction. + * @param[out] masterAddress The master address. + * @return K_OK on success; the error code otherwise. + */ + Status SubscribeHandleSend(std::shared_ptr streamMgr, const std::string &streamName, + const std::string &consumerId, uint64_t lastAckCursor, const SubscriptionConfig &config, + const std::string &clientId, Optional &streamFields, + std::string &masterAddress); + + /** + * @brief Check if the given address is the master for the testing stream. + * @param[in] streamName The name of the testing stream + * @param[in] masterAddr The given master address for the testing stream. + * @param[in] hashRanges The given hash ranges for the testing stream. + * @return True if the master address for the stream is found and matches with the given address, False otherwise. + */ + bool CheckConditionsForStream(const std::string &streamName, const std::string &masterAddr, + const worker::HashRange &hashRanges); + + /** + * @brief Get all the producers and consumers for a client created on the given stream. + * @param[in] clientId The requesting client Id. + * @param[in] streamName The stream for which pubsub list should be returned. + * @param[out] prodConList The list of producer and consumer Ids for the requesting client. + * @return The status of the call. + */ + Status GetPubSubForClientStream(const std::string &clientId, const std::string &streamName, + std::vector &prodConList); + + /** + * @brief Cleanup all the data belong to the stream locally + * @param[in] streamManager stream manager for the stream + * @return The status of the call. + */ + Status DeleteStreamLocally(StreamManagerMap::accessor &accessor); + + /** + * @brief Create a producer, i.e., register a publisher to a stream. + * @param[in] req The request instance. + * @param[in] recorder The access recorder instance. + * @param[in] serverApi Used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. 
+ */ + Status CreateProducerInternal( + const CreateProducerReqPb &req, std::shared_ptr recorder, + std::shared_ptr> serverApi); + + /** + * @brief Close a producer, force flushing and page seal, unregister a publisher to a stream. + * @param[in] req The request instance. + * @param[in] recorder The access recorder instance. + * @param[in] serverApi Used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status CloseProducerInternal( + const CloseProducerReqPb &req, std::shared_ptr recorder, + std::shared_ptr> serverApi); + + /** + * @brief Subscribe to a stream, using a subscription name, i.e., register a consumer to a subscription. + * @param[in] req The request instance. + * @param[in] recorder The access recorder instance. + * @param[in] serverApi Used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status SubscribeInternal(const SubscribeReqPb &req, std::shared_ptr recorder, + std::shared_ptr> serverApi); + + /** + * @brief Close a consumer, trigger subscription cursor change and unregister a subscribed consumer to a stream. + * @param[in] req The request instance. + * @param[in] recorder The access recorder instance. + * @param[in] serverApi Used to read request from client and write response to client. + * @return K_OK on success; the error code otherwise. + */ + Status CloseConsumerInternal( + const CloseConsumerReqPb &req, std::shared_ptr recorder, + std::shared_ptr> serverApi); + + /** + * @brief Delete stream manager and related sessions when it is not used anymore. + * @param[in] req The rpc request protobuf. + * @param[out] rsp The rpc response protobuf. + * @return K_OK on success; the error code otherwise. + */ + Status DeleteStreamImpl(const DeleteStreamReqPb &req, DeleteStreamRspPb &rsp); + + /** + * @brief Send delete stream request. + * @param[in] streamName The stream name. + * @return K_OK on success; the error code otherwise. + */ + Status DeleteStreamHandleSend(const std::string &streamName); + + /** + * @brief Query producer count in global scope for one stream + * @param[in] req The rpc request protobuf + * @param[out] rsp The rpc response protobuf + * @return K_OK on success; the error code otherwise. + */ + Status QueryGlobalProducersNumImpl(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp); + + /** + * @brief Query consumer count in global scope for one stream + * @param[in] req The rpc request protobuf + * @param[out] rsp The rpc response protobuf + * @return K_OK on success; the error code otherwise. + */ + Status QueryGlobalConsumersNumImpl(const QueryGlobalNumReqPb &req, QueryGlobalNumRsqPb &rsp); + + Status PostCreateStreamManager(const std::shared_ptr &streamManager, + const Optional &streamFields, bool reserveShm); + + /** + * @brief Erase from streamMgrDict_ without lock. + * @param[in] namespaceUri The key to be erased. + * @param[in] accessor The accessor of streamMgrDict_. + */ + void EraseFromStreamMgrDictWithoutLck(const std::string &namespaceUri, StreamManagerMap::accessor *accessor); + + friend class MasterWorkerSCServiceImpl; // They share the stream data on local worker node + friend class StreamManagerWithLock; + + std::unique_ptr remoteWorkerManager_{ nullptr }; + std::shared_timed_mutex mutex_; // protect streamMgrDict_. + StreamManagerMap streamMgrDict_; + std::atomic lifetimeLocalStreamCount_{ 0 }; + std::shared_timed_mutex mappingMutex_; // protect streamNum2StreamName_. 
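+    // Maps a numeric stream id to its full stream name; maintained by AddStreamNo/RemoveStreamNo
+    // and looked up through StreamNoToName.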
+ std::unordered_map streamNum2StreamName_; + CreatePubSubCtrl createStreamLocks_; + + std::shared_ptr> workerMasterApiManager_{ nullptr }; + + HostPort localWorkerAddress_; + HostPort masterAddress_; + + std::shared_timed_mutex clearMutex_; // Protect requests success when other client crash. + struct SubInfo { + SubInfo(std::string streamName, std::string subName, std::string consumerId) + : streamName(std::move(streamName)), subName(std::move(subName)), consumerId(std::move(consumerId)) + { + } + std::string streamName; + std::string subName; + std::string consumerId; + }; + std::map> clientProducers_; + std::map> clientConsumers_; + std::shared_ptr scAllocateManager_; + std::weak_ptr workerWorkerSCService_; + std::shared_ptr akSkManager_; + std::shared_ptr threadPool_{ nullptr }; + std::shared_ptr memAllocPool_{ nullptr }; + std::shared_ptr ackPool_{ nullptr }; + // For remote workers/producers + std::mutex remotePubStubMutex_; + std::unordered_map> remotePubStubs_; + std::atomic interrupt_; + std::future autoAck_; + EtcdClusterManager *etcdCM_{ nullptr }; // back pointer to the cluster manager +}; + +template <> +Status ClientWorkerSCServiceImpl::HandleBlockedCreateTimeout( + const std::string &streamName, const std::string &producerId, const std::string &traceId, int64_t subTimeout, + const std::chrono::steady_clock::time_point &startTime); + +template <> +Status ClientWorkerSCServiceImpl::HandleBlockedCreateTimeout( + const std::string &streamName, const std::string &producerId, const std::string &traceId, int64_t subTimeout, + const std::chrono::steady_clock::time_point &startTime); + +template +class MemAllocRequestList { +public: + Status AddBlockedCreateRequest(ClientWorkerSCServiceImpl *scSvc, + std::shared_ptr> blockedReq); + void RemoveBlockedCreateRequestFromQueueLocked(const BlockedCreateRequest *blockedReqPtr); + void HandleBlockedCreateTimeout(const std::string &streamName, const std::string &producerId, int64_t subTimeout, + const std::chrono::steady_clock::time_point &startTime); + Status GetBlockedCreateRequest(std::shared_ptr> &out); + bool Empty(); + size_t Size(); + size_t GetNextBlockedRequestSize(); + auto GetNextStartTime() + { + if (queue_.empty()) { + return std::chrono::steady_clock::now(); + } + return queue_.top()->startTime_; + } + void ClearBlockedList(); + +private: + std::shared_timed_mutex blockedListMutex_; + struct Compare { + bool operator()(const BlockedCreateRequest *a, const BlockedCreateRequest *b) + { + return a->startTime_ > b->startTime_; + } + }; + std::unordered_map>> blockedList_; + std::priority_queue *, std::vector *>, Compare> queue_; +}; +} // namespace stream_cache +} // namespace worker +} // namespace datasystem +#endif diff --git a/src/datasystem/worker/stream_cache/consumer.cpp b/src/datasystem/worker/stream_cache/consumer.cpp new file mode 100644 index 0000000..f25621f --- /dev/null +++ b/src/datasystem/worker/stream_cache/consumer.cpp @@ -0,0 +1,132 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "datasystem/worker/stream_cache/consumer.h" + +#include +#include +#include "datasystem/common/iam/tenant_auth_manager.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/perf/perf_manager.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/worker/stream_cache/consumer.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +Consumer::Consumer(std::string id, uint64_t lastAckCursor, std::string streamName, std::shared_ptr cursor) + : id_(std::move(id)), + streamName_(std::move(streamName)), + initialCursor_(lastAckCursor), + pendingRecv_(nullptr), + cursor_(std::move(cursor)) +{ + // Put the last ack cursor in the work area + UpdateWALastAckCursor(initialCursor_); +} + +Consumer::~Consumer() = default; + +Status Consumer::AddPendingReceive(uint64_t lastRecvCursor, uint64_t timeoutMs, const std::function &recvFunc, + std::shared_ptr> stream) +{ + PerfPoint point(PerfKey::WORKER_CONSUMER_ADD_PENDING_RECV); + std::unique_lock lock(mutex_); + if (pendingRecv_) { + RETURN_STATUS_LOG_ERROR(StatusCode::K_DUPLICATED, + FormatString("The consumer %s already had pending receive request.", id_)); + } + INJECT_POINT("worker.stream.before_add_pending"); + TimerQueue::TimerImpl timer; + RETURN_IF_NOT_OK(TimerQueue::GetInstance()->AddTimer(timeoutMs, recvFunc, timer)); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Add pending receive request with timer:id %zu", LogPrefix(), + timer.GetId()); + pendingRecv_ = std::make_unique(); + pendingRecv_->lastRecvCursor = lastRecvCursor; + pendingRecv_->timer = std::make_unique(timer); + pendingRecv_->stream = std::move(stream); + pendingRecv_->start = std::chrono::steady_clock::now(); + pendingRecv_->wakeupPendingRecvOnProdFault = false; + return Status::OK(); +} + +void Consumer::RemovePendingReceive(bool &wakeupPendingRecvOnProdFault) +{ + std::lock_guard lock(mutex_); + RemovePendingReceiveNoLock(wakeupPendingRecvOnProdFault); +} + +void Consumer::RemovePendingReceiveNoLock(bool &wakeupPendingRecvOnProdFault) +{ + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Remove pending receive request with timer:id %zu", LogPrefix(), + pendingRecv_->timer->GetId()); + auto nowTime = std::chrono::steady_clock::now(); + uint64_t elapsed = std::chrono::duration_cast(nowTime - pendingRecv_->start).count(); + PerfPoint::RecordElapsed(PerfKey::PENDING_RECV_WAIT_TIME, elapsed); + wakeupPendingRecvOnProdFault = pendingRecv_ != nullptr && pendingRecv_->wakeupPendingRecvOnProdFault; + pendingRecv_ = nullptr; +} + +Status Consumer::WakeUpPendingReceive(uint64_t lastAppendCursor) +{ + std::lock_guard lock(mutex_); + // If this consumer has pendingRecv. + // And its lastRecvCursor + expectRecvNum <= TargetStream.lastAppendCursor(Enough data to receive). 
+    if (pendingRecv_ != nullptr) {
+        auto lastRecvCursor = pendingRecv_->lastRecvCursor;
+        if (lastRecvCursor <= lastAppendCursor) {
+            VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Do EraseAndExecTimer for timer:id %zu", LogPrefix(),
+                                                        pendingRecv_->timer->GetId());
+            (void)TimerQueue::GetInstance()->EraseAndExecTimer(*pendingRecv_->timer);
+        }
+    }
+    return Status::OK();
+}
+
+std::string Consumer::LogPrefix() const
+{
+    return FormatString("C:%s", id_);
+}
+
+Status Consumer::SetForceClose()
+{
+    cursor_->SetForceClose();
+    std::lock_guard lock(mutex_);
+    if (pendingRecv_ != nullptr) {
+        pendingRecv_->wakeupPendingRecvOnProdFault = true;
+        CHECK_FAIL_RETURN_STATUS(TimerQueue::GetInstance()->EraseAndExecTimer(*pendingRecv_->timer),
+                                 StatusCode::K_RUNTIME_ERROR,
+                                 FormatString("Consumer %s failed to erase and exec timer on producer failure", id_));
+    }
+    return Status::OK();
+}
+
+void Consumer::CleanupConsumer()
+{
+    std::lock_guard lock(mutex_);
+    if (pendingRecv_ != nullptr) {
+        (void)TimerQueue::GetInstance()->Cancel(*pendingRecv_->timer);
+        pendingRecv_.reset();
+    }
+    initialCursor_ = 0;
+    if (cursor_) {
+        LOG_IF_ERROR(cursor_->Init(), FormatString("[%s] CleanupConsumer", LogPrefix()));
+        cursor_->UpdateWALastAckCursor(0);
+    }
+}
+}  // namespace stream_cache
+}  // namespace worker
+}  // namespace datasystem
diff --git a/src/datasystem/worker/stream_cache/consumer.h b/src/datasystem/worker/stream_cache/consumer.h
new file mode 100644
index 0000000..af1f61e
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/consumer.h
@@ -0,0 +1,168 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_CONSUMER_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_CONSUMER_H
+
+#include "datasystem/common/eventloop/timer_queue.h"
+#include "datasystem/common/stream_cache/cursor.h"
+#include "datasystem/protos/stream_posix.service.rpc.pb.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+class Consumer {
+public:
+    /**
+     * @brief Construct Consumer.
+     * @param[in] id The consumer id.
+     * @param[in] lastAckCursor The last acknowledged cursor.
+     * @param[in] streamName The stream name.
+     * @param[in] cursor The cursor shared with the client-side consumer.
+     */
+    Consumer(std::string id, uint64_t lastAckCursor, std::string streamName, std::shared_ptr cursor);
+    virtual ~Consumer();
+
+    struct PendingReceive {
+        uint64_t lastRecvCursor;
+        std::unique_ptr timer;
+        std::shared_ptr> stream;
+        std::chrono::time_point start;
+        bool wakeupPendingRecvOnProdFault;
+    };
+
+    /**
+     * @brief Add the pending receive to the consumer without lock protection; it should be called when the consumer
+     * cannot read enough elements and the timeout is greater than 0.
+     * @param[in] lastRecvCursor The client's last receive cursor.
+     * @param[in] timeoutMs The receive timeout in milliseconds.
+     * @param[in] recvFunc The callback added to the TimerQueue.
+     * @param[in] stream The stream that is used to write the response to the client.
+     * @return Status of the call.
+ */ + Status AddPendingReceive(uint64_t lastRecvCursor, uint64_t timeoutMs, const std::function &recvFunc, + std::shared_ptr> stream); + + /** + * @brief Remove the pending receive, it should be call after the callback function is executed. + */ + void RemovePendingReceive(bool &wakeupPendingRecvOnProdFault); + + /** + * @brief Remove the pending receive without lock protect, it should be call after the callback function is + * executed. + */ + void RemovePendingReceiveNoLock(bool &wakeupPendingRecvOnProdFault); + + /** + * @brief Wake up pending receive. + * @param[in] lastAppendCursor The last append cursor of the stream. + * @return Status of the call. + */ + Status WakeUpPendingReceive(uint64_t lastAppendCursor); + + /** + * @brief Get the consumer id. + * @return The consumer id. + */ + [[nodiscard]] std::string GetId() const + { + return id_; + } + + /** + * @brief Get log prefix. + * @return Log prefix. + */ + [[nodiscard]] virtual std::string LogPrefix() const; + + /** + * @brief Get the last ack cursor from the work area + * @return last ack cursor + */ + [[nodiscard]] uint64_t GetWALastAckCursor() const + { + return cursor_->GetWALastAckCursor(); + } + + /** + * @brief Update the last ack cursor from the work area + * @param[in] elementId The element Id. + * @return last ack cursor + */ + void UpdateWALastAckCursor(uint64_t elementId) const + { + cursor_->UpdateWALastAckCursor(elementId); + } + + /** + * Force a consumer when there is no producer + * @return + */ + Status SetForceClose(); + + /** + * @brief Get the element count and reset it to 0. + * @return The element count + */ + uint64_t GetElementCountAndReset() const + { + return cursor_->GetElementCountAndReset(); + } + + /** + * @brief Get the element count of the cursor + * @return The element count + */ + uint64_t GetElementCount() const + { + return cursor_->GetElementCount(); + } + + /** + * @brief Get the request count of the cursor + * @return The request count + */ + uint64_t GetRequestCountAndReset() const + { + return cursor_->GetRequestCountAndReset(); + } + + /** + * @brief Set the element count of the cursor + * @param val value to set element count to + */ + void SetElementCount(uint64_t val) const + { + return cursor_->SetElementCount(val); + } + + /** + * @brief Cleanup indexes and pending recv for this consumer. + */ + void CleanupConsumer(); + +private: + const std::string id_; + const std::string streamName_; + uint64_t initialCursor_; + std::mutex mutex_; // protect pendingRecv_ + std::unique_ptr pendingRecv_; + // A work area that is shared between the corresponding client::stream_cache::ConsumerImpl + // sz is the size of this work area. It is set up in the function ExclusivePageQueue::AddCursor + std::shared_ptr cursor_; +}; +} // namespace stream_cache +} // namespace worker +} // namespace datasystem +#endif diff --git a/src/datasystem/worker/stream_cache/master_worker_sc_service_impl.cpp b/src/datasystem/worker/stream_cache/master_worker_sc_service_impl.cpp new file mode 100644 index 0000000..5886596 --- /dev/null +++ b/src/datasystem/worker/stream_cache/master_worker_sc_service_impl.cpp @@ -0,0 +1,305 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "datasystem/worker/stream_cache/master_worker_sc_service_impl.h" + +#include "datasystem/common/util/format.h" +#include "datasystem/common/log/log_helper.h" +#include "datasystem/common/util/strings_util.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +MasterWorkerSCServiceImpl::MasterWorkerSCServiceImpl(HostPort serverAddr, HostPort masterAddr, + ClientWorkerSCServiceImpl *clientSvc, + std::shared_ptr akSkManager) + : localWorkerAddress_(std::move(serverAddr)), + masterAddress_(std::move(masterAddr)), + clientWorkerSCSvc_(clientSvc), + akSkManager_(std::move(akSkManager)) +{ +} + +Status MasterWorkerSCServiceImpl::Init() +{ + CHECK_FAIL_RETURN_STATUS(clientWorkerSCSvc_ != nullptr, StatusCode::K_NOT_READY, + "ClientWorkerService must be initialized before MasterWorkerService construction"); + LOG(INFO) << FormatString("[%s, Master address:%s] Initialization success.", LogPrefix(true), + masterAddress_.ToString()); + return MasterWorkerSCService::Init(); +} + +Status MasterWorkerSCServiceImpl::AddRemoteConsumer(const std::shared_ptr &streamManager, + const ConsumerMetaPb &meta) +{ + const auto &workerAddr = meta.worker_address(); + const auto &consumerId = meta.consumer_id(); + const auto &streamName = meta.stream_name(); + HostPort workerAddress(workerAddr.host(), workerAddr.port()); + + SubscriptionConfig config(meta.sub_config().subscription_name(), + ToSubscriptionType(meta.sub_config().subscription_type())); + uint64_t lastAckCursor; + RETURN_IF_NOT_OK_EXCEPT(streamManager->AddRemoteSubNode(workerAddress, config, consumerId, lastAckCursor), + K_DUPLICATED); + auto &remoteWorkerManager = clientWorkerSCSvc_->remoteWorkerManager_; + RETURN_IF_NOT_OK(remoteWorkerManager->AddRemoteConsumer(streamManager, localWorkerAddress_, workerAddress, + streamName, config, consumerId, lastAckCursor)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] New Consumer success", LogPrefix(), + streamName, config.subscriptionName, consumerId, + workerAddress.ToString()); + return Status::OK(); +} + +Status MasterWorkerSCServiceImpl::DelRemoteConsumer(const std::shared_ptr &streamManager, + const ConsumerMetaPb &consumerMeta) +{ + const auto &workerAddr = consumerMeta.worker_address(); + const auto &consumerId = consumerMeta.consumer_id(); + const auto &streamName = consumerMeta.stream_name(); + HostPort workerAddress(workerAddr.host(), workerAddr.port()); + + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] DelSubNode begin", LogPrefix(), + streamName, consumerMeta.sub_config().subscription_name(), consumerId, + workerAddress.ToString()); + + RETURN_IF_NOT_OK_EXCEPT(streamManager->DelRemoteSubNode(workerAddress, consumerId), K_NOT_FOUND); + auto &remoteWorkerManager = clientWorkerSCSvc_->remoteWorkerManager_; + RETURN_IF_NOT_OK_EXCEPT(remoteWorkerManager->DelRemoteConsumer(workerAddress.ToString(), streamName, consumerId), + K_NOT_FOUND); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] DelSubNode success", LogPrefix(), + streamName, consumerMeta.sub_config().subscription_name(), consumerId, + workerAddress.ToString()); + return 
Status::OK(); +} + +Status MasterWorkerSCServiceImpl::SyncPubNode(const SyncPubNodeReqPb &req, SyncPubNodeRspPb &rsp) +{ + (void)rsp; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + const std::string &streamName = req.stream_name(); + LOG(INFO) << FormatString("worker(%s) received SyncPubNode request, streamname: %s", localWorkerAddress_.ToString(), + streamName); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(clientWorkerSCSvc_->GetStreamManager(streamName, accessor), + "worker get streamManager failed"); + std::shared_ptr streamManager = accessor->second; + + auto pubNodeNum = req.worker_address_vector_size(); + + std::vector pubNodeSet; + pubNodeSet.reserve(pubNodeNum); + for (int i = 0; i < pubNodeNum; ++i) { + HostPort pubWorkerNode(req.worker_address_vector(i).host(), req.worker_address_vector(i).port()); + pubNodeSet.emplace_back(std::move(pubWorkerNode)); + } + Status rc = streamManager->SyncPubTable(pubNodeSet, req.is_reconciliation()); + LOG_IF(INFO, rc.IsError()) << "streamManager SyncPubTable failed: " << rc.ToString(); + RETURN_IF_NOT_OK(rc); + LOG(INFO) << FormatString("[%s, S:%s] SyncPubNode table success, Table size:%d", LogPrefix(true), streamName, + pubNodeNum); + return Status::OK(); +} + +Status MasterWorkerSCServiceImpl::SyncConsumerNode(const SyncConsumerNodeReqPb &req, SyncConsumerNodeRspPb &rsp) +{ + INJECT_POINT("MasterWorkerSCServiceImpl.SyncConsumerNode.sleep"); + (void)rsp; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + const std::string &streamName = req.stream_name(); + LOG(INFO) << FormatString("worker(%s) received SyncConsumerNode request, streamname: %s", + localWorkerAddress_.ToString(), streamName); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(clientWorkerSCSvc_->GetStreamManager(streamName, accessor), + "worker get streamManager failed"); + std::shared_ptr streamManager = accessor->second; + RETURN_IF_NOT_OK(streamManager->CheckIfStreamActive()); + + auto consumerNum = req.consumer_meta_vector_size(); + + // If pubTable has at least 1 consumer node, we do the following synchronize. + std::vector consumerNodeSet; + consumerNodeSet.reserve(consumerNum); + for (int i = 0; i < consumerNum; ++i) { + // ConsumerMeta = (consumerId, workerNode, subConfig, lastAckCursor). + const auto &consumerMetaPb = req.consumer_meta_vector(i); + const std::string &consumerId(consumerMetaPb.consumer_id()); + HostPort workerNode(consumerMetaPb.worker_address().host(), consumerMetaPb.worker_address().port()); + const auto &subconfigPb = consumerMetaPb.sub_config(); + SubscriptionConfig subConfig(subconfigPb.subscription_name(), + ToSubscriptionType(subconfigPb.subscription_type())); + + consumerNodeSet.emplace_back(streamName, consumerId, workerNode, subConfig, consumerMetaPb.last_ack_cursor()); + } + // Do not start tables from scratch if it is triggered by reconciliation code path, + // duplicates are skipped instead. 
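The reconciliation comment above describes a skip-duplicates pattern: when the sync is replayed by the reconciliation path, entries that already exist are tolerated instead of failing the whole request. A rough, self-contained sketch of that idea, using hypothetical names (`PubTable`, `SyncPubNodes`, `Code`) rather than the real classes and status codes:

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Hypothetical sketch of the reconciliation pattern used by the Sync* and
// UpdateTopoNotification handlers: an add that reports duplicates, and a
// caller that treats "duplicate" as success when replaying state.
enum class Code { Ok, Duplicated };

class PubTable {
public:
    Code AddRemotePubNode(const std::string &addr)
    {
        return nodes_.insert(addr).second ? Code::Ok : Code::Duplicated;
    }

private:
    std::set<std::string> nodes_;
};

bool SyncPubNodes(PubTable &table, const std::vector<std::string> &addrs, bool isReconciliation)
{
    for (const auto &addr : addrs) {
        Code rc = table.AddRemotePubNode(addr);
        if (rc == Code::Duplicated && !isReconciliation) {
            return false;  // unexpected duplicate outside reconciliation
        }
        // On the reconciliation path the duplicate is simply skipped.
    }
    return true;
}

int main()
{
    PubTable table;
    (void)table.AddRemotePubNode("10.0.0.1:9000");
    // Reconciliation replays the full node list; the duplicate is skipped.
    std::cout << SyncPubNodes(table, { "10.0.0.1:9000", "10.0.0.2:9000" }, true) << std::endl;  // 1
    return 0;
}
```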
+ bool isRecon = req.is_reconciliation(); + uint64_t lastAckCursor; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(streamManager->SyncSubTable(consumerNodeSet, isRecon, lastAckCursor), + "streamManager SyncSubTable failed"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(SyncRemoteConsumer(streamManager, consumerNodeSet, lastAckCursor), + "worker SyncRemoteConsumer failed"); + streamManager->SetRetainData(req.retain_data()); + LOG(INFO) << FormatString("[%s, S:%s] SyncConsumerNode table success, Table size:%d", LogPrefix(true), streamName, + consumerNum); + return Status::OK(); +} + +Status MasterWorkerSCServiceImpl::ClearAllRemotePub(const ClearRemoteInfoReqPb &req, ClearRemoteInfoRspPb &rsp) +{ + (void)rsp; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + const std::string &streamName = req.stream_name(); + LOG(INFO) << FormatString("worker(%s) received ClearAllRemotePub request, streamname: %s", + localWorkerAddress_.ToString(), streamName); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(clientWorkerSCSvc_->GetStreamManager(streamName, accessor), + "worker get streamManager failed"); + std::shared_ptr streamManager = accessor->second; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(streamManager->ClearAllRemotePub(), "streamManager ClearAllRemotePub failed"); + LOG(INFO) << "worker ClearAllRemotePub done, streamname: " << streamName; + return Status::OK(); +} + +Status MasterWorkerSCServiceImpl::ClearAllRemoteConsumer(const ClearRemoteInfoReqPb &req, ClearRemoteInfoRspPb &rsp) +{ + // ClearAllRemoteConsumer RPC deprecated. + INJECT_POINT("MasterWorkerSCServiceImpl.ClearAllRemoteConsumer.sleep"); + (void)rsp; + (void)req; + return Status::OK(); +} + +Status MasterWorkerSCServiceImpl::DelStreamContext(const DelStreamContextReqPb &req, DelStreamContextRspPb &rsp) +{ + (void)rsp; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + const auto &streamName = req.stream_name(); + bool forceDelete = req.force_delete(); + auto timeout = req.timeout(); + LOG(INFO) << FormatString("worker(%s) received DelStreamContext request, streamname: %s, force delete: %s", + localWorkerAddress_.ToString(), streamName, forceDelete ? 
"true" : "false"); + + RETURN_IF_NOT_OK(clientWorkerSCSvc_->DeleteStreamContext(streamName, forceDelete, timeout)); + return Status::OK(); +} + +std::string MasterWorkerSCServiceImpl::LogPrefix(bool withAddress) const +{ + if (withAddress) { + return FormatString("MasterWorkerSvc, Node:%s", localWorkerAddress_.ToString()); + } else { + return "MasterWorkerSvc"; + } +} + +Status MasterWorkerSCServiceImpl::SyncRemoteConsumer(const std::shared_ptr &streamManager, + const std::vector &remoteConsumerSet, + uint64_t lastAckCursor) +{ + const std::string streamName = streamManager->GetStreamName(); + auto &remoteWorkerManager = clientWorkerSCSvc_->remoteWorkerManager_; + for (const auto &remoteConsumer : remoteConsumerSet) { + // Do sync part + RETURN_IF_NOT_OK(remoteWorkerManager->AddRemoteConsumer( + streamManager, localWorkerAddress_, remoteConsumer.WorkerAddress(), streamName, remoteConsumer.SubConfig(), + remoteConsumer.ConsumerId(), lastAckCursor)); + } + return Status::OK(); +} + +Status MasterWorkerSCServiceImpl::QueryMetadata( + std::shared_ptr> stream) +{ + LOG(INFO) << "worker(" << localWorkerAddress_.ToString() << ") receive QueryMetaData for all streams from master"; + GetMetadataAllStreamReqPb req; + RETURN_IF_NOT_OK(stream->Read(req)); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + Status status = clientWorkerSCSvc_->SendAllStreamMetadata(req, stream); + if (status.IsError() && status.GetCode() != K_RPC_STREAM_END) { + LOG(ERROR) << "Send response to stream failed: " << status.GetMsg(); + return status; + } + LOG(INFO) << "worker QueryMetaData for all streams done"; + return stream->Finish(); +} + +Status MasterWorkerSCServiceImpl::QueryMetadata(const GetMetadataAllStreamReqPb &req, GetMetadataAllStreamRspPb &rsp) +{ + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + return clientWorkerSCSvc_->GetAllStreamMetadata(req, rsp); +} + +Status MasterWorkerSCServiceImpl::UpdateTopoNotification(const UpdateTopoNotificationReq &req, + UpdateTopoNotificationRsp &rsp) +{ + (void)rsp; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(req), "AK/SK failed."); + const auto &streamName = req.stream_name(); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("Stream:<%s> UpdateTopoNotification req:%s", streamName, + LogHelper::IgnoreSensitive(req)); + LOG(INFO) << FormatString("worker(%s) received UpdateTopoNotification request, streamname: %s", + localWorkerAddress_.ToString(), streamName); + INJECT_POINT("MasterWorkerSCServiceImpl.UpdateTopoNotification.begin"); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(clientWorkerSCSvc_->GetStreamManager(streamName, accessor), + "worker get streamManager failed"); + std::shared_ptr streamManager = accessor->second; + for (const auto &pub : req.pubs()) { + const auto &workerAddr = pub.worker_addr(); + if (pub.is_close()) { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(streamManager->HandleClosedRemotePubNode(pub.force_close()), + "streamManager HandleClosedRemotePubNode failed"); + } else { + Status rc = streamManager->AddRemotePubNode(workerAddr); + LOG_IF(INFO, rc.IsError()) << "streamManager AddRemotePubNode failed: " << rc.ToString(); + if (rc.IsError() && rc.GetCode() != K_DUPLICATED) { + RETURN_STATUS(rc.GetCode(), rc.GetMsg()); + } + StreamFields streamFields(pub.max_stream_size(), pub.page_size(), pub.auto_cleanup(), + pub.retain_num_consumer(), pub.encrypt_stream(), pub.reserve_size(), + 
pub.stream_mode()); + if (!streamFields.Empty()) { + RETURN_IF_NOT_OK(streamManager->UpdateStreamFields(streamFields, true)); + // There should be existing consumers for it to receive the notification request, + // update the reserved memory if applicable according to the page size + LOG_IF_ERROR(clientWorkerSCSvc_->ReserveMemoryFromUsageMonitor(streamName, streamFields.pageSize_), ""); + } + } + } + + for (const auto &sub : req.subs()) { + if (sub.is_close()) { + Status rc = DelRemoteConsumer(streamManager, sub.consumer()); + LOG_IF(INFO, rc.IsError()) << "worker DelRemoteConsumer failed: " << rc.ToString(); + RETURN_IF_NOT_OK(rc); + } else { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(AddRemoteConsumer(streamManager, sub.consumer()), + "worker AddRemoteConsumer failed"); + } + } + + // This request only comes when we have enough consumers + // Retain state will be set when producer is created + if (req.retain_data() == RetainDataState::State::NOT_RETAIN) { + streamManager->SetRetainData(req.retain_data()); + } + LOG(INFO) << "worker UpdateTopoNotification done, streamname: " << streamName; + return Status::OK(); +} +} // namespace stream_cache +} // namespace worker +} // namespace datasystem diff --git a/src/datasystem/worker/stream_cache/master_worker_sc_service_impl.h b/src/datasystem/worker/stream_cache/master_worker_sc_service_impl.h new file mode 100644 index 0000000..9ca18af --- /dev/null +++ b/src/datasystem/worker/stream_cache/master_worker_sc_service_impl.h @@ -0,0 +1,160 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASYSTEM_WORKER_STREAM_CACHE_MASTER_WORKER_SC_SERVICE_IMPL_H +#define DATASYSTEM_WORKER_STREAM_CACHE_MASTER_WORKER_SC_SERVICE_IMPL_H + +#include "datasystem/common/ak_sk/ak_sk_manager.h" +#include "datasystem/common/stream_cache/consumer_meta.h" +#include "datasystem/protos/worker_stream.service.rpc.pb.h" +#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h" +#include "datasystem/worker/stream_cache/stream_manager.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +class MasterWorkerSCServiceImpl : public MasterWorkerSCService { +public: + /** + * @brief Construct MasterWorkerSCServiceImpl that is used to provide service for master. + * @param[in] serverAddr The worker address. + * @param[in] masterAddr The master address. + * @param[in] clientSvc The pointer of client call worker service. + * @param[in] akSkManager Used to do AK/SK authenticate. + */ + MasterWorkerSCServiceImpl(HostPort serverAddr, HostPort masterAddr, ClientWorkerSCServiceImpl *clientSvc, + std::shared_ptr akSkManager); + ~MasterWorkerSCServiceImpl() override = default; + + /** + * @brief Init the service. + * @return Status of the call. + */ + Status Init() override; + + /** + * @brief Synchronize all remote pub node for target stream to current worker node. + * Invoked when first consumer occurs on current node. 
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status SyncPubNode(const SyncPubNodeReqPb &req, SyncPubNodeRspPb &rsp) override;
+
+    /**
+     * @brief Synchronize all remote consumer nodes for the target stream to the current worker node.
+     * Invoked when the first producer appears on the current node.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status SyncConsumerNode(const SyncConsumerNodeReqPb &req, SyncConsumerNodeRspPb &rsp) override;
+
+    /**
+     * @brief Clear all remote pub nodes for the target stream on the current worker node.
+     * Invoked when the last consumer disappears from the current node.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status ClearAllRemotePub(const ClearRemoteInfoReqPb &req, ClearRemoteInfoRspPb &rsp) override;
+
+    /**
+     * @brief Clear all remote consumer nodes for the target stream on the current worker node.
+     * Invoked when the last producer disappears from the current node.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status ClearAllRemoteConsumer(const ClearRemoteInfoReqPb &req, ClearRemoteInfoRspPb &rsp) override;
+
+    /**
+     * @brief Delete the stream context broadcast to this worker.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status DelStreamContext(const DelStreamContextReqPb &req, DelStreamContextRspPb &rsp) override;
+
+    /**
+     * @brief Query metadata for all streams from the worker.
+     * @param[in, out] stream The server reader writer session.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status QueryMetadata(
+        std::shared_ptr> stream) override;
+
+    /**
+     * @brief Query metadata from the worker. This version simply populates the rsp for all the streams for a master.
+     * @param[in] req The metadata request to look up.
+     * @param[out] rsp The metadata response populated with the results.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status QueryMetadata(const GetMetadataAllStreamReqPb &req, GetMetadataAllStreamRspPb &rsp);
+
+    /**
+     * @brief Master notifies the worker to update the stream topology.
+     * @param[in] req The rpc request protobuf.
+     * @param[out] rsp The rpc response protobuf.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status UpdateTopoNotification(const UpdateTopoNotificationReq &req, UpdateTopoNotificationRsp &rsp) override;
+
+    /**
+     * @brief Get log prefix.
+     * @param[in] withAddress Whether to log with the address; the default is false.
+     * @return Log prefix.
+     */
+    [[nodiscard]] std::string LogPrefix(bool withAddress = false) const;
+
+private:
+    /**
+     * @brief Synchronize remote consumers to remoteWorkerManager_.
+     * @param[in] streamManager Target stream.
+     * @param[in] remoteConsumerSet Remote consumer set obtained in the CreateProducer process.
+     * @param[in] lastAckCursor The last ack cursor.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status SyncRemoteConsumer(const std::shared_ptr &streamManager,
+                              const std::vector &remoteConsumerSet, uint64_t lastAckCursor);
+
+    /**
+     * @brief Master notifies the worker to add a remote consumer.
+     * @param[in] streamManager The stream manager.
+     * @param[in] consumerMeta The consumer metadata.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status AddRemoteConsumer(const std::shared_ptr &streamManager, const ConsumerMetaPb &consumerMeta);
+
+    /**
+     * @brief Master notifies the worker to delete a remote consumer.
+     * @param[in] streamManager The stream manager.
+     * @param[in] consumerMeta The consumer metadata.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status DelRemoteConsumer(const std::shared_ptr &streamManager, const ConsumerMetaPb &consumerMeta);
+
+    HostPort localWorkerAddress_;
+    HostPort masterAddress_;
+
+    ClientWorkerSCServiceImpl *clientWorkerSCSvc_;
+    std::shared_ptr akSkManager_;
+};
+}  // namespace stream_cache
+}  // namespace worker
+}  // namespace datasystem
+
+#endif  // DATASYSTEM_WORKER_STREAM_CACHE_MASTER_WORKER_SC_SERVICE_IMPL_H
diff --git a/src/datasystem/worker/stream_cache/metrics/CMakeLists.txt b/src/datasystem/worker/stream_cache/metrics/CMakeLists.txt
new file mode 100644
index 0000000..51b6255
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/metrics/CMakeLists.txt
@@ -0,0 +1,16 @@
+set(SC_METRICS_SRCS
+    sc_metrics.cpp
+    sc_metrics_monitor.cpp)
+
+set(SC_METRICS_DEPEND_LIBS
+    common_log
+    common_metrics
+    common_util
+    posix_protos
+    )
+
+add_library(sc_metrics STATIC ${SC_METRICS_SRCS})
+target_link_libraries(sc_metrics PRIVATE ${SC_METRICS_DEPEND_LIBS})
+add_dependencies(sc_metrics
+    posix_protos
+    )
\ No newline at end of file
diff --git a/src/datasystem/worker/stream_cache/metrics/sc_metrics.cpp b/src/datasystem/worker/stream_cache/metrics/sc_metrics.cpp
new file mode 100644
index 0000000..0455b5f
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/metrics/sc_metrics.cpp
@@ -0,0 +1,142 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "datasystem/worker/stream_cache/metrics/sc_metrics.h" +#include "datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h" + +namespace datasystem { +// We will print in this order +std::vector SCStreamMetrics::streamMetrics_ = { + StreamMetric::NumLocalProducers, + StreamMetric::NumRemoteProducers, + StreamMetric::NumLocalConsumers, + StreamMetric::NumRemoteConsumers, + + StreamMetric::SharedMemoryUsed, + StreamMetric::LocalMemoryUsed, + + StreamMetric::NumTotalElementsSent, + StreamMetric::NumTotalElementsReceived, + StreamMetric::NumTotalElementsAcked, + StreamMetric::NumSendRequests, + StreamMetric::NumReceiveRequests, + + StreamMetric::NumPagesCreated, + StreamMetric::NumPagesReleased, + StreamMetric::NumPagesInUse, + StreamMetric::NumPagesCached, + StreamMetric::NumBigPagesCreated, + StreamMetric::NumBigPagesReleased, + + StreamMetric::NumLocalProducersBlocked, + StreamMetric::NumRemoteProducersBlocked, + StreamMetric::NumRemoteConsumersBlocking, + + StreamMetric::RetainDataState, + StreamMetric::StreamState +}; +// We will print in this order +std::vector SCStreamMetrics::masterStreamMetrics_ = { + StreamMetric::NumProducersMaster, + StreamMetric::NumConsumersMaster +}; +void SCMetrics::LogMetric(const StreamMetric metric, const uint64_t value) +{ + metricsValuesMap_[metric].store(value, std::memory_order_relaxed); +} + +void SCMetrics::IncrementMetric(const StreamMetric metric, const uint64_t inc) +{ + metricsValuesMap_[metric].fetch_add(inc, std::memory_order_relaxed); +} + +void SCMetrics::DecrementMetric(const StreamMetric metric, const uint64_t inc) +{ + metricsValuesMap_[metric].fetch_sub(inc, std::memory_order_relaxed); +} + +void SCMetrics::Init(const std::vector &metrics) +{ + // Initialize all metrics + for (auto metric : metrics) { + LogMetric(metric, 0); + } +} + +std::string SCMetrics::PrintMetric(const StreamMetric metric) +{ + return std::to_string(GetMetric(metric)); +} + +std::string SCMetrics::PrintMetrics(const std::vector &metrics) +{ + std::string out; + for (auto metric : metrics) { + out += PrintMetric(metric) + "/"; + } + // remove last / + out.pop_back(); + return out; +} + +uint64_t SCMetrics::GetMetric(const StreamMetric metric) +{ + return metricsValuesMap_[metric].load(std::memory_order_relaxed); +} + +int64_t SCMetrics::GetProducersNum() +{ + return metricsValuesMap_[StreamMetric::NumLocalProducers].load(std::memory_order_relaxed); +} + +int64_t SCMetrics::GetConsumersNum() +{ + return metricsValuesMap_[StreamMetric::NumLocalConsumers].load(std::memory_order_relaxed); +} + +uint64_t SCMetrics::GetLocalMemUsed() +{ + return metricsValuesMap_[StreamMetric::LocalMemoryUsed].load(std::memory_order_relaxed); +} + +SCStreamMetrics::SCStreamMetrics(std::string streamName) : streamName_(streamName) +{ + Init(streamMetrics_); + Init(masterStreamMetrics_); +} + +SCStreamMetrics::~SCStreamMetrics() +{ + ScMetricsMonitor::Instance()->ExitStream(this->streamName_, this->PrintMetrics(true)); +} + +std::string SCStreamMetrics::PrintMetrics(const bool isExit) +{ + std::string exit = isExit ? 
" exit" : ""; + std::string result = streamName_ + exit + "/"; + if (isMgrExist) { + result += SCMetrics::PrintMetrics(streamMetrics_) + "/"; + } else { + result += std::string(SCStreamMetrics::streamMetrics_.size() - 1, '/'); + } + if (isMetaExist) { + result += SCMetrics::PrintMetrics(masterStreamMetrics_); + } else { + result += std::string(SCStreamMetrics::masterStreamMetrics_.size() - 1, '/'); + } + return result; +} +} // namespace datasystem \ No newline at end of file diff --git a/src/datasystem/worker/stream_cache/metrics/sc_metrics.def b/src/datasystem/worker/stream_cache/metrics/sc_metrics.def new file mode 100644 index 0000000..bb919cc --- /dev/null +++ b/src/datasystem/worker/stream_cache/metrics/sc_metrics.def @@ -0,0 +1,39 @@ +// Stream Metrics +// Topology +SC_METRIC_KEY_DEF(NumLocalProducers, "Number of local producers on the stream") +SC_METRIC_KEY_DEF(NumRemoteProducers, "Number of remote producers on the stream") +SC_METRIC_KEY_DEF(NumLocalConsumers, "Number of local consumers on the stream") +SC_METRIC_KEY_DEF(NumRemoteConsumers, "Number of remote producers on the stream") + +// Memory Metrics +SC_METRIC_KEY_DEF(SharedMemoryUsed, "Amount of shared memory currently used by the stream") +SC_METRIC_KEY_DEF(LocalMemoryUsed, "Amount of local memory currently used by the stream") + +// Data Metrics +SC_METRIC_KEY_DEF(NumTotalElementsSent, "Number of elements sent by producers on the stream") +SC_METRIC_KEY_DEF(NumTotalElementsReceived, "Number of elements received by all consumers on the stream") +SC_METRIC_KEY_DEF(NumTotalElementsAcked, "Number of elements acked by all consumers on the stream") +SC_METRIC_KEY_DEF(NumSendRequests, "Number of producer send requests called by the client on stream") +SC_METRIC_KEY_DEF(NumReceiveRequests, "Number of consumer receive requests called by the client on stream") + +// Page Metrics +SC_METRIC_KEY_DEF(NumPagesCreated, "Number of normal pages created in the stream") +SC_METRIC_KEY_DEF(NumPagesReleased, "Number of normal pages released in the stream") +SC_METRIC_KEY_DEF(NumPagesInUse, "Number of pages currently in use in the stream") +SC_METRIC_KEY_DEF(NumPagesCached, "Number of pages currently cached in the stream") +SC_METRIC_KEY_DEF(NumBigPagesCreated, "Number of big element pages created in the stream") +SC_METRIC_KEY_DEF(NumBigPagesReleased, "Number of big element pages released in the stream") + +// Blocking Metrics +SC_METRIC_KEY_DEF(NumLocalProducersBlocked, "Number of local producers blocked on the stream") +SC_METRIC_KEY_DEF(NumRemoteProducersBlocked, "Number of remote producers blocked on the stream") +SC_METRIC_KEY_DEF(NumRemoteConsumersBlocking, "Number of remote consumers that are blocking") + +// State Metrics +SC_METRIC_KEY_DEF(RetainDataState, "Current retain data state of the stream (INIT, RETAIN, NOT_RETAIN)") +SC_METRIC_KEY_DEF(StreamState, + "Current state of the stream (ACTIVE, RESET_IN_PROGRESS, RESET_COMPLETE, DELETE_IN_PROGRESS)") + +// Master Stream Metrics +SC_METRIC_KEY_DEF(NumProducersMaster, "Number of producers in the stream on the master") +SC_METRIC_KEY_DEF(NumConsumersMaster, "Number of consumers in the stream on the master") \ No newline at end of file diff --git a/src/datasystem/worker/stream_cache/metrics/sc_metrics.h b/src/datasystem/worker/stream_cache/metrics/sc_metrics.h new file mode 100644 index 0000000..000a593 --- /dev/null +++ b/src/datasystem/worker/stream_cache/metrics/sc_metrics.h @@ -0,0 +1,124 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. 
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Defines SCMetrics, SCWorkerMetrics, SCStreamMetrics, SCMasterStreamMetrics.
+ */
+#ifndef DATASYSTEM_STREAM_SC_METRICS_H
+#define DATASYSTEM_STREAM_SC_METRICS_H
+
+#include
+#include
+#include
+#include
+
+namespace datasystem {
+/**
+ * @brief The StreamMetric enum specifies the stream metric keys.
+ * New enum values should be added to sc_metrics.def.
+ */
+#define SC_METRIC_KEY_DEF(keyEnum, ...) keyEnum,
enum class StreamMetric : size_t {
+    NONE = 0,
+
+#include "datasystem/worker/stream_cache/metrics/sc_metrics.def"
+};
+#undef SC_METRIC_KEY_DEF
+
+class SCMetrics {
+public:
+    /**
+     * @brief Update the specified metric.
+     */
+    void LogMetric(const StreamMetric metric, const uint64_t value);
+
+    /**
+     * @brief Increment the specified metric by inc.
+     */
+    void IncrementMetric(const StreamMetric metric, const uint64_t inc);
+
+    /**
+     * @brief Decrement the specified metric by inc.
+     */
+    void DecrementMetric(const StreamMetric metric, const uint64_t inc);
+
+    /**
+     * @brief Get the print string of the specified metric.
+     * @return Stream metric string.
+     */
+    std::string PrintMetric(const StreamMetric metric);
+
+    /**
+     * @brief Get the print string of the specified metrics.
+     * @return Stream metrics string.
+     */
+    std::string PrintMetrics(const std::vector &metrics);
+
+    /**
+     * @brief Get the specified metric.
+     * @return Stream metric.
+     */
+    uint64_t GetMetric(const StreamMetric metric);
+
+    /**
+     * @brief Obtain local producer num.
+     * @return Num of local producers.
+     */
+    int64_t GetProducersNum();
+
+    /**
+     * @brief Obtain local consumer num.
+     * @return Num of local consumers.
+     */
+    int64_t GetConsumersNum();
+
+    /**
+     * @brief Obtain local memory usage.
+     * @return Local memory usage.
+     */
+    uint64_t GetLocalMemUsed();
+
+protected:
+    /**
+     * @brief Initialize the specified metrics.
+     */
+    void Init(const std::vector &metrics);
+
+private:
+    std::unordered_map> metricsValuesMap_;
+};
+
+class SCStreamMetrics : public SCMetrics {
+public:
+    SCStreamMetrics(std::string streamName);
+    ~SCStreamMetrics();
+
+    /**
+     * @brief Get the print string of all metrics in the stream.
+     * @return Stream metrics string.
+     */
+    std::string PrintMetrics(bool isExit = false);
+
+    static std::vector streamMetrics_;
+    static std::vector masterStreamMetrics_;
+    bool isMgrExist = false;
+    bool isMetaExist = false;
+
+private:
+    std::string streamName_;
+};
+}  // namespace datasystem
+#endif
\ No newline at end of file
diff --git a/src/datasystem/worker/stream_cache/metrics/sc_metrics_monitor.cpp b/src/datasystem/worker/stream_cache/metrics/sc_metrics_monitor.cpp
new file mode 100644
index 0000000..002d58c
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/metrics/sc_metrics_monitor.cpp
@@ -0,0 +1,180 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
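The SCMetrics counters declared above are per-metric atomics kept in a map that is fully populated by `Init` before concurrent updates begin, so only the atomics themselves are ever mutated concurrently and relaxed ordering suffices for independent statistics. A self-contained sketch of that storage pattern; `Metric` and `Counters` are illustrative names:

```cpp
#include <atomic>
#include <cstdint>
#include <iostream>
#include <unordered_map>

// Hypothetical sketch of the counter storage used by SCMetrics: one relaxed
// std::atomic<uint64_t> per metric. The map itself is not protected; as in
// SCMetrics::Init, all keys must be inserted up front, since a rehash during
// insert would race with concurrent readers of the atomics.
enum class Metric { ElementsSent, ElementsAcked };

class Counters {
public:
    Counters()
    {
        // Pre-populate every key; afterwards only the atomics are mutated.
        values_[Metric::ElementsSent].store(0, std::memory_order_relaxed);
        values_[Metric::ElementsAcked].store(0, std::memory_order_relaxed);
    }
    void Increment(Metric m, uint64_t inc)
    {
        values_[m].fetch_add(inc, std::memory_order_relaxed);
    }
    uint64_t Get(Metric m)
    {
        return values_[m].load(std::memory_order_relaxed);
    }

private:
    std::unordered_map<Metric, std::atomic<uint64_t>> values_;
};

int main()
{
    Counters c;
    c.Increment(Metric::ElementsSent, 3);
    std::cout << c.Get(Metric::ElementsSent) << std::endl;  // 3
    return 0;
}
```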
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h" + +#include + +#include "datasystem/common/constants.h" +#include "datasystem/common/log/logging.h" +#include "datasystem/common/log/log.h" +#include "datasystem/common/metrics/hard_disk_exporter/hard_disk_exporter.h" +#include "datasystem/common/util/lock_helper.h" +#include "datasystem/common/util/validator.h" +#include "datasystem/worker/stream_cache/stream_manager.h" + +DS_DECLARE_bool(log_monitor); +DS_DECLARE_uint64(sc_local_cache_memory_size_mb); +DS_DEFINE_uint32(sc_metrics_log_interval_s, 60, "Interval between logging stream metrics. Default to 60s"); +DS_DEFINE_validator(sc_metrics_log_interval_s, &Validator::ValidateUint32); +DS_DECLARE_string(log_dir); + +const uint64_t MAX_LOCAL_CACHE_MEMORY_BYTES = FLAGS_sc_local_cache_memory_size_mb * 1024 * 1024; +const int THREE_DECIMAL_PLACES = 1000; + +#define MONITOR_LOCK_ARGS(lockname) \ + (lockname), [funName = __FUNCTION__] { return FormatString("%s, %s:%s", #lockname, funName, __LINE__); } + +namespace datasystem { +Status ScMetricsMonitor::StartMonitor() +{ + if (FLAGS_log_monitor) { + std::string filePath = FLAGS_log_dir + "/" + SC_METRICS_LOG_NAME + ".log"; + auto hardDiskExporter = std::make_unique(); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Logging::CreateLogDir(), K_NOT_READY, "Log file creation failed"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(hardDiskExporter->Init(filePath), "hardDiskExporter Init failed."); + exporter_ = std::move(hardDiskExporter); + thread_ = std::make_unique(&ScMetricsMonitor::Tick, this); + exitPrintThread_ = std::make_unique(1); + } + isEnabled_ = FLAGS_log_monitor; + return Status::OK(); +} + +ScMetricsMonitor::~ScMetricsMonitor() +{ + if (isEnabled_) { + Shutdown(); + } +} + +ScMetricsMonitor *ScMetricsMonitor::Instance() +{ + static ScMetricsMonitor inst; + return &inst; +} + +bool ScMetricsMonitor::IsEnabled() +{ + return isEnabled_; +} + +void ScMetricsMonitor::UpdateAndPrintMetrics() +{ + Uri uri(__FILE__); + std::unordered_map tempStreams; + { + // Make a temporary copy of streams_ without holding the lock for too long + ReadLockHelper rlock(MONITOR_LOCK_ARGS(mutex_)); + tempStreams = streams_; + } + // Print stream metrics + for (auto streamIt = tempStreams.begin(); streamIt != tempStreams.end(); ++streamIt) { + StreamEntry entry = streamIt->second; + auto streamManager = entry.mgr.lock(); + auto streamMeta = entry.meta.lock(); + if (streamManager || streamMeta) { + std::shared_ptr scMetric; + if (streamManager) { + streamManager->UpdateStreamMetrics(); + scMetric = streamManager->GetSCStreamMetrics(); + } + if (streamMeta) { + streamMeta->UpdateStreamMetrics(); + scMetric = streamMeta->GetSCStreamMetrics(); + } + exporter_->Send(scMetric->PrintMetrics(), uri, __LINE__); + } + } + // Flush log message + exporter_->SubmitWriteMessage(); +} + +Status ScMetricsMonitor::AddStream(const std::string streamName, const std::weak_ptr stream, + std::shared_ptr &metrics) +{ + WriteLockHelper wlock(MONITOR_LOCK_ARGS(mutex_)); + StreamEntry &entry = streams_[streamName]; + entry.mgr = stream; + // Get the stream 
metrics object from StreamMetaData if it already exists + if (auto ptr = entry.meta.lock()) { + metrics = ptr->GetSCStreamMetrics(); + } else { + metrics = std::make_shared(streamName); + } + metrics->isMgrExist = true; + return Status::OK(); +} + +Status ScMetricsMonitor::AddStreamMeta(const std::string streamName, const std::weak_ptr streamMeta, + std::shared_ptr &metrics) +{ + WriteLockHelper wlock(MONITOR_LOCK_ARGS(mutex_)); + StreamEntry &entry = streams_[streamName]; + entry.meta = streamMeta; + // Get the stream metrics object from StreamManager if it already exists + if (auto ptr = entry.mgr.lock()) { + metrics = ptr->GetSCStreamMetrics(); + } else { + metrics = std::make_shared(streamName); + } + metrics->isMetaExist = true; + return Status::OK(); +} + +void ScMetricsMonitor::ExitStream(const std::string streamName, const std::string msg) +{ + if (isEnabled_) { + // Execute print in another thread to avoid destructor taking time + exitPrintThread_->Execute([this, msg]() { + Uri uri(__FILE__); + exporter_->Send(msg, uri, __LINE__); + }); + // Erase from streams_ + WriteLockHelper wlock(MONITOR_LOCK_ARGS(mutex_)); + streams_.erase(streamName); + } +} + +void ScMetricsMonitor::Tick() +{ + const int CONVERT_TO_MS = 1000; + std::chrono::time_point prevLogTime; + while (!interruptFlag_) { + std::chrono::time_point nowTime = clock::now(); + int64_t elapsed = std::chrono::duration_cast(nowTime - prevLogTime).count(); + // Check if its been FLAGS_sc_metrics_log_interval_s number of secs + if (elapsed >= FLAGS_sc_metrics_log_interval_s * CONVERT_TO_MS) { + prevLogTime = nowTime; + UpdateAndPrintMetrics(); + } + // Wait for FLAGS_sc_metrics_log_interval_s number of secs between another log + cvLock_.WaitFor(FLAGS_sc_metrics_log_interval_s * CONVERT_TO_MS); + } +} + +void ScMetricsMonitor::Shutdown() +{ + if (!thread_) { + return; + } + interruptFlag_ = true; + cvLock_.Set(); + if (thread_->joinable()) { + thread_->join(); + } +} +} // namespace datasystem diff --git a/src/datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h b/src/datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h new file mode 100644 index 0000000..0208410 --- /dev/null +++ b/src/datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h @@ -0,0 +1,115 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
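The Tick/Shutdown pair above is the classic interruptible periodic loop: sleep on a waitable object with a timeout so that shutdown does not have to wait out the remaining logging interval. A minimal sketch of the same shape, using a condition variable where the code above uses WaitPost; `Monitor` is an illustrative name:

```cpp
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// Hypothetical sketch of the monitor tick loop: emit metrics, then wait for
// either the interval to elapse or Shutdown() to set the stop flag.
class Monitor {
public:
    void Run(std::chrono::milliseconds interval)
    {
        std::unique_lock<std::mutex> lock(mutex_);
        while (!stop_) {
            std::cout << "emit metrics" << std::endl;  // UpdateAndPrintMetrics() stand-in
            // Wakes early if Shutdown() fires, the role of cvLock_.WaitFor above.
            cv_.wait_for(lock, interval, [&] { return stop_; });
        }
    }
    void Shutdown()
    {
        {
            std::lock_guard<std::mutex> lk(mutex_);
            stop_ = true;  // equivalent of interruptFlag_
        }
        cv_.notify_all();  // equivalent of cvLock_.Set()
    }

private:
    std::mutex mutex_;
    std::condition_variable cv_;
    bool stop_{ false };
};

int main()
{
    Monitor m;
    std::thread t([&] { m.Run(std::chrono::milliseconds(100)); });
    std::this_thread::sleep_for(std::chrono::milliseconds(250));
    m.Shutdown();  // the loop exits without waiting a full interval
    t.join();
    return 0;
}
```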
+ */ + +/** + * Description: Defines ScMetricsMonitor class to monitor and print stream metrics + */ +#ifndef DATASYSTEM_STREAM_SC_METRICS_MONITOR_H +#define DATASYSTEM_STREAM_SC_METRICS_MONITOR_H + +#include +#include +#include +#include + +#include "datasystem/common/util/thread_pool.h" +#include "datasystem/common/metrics/metrics_exporter.h" +#include "datasystem/worker/stream_cache/metrics/sc_metrics.h" + +DS_DECLARE_uint32(sc_metrics_log_interval_s); +namespace datasystem { +namespace worker { + namespace stream_cache { + class StreamManager; + } +} +namespace master { + class StreamMetadata; +} +using StreamManager = worker::stream_cache::StreamManager; +using StreamMetadata = master::StreamMetadata; +using clock = std::chrono::steady_clock; +class ScMetricsMonitor { +public: + ~ScMetricsMonitor(); + /** + * @brief Get the Singleton Metrics manager instance. + * @return ScMetricsMonitor instance. + */ + static ScMetricsMonitor *Instance(); + + /** + * @brief Init and Start the ScMetricsMonitor. + * @return Status of the call + */ + Status StartMonitor(); + + /** + * @brief Register a stream to monitor. + * @return Status of the call + */ + Status AddStream(const std::string streamName, const std::weak_ptr stream, + std::shared_ptr &metrics); + + /** + * @brief Register a master meta to monitor. + * @return Status of the call + */ + Status AddStreamMeta(const std::string streamName, const std::weak_ptr streamMeta, + std::shared_ptr &metrics); + + /** + * @brief Trigger Monitoring logs. Should call in main thread. + */ + void Tick(); + + /** + * @brief Shutdown the monitor + */ + void Shutdown(); + + /** + * @brief Check whether stream metrics monitoring is enabled + * @return True if stream metrics monitoring is enabled + */ + bool IsEnabled(); + + /** + * @brief Remove stream from streams_ map, Print the stream metric message + */ + void ExitStream(const std::string streamName, const std::string msg); + +private: + /** + * @brief Updates and prints all worker and stream metrics + */ + void UpdateAndPrintMetrics(); + struct StreamEntry { + std::weak_ptr mgr; + std::weak_ptr meta; + }; + + std::unique_ptr thread_{ nullptr }; + std::unique_ptr exitPrintThread_{ nullptr }; + std::unordered_map streams_; + std::shared_timed_mutex mutex_; // protect streams_ map + std::atomic interruptFlag_{ false }; + std::unique_ptr exporter_{ nullptr }; + WaitPost cvLock_; + bool isEnabled_{ false }; +}; +} // namespace datasystem +#endif \ No newline at end of file diff --git a/src/datasystem/worker/stream_cache/page_queue/CMakeLists.txt b/src/datasystem/worker/stream_cache/page_queue/CMakeLists.txt new file mode 100644 index 0000000..6dbb7cf --- /dev/null +++ b/src/datasystem/worker/stream_cache/page_queue/CMakeLists.txt @@ -0,0 +1,18 @@ +set(SC_PAGE_QUEUE_SRCS + page_queue_base.cpp + exclusive_page_queue.cpp + shared_page_queue.cpp + shared_page_queue_group.cpp + page_queue_handler.cpp +) + +set(SC_PAGE_QUEUE_DEPEND_LIBS + common_log + common_util +) + +add_library(sc_page_queue STATIC ${SC_PAGE_QUEUE_SRCS}) +target_link_libraries(sc_page_queue PRIVATE ${SC_PAGE_QUEUE_DEPEND_LIBS}) +add_dependencies(sc_page_queue + share_memory_protos +) diff --git a/src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.cpp b/src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.cpp new file mode 100644 index 0000000..146d4fb --- /dev/null +++ b/src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.cpp @@ -0,0 +1,603 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. 
All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: ExclusivePageQueue + */ + +#include "datasystem/worker/stream_cache/page_queue/exclusive_page_queue.h" + +#include "datasystem/common/constants.h" +#include "datasystem/common/flags/flags.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/util/lock_helper.h" +#include "datasystem/common/util/raii.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/worker/stream_cache/stream_data_pool.h" +#include "datasystem/worker/stream_cache/stream_manager.h" + +DS_DEFINE_uint32(sc_cache_pages, 16, "Default number of cache pages"); +DS_DECLARE_string(sc_encrypt_secret_key); +DS_DECLARE_string(encrypt_kit); +DS_DECLARE_uint64(shared_memory_size_mb); + +namespace datasystem { +namespace worker { +namespace stream_cache { + +Status ExclusivePageQueue::UnblockProducers() +{ + // If there is at least one page available, wake up all blocked producers. + // BigElement size is always greater than one page and so when a BigElement + // is released, we know there should be sufficient memory to create one data + // page. + { + ReadLockHelper rlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_)); + if (pendingFreePages_.empty() && ackChain_.empty() + && streamFields_.maxStreamSize_ - usedMemBytes_ < static_cast(streamFields_.pageSize_)) { + return Status::OK(); + } + } + { + // unblock remote stream send + std::lock_guard lock2(unblockMutex_); + std::for_each(unblockCallbacks_.begin(), unblockCallbacks_.end(), [](auto &cb) { cb.second(); }); + // We need to clear the call backs after we sent them + unblockCallbacks_.clear(); + } + // Ask the parent stream manager to unblock any local CreateDataPage waiters + streamMgr_->UnblockCreators(); + return Status::OK(); +} + +Status ExclusivePageQueue::GetDataPage( + const GetDataPageReqPb &req, const std::shared_ptr &consumer, + const std::shared_ptr> &serverApi) +{ + const auto &lastRecvCursor = req.last_recv_cursor(); + const auto timeoutMs = static_cast(req.timeout_ms()); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Locate page. lastRecvCursor %zu. Timeout %u", LogPrefix(), + lastRecvCursor, timeoutMs); + std::shared_ptr page; + Status rc = LocatePage(lastRecvCursor, page); + if (rc.GetCode() == K_NOT_FOUND && timeoutMs > 0) { + // Worker execute the pending receive timer before add the pending receive task. + TimerQueue::TimerImpl timer; + auto traceID = Trace::Instance().GetTraceID(); + auto clientService = streamMgr_->GetClientService(); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(clientService != nullptr, K_SHUTTING_DOWN, "worker shutting down."); + auto streamName = streamMgr_->GetStreamName(); + GetDataPageReqPb rq = req; // Pass a non-const copy to the lambda + auto func = [clientService, streamName, rq, serverApi, traceID]() mutable { + // Turn off the timer because we are going call itself. Otherwise, we run into loop. 
+ rq.set_timeout_ms(0); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceID); + // When wakes up, check if the stream has been deleted or not. + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK(clientService->GetStreamManager(streamName, accessor)); + std::shared_ptr streamMgr = accessor->second; + RETURN_IF_NOT_OK(streamMgr->CheckIfStreamActive()); + std::shared_ptr sub; + RETURN_IF_NOT_OK(streamMgr->GetSubscription(rq.subscription_name(), sub)); + const auto &consumerId = rq.consumer_id(); + CHECK_FAIL_RETURN_STATUS(sub->GetSubscriptionType() == SubscriptionType::STREAM, StatusCode::K_INVALID, + "Only support STREAM mode."); + std::shared_ptr consumer; + RETURN_IF_NOT_OK(sub->GetConsumer(consumerId, consumer)); + Status rc; + bool wakeupPendingRecvOnProdFault; + consumer->RemovePendingReceive(wakeupPendingRecvOnProdFault); + if (wakeupPendingRecvOnProdFault) { + rc = { K_SC_PRODUCER_NOT_FOUND, "Consumer cannot continue due to producer being forced off." }; + serverApi->SendStatus(rc); + return Status::OK(); + } + rc = streamMgr->GetExclusivePageQueue()->GetDataPage(rq, consumer, serverApi); + if (rc.IsError()) { + serverApi->SendStatus(rc); + } + return Status::OK(); + }; + return consumer->AddPendingReceive(lastRecvCursor, timeoutMs, func, serverApi); + } + RETURN_IF_NOT_OK(rc); + RETURN_RUNTIME_ERROR_IF_NULL(page); + return ReturnGetPageRspPb(page->GetShmView(), serverApi); +} + +Status ExclusivePageQueue::ReturnGetPageRspPb( + const ShmView &shmView, + const std::shared_ptr> &serverApi) +{ + GetDataPageRspPb rsp; + ShmViewPb pb; + pb.set_fd(shmView.fd); + pb.set_mmap_size(shmView.mmapSz); + pb.set_offset(shmView.off); + pb.set_size(shmView.sz); + rsp.mutable_page_view()->CopyFrom(pb); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi->Write(rsp), "Write reply to client stream failed"); + // We do not have to increase the reference count. This call is made for the purpose of + // consumers (both local and remote) and we won't ack any page consumers are still reading. + return Status::OK(); +} + +Status ExclusivePageQueue::UpdateStreamFields(const StreamFields &streamFields) +{ + const static uint64_t shmMemSize = FLAGS_shared_memory_size_mb * 1024 * 1024; + std::unique_lock lock(cfgMutex_); + // If the fields were empty, do not sanity check them and just set the fields to the new values. + if (streamFields_.Empty()) { + // The optional parameters are given, sanity check them first against the shared memory limits. + // (Other checks are already done at the client side) + CHECK_FAIL_RETURN_STATUS( + streamFields.maxStreamSize_ <= shmMemSize, K_INVALID, + FormatString("maxStreamSize exceeds shared memory size [stream size, shm size]: [%zu, %zu] ", + streamFields.maxStreamSize_, shmMemSize)); + CHECK_FAIL_RETURN_STATUS(static_cast(streamFields.pageSize_) <= shmMemSize, K_INVALID, + FormatString("pageSize exceeds shared memory size [page size, shm size]: [%zu, %zu] ", + streamFields.pageSize_, shmMemSize)); + streamFields_ = streamFields; + LOG(INFO) << FormatString( + "[%s] Stream configuration updated with max stream size: %zu, page size: %zu, auto cleanup: %s, " + "retain for num consumers: %zu, encrypt stream: %s, and reserve size: %zu", + LogPrefix(), streamFields_.maxStreamSize_, streamFields_.pageSize_, + streamFields_.autoCleanup_ ? "true" : "false", streamFields_.retainForNumConsumers_, + streamFields_.encryptStream_ ? 
"true" : "false", streamFields_.reserveSize_); + return Status::OK(); + } + + // otherwise, sanity check that the new settings are a match to the old ones (and do not change them) + CHECK_FAIL_RETURN_STATUS( + streamFields_ == streamFields, K_INVALID, + FormatString("[%s] Changing stream config fields [max stream size, page size, auto cleanup, retain for num " + "consumers, encrypt stream, reserve size, stream mode] not supported: Current: [%zu, %zu, %s, " + "%zu, %s, %zu, %d] " + "Invalid: [%zu, %zu, %s, %zu, %s, %zu, %d]", + LogPrefix(), streamFields_.maxStreamSize_, streamFields_.pageSize_, + (streamFields_.autoCleanup_ ? "true" : "false"), streamFields_.retainForNumConsumers_, + streamFields_.encryptStream_ ? "true" : "false", streamFields_.reserveSize_, + streamFields_.streamMode_, streamFields.maxStreamSize_, streamFields.pageSize_, + (streamFields.autoCleanup_ ? "true" : "false"), streamFields.retainForNumConsumers_, + streamFields.encryptStream_ ? "true" : "false", streamFields.reserveSize_, + streamFields.streamMode_)); + return Status::OK(); +} + +size_t ExclusivePageQueue::GetPageSize() const +{ + return streamFields_.pageSize_; +} + +bool ExclusivePageQueue::AutoCleanup() const +{ + return streamFields_.autoCleanup_; +} + +uint64_t ExclusivePageQueue::GetReserveSize() const +{ + return streamFields_.reserveSize_ == 0 ? streamFields_.pageSize_ : streamFields_.reserveSize_; +} + +std::string ExclusivePageQueue::LogPrefix() const +{ + return FormatString("S:%s", streamMgr_->GetStreamName()); +} + +Status ExclusivePageQueue::CheckHadEnoughMem(size_t memSize) +{ + PerfPoint point(PerfKey::PAGE_CHECK_MEM); + INJECT_POINT("worker.CheckHadEnoughMem"); + auto maxStreamSize = streamFields_.maxStreamSize_; + // We will use the other form to calculate a theoretical used size by taking + // pendingFreePages_ into consideration. + // The pages in the pendingFreePages_ will be released when they are safe to do + // In a way, we are borrowing some bytes from the Arena but will return them as soon + // as the pending free pages are released. + // Note that the difference of two atomic variables does not always give accurate result. + // But pendingFreeBytes_ should always be subset of usedMemBytes_ + auto usedMemBytes = std::min(usedMemBytes_.load(std::memory_order_relaxed), maxStreamSize); + auto pendingFreeBytes = pendingFreeBytes_.load(std::memory_order_relaxed); + auto bytesUsed = (pendingFreeBytes <= usedMemBytes) ? 
usedMemBytes - pendingFreeBytes : 0;
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] max memory %zu, used %zu, need %zu", LogPrefix(), maxStreamSize,
+                                                bytesUsed, memSize);
+    if (maxStreamSize - bytesUsed < memSize) {
+        RETURN_STATUS(
+            StatusCode::K_OUT_OF_MEMORY,
+            FormatString("[%s] The stream does not have enough memory: max memory %zu, used %zu, need %zu",
+                         LogPrefix(), maxStreamSize, bytesUsed, memSize));
+    }
+    return Status::OK();
+}
+
+Status ExclusivePageQueue::VerifyWhenAlloc() const
+{
+    std::unique_lock lock(cfgMutex_);
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        !streamFields_.Empty(), StatusCode::K_RUNTIME_ERROR,
+        FormatString("[%s] Uninitialized page or stream sizes in stream page owner", LogPrefix()));
+    return Status::OK();
+}
+
+Status ExclusivePageQueue::AllocateMemoryImpl(size_t memSizeNeeded, ShmUnit &shmUnit, bool retryOnOOM)
+{
+    auto tenantId = TenantAuthManager::ExtractTenantId(streamMgr_->GetStreamName());
+    return streamMgr_->GetAllocManager()->AllocateMemoryForStream(tenantId, streamMgr_->GetStreamName(), memSizeNeeded,
+                                                                  true, shmUnit, retryOnOOM);
+}
+
+Status ExclusivePageQueue::ReserveStreamMemory()
+{
+    RETURN_OK_IF_TRUE(reserveState_.pageZeroCreated && reserveState_.freeListCreated);
+    // This creates the initial first page when the object is created.
+    // It is driven by StreamManager.
+    {
+        std::unique_lock lock(cfgMutex_);
+        // If we don't have the page size (yet), return K_NOT_READY.
+        if (streamFields_.Empty()) {
+            const std::string errMsg =
+                FormatString("[%s] Uninitialized page or stream sizes in stream page owner", LogPrefix());
+            LOG(INFO) << errMsg;
+            RETURN_STATUS(K_NOT_READY, errMsg);
+        }
+    }
+    // We will reserve stream memory up to streamFields_.reserveSize_.
+    // If reserveSize_ is 0, it defaults to one page size.
+    // In other words, we always allocate at least the first page.
+    auto createPageZero = [this]() {
+        INJECT_POINT("CreatePageZero.AllocMemory");
+        RETURN_OK_IF_TRUE(reserveState_.pageZeroCreated);
+        std::shared_ptr<StreamDataPage> lastPage;
+        // We pass a null ShmView as the last reference and the page creation will be a no-op
+        // if the first page is already created.
+        // We do not want to call the CreateOrGetLastDataPage version because the stream manager
+        // is not fully created yet. It is not wise to call TryWakeUpPendingReceive() or update
+        // any producer's mailbox.
+        Status rc = CreateOrGetLastDataPageImpl(0, ShmView(), lastPage, true);
+        if (rc.IsOk()) {
+            reserveState_.pageZeroCreated = true;
+        }
+        return rc;
+    };
+    Status rc = createPageZero();
+    // Map OOM to a new return code if we aren't able to reserve the initial page.
+    if (rc.GetCode() == K_OUT_OF_MEMORY) {
+        RETURN_STATUS_LOG_ERROR(
+            K_SC_STREAM_RESOURCE_ERROR,
+            FormatString("[%s] Resource error. Unable to reserve the first shared memory page, detail: %s",
+                         LogPrefix(), rc.ToString()));
+    }
+    if (streamFields_.reserveSize_ <= streamFields_.pageSize_) {
+        reserveState_.freeListCreated = true;  // No need to continue.
+        return rc;
+    }
+    // The rest of the code handles the case where reserveSize_ > page size.
+    RETURN_IF_NOT_OK(ReserveAdditionalMemory());
+    return Status::OK();
+}
+
+Status ExclusivePageQueue::ReserveAdditionalMemory()
+{
+    // Switch to signed integers.
+    auto pageSize = static_cast<int64_t>(streamFields_.pageSize_);
+    // Page 0 has been created.
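+    // Worked example (hypothetical sizes): with reserveSize_ = 4 MB and pageSize_ = 1 MB,
+    // page 0 already covers 1 MB, so remaining starts at 3 MB and the loop below
+    // pre-allocates three more pages onto the free list.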
+ int64_t remaining = static_cast(streamFields_.reserveSize_) - pageSize; + // We must add the rest to the ackChain_, not idxChain_ + WriteLockHelper xlock1(STREAM_COMMON_LOCK_ARGS(lastPageMutex_)); + WriteLockHelper xlock2(STREAM_COMMON_LOCK_ARGS(idxMutex_)); + // Check again after we have the lock. + RETURN_OK_IF_TRUE(reserveState_.freeListCreated); + std::list freeList; + std::vector undoList; + bool needRollback = true; + Raii raii([this, &needRollback, &undoList]() { + if (needRollback) { + // Undo the changes. We can leave the lastPage alone. It is more work to take it out. + (void)FreePages(undoList, false); + } + }); + auto func = [this, remaining, pageSize, &freeList, &undoList]() mutable -> Status { + while (remaining > 0) { + auto f = []() { + INJECT_POINT("ReserveAdditionalMemory.AllocMemory"); + return Status::OK(); + }; + RETURN_IF_NOT_OK(f()); + std::shared_ptr page; + RETURN_IF_NOT_OK(CreateNewPage(page, true)); + undoList.push_back(page->GetPageId()); + freeList.emplace_back(0, std::move(page)); + remaining -= pageSize; + } + LOG(INFO) << FormatString("[%s] Reserve additional %zu pages", LogPrefix(), freeList.size()); + return Status::OK(); + }; + Status rc = func(); + // As in other code path. Change the rc from OOM to RESOURCE_ERROR + if (rc.GetCode() == K_OUT_OF_MEMORY) { + RETURN_STATUS_LOG_ERROR(K_SC_STREAM_RESOURCE_ERROR, + FormatString("[%s] Resource error. Unable to reserve additional %zu bytes", LogPrefix(), + static_cast(streamFields_.reserveSize_) - pageSize)); + return rc; + } + Optional> reserveList(freeList); + RETURN_IF_NOT_OK(AppendFreePagesImplNotLocked(0, reserveList, false)); + reserveState_.freeListCreated = true; + needRollback = false; + return Status::OK(); +} + +Status ExclusivePageQueue::InsertBigElement(void *buf, size_t sz, std::pair &res, uint64_t timeoutMs, + const bool headerBit, StreamMetaShm *streamMetaShm) +{ + RaiiPlus raiiP; + if (streamMetaShm) { + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(streamMetaShm->TryIncUsage(sz), "TryIncUsage failed"); + raiiP.AddTask([sz, streamMetaShm]() { + LOG_IF_ERROR(streamMetaShm->TryDecUsage(sz), ""); + }); + } + + auto flags = InsertFlags::REMOTE_ELEMENT | InsertFlags::BIG_ELEMENT; + ShmView outView; + RETURN_IF_NOT_OK(streamMgr_->AllocBigShmMemoryInternalReq(timeoutMs, sz, outView)); + std::shared_ptr pageUnitInfo; + RETURN_IF_NOT_OK(LocatePage(outView, pageUnitInfo)); + // From now on make sure we free the memory on error exit + bool needRollback = true; + auto raii = std::make_unique([this, &needRollback, &pageUnitInfo]() { + if (needRollback) { + std::vector v; + auto pageId = StreamPageBase::CreatePageId(pageUnitInfo); + v.push_back(pageId); + (void)FreePages(v, true); + } + }); + auto bigElementPage = std::make_shared(pageUnitInfo, false); + std::string pointerString; + RETURN_IF_NOT_OK(bigElementPage->Init()); + // using ExclusivePage for consumer node, no need streamNo. 
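+    // Sizing note: an element takes this path when its size exceeds
+    // pageSize_ - StreamDataPage::PageOverhead() (see BatchInsert below). With a
+    // hypothetical 1 MB page and 4 KB of overhead, anything above 1044480 bytes gets its
+    // own dedicated page, and only a serialized ShmView pointing at that page is inserted
+    // into the regular data page chain.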
+ HeaderAndData ele(reinterpret_cast(buf), sz, 0); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(StreamDataPage::SerializeToShmViewPb(bigElementPage->GetShmView(), pointerString), + "Serialization error"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(bigElementPage->Insert(ele), "BigElement insert"); + + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Inserting a BigElement [S:%zu] into BigElement page: %s", + streamMgr_->GetStreamName(), ele.size, bigElementPage->GetPageId()); + INJECT_POINT("InsertBigElement.Rollback"); + std::vector v; + v.push_back(pointerString.size()); + std::vector headerBits; + headerBits.push_back(headerBit); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + BatchInsertImpl(pointerString.data(), v, res, flags, timeoutMs, headerBits, streamMetaShm), "Insert pointer"); + res.second = sz; + needRollback = false; + raiiP.ClearAllTask(); + (void)needRollback; + return Status::OK(); +} + +Status ExclusivePageQueue::BatchInsertImpl(void *buf, std::vector &sz, std::pair &res, + InsertFlags flags, uint64_t timeoutMs, const std::vector &headerBits, + StreamMetaShm *streamMetaShm) +{ + // This is a special form of batch insert. The data is sent from a remote worker + // and the elements are already in the correct reverse order. + Status rc; + std::shared_ptr lastPage; + auto handleNoSpace = [this, &lastPage, &timeoutMs]() { + // As in local producer case, if there is a free page coming after, + // seal the current page. + // First of all we need to lock the page to block other producers. + ShmView nextPage; + bool isFreePage; + lastPage->nextPage_->GetView(nextPage, isFreePage, std::numeric_limits::max()); + // If there is no next page encoded on the page, just return no space + if (nextPage.fd <= 0) { + RETURN_STATUS(K_NO_SPACE, "No next page"); + } + // There is a next page. Check if the current page has been sealed. + if (isFreePage) { + // Now we will seal the current page, acquire the next page. + auto func = [this](const ShmView &v, std::shared_ptr &out) { return LocatePage(v, out); }; + RETURN_IF_NOT_OK_EXCEPT(lastPage->Seal(nextPage, timeoutMs, func, LogPrefix()), K_DUPLICATED); + } + RETURN_STATUS(K_SC_END_OF_PAGE, "New empty page is created"); + }; + do { + std::vector remaining(sz.begin() + res.first, sz.end()); + ShmView outView; + RETURN_IF_NOT_OK(streamMgr_->AllocDataPageInternalReq( + timeoutMs, lastPage == nullptr ? ShmView() : lastPage->GetShmView(), outView)); + std::shared_ptr pageUnitInfo; + RETURN_IF_NOT_OK(LocatePage(outView, pageUnitInfo)); + lastPage = std::make_shared(pageUnitInfo, 0, false, isSharedPage_); + RETURN_IF_NOT_OK(lastPage->Init()); + RETURN_IF_NOT_OK(lastPage->RefPage(FormatString("%s:%s", __FUNCTION__, __LINE__))); + std::pair batchRes(0, 0); + rc = lastPage->BatchInsert(buf, remaining, timeoutMs, batchRes, flags, headerBits, streamMetaShm); + res.first += batchRes.first; + res.second += batchRes.second; + if (rc.GetCode() == K_NO_SPACE) { + rc = handleNoSpace(); + } + Status rc2 = lastPage->ReleasePage(FormatString("%s:%s", __FUNCTION__, __LINE__)); + if (rc.IsOk()) { + RETURN_IF_NOT_OK(rc2); + if (res.first < sz.size()) { + // Need to find another page to continue + std::string msg = + FormatString("[%s] Page can only insert %zu rows. Total %zu. Remaining %zu. 
Continue to next page", + LogPrefix(), batchRes.first, sz.size(), sz.size() - res.first); + VLOG(SC_NORMAL_LOG_LEVEL) << msg; + rc = Status(K_TRY_AGAIN, msg); + } + } + } while (rc.GetCode() == K_NO_SPACE || rc.GetCode() == K_SC_END_OF_PAGE || rc.GetCode() == K_TRY_AGAIN); + return rc; +} + +Status ExclusivePageQueue::BatchInsert(void *buf, std::vector &sz, std::pair &res, + uint64_t timeoutMs, const std::vector &headerBits, + StreamMetaShm *streamMetaShm) +{ + // Check if it is a big element. If it is a big element, remote manager will send + // it separately in its own PV. + if (sz.size() == 1 && sz.at(0) > static_cast(streamFields_.pageSize_ - StreamDataPage::PageOverhead())) { + return InsertBigElement(buf, sz.at(0), res, timeoutMs, headerBits.at(0), streamMetaShm); + } else { + return BatchInsertImpl(buf, sz, res, InsertFlags::REMOTE_ELEMENT, timeoutMs, headerBits, streamMetaShm); + } +} + +Status ExclusivePageQueue::ReleaseAllPages() +{ + auto lastAppendCursor = GetLastAppendCursor(); + WriteLockHelper xlock1(STREAM_COMMON_LOCK_ARGS(lastPageMutex_)); + WriteLockHelper xlock2(STREAM_COMMON_LOCK_ARGS(idxMutex_)); + WriteLockHelper xlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_)); + lastPage_.reset(); + idxChain_.clear(); + ackChain_.clear(); + pendingFreePages_.clear(); + pendingFreeBytes_ = 0; + { + WriteLockHelper xlock4(STREAM_COMMON_LOCK_ARGS(poolMutex_)); + shmPool_.clear(); + } + usedMemBytes_ = 0; + lastAckCursor_ = lastAppendCursor; + nextCursor_ = lastAppendCursor + 1; + reserveState_.pageZeroCreated = false; + reserveState_.freeListCreated = false; + // Set release metrics equal to create metrics + // Metrics safe since already protected by other locks + scMetricPagesReleased_.store(scMetricPagesCreated_.load(std::memory_order_relaxed), std::memory_order_relaxed); + scMetricBigPagesReleased_.store(scMetricBigPagesCreated_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + LOG(INFO) << FormatString("[%s] ExclusivePageQueue release all pages done", LogPrefix()); + return Status::OK(); +} + +Status ExclusivePageQueue::Reset() +{ + // Synchronize with the timer queue + { + std::unique_lock lock(unblockMutex_); + unblockCallbacks_.clear(); + } + // Block any new page created + WriteLockHelper xlock1(STREAM_COMMON_LOCK_ARGS(lastPageMutex_)); + // Block timer queue + WriteLockHelper xlock2(STREAM_COMMON_LOCK_ARGS(idxMutex_)); + WriteLockHelper xlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_)); + // Drop the reference on the last page + if (lastPage_) { + RETURN_IF_NOT_OK(lastPage_->ReleasePage(LogPrefix())); + } + lastPage_.reset(); + idxChain_.clear(); + ackChain_.clear(); + pendingFreePages_.clear(); + pendingFreeBytes_ = 0; + { + WriteLockHelper xlock4(STREAM_COMMON_LOCK_ARGS(poolMutex_)); + shmPool_.clear(); + } + usedMemBytes_ = 0; + lastAckCursor_ = 0; + nextCursor_ = 1; + reserveState_.pageZeroCreated = false; + reserveState_.freeListCreated = false; + // Set release metrics equal to create metrics + // Metrics safe since already protected by other locks + scMetricPagesReleased_.store(scMetricPagesCreated_.load(std::memory_order_relaxed), std::memory_order_relaxed); + scMetricBigPagesReleased_.store(scMetricBigPagesCreated_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + LOG(INFO) << FormatString("[%s] ExclusivePageQueue reset done", LogPrefix()); + return Status::OK(); +} + +void ExclusivePageQueue::RemoveUnblockCallback(const std::string &addr) +{ + std::lock_guard lock(unblockMutex_); + unblockCallbacks_.erase(addr); +} + +void 
ExclusivePageQueue::AddUnblockCallback(const std::string &addr, const std::function &unblockCallback) +{ + std::lock_guard lock(unblockMutex_); + unblockCallbacks_.emplace(addr, unblockCallback); +} + +ExclusivePageQueue::ExclusivePageQueue(StreamManager *mgr, Optional cfg) + : streamFields_(0, 0, false, 0, false, 0, StreamMode::MPMC), + maxWindowCount_(1), + streamMgr_(mgr), + streamName_(streamMgr_ ? streamMgr_->GetStreamName() : "") +{ + if (cfg) { + streamFields_ = cfg.value(); + } +} + +bool ExclusivePageQueue::IsEncryptStream(const std::string &streamName) const +{ + (void)streamName; + // Note: If FLAGS_sc_encrypt_secret_key is empty, or if FLAGS_encrypt_kit is set to plaintext, + // or if authentication function between components is already enabled, + // then stream data encryption/decryption is not applicable. + // But it should not fail the stream creation, so the check is delayed. + return streamFields_.encryptStream_ && !FLAGS_sc_encrypt_secret_key.empty() + && FLAGS_encrypt_kit != ENCRYPT_KIT_PLAINTEXT; +} + +RemoteWorkerManager *ExclusivePageQueue::GetRemoteWorkerManager() const +{ + return streamMgr_->GetRemoteWorkerManager(); +} + +uint64_t ExclusivePageQueue::GetNumPagesCached() const +{ + ReadLockHelper rlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_)); + return ackChain_.size(); +} + +uint64_t ExclusivePageQueue::GetNumPagesInUse() const +{ + ReadLockHelper rlock2(STREAM_COMMON_LOCK_ARGS(idxMutex_)); + return idxChain_.size(); +} + +Status ExclusivePageQueue::UpdateLocalCursorLastDataPage(const ShmView &shmView) +{ + RETURN_RUNTIME_ERROR_IF_NULL(updateLastDataPageHandler_); + return updateLastDataPageHandler_(shmView); +} + +std::pair ExclusivePageQueue::GetNextBlockedRequestSize() +{ + return streamMgr_->GetNextBlockedRequestSize(); +} + +std::shared_ptr ExclusivePageQueue::SharedFromThis() +{ + return std::static_pointer_cast(shared_from_this()); +} +} // namespace stream_cache +} // namespace worker +} // namespace datasystem diff --git a/src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.h b/src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.h new file mode 100644 index 0000000..14da6cc --- /dev/null +++ b/src/datasystem/worker/stream_cache/page_queue/exclusive_page_queue.h @@ -0,0 +1,281 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * Description: ExclusivePageQueue + */ + +#ifndef DATASYSTEM_WORKER_STREAM_CACHE_EXCLUSIVE_PAGE_QUEUE_H +#define DATASYSTEM_WORKER_STREAM_CACHE_EXCLUSIVE_PAGE_QUEUE_H + +#include + +#include + +#include "datasystem/common/rpc/rpc_server_stream_base.h" +#include "datasystem/common/stream_cache/stream_data_page.h" +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/common/stream_cache/stream_meta_shm.h" +#include "datasystem/common/util/bitmask_enum.h" +#include "datasystem/worker/stream_cache/consumer.h" +#include "datasystem/worker/stream_cache/page_queue/page_queue_base.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +class RemoteWorkerManager; +class StreamManager; +/** + * A class that manages the stream data pages of a stream + */ +class ExclusivePageQueue : public std::enable_shared_from_this, public PageQueueBase { +public: + ExclusivePageQueue(StreamManager *mgr, Optional cfg); + ~ExclusivePageQueue() = default; + + /** + * @brief Verifies the input stream fields match the existing setting. + * If the existing settings are uninitialized, updates the values. + * @param[in] streamFields The stream fields with page size and max stream size to check + * @return Status of the call. + */ + Status UpdateStreamFields(const StreamFields &streamFields); + + /** + * @return The reserve memory size + */ + uint64_t GetReserveSize() const; + + /** + * @return T if auto cleanup is on + */ + bool AutoCleanup() const; + + /** + * @brief Reserve stream memory when the stream data object is created. + * @return K_OK if successful or if the 1st page is already created. + * Possible error code return can be K_OUT_OF_MEMORY + * @note If the page size of the stream is unknown, the function returns K_OK without allocating any memory. + */ + Status ReserveStreamMemory(); + + /** + * Reset the object + * @return + */ + Status Reset(); + + /** + * @brief Getter of lastAckCursor_. + * @return lastAckCursor_. + */ + uint64_t GetLastAckCursor() const + { + return lastAckCursor_; + } + + /** + * @brief Add callback to unblock sending stream. + * @param[in] addr The producer worker address. + * @param[in] unblockCallback The callback functions to unblock the stream. + */ + void AddUnblockCallback(const std::string &addr, const std::function &unblockCallback); + + /** + * @brief Remove callback when remote producer is getting deleted. + * @param[in] addr The producer worker address. 
+ */ + void RemoveUnblockCallback(const std::string &addr); + + /* + * Link back to the stream manager + */ + auto GetStreamManager() + { + return streamMgr_; + } + + std::string GetStreamName() const override + { + return streamName_; + } + + RemoteWorkerManager *GetRemoteWorkerManager() const override; + + /** + * @brief Gets the max stream size + * @return Max stream size + */ + uint32_t GetMaxStreamSize() const + { + return streamFields_.maxStreamSize_; + } + + /** + * @brief Gets the shared memory used + * @return Shared memory used + */ + uint64_t GetSharedMemoryUsed() const + { + return usedMemBytes_.load(std::memory_order_relaxed); + } + + /** + * @brief Gets the number of pages created + */ + uint64_t GetNumPagesCreated() const + { + return scMetricPagesCreated_.load(std::memory_order_relaxed); + } + + /** + * @brief Gets the number of pages released + */ + uint64_t GetNumPagesReleased() const + { + return scMetricPagesReleased_.load(std::memory_order_relaxed); + } + + /** + * @brief Gets the number of pages in use + */ + uint64_t GetNumPagesInUse() const; + + /** + * @brief Gets the number of pages cached + */ + uint64_t GetNumPagesCached() const; + + /** + * @brief Gets the number of big element pages created + */ + uint64_t GetNumBigPagesCreated() const + { + return scMetricBigPagesCreated_.load(std::memory_order_relaxed); + } + + /** + * @brief Gets the number of big element pages released + */ + uint64_t GetNumBigPagesReleased() const + { + return scMetricBigPagesReleased_.load(std::memory_order_relaxed); + } + + /** + * @brief Batch insert. + * @param[in] buf contiguous payload of the elements in reverse order + * @param[in] sz vector of the size of the elements + * @param[in] headerBits Is the element's data contain header + * @param[in] streamMetaShm The pointer to streamMetaShm + * @return Status + * @note This function is not thread safe, and/or lock safe. + */ + Status BatchInsert(void *buf, std::vector &sz, std::pair &res, uint64_t timeoutMs, + const std::vector &headerBits, StreamMetaShm *StreamMetaShm); + + /** + * @brief Used by the consumer to return the first starting page of the lastAckCursor + * @param req + * @param consumer + * @param GetDataPage unary server writer handle + * @return OK + */ + Status GetDataPage(const GetDataPageReqPb &req, const std::shared_ptr &consumer, + const std::shared_ptr> &serverApi); + + /** + * @brief Gets the stream fields + * @param[out] streamFields Return the stream fields with page size and max stream size. + */ + void GetStreamFields(StreamFields &streamFields) + { + streamFields = streamFields_; + } + + /** + * @brief Get encryptStream of streamFields_, and apply sanity checks. + * @return true if stream data encryption is applicable. 
+ */ + bool IsEncryptStream(const std::string &streamName) const override; + + /** + * @brief Get max window count + */ + auto GetMaxWindowCount() const + { + return maxWindowCount_; + } + + /** + * @brief Logs the cursors + * @return + */ + void LogCursors(); + + Status ReleaseAllPages(); + + void RegisterUpdateLastDataPageHandler(std::function updateLastDataPageHandler) + { + updateLastDataPageHandler_ = std::move(updateLastDataPageHandler); + } + +public: + std::string LogPrefix() const override; + size_t GetPageSize() const override; + Status CheckHadEnoughMem(size_t memSize) override; + Status VerifyWhenAlloc() const override; + Status AllocateMemoryImpl(size_t memSizeNeeded, ShmUnit &shmUnit, bool retryOnOOM) override; + Status UpdateLocalCursorLastDataPage(const ShmView &shmView) override; + Status AfterAck() override + { + return UnblockProducers(); + } + + virtual std::pair GetNextBlockedRequestSize() override; + +private: + std::shared_ptr SharedFromThis() override; + friend class StreamDataPool; + mutable std::mutex cfgMutex_; // protect streamFields_ + StreamFields streamFields_; + uint64_t maxWindowCount_; + mutable std::mutex unblockMutex_; // unblockCallbacks_ + std::unordered_map> unblockCallbacks_; + StreamManager *streamMgr_; // Back pointer to parent class + const std::string streamName_; + struct { + std::atomic pageZeroCreated{ false }; + std::atomic freeListCreated{ false }; + } reserveState_; + + std::function updateLastDataPageHandler_; + + static Status ReturnGetPageRspPb( + const ShmView &shmView, + const std::shared_ptr> &serverApi); + Status InsertBigElement(void *buf, size_t sz, std::pair &res, uint64_t timeoutMs, + const bool headerBit, StreamMetaShm *streamMetaShm); + Status BatchInsertImpl(void *buf, std::vector &sz, std::pair &res, InsertFlags flags, + uint64_t timeoutMs, const std::vector &headerBits, StreamMetaShm *streamMetaShm); + Status UnblockProducers(); + Status ReserveAdditionalMemory(); +}; + +} // namespace stream_cache +} // namespace worker +} // namespace datasystem +#endif diff --git a/src/datasystem/worker/stream_cache/page_queue/page_queue_base.cpp b/src/datasystem/worker/stream_cache/page_queue/page_queue_base.cpp new file mode 100644 index 0000000..05cdb8b --- /dev/null +++ b/src/datasystem/worker/stream_cache/page_queue/page_queue_base.cpp @@ -0,0 +1,1239 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * Description: PageQueueBase + */ + +#include +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/util/status_helper.h" + +#include "datasystem/common/constants.h" +#include "datasystem/common/util/lock_helper.h" +#include "datasystem/common/util/request_counter.h" +#include "datasystem/common/util/rpc_util.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/worker/stream_cache/page_queue/page_queue_base.h" +#include "datasystem/worker/stream_cache/remote_worker_manager.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +PageQueueBase::PageQueueBase() : lastPage_(nullptr), usedMemBytes_(0), lastAckCursor_(0), nextCursor_(1) +{ +} + +uint64_t PageQueueBase::GetLastAppendCursor() const +{ + ReadLockHelper rlock1(STREAM_COMMON_LOCK_ARGS(lastPageMutex_)); + if (lastPage_ == nullptr) { + return lastAckCursor_; + } + return lastPage_->GetLastCursor(); +} + +Status PageQueueBase::CreateOrGetLastDataPage(uint64_t timeoutMs, const ShmView &lastView, + std::shared_ptr &lastPage, bool retryOnOOM) +{ + // This function can be called from two different threads. + // Remote producers can call ExclusivePageQueue::BatchInsert and + // local producers can call via rpc. + // Any changes must ensure it is thread safe. + RETURN_IF_NOT_OK(CreateOrGetLastDataPageImpl(timeoutMs, lastView, lastPage, retryOnOOM)); + return Status::OK(); +} + +Status PageQueueBase::CreateOrGetLastDataPageImpl(uint64_t timeoutMs, const ShmView &lastView, + std::shared_ptr &lastPage, bool retryOnOOM) +{ + WriteLockHelper xlock1(STREAM_COMMON_LOCK_ARGS(lastPageMutex_)); + // Deal with the easy case when the object is empty + if (lastPage_ == nullptr) { + RETURN_IF_NOT_OK(CreateNewPage(lastPage, retryOnOOM)); + auto nextCursor = nextCursor_.load(std::memory_order_relaxed); + // Now we update the beginning cursor of the page + __atomic_store_n(&lastPage->GetPageHeader()->begCursor_, nextCursor, __ATOMIC_SEQ_CST); + // Add it to the chain + WriteLockHelper xlock2(STREAM_COMMON_LOCK_ARGS(idxMutex_)); + idxChain_.emplace_back(nextCursor - 1, lastPage); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Chain page<%s> [%zu, ) to index", LogPrefix(), + lastPage->GetPageId(), nextCursor); + // Swap the page. Increment the ref count by 1. + RETURN_IF_NOT_OK(lastPage->RefPage(FormatString("%s:%s", __FUNCTION__, __LINE__))); + lastPage_ = lastPage; + RETURN_IF_NOT_OK(VerifyLastPageRefCountNotLocked()); + // Update all local producers that a new page is added + RETURN_IF_NOT_OK(UpdateLocalCursorLastDataPage(lastPage->GetShmView())); + // Early exit + return Status::OK(); + } + // Update the lastPage_ because it can be stale due to the way producers can insert elements + // past the lastPage_ into the logical free pages. A simpler way is append an empty free list with + // no seal. + WriteLockHelper xlock2(STREAM_COMMON_LOCK_ARGS(idxMutex_)); + auto emptyPageList = Optional>(); + RETURN_IF_NOT_OK(AppendFreePagesImplNotLocked(timeoutMs, emptyPageList, false)); + lastPage = lastPage_; + // Now check if there is any free pages in the chain. If there is at least one, there is no + // need to create any new page. We can just check what's next after the lastPage_. + // We aren't going to lock this page to prevent any producer to insert. 
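+    // Fast path summary: if the ack chain still holds at least one recycled free page,
+    // producers can seal into it on their own, so we return early below instead of
+    // allocating a brand-new page here.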
+ { + ReadLockHelper rlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_)); + RETURN_OK_IF_TRUE(!ackChain_.empty()); + } + // Two producers (both local and remote) can see the page is full since they share the same page + // and both sends create a new page request. We will serialize the requests and ensure we don't + // create redundant new pages. + // To tell the difference if it is a false alarm and a real request, each producer will tell + // us the page info it currently has. If it is the same as our last page, we will create a new + // one. Otherwise, we simply return our current last page. + ShmView lastPageView = lastPage_->GetShmView(); + if (lastView != lastPageView) { + // Producer's view is stale. Return our last page to this producer to try + return Status::OK(); + } + // At this point, we are going to create a new page. Due to the way we recycle used pages to + // the chain, we are going to do the same but will seal the last page. + std::shared_ptr page; + RETURN_IF_NOT_OK(CreateNewPage(page, retryOnOOM)); + std::list freeList; + freeList.emplace_back(0, std::move(page)); + auto freeListOptional = Optional>(freeList); + RETURN_IF_NOT_OK(AppendFreePagesImplNotLocked(timeoutMs, freeListOptional, true)); + lastPage = lastPage_; + return Status::OK(); +} + +Status PageQueueBase::CreateNewPage(std::shared_ptr &lastPage, bool retryOnOOM) +{ + RETURN_IF_NOT_OK(VerifyWhenAlloc()); + auto pageSize = GetPageSize(); + // Check if there is a pending free page that we can steal + { + WriteLockHelper xlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_)); + if (!pendingFreePages_.empty()) { + auto &list = std::get(pendingFreePages_.front()); + lastPage = std::move(list.back()); + list.pop_back(); + pendingFreeBytes_ -= pageSize; + if (list.empty()) { + pendingFreePages_.pop_front(); + } + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Recycle Page<%s>", LogPrefix(), lastPage->GetPageId()); + // Make sure we clear the next page pointer. The code path ReclaimAckedChain can move a list + // of free pages from ackChain which are chained to each other. + RETURN_IF_NOT_OK(lastPage->GetSharedMemViewForNextPage()->Init(true)); + return Status::OK(); + } + } + std::shared_ptr pageUnitInfo; + // We now count metadata as part of the page size. + RETURN_IF_NOT_OK(AllocMemory(pageSize, false, pageUnitInfo, retryOnOOM)); + // The lock id of worker is 0 by default. + auto page = std::make_shared(pageUnitInfo, 0, false, isSharedPage_); + RETURN_IF_NOT_OK(page->Init()); + // Worker has the extra work to initialize the page + RETURN_IF_NOT_OK(page->InitEmptyPage()); + lastPage = std::move(page); + return Status::OK(); +} + +Status PageQueueBase::AllocMemory(size_t pageSize, bool bigElement, std::shared_ptr &shmUnitInfo, + bool retryOnOOM) +{ + PerfPoint point(PerfKey::PAGE_CREATE_NEW); + auto pageUnit = std::make_unique(); + // maxStreamSize_ and pageSize_ are initialized to 0 in the constructor. They need to be set later + // in order to make the CreateNewPage usable. This check ensures they are initialized properly. + RETURN_IF_NOT_OK(VerifyWhenAlloc()); + { + // Two threads can call AllocMemory at the same time and both think they have enough memory, + // and can result in over exceeding the max stream size. So we will serialize the memory + // allocation using the allocMutex_ lock. 
+ std::unique_lock lock(allocMutex_); + RETURN_IF_NOT_OK(CheckHadEnoughMem(pageSize)); + RETURN_IF_NOT_OK(AllocateMemoryImpl(pageSize, *pageUnit, retryOnOOM)); + usedMemBytes_ += pageSize; // Track the used bytes now that the allocation is successful + } + // Create another ShmUnitInfo because the destructor of ShmUnit will release the memory + auto pageUnitInfo = std::make_shared(); + pageUnitInfo->fd = pageUnit->fd; + pageUnitInfo->mmapSize = pageUnit->mmapSize; + pageUnitInfo->size = pageUnit->size; + pageUnitInfo->offset = pageUnit->offset; + pageUnitInfo->pointer = pageUnit->pointer; + auto page = std::make_shared(pageUnitInfo, false); + RETURN_IF_NOT_OK(page->Init()); + RETURN_IF_NOT_OK(AddPageToPool(page->GetPageId(), std::move(pageUnit), bigElement)); + if (bigElement) { + scMetricBigPagesCreated_.fetch_add(1, std::memory_order_relaxed); + } else { + scMetricPagesCreated_.fetch_add(1, std::memory_order_relaxed); + } + LOG(INFO) << FormatString("[%s] Create shared memory page<%s> successful.", LogPrefix(), page->GetPageId()); + shmUnitInfo = pageUnitInfo; + return Status::OK(); +} + +Status PageQueueBase::AddPageToPool(const std::string &pageId, std::unique_ptr &&pageUnit, bool bigElement) +{ + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_)); + pageUnit->refCount = 1; // The ref count is used in the case of BigElement page only. + ShmPagesMap::accessor accessor; + auto memInfo = std::make_shared(); + memInfo->pageUnit = std::move(pageUnit); + memInfo->createTime = std::chrono::steady_clock::now(); + memInfo->bigElement = bigElement; + bool success = shmPool_.emplace(accessor, pageId, std::move(memInfo)); + RETURN_OK_IF_TRUE(success); + RETURN_STATUS(K_DUPLICATED, FormatString("[%s] Page %s already in the pool", LogPrefix(), pageId)); +} + +Status PageQueueBase::VerifyLastPageRefCountNotLocked() const +{ + auto refCount = lastPage_->GetRefCount(); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(refCount > 1, // at least 2 + K_RUNTIME_ERROR, + FormatString("[%s] Unexpected ref count %zu on the last page<%s>", LogPrefix(), + refCount, lastPage_->GetPageId())); + return Status::OK(); +} + +Status PageQueueBase::AppendFreePagesImplNotLocked(uint64_t timeoutMs, + Optional> &optionalFreeList, bool seal, + const bool updateLocalPubLastPage) +{ + std::shared_ptr lastPage; + auto iter = idxChain_.begin(); + // Refresh lastPage_ if the producer(s) have reused some of the free pages in the ack chain + RETURN_IF_NOT_OK_EXCEPT(RefreshLastPage(iter, lastPage), K_NOT_FOUND); + RETURN_OK_IF_TRUE(lastPage == nullptr); + if (optionalFreeList) { + std::list &freeList = optionalFreeList.value(); + WriteLockHelper xlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_)); + AddListToAckChain(freeList); + if (!ackChain_.empty()) { + // Link the next free page here to beginning of the ack chain. + // We should lock the last page if any producer is trying to seal it at the same time. 
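+            // Only the head of the ack chain needs to be stitched to lastPage's next-page
+            // view here, and only when no producer has sealed a next page yet (checked via
+            // HasNextPage() below); the free pages themselves were already chained to each
+            // other by AddListToAckChain().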
+ StreamPageLock pageLock(lastPage); + RETURN_IF_NOT_OK(pageLock.Lock(timeoutMs)); + if (!lastPage->HasNextPage()) { + auto &page = ackChain_.begin()->second; + auto &smv = lastPage->GetSharedMemViewForNextPage(); + LOG_IF_ERROR( + smv->SetView(page->GetShmView(), true, std::numeric_limits::max()), + FormatString("[%s] The page %s set next page view failed", LogPrefix(), lastPage->GetPageId())); + } + } + } + // If asked to seal the last page, make sure it is not empty and has a free page in the ack chain + if (seal && !lastPage->Empty()) { + WriteLockHelper xlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_)); + if (!ackChain_.empty()) { + // Get a reference to the beginning of the free chain. It is possible a producer has already + // sealed the page. So we will tolerate K_DUPLICATED return code. + auto &ele = ackChain_.front(); + auto func = [&ele](const ShmView &v, std::shared_ptr &out) { + (void)v; + out = ele.second; + return Status::OK(); + }; + RETURN_IF_NOT_OK_EXCEPT(lastPage->Seal(ele.second->GetShmView(), timeoutMs, func, LogPrefix()), + K_DUPLICATED); + auto &page = ele.second; + ele.first = page->GetBegCursor() - 1; // Using the last cursor of lastPage as the key + idxChain_.emplace_back(ele); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Chain page<%s> [%zu, ) to index", LogPrefix(), + page->GetPageId(), page->GetBegCursor()); + nextCursor_.store(page->GetBegCursor(), std::memory_order_relaxed); + std::swap(page, lastPage); + iter = std::prev(idxChain_.end()); + ackChain_.pop_front(); + } + } + if (lastPage.get() != lastPage_.get()) { + RETURN_IF_NOT_OK(VerifyLastPageRefCountNotLocked()); + RETURN_IF_NOT_OK(lastPage_->ReleasePage(FormatString("[%s] %s:%s", LogPrefix(), __FUNCTION__, __LINE__))); + INJECT_POINT("AppendFreePagesImplNotLocked.sleep"); + // Swap the page. Increment the ref count by 1. + RETURN_IF_NOT_OK(lastPage->RefPage(FormatString("[%s] %s:%s", LogPrefix(), __FUNCTION__, __LINE__))); + lastPage_ = lastPage; + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Last page updated to %s. begCursor %zu", LogPrefix(), + lastPage_->GetPageId(), lastPage_->GetBegCursor()); + RETURN_IF_NOT_OK(VerifyLastPageRefCountNotLocked()); + // Update all local producers that a new page is added + if (updateLocalPubLastPage) { + RETURN_IF_NOT_OK(UpdateLocalCursorLastDataPage(lastPage->GetShmView())); + } + } + return Status::OK(); +} + +Status PageQueueBase::RefreshLastPage(std::list::iterator &out, std::shared_ptr &lastPage) +{ + LogCursors(); + auto iter = idxChain_.begin(); + while (iter != idxChain_.end() && iter->second.get() != lastPage_.get()) { + ++iter; + } + CHECK_FAIL_RETURN_STATUS(iter != idxChain_.end(), K_NOT_FOUND, + FormatString("[%s] Unable to locate lastPage_", LogPrefix())); + lastPage = iter->second; + // If the lastPage_ is sealed by the producers, move those pages that have been reused back to idx chain. 
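+    // Walk sketch: follow the HasNextPage() links starting at lastPage_. A successor that
+    // already sits in idxChain_ just advances the iterator; a successor still at the head
+    // of ackChain_ is first spliced onto idxChain_ (keyed by the previous page's last
+    // cursor) before advancing. The walk stops at the first page without a next page.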
+ while (lastPage->HasNextPage()) { + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + lastPage->GetBegCursor() > 0 && !lastPage->Empty(), K_OUT_OF_RANGE, + FormatString("[%s] Invalid page<%s> begCursor %zu slotCount %zu", LogPrefix(), lastPage->GetPageId(), + lastPage->GetBegCursor(), lastPage->GetSlotCount())); + uint64_t lastAppendCursor = lastPage->GetLastCursor(); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Page<%s> [%zu, %zu] is sealed by producer", LogPrefix(), + lastPage->GetPageId(), lastPage->GetBegCursor(), lastAppendCursor); + // Previous run of this function may hit lock timeout while we are moving things from + // ack chain to idx chain and leave the lastPage_ not pointing to the right place. + // But it is easy to verify the position of the iterator. + auto next = std::next(iter, 1); + if (next != idxChain_.end()) { + auto page = next->second; + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + lastAppendCursor + 1 == page->GetBegCursor(), K_OUT_OF_RANGE, + FormatString("[%s] Expect page<%s> begCursor %zu but get %zu", LogPrefix(), page->GetPageId(), + lastAppendCursor + 1, page->GetBegCursor())); + std::swap(page, lastPage); + iter = next; + continue; + } + WriteLockHelper xlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_)); + CHECK_FAIL_RETURN_STATUS( + !ackChain_.empty(), K_NOT_FOUND, + FormatString("[%s] Last page<%s> is sealed but there is no next page", LogPrefix(), lastPage->GetPageId())); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(lastAppendCursor + 1 == ackChain_.front().second->GetBegCursor(), + K_OUT_OF_RANGE, + FormatString("[%s] Expect page<%s> begCursor %zu but get %zu", LogPrefix(), + ackChain_.front().second->GetPageId(), lastAppendCursor + 1, + ackChain_.front().second->GetBegCursor())); + auto ele = std::move(ackChain_.front()); + ackChain_.pop_front(); + ele.first = lastAppendCursor; + auto page = ele.second; + idxChain_.emplace_back(ele); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Chain page<%s> [%zu, ) to index", LogPrefix(), + ele.second->GetPageId(), lastAppendCursor + 1); + nextCursor_.store(lastAppendCursor + 1, std::memory_order_relaxed); + std::swap(page, lastPage); + iter = std::prev(idxChain_.end()); + } + out = iter; + return Status::OK(); +} + +Status PageQueueBase::VerifyWhenAlloc() const +{ + return Status::OK(); +} + +void PageQueueBase::LogCursors() +{ + uint64_t totalElements = 0; + // Found from GetLastAppendCursor() + totalElements = (lastPage_ != nullptr) ? lastPage_->GetLastCursor() : 0; + + const int logPerCount = VLOG_IS_ON(SC_NORMAL_LOG_LEVEL) ? 
1 : 1000; + LOG_EVERY_N(INFO, logPerCount) << FormatString("[%s] %d received elements, %d elements acked", LogPrefix(), + totalElements, lastAckCursor_); +} + +void PageQueueBase::AddListToAckChain(std::list &freeList) +{ + std::shared_ptr prev; + if (!ackChain_.empty()) { + auto prevIt = std::prev(ackChain_.end()); + prev = prevIt->second; + } + auto iter = freeList.begin(); + while (iter != freeList.end()) { + if (prev) { + prev->GetSharedMemViewForNextPage()->SetView(iter->second->GetShmView(), true, + std::numeric_limits::max()); + } + ackChain_.emplace_back(0, iter->second); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Chain page<%s> to free list", LogPrefix(), + iter->second->GetPageId()); + prev = iter->second; + iter = freeList.erase(iter); + } +} + +Status PageQueueBase::Ack(uint64_t cursor, StreamMetaShm *streamMetaShm) +{ + static auto last = std::chrono::steady_clock::now(); + INJECT_POINT("ExclusivePageQueue.Ack.Start"); + std::list freeList; + // Producers can insert the elements past the lastPage_. We need it + // updated before we process the Ack chain. The easiest way is pass + // an empty freeList to AppendFreePages + RETURN_IF_NOT_OK(AppendFreePages(freeList)); + // Now we traverse to Ack chain. + std::vector bigElementPage; + std::vector status; + Status rc = AckImpl(cursor, freeList, bigElementPage, streamMetaShm); + status.emplace_back(rc); + if (freeList.empty() && bigElementPage.empty()) { + // If nothing to reclaim. Periodically dump the pool stat. But not too much to flood the log + const int dumpInterval = 10; + auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - last).count() >= dumpInterval) { + last = now; + DumpPoolPages(0); + } + } + rc = ProcessBigElementPages(bigElementPage, streamMetaShm); + status.emplace_back(rc); + rc = ProcessAckedPages(cursor, freeList); + status.emplace_back(rc); + rc = AfterAck(); + status.emplace_back(rc); + auto iter = std::find_if(status.begin(), status.end(), [](auto &kv) { return kv.IsError(); }); + if (iter != status.end()) { + return (*iter); + } + return Status::OK(); +} + +Status PageQueueBase::AckImpl(uint64_t cursor, std::list &freeList, std::vector &bigElementPages, + StreamMetaShm *streamMetaShm) +{ + PerfPoint point(PerfKey::MANAGER_ACK_CURSOR_GET_LOCK); + WriteLockHelper xlock2(STREAM_COMMON_LOCK_ARGS(idxMutex_)); + point.RecordAndReset(PerfKey::MANAGER_ACK_CURSOR_LOGIC); + RETURN_OK_IF_TRUE(cursor <= lastAckCursor_); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Ack cursor %zu", LogPrefix(), cursor); + auto it = idxChain_.begin(); + while (it != idxChain_.end()) { + bool keepThisPageInChain = false; + auto &page = it->second; + auto id = page->GetPageId(); + auto slotCount = page->GetSlotCount(); + auto begCursor = page->GetBegCursor(); + auto lastCursor = it->first + slotCount; + if (begCursor > cursor) { + VLOG(SC_DEBUG_LOG_LEVEL) << FormatString("[%s] Ack cursor %zu stops at page<%s> range [%zu, %zu)", + LogPrefix(), cursor, id, it->first + 1, lastCursor + 1); + return Status::OK(); + } + // For a data page, we can only wait until its reference count drops to 1 to recycle/release + // the page. But that will take too long to release BigElement row. So we will browse through + // the slot and release the BigElement memory prior to ack cursor + RETURN_IF_NOT_OK(ReleaseBigElementsUpTo(cursor, page, bigElementPages, keepThisPageInChain)); + // Now we check the reference count. It is not 1, move on. + // When a page is created, its reference count is 1. 
See StreamDataPage::InitEmptyPage(). + // Any worker/producer/client will increase/decrease the reference count for access. + // Also, ExclusivePageQueue always maintains an additional reference to the last page. + // If this page has some unfinished BigElement, move on to the next page regardless of the reference + // count + auto count = page->GetRefCount(); + if (count != 1 || keepThisPageInChain) { + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString( + "[%s] Page<%s> pending ack. NextAckCursor = %zu. Reference count %d", LogPrefix(), id, cursor, count); + // We can't free this page (yet) but move on to the next one. + ++it; + continue; + } + // it is the last page, won't do anything do it until the all producers/consumers are gone. + auto nextIt = std::next(it, 1); + if (nextIt == idxChain_.end()) { + break; + } + if (lastCursor < cursor) { + // Every row on this page has been processed by the clients. + RETURN_IF_NOT_OK(page->WakeUpConsumers()); + // Move forward the lastAckCursor_, it is possible we still have a page with + // begCursor < lastAckCursor_ due to reference count. So make sure we don't + // move it backward. + auto lastAckCursor = lastAckCursor_.load(std::memory_order_relaxed); + lastAckCursor_.store(std::max(lastCursor, lastAckCursor), std::memory_order_relaxed); + // Cache the free page + auto size = page->GetTotalEleSize(); + freeList.emplace_back(std::move(*it)); + it = idxChain_.erase(it); + if (streamMetaShm) { + LOG_IF_ERROR(streamMetaShm->TryDecUsage(size), "TryDecUsage failed"); + } + LogCursors(); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Ack page<%s> [%zu, %zu] success.", LogPrefix(), id, + begCursor, lastCursor); + continue; + } else { + // begCursor <= cursor < lastCursor + VLOG(SC_DEBUG_LOG_LEVEL) << FormatString("[%s] Ack cursor %zu stops at page<%s> range [%zu, %zu)", + LogPrefix(), cursor, id, it->first + 1, lastCursor + 1); + return Status::OK(); + } + } + return Status::OK(); +} + +Status PageQueueBase::ReleaseBigElementsUpTo(uint64_t cursor, std::shared_ptr &page, + std::vector &bigElementPages, bool &keepThisPageInChain) +{ + std::vector> bigId; + RETURN_IF_NOT_OK(page->ExtractBigElementsUpTo(cursor, bigId, false)); + // Pop (from the back) all those BigElement pages that are in still in use + while (!bigId.empty()) { + int32_t bigRefCount; + auto pageInfo = std::make_shared("", bigId.back().second, nullptr); + auto pageId = StreamPageBase::CreatePageId(pageInfo); + RETURN_IF_NOT_OK(GetBigElementPageRefCount(pageId, bigRefCount)); + if (bigRefCount > 1) { + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] BigElementPage<%s> still in use", LogPrefix(), pageId); + // We need to come back to this page to reclaim the big memory + keepThisPageInChain = true; + bigId.pop_back(); + } else { + break; + } + } + if (!bigId.empty()) { + auto ackCursor = bigId.back().first; + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Release BigElement up to cursor %zu", LogPrefix(), ackCursor); + bigId.clear(); + // Note: Here the attribute of the big element will be removed. 
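+        // Two-pass sketch: the first ExtractBigElementsUpTo() call above (final argument
+        // false) only peeks. In-use entries (dedicated page refCount > 1) were trimmed from
+        // the back, so this second call (final argument true) releases exactly the prefix
+        // of BigElements up to the last safely reclaimable cursor.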
+ RETURN_IF_NOT_OK(page->ExtractBigElementsUpTo(ackCursor, bigId, true)); + std::transform(bigId.begin(), bigId.end(), std::back_inserter(bigElementPages), + [](auto &kv) { return kv.second; }); + } + return Status::OK(); +} + +Status PageQueueBase::GetBigElementPageRefCount(const std::string &pageId, int32_t &refCount) +{ + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_)); + ShmPagesMap::const_accessor accessor; + bool success = shmPool_.find(accessor, pageId); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(success, K_NOT_FOUND, + FormatString("[%s] Page<%s> not found", LogPrefix(), pageId)); + refCount = accessor->second->pageUnit->GetRefCount(); + return Status::OK(); +} + +Status PageQueueBase::AppendFreePages(std::list &freeList, const bool updateLocalPubLastPage) +{ + // The page lock acquisition is not expected to fail in this case, + // but there can be situation where the page and page lock need recovery. + // In that case, idxMutex_ needs to be released for TryUnlockByLockId logic, + // so we perform retry until success if error code is K_TRY_AGAIN from PageLock::Lock. + // Note that lastPageMutex_ also needs to be released for the + // CreateStreamManagerIfNotExist -> CreatePageZero code path deadlock. + const int32_t RETRY_FOREVER = std::numeric_limits::max(); + auto freeListOptional = Optional>(freeList); + auto func = [this, &freeListOptional, updateLocalPubLastPage](int32_t) { + // Block rpc and pause the idx and ack(in the correct order). + WriteLockHelper xlock1(STREAM_COMMON_LOCK_ARGS(lastPageMutex_)); + WriteLockHelper xlock2(STREAM_COMMON_LOCK_ARGS(idxMutex_)); + return AppendFreePagesImplNotLocked(DEFAULT_TIMEOUT_MS, freeListOptional, false, updateLocalPubLastPage); + }; + RETURN_IF_NOT_OK(RetryOnError(RETRY_FOREVER, func, [] { return Status::OK(); }, { K_TRY_AGAIN })); + return Status::OK(); +} + +void PageQueueBase::DumpPoolPages(int level) const +{ + // Sort (in reverse order) based on duration + using second = std::chrono::duration>; + struct PageStrInfo { + second key; + std::string info; + }; + struct Compare { + bool operator()(const PageStrInfo &a, const PageStrInfo &b) + { + return a.key < b.key; + } + }; + std::priority_queue, Compare> bigEleQue; + std::priority_queue, Compare> dataPageQue; + std::ostringstream oss; + size_t numRegularPages = 0; + size_t numBigElements = 0; + size_t totalDataPageSz = 0; + size_t totalBigElementSz = 0; + size_t poolSize = 0; + + WriteLockHelper xlock4(STREAM_COMMON_LOCK_ARGS(poolMutex_)); + poolSize = shmPool_.size(); + auto now = std::chrono::steady_clock::now(); + for (auto &ele : shmPool_) { + const std::string &pageId = ele.first; + auto &shmEleInfo = ele.second; + if (shmEleInfo->bigElement) { + ++numBigElements; + totalBigElementSz += shmEleInfo->pageUnit->size; + if (level >= SC_NORMAL_LOG_LEVEL) { + auto key = std::chrono::duration_cast(now - shmEleInfo->createTime); + auto val = FormatString("[%s] Page<%s>, Duration: [%.6lf]s", LogPrefix(), pageId, key.count()); + auto info = PageStrInfo{ .key = key, .info = val }; + bigEleQue.emplace(std::move(info)); + } + } else { + ++numRegularPages; + totalDataPageSz += GetPageSize(); + if (level >= SC_INTERNAL_LOG_LEVEL) { + auto key = std::chrono::duration_cast(now - shmEleInfo->createTime); + auto val = FormatString("[%s] Page<%s>, Duration: [%.6lf]s", LogPrefix(), pageId, key.count()); + auto info = PageStrInfo{ .key = key, .info = val }; + dataPageQue.emplace(std::move(info)); + } + } + } + + // Let go of the lock. We have gathered what we need. 
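+    // Verbosity recap: level 0 prints only the summary line; SC_NORMAL_LOG_LEVEL adds the
+    // longest-lived BigElement pages; SC_INTERNAL_LOG_LEVEL also lists data pages. The
+    // priority queues above order both lists by page age, oldest first.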
+ xlock4.unlock(); + + // Summary + oss << FormatString( + "[%s] Dump pool stat (v=%d). Size of pool: %zu. Number of data pages: %zu. Total size of data pages: %zu. " + "Number of BigElement: %zu. Total size of BigElements: %zu\n", + LogPrefix(), level, poolSize, numRegularPages, totalDataPageSz, numBigElements, totalBigElementSz); + + // Only dump the summary when v=0 + if (level < SC_NORMAL_LOG_LEVEL) { + LOG(INFO) << oss.str(); + return; + } + + auto func = [&oss](std::priority_queue, Compare> &que, size_t maxLines) { + size_t i = 0; + while (!que.empty()) { + auto ele = que.top(); + que.pop(); + oss << FormatString("%s. seq:[%zu]\n", ele.info, ++i); + if (i >= maxLines) { + break; + } + } + }; + + // For v=1, dump the BigElements (page id only) + if (level >= SC_NORMAL_LOG_LEVEL && numBigElements > 0) { + oss << FormatString("[%s] Dump the first few BigElements (sort by duration)\n", LogPrefix()); + const size_t maxLines = 10; + func(bigEleQue, level == SC_DEBUG_LOG_LEVEL ? numBigElements : maxLines); + } + + // Dump the BigElements when v=1 + if (level < SC_INTERNAL_LOG_LEVEL) { + LOG(INFO) << oss.str(); + return; + } + + // Rest is v=2 + if (numRegularPages > 0) { + oss << FormatString("[%s] Dump the first few data pages (sort by duration)\n", LogPrefix()); + const size_t maxLines = 5; + func(dataPageQue, level == SC_DEBUG_LOG_LEVEL ? numRegularPages : maxLines); + } + + LOG(INFO) << oss.str(); +} + +Status PageQueueBase::MoveUpLastPage(const bool updateLocalPubLastPage) +{ + std::list freeList; + // Producers can insert the elements past the lastPage_. We need it + // updated before we process the Ack chain. The easiest way is pass + // an empty freeList to AppendFreePages + return AppendFreePages(freeList, updateLocalPubLastPage); +} + +Status PageQueueBase::ReclaimAckedChain(uint64_t timeoutMs) +{ + // Block rpc and pause the idx and ack(in the correct order). + WriteLockHelper xlock1(STREAM_COMMON_LOCK_ARGS(lastPageMutex_)); + WriteLockHelper xlock2(STREAM_COMMON_LOCK_ARGS(idxMutex_)); + // Don't lock ack chain yet because we will call the function RefreshLastPage + std::shared_ptr lastPage; + auto iter = idxChain_.begin(); + // Refresh lastPage_ if the producer(s) have reused some of the free pages in the ack chain + RETURN_IF_NOT_OK_EXCEPT(RefreshLastPage(iter, lastPage), K_NOT_FOUND); + RETURN_OK_IF_TRUE(lastPage == nullptr); + // Now we can lock the ackChain + WriteLockHelper xlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_)); + RETURN_OK_IF_TRUE(ackChain_.empty()); + // We have to be careful if we can take out a page from the ack chain + // because any producer can freely traverse the chain. The only + // way to block them is to lock the page. + while (true) { + StreamPageLock pageLock(lastPage); + RETURN_IF_NOT_OK(pageLock.Lock(timeoutMs)); + ShmView nextPageView; + std::shared_ptr nextPage; + bool isFreePage; + RETURN_IF_NOT_OK(lastPage->GetSharedMemViewForNextPage()->GetView(nextPageView, isFreePage, timeoutMs)); + RETURN_OK_IF_TRUE(nextPageView.fd <= 0); + RETURN_IF_NOT_OK(LocatePage(nextPageView, nextPage)); + if (!isFreePage) { + std::swap(lastPage, nextPage); + continue; + } + // We have the current page locked (all producers are blocked) and there is a pointer to + // a free page. We can't free this free page but rather the one(s) follow it because + // other producers may already spot the existence of this page and is trying to seal + // (and wait on our page lock). 
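+        // Safety rule: the free page directly referenced by the locked page's next-page
+        // view must be kept, since a blocked producer may already hold its ShmView; only
+        // the pages after it in the ack chain are spliced out and sent to
+        // MoveFreeListToPendFree() for deferred release.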
+ iter = ackChain_.begin(); + while (iter != ackChain_.end()) { + ShmView v = iter->second->GetShmView(); + if (v != nextPageView) { + ++iter; + continue; + } + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Keeping Page<%s> in the free list", LogPrefix(), + nextPage->GetPageId()); + // As mentioned before, we can't free this one, and it is too dangerous. + // But we will free the rest. + ++iter; + break; + } + RETURN_OK_IF_TRUE(iter == ackChain_.end()); + // Two more things to do. + // Clear the next pointer + nextPage->GetSharedMemViewForNextPage()->SetView(ShmView(), false, std::numeric_limits::max()); + // Split the chain starting from iter + std::list freeList; + freeList.splice(freeList.end(), ackChain_, iter, ackChain_.end()); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Move %zu pages out from ack chain", LogPrefix(), + freeList.size()); + return MoveFreeListToPendFree(0, freeList); + } + return Status::OK(); +} + +Status PageQueueBase::ReleaseMemory(const ShmView &pageView) +{ + std::shared_ptr pageInfo; + RETURN_IF_NOT_OK(LocatePage(pageView, pageInfo)); + auto bigElementPage = std::make_shared(pageInfo, true); + RETURN_IF_NOT_OK(bigElementPage->Init()); + LOG(INFO) << FormatString("[%s] Release big element page<%s>", LogPrefix(), bigElementPage->GetPageId()); + std::vector list; + list.push_back(bigElementPage->GetPageId()); + return FreePages(list, true); +} + +void PageQueueBase::TryUnlockByLockId(uint32_t lockId) +{ + // This form of recovery is obsolete except for down level client. + // The problem of this logic is the locks are acquired in the wrong + // order (which can lead to deadlock). The page can be locked + // due to client crash but the logic is trying to acquire idxMutex_ + // which is the opposite order of other code path where the idxMutex_ + // is locked first, and then the page lock. A better method is to + // use the cursor_ info. 
+    WriteLockHelper xlock2(STREAM_COMMON_LOCK_ARGS(idxMutex_));
+    auto it = idxChain_.begin();
+    while (it != idxChain_.end()) {
+        it->second->TryUnlockByLockId(lockId);
+        ++it;
+    }
+}
+
+void PageQueueBase::ForceUnlockMemViemForPages(uint32_t lockId)
+{
+    WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(poolMutex_));
+    for (auto &ele : shmPool_) {
+        if (ele.second->bigElement) {
+            continue;
+        }
+        VLOG(1) << FormatString("Try unlock for stream: %s, page id: %s, lockId: %u", LogPrefix(), ele.first, lockId);
+        auto &pageUnit = ele.second->pageUnit;
+        auto pageUnitInfo = std::make_shared();
+        pageUnitInfo->fd = pageUnit->fd;
+        pageUnitInfo->mmapSize = pageUnit->mmapSize;
+        pageUnitInfo->size = pageUnit->size;
+        pageUnitInfo->offset = pageUnit->offset;
+        pageUnitInfo->pointer = pageUnit->pointer;
+
+        auto page = std::make_shared(pageUnitInfo, WORKER_LOCK_ID, false, isSharedPage_);
+        auto rc = page->Init();
+        if (rc.IsError()) {
+            LOG(ERROR) << FormatString("%s, PageId: %s, page init failed: %s", LogPrefix(), ele.first, rc.ToString());
+            continue;
+        }
+
+        auto &smv = page->GetSharedMemViewForNextPage();
+        if (smv != nullptr) {
+            auto msg = FormatString("%s, PageId: %s", LogPrefix(), ele.first);
+            LOG_IF_ERROR(smv->ForceUnLock(lockId, msg), "ForceUnLock for page failed");
+        }
+    }
+}
+
+Status PageQueueBase::ProcessBigElementPages(std::vector &bigElementId, StreamMetaShm *streamMetaShm)
+{
+    RETURN_OK_IF_TRUE(bigElementId.empty());
+    std::vector freeList;
+    auto func = [this, &freeList](const ShmView &v) {
+        std::shared_ptr shmInfo;
+        RETURN_IF_NOT_OK(LocatePage(v, shmInfo));
+        auto bigElementPage = std::make_shared(shmInfo, true);
+        RETURN_IF_NOT_OK(bigElementPage->Init());
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Release big element page<%s>", LogPrefix(),
+                                                  bigElementPage->GetPageId());
+        freeList.emplace_back(bigElementPage->GetPageId());
+        return Status::OK();
+    };
+    for (auto &v : bigElementId) {
+        (void)func(v);
+    }
+    return FreePages(freeList, true, streamMetaShm);
+}
+
+Status PageQueueBase::LocatePage(const ShmView &v, std::shared_ptr &out)
+{
+    auto pageInfo = std::make_shared("", v, nullptr);
+    auto pageId = StreamPageBase::CreatePageId(pageInfo);
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_));
+    ShmPagesMap::accessor accessor;
+    bool exist = shmPool_.find(accessor, pageId);
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(exist, K_NOT_FOUND,
+                                         FormatString("[%s] Page %s not found", LogPrefix(), pageId));
+    pageInfo->pointer = accessor->second->pageUnit->pointer;
+    out = std::move(pageInfo);
+    return Status::OK();
+}
+
+std::pair PageQueueBase::GetNextBlockedRequestSize()
+{
+    return std::make_pair(0, false);
+}
+
+Status PageQueueBase::ProcessAckedPages(uint64_t cursor, std::list &freeList)
+{
+    // Some locking orders to consider:
+    // StreamManager::UnblockCreators can hold these locks in this order
+    // (a) StreamManager::streamManagerBlockedListsMutex_
+    // (b) Call StreamManager::AllocBigShmMemory
+    // (c) Call ExclusivePageQueue::ReclaimAckedChain
+    // (d) Hold ExclusivePageQueue::lastPageMutex_
+    // (e) Hold ExclusivePageQueue::idxMutex_
+    // (f) Wait for ExclusivePageQueue::ackMutex_
+
+    // So we need to follow the same order.
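+    // In other words, the acquisition order is always
+    //   streamManagerBlockedListsMutex_ -> lastPageMutex_ -> idxMutex_ -> ackMutex_
+    // and an earlier mutex must never be requested while a later one is held.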
+    // Find out the next request's size to determine if we should cache or free
+
+    size_t nextReqSz;
+    bool bigElement;
+    std::tie(nextReqSz, bigElement) = GetNextBlockedRequestSize();
+    {
+        WriteLockHelper xlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_));
+        RETURN_OK_IF_TRUE(freeList.empty() && pendingFreePages_.empty());
+        // Clear all pages back to empty
+        for (auto &e : freeList) {
+            RETURN_IF_NOT_OK(e.second->ResetToEmpty());
+        }
+        if (bigElement && (CheckHadEnoughMem(nextReqSz).GetCode() == K_OUT_OF_MEMORY)
+            && (!freeList.empty() || !pendingFreePages_.empty())) {
+            // We have a BigElement in the next request, and don't have enough stream memory to serve it.
+            // We will stop the caching, and free as much as we can.
+            LOG(WARNING) << FormatString(
+                "[%s] Not enough stream memory to handle BigElement %zu bytes request, used %zu", LogPrefix(),
+                nextReqSz, usedMemBytes_);
+            return MoveFreeListToPendFree(cursor, freeList);
+        }
+        // The size of the Ack chain should be <= FLAGS_sc_cache_pages.
+        // That said, we may have reserved more pages than FLAGS_sc_cache_pages in the function
+        // ReserveStreamMemory, and we will then not touch anything already in the ackChain.
+        // Before this function is called, we have already moved as many 'recycled' pages
+        // back to the idx chain as possible. The producers may still continue to exhaust
+        // the ack chain, but we shouldn't block them just to maintain FLAGS_sc_cache_pages.
+        // So the producers may end up sending an RPC to ask for more free pages.
+        // The value of FLAGS_sc_cache_pages should be carefully chosen.
+        const size_t chainLength = ackChain_.size();
+        const size_t sizeToKeep = std::max(FLAGS_sc_cache_pages, 0);
+        std::list pendingList;
+        while (chainLength + freeList.size() > sizeToKeep) {
+            if (freeList.empty()) {
+                break;
+            }
+            auto ele = std::move(freeList.back());
+            freeList.pop_back();
+            pendingList.emplace_back(std::move(ele));
+        }
+        // Return the memory back to the pool. We aren't going to free them right away
+        // due to a race condition between the worker and the producer. We only
+        // free ack pages that were acked a while ago. That should give enough time
+        // for the producers to move off the ack pages.
+        if (!pendingList.empty() || !pendingFreePages_.empty()) {
+            RETURN_IF_NOT_OK(MoveFreeListToPendFree(cursor, pendingList));
+        }
+    }
+    // Link them to the chain as 'logical' pages.
+    // We're still calling the function even though the freeList can be empty.
+    // This is to continue the rest of the flow to unblock producers and creators.
+    // Also, we can do one more final update to the idx chain.
+    return AppendFreePages(freeList);
+}
+
+Status PageQueueBase::FreePages(std::vector &pages, bool bigElementPage, StreamMetaShm *streamMetaShm)
+{
+    PerfPoint point(PerfKey::PAGE_RELEASE);
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_));
+    auto b4 = usedMemBytes_.load(std::memory_order_relaxed);
+    size_t totalReleased = 0;
+    while (!pages.empty()) {
+        auto pageId = std::move(pages.back());
+        pages.pop_back();
+        ShmPagesMap::accessor accessor;
+        bool exist = shmPool_.find(accessor, pageId);
+        if (!exist) {
+            LOG(ERROR) << FormatString("[%s] Page %s not found", LogPrefix(), pageId);
+            continue;
+        }
+        auto &pageUnit = accessor->second->pageUnit;
+        CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+            bigElementPage == accessor->second->bigElement, K_RUNTIME_ERROR,
+            FormatString("[%s, Page<%s>] BigElement attribute doesn't match.", LogPrefix(), pageId));
+        size_t memSizeToRelease = bigElementPage ? pageUnit->size : GetPageSize();
+        RETURN_IF_NOT_OK(pageUnit->FreeMemory());
+        if (bigElementPage) {
+            if (streamMetaShm) {
+                LOG_IF_ERROR(streamMetaShm->TryDecUsage(memSizeToRelease), "TryDecUsage failed");
+            }
+            scMetricBigPagesReleased_.fetch_add(1, std::memory_order_relaxed);
+        } else {
+            scMetricPagesReleased_.fetch_add(1, std::memory_order_relaxed);
+        }
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Page<%s> is released", LogPrefix(), pageId);
+        shmPool_.erase(accessor);
+        usedMemBytes_ -= memSizeToRelease;
+        totalReleased += memSizeToRelease;
+    }
+    VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Total memory release %zu. Latest usedMemBytes %zu", LogPrefix(),
+                                              totalReleased, b4 - totalReleased);
+    return Status::OK();
+}
+
+Status PageQueueBase::MoveFreeListToPendFree(uint64_t cursor, std::list &freeList)
+{
+    RETURN_IF_NOT_OK(FreePendingList());
+    std::vector> pagesToFree;
+    while (!freeList.empty()) {
+        auto ele = std::move(freeList.back());
+        freeList.pop_back();
+        pagesToFree.push_back(std::move(ele.second));
+    }
+    if (!pagesToFree.empty()) {
+        auto now = std::chrono::steady_clock::now();
+        pendingFreeBytes_ += pagesToFree.size() * GetPageSize();
+        auto ele = std::make_tuple(cursor, now, std::move(pagesToFree));
+        pendingFreePages_.emplace_back(std::move(ele));
+    }
+    return Status::OK();
+}
+
+Status PageQueueBase::FreePendingList()
+{
+    // Return the memory back to the pool. We aren't going to free them right away
+    // due to a race condition between the worker and the producer. We only
+    // free ack pages that were acked a while ago. That should give enough time
+    // for the producers to move off the ack pages.
+    auto now = std::chrono::steady_clock::now();
+    if (!pendingFreePages_.empty()) {
+        std::chrono::time_point start;
+        uint64_t begCursor;
+        std::tie(begCursor, start, std::ignore) = pendingFreePages_.front();
+        const int interval = 12;
+        if (std::chrono::duration_cast(now - start).count() >= interval) {
+            auto ele = std::move(pendingFreePages_.front());
+            pendingFreePages_.pop_front();
+            std::vector freePages;
+            auto &list = std::get<K_FREE_LIST>(ele);
+            std::transform(list.begin(), list.end(), std::back_inserter(freePages),
+                           [](const auto &kv) { return kv->GetPageId(); });
+            RETURN_IF_NOT_OK(FreePages(freePages));
+            pendingFreeBytes_ -= list.size() * GetPageSize();
+            VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Free pages from cursor %zu ack", LogPrefix(), begCursor);
+        }
+    }
+    return Status::OK();
+}
+
+Status PageQueueBase::LocatePage(const ShmView &v, std::shared_ptr &out)
+{
+    std::shared_ptr pageInfo;
+    RETURN_IF_NOT_OK(LocatePage(v, pageInfo));
+    auto page = std::make_shared(pageInfo, 0, false, isSharedPage_);
+    RETURN_IF_NOT_OK(page->Init());
+    out = std::move(page);
+    return Status::OK();
+}
+
+ShmView PageQueueBase::GetLastPageShmView()
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(lastPageMutex_));
+    if (lastPage_) {
+        return lastPage_->GetShmView();
+    }
+    return ShmView();
+}
+
+Status PageQueueBase::LocatePage(uint64_t lastAppendCursor, std::shared_ptr &out, bool incRef)
+{
+    // We are going to lock both chains in the correct order to avoid deadlock.
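+    // (idxMutex_ first, then ackMutex_, matching the order ProcessAckedPages documents above.)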
+    ReadLockHelper rlock2(STREAM_COMMON_LOCK_ARGS(idxMutex_));
+    ReadLockHelper rlock3(STREAM_COMMON_LOCK_ARGS(ackMutex_));
+    // If both chains are empty, return K_NOT_FOUND
+    if (idxChain_.empty() && ackChain_.empty()) {
+        RETURN_STATUS(K_NOT_FOUND, FormatString("[%s] Stream index is empty.", LogPrefix()));
+    }
+    // Now we want to locate the page that contains '1 + lastAppendCursor'.
+    // The way we build the index is using the cursor of the last slot of the *previous* page.
+    // Here is the tricky part. Producers can insert past the idxChain and continue onto
+    // the ackChain to recycle the pages.
+    CHECK_FAIL_RETURN_STATUS(lastAppendCursor >= lastAckCursor_, K_NOT_FOUND, "Page has been released already");
+    auto funcName = __FUNCTION__;
+    auto func = [this, incRef, &out, &funcName](const std::shared_ptr &page) {
+        // If asked to increase the reference, do it while we are holding the idx and ack chain mutexes
+        if (incRef) {
+            RETURN_IF_NOT_OK(page->RefPage(FormatString("[%s] %s:%d", LogPrefix(), funcName, __LINE__)));
+        }
+        out = page;
+        return Status::OK();
+    };
+    // We will start with the idxChain, then continue to the ackChain
+    std::shared_ptr cur = nullptr;
+    std::shared_ptr prev = nullptr;
+    auto it = idxChain_.begin();
+    while (it != idxChain_.end()) {
+        uint64_t endCursor = it->first;
+        cur = it->second;
+        auto lastCursor = endCursor + cur->GetSlotCount();
+        if (endCursor < lastCursor && lastCursor <= lastAckCursor_) {
+            // This page is pending release. Move on to
+            // the next one. No need to update prev.
+            ++it;
+            continue;
+        }
+        // Loop invariant:
+        // endCursor < every cursor on it->second
+        if (lastAppendCursor == endCursor) {
+            // cur is what we are looking for because
+            // the begCursor of the page is 1+lastAppendCursor
+            return func(cur);
+        } else if (lastAppendCursor < endCursor) {
+            // prev is what we are looking for unless it is null
+            if (prev == nullptr) {
+                RETURN_STATUS(K_NOT_FOUND,
+                              FormatString("[%s] cursor %zu has been ack'ed already", LogPrefix(), lastAppendCursor));
+            }
+            return func(prev);
+        } else {
+            prev = cur;
+            ++it;
+        }
+    }
+    // If we get here, we have to continue onto the ack chain, which we can't walk through the list;
+    // we must follow the next-page pointer on each page instead.
+    CHECK_FAIL_RETURN_STATUS(prev != nullptr, K_NOT_FOUND,
+                             FormatString("[%s, cursor %zu] prev is null.", LogPrefix(), lastAppendCursor));
+    while (prev->HasNextPage()) {
+        uint64_t endCursor = prev->GetLastCursor();
+        // The logic is similar
+        if (lastAppendCursor < endCursor) {
+            return func(prev);
+        }
+        ShmView v = prev->GetNextPage();
+        RETURN_IF_NOT_OK(LocatePage(v, cur));
+        if (lastAppendCursor == endCursor) {
+            // cur is what we are looking for because
+            // the begCursor of the page is 1+lastAppendCursor
+            return func(cur);
+        }
+        prev = cur;
+    }
+    // We do not have to increase the reference count. This call is made for the purpose of
+    // consumers (both local and remote), and we won't ack any page consumers are still reading.
+    return func(cur);
+}
+
+Status PageQueueBase::ScanAndEval(uint64_t &lastAckCursor, uint64_t timeoutMs,
+                                  const std::vector &remoteWorkers, ScanFlags flag)
+{
+    // By design, lastAckCursor_ is always on a page boundary. Ensure lastAckCursor is within range.
+    CHECK_FAIL_RETURN_STATUS(lastAckCursor_ <= lastAckCursor, K_INVALID,
+                             FormatString("[%s] lastAckCursor [%zu] is invalid. Stream has been reclaimed to %zu",
+                                          LogPrefix(), lastAckCursor, lastAckCursor_));
+    auto funcName = __FUNCTION__;
+    do {
+        INJECT_POINT("ExclusivePageQueue.ScanAndEval.wait");
+        std::shared_ptr lastPage;
+        // Usually we don't increase the reference for a consumer because only we can ack the page.
+        // But the page we get back isn't necessarily the page that contains lastAckCursor+1; it can
+        // be the last page of the idxChain.
+        RETURN_IF_NOT_OK(LocatePage(lastAckCursor, lastPage, true));
+        RETURN_RUNTIME_ERROR_IF_NULL(lastPage);
+        const std::string logPrefix = FormatString("%s Page:%s", LogPrefix(), lastPage->GetPageId());
+        Raii unfix([&lastPage, logPrefix, &funcName]() {
+            // We asked to fix the page with an extra reference count.
+            LOG_IF_ERROR(lastPage->ReleasePage(FormatString("[%s] %s:%d", logPrefix, funcName, __LINE__)),
+                         "Page unfix");
+        });
+        std::vector dirtyElements;
+        Status rc = lastPage->Receive(lastAckCursor, timeoutMs, dirtyElements);
+        if (rc.GetCode() == K_SC_END_OF_PAGE) {
+            // Ensure lastAckCursor is the last cursor on this page
+            auto numElements = lastPage->GetSlotCount();
+            auto begCursor = lastPage->GetBegCursor();
+            CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+                lastAckCursor + 1 == begCursor + numElements, K_OUT_OF_RANGE,
+                FormatString("[%s] LastAppendCursor mismatch. lastAckCursor %zu, begCursor %zu, numSlots %zu",
+                             logPrefix, lastAckCursor, begCursor, numElements));
+            // New dirty data on the next page. Early exit if pageBreak is set
+            if (TESTFLAG(flag, ScanFlags::PAGE_BREAK)) {
+                return rc;
+            }
+            continue;
+        }
+        // Another possible error is K_TRY_AGAIN. Basically it means no new elements have been detected
+        // since the last check. No need to log an error for this case
+        if (rc.GetCode() == K_TRY_AGAIN) {
+            return rc;
+        }
+        RETURN_IF_NOT_OK_PRINT_ERROR_MSG(rc, FormatString("[S:%s]", logPrefix));
+        RETURN_OK_IF_TRUE(dirtyElements.empty());
+        // Ensure that the producer will not write new data
+        RequestCounter::GetInstance().ResetLastArrivalTime("ExclusivePageQueue::ScanAndEval");
+        INJECT_POINT("ExclusivePageQueue.ScanAndEval");
+        uint64_t nextAppendCursor = lastAckCursor + dirtyElements.size();
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Eval cursor [%zu, %zu)", logPrefix, lastAckCursor + 1,
+                                                  nextAppendCursor + 1);
+        // Pass the dirty elements to the destination
+        for (const auto &remoteWorker : remoteWorkers) {
+            RETURN_IF_NOT_OK(
+                SendElements(lastPage, lastAckCursor + 1, nextAppendCursor + 1, remoteWorker, dirtyElements));
+        }
+        // Return the next append cursor
+        lastAckCursor = nextAppendCursor;
+        if (TESTFLAG(flag, ScanFlags::EVAL_BREAK)) {
+            break;
+        }
+    } while (true);
+    return Status::OK();
+}
+
+Status PageQueueBase::SendElements(const std::shared_ptr &page, uint64_t begCursor, uint64_t endCursor,
+                                   const std::string &remoteWorker, std::vector recvElements)
+{
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[RW:%s, %s] Flush element cursor:[%zu, %zu) PageId: %s", remoteWorker,
+                                                LogPrefix(), begCursor, endCursor, page->GetPageId());
+    // Filter out elements that are from remote producers. We should not send them again remotely.
+    CHECK_FAIL_RETURN_STATUS(
+        endCursor - begCursor == recvElements.size(), K_OUT_OF_RANGE,
+        FormatString("[RW:%s, %s] Range not match. begCursor = %zu, endCursor = %zu, vector size = %zu", remoteWorker,
+                     LogPrefix(), begCursor, endCursor, recvElements.size()));
+    auto *remoteWorkerManager = GetRemoteWorkerManager();
+    // Elements are written in reverse order. But we can still pack them together as long
+    // as the receiving worker knows how to walk the payload.
+    auto iter = recvElements.begin();
+    while (iter != recvElements.end()) {
+        std::shared_ptr elementView;
+        RETURN_IF_NOT_OK(SendElementView::CreateSendElementView(page, remoteWorker, *iter, SharedFromThis(),
+                                                                remoteWorkerManager, elementView));
+        ++iter;
+        // A BigElement should be sent in its own PV. Otherwise, pack elements of the same nature into one PV.
+        while (iter != recvElements.end()) {
+            if (!elementView->PackDataElement(*iter, false, remoteWorkerManager)) {
+                break;
+            }
+            ++iter;
+        }
+        // Pass the PV to RWM
+        RETURN_IF_NOT_OK(remoteWorkerManager->SendElementsView(elementView));
+    }
+    return Status::OK();
+}
+
+Status PageQueueBase::IncBigElementPageRefCount(const std::string &pageId)
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_));
+    ShmPagesMap::accessor accessor;
+    bool success = shmPool_.find(accessor, pageId);
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(success, K_NOT_FOUND,
+                                         FormatString("[%s] Page<%s> not found", LogPrefix(), pageId));
+    auto &pageUnit = accessor->second->pageUnit;
+    pageUnit->IncrementRefCount();
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] BigElement page<%s> ref count %zu", LogPrefix(), pageId,
+                                                pageUnit->refCount);
+    return Status::OK();
+}
+
+Status PageQueueBase::ExtractBigElement(DataElement &ele, std::shared_ptr &bigElementPage)
+{
+    // Double-check that it is a big element. If it is not, we simply tolerate it and return OK.
+    RETURN_OK_IF_TRUE(!ele.IsBigElement());
+    ShmView v;
+    RETURN_IF_NOT_OK(StreamDataPage::ParseShmViewPb(ele.ptr, ele.size, v));
+    std::shared_ptr pageInfo;
+    RETURN_IF_NOT_OK(LocatePage(v, pageInfo));
+    auto page = std::make_shared(pageInfo, false);
+    RETURN_IF_NOT_OK(page->Init());
+    // Replace the original pointer with the big element pointer
+    ele.ptr = reinterpret_cast(page->GetPointer());
+    ele.size = page->PageSize();
+    bigElementPage = std::move(page);
+    return Status::OK();
+}
+
+Status PageQueueBase::DecBigElementPageRefCount(const std::string &pageId)
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_));
+    ShmPagesMap::accessor accessor;
+    bool success = shmPool_.find(accessor, pageId);
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(success, K_NOT_FOUND,
+                                         FormatString("[%s] Page<%s> not found", LogPrefix(), pageId));
+    auto &pageUnit = accessor->second->pageUnit;
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
+        pageUnit->refCount >= 1, K_OUT_OF_RANGE,
+        FormatString("[%s] Page<%s> ref count %zu unexpected", LogPrefix(), pageId, pageUnit->refCount));
+    pageUnit->DecrementRefCount();
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] BigElement page<%s> ref count %zu", LogPrefix(), pageId,
+                                                pageUnit->refCount);
+    return Status::OK();
+}
+
+Status PageQueueBase::UpdatePageRefIfExist(const ShmView &v, const std::string &logPrefix, const bool toggle)
+{
+    auto pageInfo = std::make_shared("", v, nullptr);
+    auto pageId = StreamPageBase::CreatePageId(pageInfo);
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(poolMutex_));
+    ShmPagesMap::accessor accessor;
+    bool exist = shmPool_.find(accessor, pageId);
+    if (!exist) {
+        const std::string errMsg = FormatString("[%s] Page %s not found", LogPrefix(), pageId);
+        LOG(INFO) << errMsg;
+        return { K_NOT_FOUND, errMsg };
+    }
+    pageInfo->pointer = accessor->second->pageUnit->pointer;
+    auto page = std::make_shared(pageInfo, 0, false, isSharedPage_);
+    RETURN_IF_NOT_OK(page->Init());
+    if (toggle) {
+        return page->RefPage(logPrefix);
+    } else {
+        return page->ReleasePage(logPrefix);
+    }
+}
+} // namespace stream_cache
+} // namespace worker
+} // namespace datasystem
diff --git a/src/datasystem/worker/stream_cache/page_queue/page_queue_base.h b/src/datasystem/worker/stream_cache/page_queue/page_queue_base.h
new file mode 100644
index 0000000..d40d873
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/page_queue/page_queue_base.h
@@ -0,0 +1,304 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: PageQueueBase
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_PAGE_QUEUE_BASE_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_PAGE_QUEUE_BASE_H
+
+#include
+#include
+#include
+
+#include "datasystem/common/flags/flags.h"
+#include "datasystem/common/stream_cache/stream_data_page.h"
+
+DS_DECLARE_uint32(sc_cache_pages);
+
+#define STREAM_COMMON_LOCK_ARGS(lockname) \
+    (lockname), [this, funName = __FUNCTION__] { \
+        return FormatString("%s %s, %s:%d", LogPrefix(), #lockname, funName, __LINE__); \
+    }
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+class RemoteWorkerManager;
+enum class ScanFlags : uint32_t { NONE = 0, SCAN_LOCK = 1u, PAGE_BREAK = 1u << 1, EVAL_BREAK = 1u << 2 };
+ENABLE_BITMASK_ENUM_OPS(ScanFlags);
+
+class PageQueueBase {
+public:
+    struct ShmMemInfo {
+        std::unique_ptr pageUnit;
+        std::chrono::time_point createTime;
+        bool bigElement;
+    };
+    using ShmPagesMap = tbb::concurrent_hash_map>;
+    using PageShmUnit = std::pair>;
+
+    PageQueueBase();
+    virtual ~PageQueueBase() = default;
+
+    /**
+     * @brief Create or get the last data page
+     * @param timeoutMs Timeout in milliseconds
+     * @param lastView ShmView of the caller's last page (if any)
+     * @param lastPage output
+     * @param retryOnOOM Retry on OOM if true.
+     * @return OK if successful
+     * @note If the caller's lastView is outdated, we will return the last page
+     *       rather than creating a new one. Only if the lastView matches the
+     *       current last page do we create a new page.
+     */
+    Status CreateOrGetLastDataPage(uint64_t timeoutMs, const ShmView &lastView,
+                                   std::shared_ptr &lastPage, bool retryOnOOM);
+
+    /**
+     * @brief Ack to see if any streamPageView has been consumed by all consumers and can be erased.
+     * @param[in] cursor advanced ack cursor.
+     * @param[in] streamMetaShm The pointer to streamMetaShm
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status Ack(uint64_t cursor, StreamMetaShm *streamMetaShm = nullptr);
+
+    /**
+     * @brief Locate a page that contains lastAppendCursor + 1
+     * @param lastAppendCursor
+     * @param out
+     * @return OK if found
+     */
+    Status LocatePage(uint64_t lastAppendCursor, std::shared_ptr &out, bool incRef = false);
+
+    /**
+     * @brief Allocate memory and append to PageQueue.
+     * @param[in] pageSz The pageSize.
+     * @param[in] bigElement Is big element or not.
+     * @param[out] pageUnitInfo The shm page info.
+     * @param[in] retryOnOOM Retry on OOM if true.
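+     * @note A BigElement gets its own page of exactly pageSz bytes, whereas regular
+     *       data pages are sized by GetPageSize(); FreePages relies on this distinction
+     *       when computing how much memory a released page returns.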
+     * @return Status of this call
+     */
+    Status AllocMemory(size_t pageSz, bool bigElement, std::shared_ptr &pageUnitInfo, bool retryOnOOM);
+
+    /**
+     * @brief Scan data from the page queue and send it to remote workers.
+     * @param[in/out] lastAckCursor The last ack cursor.
+     * @param[in] timeoutMs The scan timeout.
+     * @param[in] remoteWorkers The remote worker address.
+     * @param[in] flag The scan flags.
+     * @return Status of this call
+     */
+    Status ScanAndEval(uint64_t &lastAckCursor, uint64_t timeoutMs, const std::vector &remoteWorkers,
+                       ScanFlags flag);
+
+    /**
+     * Dump the memory pool to the log for diagnostics
+     */
+    void DumpPoolPages(int level) const;
+
+    /**
+     * @brief Force an update of the last page location
+     * @param[in] updateLocalPubLastPage If local producers need to get an update of the latest last page.
+     * @return Status object
+     */
+    Status MoveUpLastPage(const bool updateLocalPubLastPage = true);
+
+    Status ReclaimAckedChain(uint64_t timeoutMs);
+
+    Status ReleaseMemory(const ShmView &pageView);
+
+    /**
+     * Crash recovery based on lockId
+     * @param lockId
+     */
+    void TryUnlockByLockId(uint32_t lockId);
+
+    /**
+     * @brief Crash recovery for a lost client to unlock the mem view of the current page queue.
+     * @param[in] lockId The lock id.
+     */
+    void ForceUnlockMemViemForPages(uint32_t lockId);
+
+    /**
+     * @brief Get the StreamDataPage by ShmView.
+     * @param[in] v The ShmView instance.
+     * @param[out] out The StreamDataPage instance.
+     * @return Status of this call.
+     */
+    Status LocatePage(const ShmView &v, std::shared_ptr &out);
+
+    /**
+     * @brief Get the ShmView of the last page
+     * @return The ShmView of the last page.
+     */
+    ShmView GetLastPageShmView();
+
+public:
+    /**
+     * @brief Debugging log prefix
+     * @return the log prefix.
+     */
+    virtual std::string LogPrefix() const = 0;
+
+    /**
+     * @brief Gets the page size
+     * @return The page size value
+     */
+    virtual size_t GetPageSize() const = 0;
+
+    /**
+     * @brief Verify when allocating memory.
+     * @return Status of this call.
+     */
+    virtual Status VerifyWhenAlloc() const;
+
+    /**
+     * @brief Check whether this stream has enough memory.
+     * @param[in] memSize The memory size.
+     * @return K_OK on success; the error code otherwise.
+     */
+    virtual Status CheckHadEnoughMem(size_t memSize) = 0;
+
+    /**
+     * @brief The implementation of shared memory allocation for the stream cache.
+     * @param[in] memSizeNeeded The page size.
+     * @param[out] shmUnit The share unit instance.
+     * @return Status of this call.
+     */
+    virtual Status AllocateMemoryImpl(size_t memSizeNeeded, ShmUnit &shmUnit, bool retryOnOOM) = 0;
+
+    /**
+     * @brief Update the cursor to notify that the last page changed.
+     * @param[in] shmView The ShmView for the last page.
+     * @return Status of this call
+     */
+    virtual Status UpdateLocalCursorLastDataPage(const ShmView &shmView) = 0;
+
+    /**
+     * @brief Called after ack finishes.
+     * @return Status of this call
+     */
+    virtual Status AfterAck() = 0;
+
+    /**
+     * @brief Get encryptStream of streamFields_, and apply sanity checks.
+     * @return true if stream data encryption is applicable.
+     */
+    virtual bool IsEncryptStream(const std::string &streamName) const = 0;
+
+    /**
+     * @brief Get the stream name or the shared page identifier.
+     * @return Stream name or the shared page identifier.
+     */
+    virtual std::string GetStreamName() const = 0;
+
+    /**
+     * @brief Get the remote worker manager.
+     * @return Raw ptr to the remote worker manager.
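+     * @code
+     * // Sketch (mirrors how SendElements in the .cpp above uses this hook):
+     * auto *rwm = GetRemoteWorkerManager();
+     * RETURN_IF_NOT_OK(rwm->SendElementsView(elementView));
+     * @endcode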
+     */
+    virtual RemoteWorkerManager *GetRemoteWorkerManager() const = 0;
+
+    virtual Status IncBigElementPageRefCount(const std::string &pageId);
+    virtual Status ExtractBigElement(DataElement &ele, std::shared_ptr &bigElementPage);
+    virtual Status DecBigElementPageRefCount(const std::string &pageId);
+    virtual Status UpdatePageRefIfExist(const ShmView &v, const std::string &logPrefix, bool toggle);
+
+    /**
+     * @brief Get the blocked request size info.
+     * @return the blocked request size info
+     */
+    virtual std::pair GetNextBlockedRequestSize();
+
+    /**
+     * @brief Send elements to a remote worker.
+     * @param[in] page The data page for the elements
+     * @param[in] begCursor The begin cursor
+     * @param[in] endCursor The end cursor
+     * @param[in] remoteWorker The remote worker address.
+     * @param[in] recvElements The elements to send to the remote worker
+     * @return Status of this call
+     */
+    virtual Status SendElements(const std::shared_ptr &page, uint64_t begCursor, uint64_t endCursor,
+                                const std::string &remoteWorker, std::vector recvElements);
+
+    /**
+     * @brief Getter of lastAppendCursor_.
+     * @return lastAppendCursor_.
+     */
+    [[nodiscard]] uint64_t GetLastAppendCursor() const;
+
+protected:
+    virtual std::shared_ptr SharedFromThis() = 0;
+    void LogCursors();
+    Status CreateOrGetLastDataPageImpl(uint64_t timeoutSec, const ShmView &lastView,
+                                       std::shared_ptr &lastPage, bool retryOnOOM);
+
+    Status CreateNewPage(std::shared_ptr &lastPage, bool retryOnOOM);
+    Status AddPageToPool(const std::string &pageId, std::unique_ptr &&pageUnit, bool bigElement);
+    Status VerifyLastPageRefCountNotLocked() const;
+    Status AppendFreePagesImplNotLocked(uint64_t timeoutMs, Optional> &freeList, bool seal,
+                                        const bool updateLocalPubLastPage = true);
+    Status RefreshLastPage(std::list::iterator &iter, std::shared_ptr &lastPage);
+    void AddListToAckChain(std::list &freeList);
+    Status AckImpl(uint64_t cursor, std::list &freeList, std::vector &bigElementPage,
+                   StreamMetaShm *streamMetaShm = nullptr);
+    Status ReleaseBigElementsUpTo(uint64_t cursor, std::shared_ptr &page,
+                                  std::vector &bigElementPages, bool &keepThisPageInChain);
+    Status GetBigElementPageRefCount(const std::string &pageId, int32_t &refCount);
+    Status AppendFreePages(std::list &freeList, const bool updateLocalPubLastPage = true);
+    Status ProcessBigElementPages(std::vector &bigElementId, StreamMetaShm *streamMetaShm);
+    Status LocatePage(const ShmView &v, std::shared_ptr &out);
+    Status ProcessAckedPages(uint64_t cursor, std::list &freeList);
+    Status FreePages(std::vector &pages, bool bigElementPage = false,
+                     StreamMetaShm *streamMetaShm = nullptr);
+    Status FreePendingList();
+    Status MoveFreeListToPendFree(uint64_t cursor, std::list &freeList);
+
+    mutable std::shared_timed_mutex ackMutex_;       // protect ackChain_/pendingFreePages_
+    mutable std::shared_timed_mutex idxMutex_;       // protect idxChain_
+    mutable std::shared_timed_mutex poolMutex_;      // protect shmPool_
+    mutable std::shared_timed_mutex lastPageMutex_;  // protect lastPage_
+    std::mutex allocMutex_;  // protect shm memory allocation to avoid exceeding the max stream size limit.
+
+    std::list ackChain_;
+    std::list idxChain_;
+    ShmPagesMap shmPool_;
+
+    static constexpr int K_FREE_LIST = 2;
+    std::deque, std::vector>>> pendingFreePages_;
+    std::atomic pendingFreeBytes_{ 0 };
+
+    std::shared_ptr lastPage_; // last page.
ref count > 0 + std::atomic usedMemBytes_; + std::atomic_uint64_t lastAckCursor_; + std::atomic_uint64_t nextCursor_; + + // Stream metrics variables + std::atomic scMetricPagesCreated_{ 0 }; + std::atomic scMetricPagesReleased_{ 0 }; + std::atomic scMetricBigPagesCreated_{ 0 }; + std::atomic scMetricBigPagesReleased_{ 0 }; + + bool isSharedPage_{ false }; +}; +} // namespace stream_cache +} // namespace worker +} // namespace datasystem +#endif diff --git a/src/datasystem/worker/stream_cache/page_queue/page_queue_handler.cpp b/src/datasystem/worker/stream_cache/page_queue/page_queue_handler.cpp new file mode 100644 index 0000000..b82af4e --- /dev/null +++ b/src/datasystem/worker/stream_cache/page_queue/page_queue_handler.cpp @@ -0,0 +1,420 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: PageQueueHandler + */ + +#include "datasystem/worker/stream_cache/page_queue/page_queue_handler.h" +#include +#include + +#include "datasystem/common/constants.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/util/lock_helper.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/utils/status.h" +#include "datasystem/worker/stream_cache/page_queue/exclusive_page_queue.h" +#include "datasystem/worker/stream_cache/stream_manager.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +PageQueueHandler::PageQueueHandler(StreamManager *mgr, Optional cfg) +{ + exclusivePageQueue_ = CreateExclusivePageQueue(mgr, cfg); + if (exclusivePageQueue_) { + exclusivePageQueue_->RegisterUpdateLastDataPageHandler( + [this](const ShmView &shmView) { return UpdateLocalCursorLastDataPage(shmView); }); + } + enableSharedPage_ = cfg && StreamManager::EnableSharedPage(cfg->streamMode_); + if (mgr != nullptr) { + streamName_ = mgr->GetStreamName(); + } +} + +std::shared_ptr PageQueueHandler::CreateExclusivePageQueue(StreamManager *mgr, + Optional cfg) +{ + return std::make_shared(mgr, cfg); +} + +Status PageQueueHandler::UpdateStreamFields(const StreamFields &streamFields) +{ + RETURN_IF_NOT_OK(exclusivePageQueue_->UpdateStreamFields(streamFields)); + enableSharedPage_ = StreamManager::EnableSharedPage(streamFields.streamMode_); + return Status::OK(); +} + +Status PageQueueHandler::CreateOrGetLastDataPage(uint64_t timeoutMs, const ShmView &lastView, + std::shared_ptr &lastPage, bool retryOnOOM) +{ + { + std::shared_lock locker(mutex_); + if (sharedPageQueue_ != nullptr) { + return sharedPageQueue_->CreateOrGetLastDataPage(timeoutMs, lastView, lastPage, retryOnOOM); + } + } + return exclusivePageQueue_->CreateOrGetLastDataPage(timeoutMs, lastView, lastPage, retryOnOOM); +} + +bool PageQueueHandler::ExistsSharedPageQueue() const +{ + std::shared_lock locker(mutex_); + return sharedPageQueue_ != nullptr; +} + +std::string PageQueueHandler::LogPrefix() const +{ + return FormatString("S:%s", streamName_); +} + +Status 
PageQueueHandler::AddCursor(const std::string &id, bool isProducer, std::shared_ptr &out, ShmView &view) +{ + auto lastPageShmView = exclusivePageQueue_->GetLastPageShmView(); + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(cursorMutex_)); + INJECT_POINT("worker.AddCursor.afterLockCursorMutex"); + CHECK_FAIL_RETURN_STATUS(cursorMap_.find(id) == cursorMap_.end(), K_DUPLICATED, + FormatString("[%s Add:%s] already created", streamName_, id)); + auto func = [this, id](std::deque> &cache, size_t sz, std::unique_ptr &shmUnit) { + if (cache.empty()) { + shmUnit = std::make_unique(); + shmUnit->SetHardFreeMemory(); + std::string tenantId = TenantAuthManager::ExtractTenantId(streamName_); + RETURN_IF_NOT_OK(shmUnit->AllocateMemory(tenantId, sz, false, ServiceType::STREAM)); + } else { + shmUnit = std::move(cache.front()); + cache.pop_front(); + } + auto rc = memset_s(shmUnit->GetPointer(), sz, 0, sz); + CHECK_FAIL_RETURN_STATUS(rc == 0, K_RUNTIME_ERROR, + FormatString("[S:%s, Add:%s] Memset to 0 results in errno %d", streamName_, id, rc)); + return Status::OK(); + }; + std::unique_ptr shmUnit; + const size_t cursorSize = Cursor::K_CURSOR_SIZE_V2; + RETURN_IF_NOT_OK(func(cacheCursor_, cursorSize, shmUnit)); + + auto cursor = std::make_shared(shmUnit->GetPointer(), cursorSize, WORKER_LOCK_ID); + RETURN_IF_NOT_OK(cursor->Init()); + RETURN_IF_NOT_OK(cursor->SetWorkerVersion(Cursor::K_WORKER_EYECATCHER_V1)); + + // last page ref + std::unique_ptr lastPageRefShmUnit; + std::shared_ptr lastPageRefShmViewImpl; + if (enableSharedPage_ && isProducer) { + ShmView lastPageRefShmView; + bool usingSharedPageQueue; + { + std::shared_lock locker(mutex_); + usingSharedPageQueue = sharedPageQueue_ != nullptr; + } + if (usingSharedPageQueue) { + RETURN_IF_NOT_OK(sharedPageQueue_->GetOrCreateLastPageRef(lastPageRefShmView, lastPageRefShmViewImpl)); + } else { + const size_t lastPageRefSize = sizeof(SharedMemView); + RETURN_IF_NOT_OK(func(cacheLastPageRef_, lastPageRefSize, lastPageRefShmUnit)); + lastPageRefShmView = lastPageRefShmUnit->GetShmView(); + lastPageRefShmViewImpl = + std::make_shared(lastPageRefShmUnit->GetPointer(), lastPageRefSize, WORKER_LOCK_ID); + RETURN_IF_NOT_OK(lastPageRefShmViewImpl->Init(false)); + } + + RETURN_IF_NOT_OK( + cursor->SetLastPageRef(lastPageRefShmView, std::numeric_limits::max(), usingSharedPageQueue)); + LOG(INFO) << FormatString("[%s, Add:%s] Update the last page ref to %s", LogPrefix(), id, + lastPageRefShmView.ToStr()); + } else { + RETURN_IF_NOT_OK(cursor->SetLastPage(lastPageShmView, std::numeric_limits::max())); + LOG(INFO) << FormatString("[%s, Add:%s] Update the last page to %s", LogPrefix(), id, lastPageShmView.ToStr()); + } + + out = cursor; + view = shmUnit->GetShmView(); + // Add them to the cursor map + auto cInfo = std::make_unique(); + cInfo->shmUnit = std::move(shmUnit); + cInfo->cursor = std::move(cursor); + cInfo->lastPageRefShmUnit = std::move(lastPageRefShmUnit); + cInfo->lastPageRefShmViewImpl = std::move(lastPageRefShmViewImpl); + (void)cursorMap_.emplace(id, std::move(cInfo)); + LOG(INFO) << FormatString("[%s, Add:%s] Cursor added. 
Number of cursors %zu", streamName_, id, cursorMap_.size());
+    return Status::OK();
+}
+
+Status PageQueueHandler::DeleteCursor(const std::string &id)
+{
+    constexpr static size_t maxCache = 5;
+    WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(cursorMutex_));
+    auto iter = cursorMap_.find(id);
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(iter != cursorMap_.end(), K_NOT_FOUND,
+                                         FormatString("[%s] cursor for %s not found", streamName_, id));
+    std::unique_ptr cInfo = std::move(iter->second);
+    cursorMap_.erase(iter);
+    if (cacheCursor_.size() < maxCache) {
+        cacheCursor_.emplace_back(std::move(cInfo->shmUnit));
+    }
+    if (cacheLastPageRef_.size() < maxCache && cInfo->lastPageRefShmUnit != nullptr) {
+        cacheLastPageRef_.emplace_back(std::move(cInfo->lastPageRefShmUnit));
+    }
+    LOG(INFO) << FormatString("[%s, Delete:%s] Cursor removed. Number of cursors %zu", streamName_, id,
+                              cursorMap_.size());
+    return Status::OK();
+}
+
+void PageQueueHandler::ForceUnlockByCursor(const std::string &cursorId, bool isProducer, uint32_t lockId)
+{
+    bool fallback = false;
+    LOG_IF_ERROR(ForceUnlockByCursorImpl(cursorId, lockId, fallback),
+                 FormatString("ForceUnlockByCursorImpl for %s %s and lockId %u failed",
+                              (isProducer ? "producer" : "consumer"), cursorId, lockId));
+    if (isProducer && fallback) {
+        LOG(INFO) << FormatString("[%s, P:%s] Switch to use the V1 client recovery logic.", LogPrefix(), cursorId);
+        TryUnlockByLockId(lockId);
+    }
+}
+
+Status PageQueueHandler::ForceUnlockByCursorImpl(const std::string &cursorId, uint32_t lockId, bool &fallback)
+{
+    WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(cursorMutex_));
+    auto iter = cursorMap_.find(cursorId);
+    if (iter == cursorMap_.end() || iter->second == nullptr) {
+        return Status::OK();
+    }
+
+    // Only a V2 client will update the last locked page field. If it is a V1 client, then
+    // we will fall back to the old method.
+    auto &cursor = iter->second->cursor;
+    if (cursor->GetClientVersion() < Cursor::K_CURSOR_SIZE_V2) {
+        fallback = true;
+        return Status::OK();
+    }
+
+    LOG(INFO) << FormatString("[%s, cursorId:%s] V2 client detected.", LogPrefix(), cursorId);
+    auto unlock = [this, &cursor, lockId, &cursorId]() {
+        // unlock the SharedMemView in the Cursor
+        auto msg = FormatString("%s cursorId:%s", LogPrefix(), cursorId);
+        LOG_IF_ERROR(cursor->ForceUnLock(lockId, msg), "Cursor ForceUnLock failed");
+        // Get the last page (potentially) locked by this producer.
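+        // (A V2 client records the page it is about to lock in this cursor work area,
+        // so after a client crash we can unlock exactly that page instead of scanning
+        // every page the way the V1 fallback does.)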
+ ShmView view; + RETURN_IF_NOT_OK(cursor->GetLastLockedPageView(view, DEFAULT_TIMEOUT_MS)); + LOG(INFO) << FormatString("[%s, cursorId:%s] Last locked page<%s>", LogPrefix(), cursorId, view.ToStr()); + RETURN_OK_IF_TRUE(view == ShmView()); // No page is locked + // Get the page + std::shared_ptr page; + RETURN_IF_NOT_OK(LocatePage(view, page)); + LOG(INFO) << FormatString("[%s, cursorId:%s] Unlock page<%s>", LogPrefix(), cursorId, page->GetPageId()); + page->TryUnlockByLockId(lockId); + return Status::OK(); + }; + auto status = unlock(); + fallback = status.IsError(); + return status; +} + +void PageQueueHandler::TryUnlockByLockId(uint32_t lockId) +{ + { + std::shared_lock locker(mutex_); + if (sharedPageQueue_ != nullptr) { + sharedPageQueue_->TryUnlockByLockId(lockId); + } + } + exclusivePageQueue_->TryUnlockByLockId(lockId); +} + +void PageQueueHandler::ForceUnlockMemViemForPages(uint32_t lockId) +{ + { + std::shared_lock locker(mutex_); + if (sharedPageQueue_ != nullptr) { + sharedPageQueue_->ForceUnlockMemViemForPages(lockId); + } + } + exclusivePageQueue_->ForceUnlockMemViemForPages(lockId); +} + +Status PageQueueHandler::LocatePage(const ShmView &v, std::shared_ptr &out) +{ + { + std::shared_lock locker(mutex_); + if (sharedPageQueue_ != nullptr) { + auto rc = sharedPageQueue_->LocatePage(v, out); + // continue find from exclusive page queue if not exists. + if (rc.IsOk() || rc.GetCode() != K_NOT_FOUND) { + return rc; + } + } + } + return exclusivePageQueue_->LocatePage(v, out); +} + +Status PageQueueHandler::UpdateLocalCursorLastDataPage(const ShmView &shmView) +{ + INJECT_POINT("worker.UpdateLocalCursorLastDataPage.beforeLockCursorMutex"); + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(cursorMutex_)); + if (!enableSharedPage_) { + for (auto &ele : cursorMap_) { + // if a SetLastPage errors out go to next one + // When crash handling happens, we will clean this up by deleting pub + INJECT_POINT("UpdateLocalPubLastDataPage.skip"); + WARN_IF_ERROR(ele.second->cursor->SetLastPage(shmView, DEFAULT_TIMEOUT_MS), + FormatString("[%s] UpdateLocalCursorLastDataPage error", LogPrefix())); + } + return Status::OK(); + } + for (auto &ele : cursorMap_) { + if (ele.second->lastPageRefShmViewImpl == nullptr) { + LOG(WARNING) << LogPrefix() << " lastPageRefShmViewImpl is nullptr"; + continue; + } + WARN_IF_ERROR(ele.second->lastPageRefShmViewImpl->SetView(shmView, false, DEFAULT_TIMEOUT_MS), + FormatString("[%s] UpdateLocalCursorLastDataPage error", LogPrefix())); + } + return Status::OK(); +} + +Status PageQueueHandler::MoveUpLastPage(const bool updateLocalPubLastPage) +{ + { + std::shared_lock locker(mutex_); + if (sharedPageQueue_ != nullptr) { + return sharedPageQueue_->MoveUpLastPage(updateLocalPubLastPage); + } + } + return exclusivePageQueue_->MoveUpLastPage(updateLocalPubLastPage); +} + +Status PageQueueHandler::AllocMemory(size_t pageSz, bool bigElement, std::shared_ptr &pageUnitInfo, + bool retryOnOOM) +{ + { + std::shared_lock locker(mutex_); + if (sharedPageQueue_ != nullptr) { + return sharedPageQueue_->AllocMemory(pageSz, bigElement, pageUnitInfo, retryOnOOM); + } + } + return exclusivePageQueue_->AllocMemory(pageSz, bigElement, pageUnitInfo, retryOnOOM); +} + +Status PageQueueHandler::ReclaimAckedChain(uint64_t timeoutMs) +{ + { + std::shared_lock locker(mutex_); + if (sharedPageQueue_ != nullptr) { + return sharedPageQueue_->ReclaimAckedChain(timeoutMs); + } + } + return exclusivePageQueue_->ReclaimAckedChain(timeoutMs); +} + +Status PageQueueHandler::ReleaseMemory(const ShmView 
&pageView)
+{
+    {
+        std::shared_lock locker(mutex_);
+        if (sharedPageQueue_ != nullptr) {
+            // The memory may have been allocated from the exclusive page queue
+            // but released only after the switch to the shared page queue.
+            auto rc = sharedPageQueue_->ReleaseMemory(pageView);
+            if (rc.GetCode() != K_NOT_FOUND) {
+                return rc;
+            }
+        }
+    }
+    return exclusivePageQueue_->ReleaseMemory(pageView);
+}
+
+void PageQueueHandler::DumpPoolPages(int level) const
+{
+    {
+        std::shared_lock locker(mutex_);
+        if (sharedPageQueue_ != nullptr) {
+            sharedPageQueue_->DumpPoolPages(level);
+            return;
+        }
+    }
+    exclusivePageQueue_->DumpPoolPages(level);
+}
+
+size_t PageQueueHandler::GetPageSize() const
+{
+    {
+        std::shared_lock locker(mutex_);
+        if (sharedPageQueue_ != nullptr) {
+            return sharedPageQueue_->GetPageSize();
+        }
+    }
+    return exclusivePageQueue_->GetPageSize();
+}
+
+std::string PageQueueHandler::GetSharedPageQueueId() const
+{
+    std::shared_lock locker(mutex_);
+    if (sharedPageQueue_ != nullptr) {
+        return sharedPageQueue_->GetPageQueueId();
+    }
+    return "";
+}
+
+void PageQueueHandler::SetSharedPageQueue(std::shared_ptr &sharedPageQueue)
+{
+    LOG(INFO) << LogPrefix() << " update SharedPageQueue to " << sharedPageQueue->GetPageQueueId();
+    {
+        ShmView lastPageRefShmView;
+        std::shared_ptr lastPageRefShmViewImpl;
+        LOG_IF_ERROR(sharedPageQueue->GetOrCreateLastPageRef(lastPageRefShmView, lastPageRefShmViewImpl),
+                     FormatString("%s GetOrCreateLastPageRef failed.", LogPrefix()));
+        ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(cursorMutex_));
+        for (auto &ele : cursorMap_) {
+            LOG_IF_ERROR(ele.second->cursor->SetLastPageRef(lastPageRefShmView, DEFAULT_TIMEOUT_MS, true),
+                         FormatString("%s update last page ref failed.", LogPrefix()));
+        }
+    }
+    std::lock_guard locker(mutex_);
+    sharedPageQueue_ = sharedPageQueue;
+}
+
+Status PageQueueHandler::GetOrCreateShmMeta(const std::string &tenantId, ShmView &view)
+{
+    {
+        std::shared_lock lck(streamMetaShmMux_);
+        if (shmUnitOfStreamMeta_) {
+            view = shmUnitOfStreamMeta_->GetShmView();
+            return Status::OK();
+        }
+    }
+    {
+        std::lock_guard lck(streamMetaShmMux_);
+        if (shmUnitOfStreamMeta_) {
+            view = shmUnitOfStreamMeta_->GetShmView();
+            return Status::OK();
+        }
+        auto shmUnitOfStreamMeta = std::make_unique();
+        RETURN_IF_NOT_OK(shmUnitOfStreamMeta->AllocateMemory(tenantId, streamMetaShmSize_, false, ServiceType::STREAM));
+        auto rc = memset_s(shmUnitOfStreamMeta->GetPointer(), streamMetaShmSize_, 0, streamMetaShmSize_);
+        CHECK_FAIL_RETURN_STATUS(rc == 0, K_RUNTIME_ERROR, FormatString("Memset to 0 results in errno %d", rc));
+        StreamFields streamFields;
+        exclusivePageQueue_->GetStreamFields(streamFields);
+        streamMetaShm_ = std::make_unique(streamName_, shmUnitOfStreamMeta->GetPointer(),
+                                          streamMetaShmSize_, streamFields.maxStreamSize_);
+        RETURN_IF_NOT_OK(streamMetaShm_->Init());
+        view = shmUnitOfStreamMeta->GetShmView();
+        shmUnitOfStreamMeta_ = std::move(shmUnitOfStreamMeta);
+    }
+    return Status::OK();
+}
+} // namespace stream_cache
+} // namespace worker
+} // namespace datasystem
diff --git a/src/datasystem/worker/stream_cache/page_queue/page_queue_handler.h b/src/datasystem/worker/stream_cache/page_queue/page_queue_handler.h
new file mode 100644
index 0000000..c0d90bf
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/page_queue/page_queue_handler.h
@@ -0,0 +1,191 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: PageQueueHandler
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_PAGE_QUEUE_HANDLER_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_PAGE_QUEUE_HANDLER_H
+
+#include
+#include
+
+#include "datasystem/common/shared_memory/shm_unit.h"
+#include "datasystem/common/stream_cache/stream_data_page.h"
+#include "datasystem/common/stream_cache/stream_fields.h"
+#include "datasystem/common/stream_cache/stream_meta_shm.h"
+#include "datasystem/utils/optional.h"
+#include "datasystem/worker/stream_cache/page_queue/shared_page_queue.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+class StreamManager;
+class PageQueueHandler {
+public:
+    PageQueueHandler(StreamManager *mgr, Optional cfg);
+    ~PageQueueHandler() = default;
+
+    static std::shared_ptr CreateExclusivePageQueue(StreamManager *mgr, Optional cfg);
+
+    auto GetExclusivePageQueue() const
+    {
+        return exclusivePageQueue_;
+    }
+
+    std::string GetSharedPageQueueId() const;
+
+    bool ExistsSharedPageQueue() const;
+
+    Status CreateOrGetLastDataPage(uint64_t timeoutMs, const ShmView &lastView,
+                                   std::shared_ptr &lastPage, bool retryOnOOM);
+
+    /**
+     * @brief A shared memory work area that is shared between this worker and the producer/consumer
+     * @param[in] id ProducerId or ConsumerId
+     * @param[in] isProducer Add cursor for producer.
+     * @param[out] out shared work area
+     * @param[out] view ShmView to access this work area
+     * @return Status
+     */
+    Status AddCursor(const std::string &id, bool isProducer, std::shared_ptr &out, ShmView &view);
+
+    /**
+     * @brief Delete a shared memory work area
+     * @param[in] id ProducerId or ConsumerId
+     * @return Status
+     */
+    Status DeleteCursor(const std::string &id);
+
+    /**
+     * @brief Crash recovery for a lost client to unlock by cursor.
+     * @param[in] cursorId The cursorId.
+     * @param[in] isProducer True for producer.
+     * @param[in] lockId The lock id.
+     */
+    void ForceUnlockByCursor(const std::string &cursorId, bool isProducer, uint32_t lockId);
+
+    /**
+     * @brief Crash recovery for a lost client to unlock by cursor.
+     * @param[in] cursorId The cursorId.
+     * @param[in] lockId The lock id.
+     * @param[out] fallback Whether we need to fall back to the old logic.
+     * @return Status of this call.
+     */
+    Status ForceUnlockByCursorImpl(const std::string &cursorId, uint32_t lockId, bool &fallback);
+
+    /**
+     * @brief Unlock the mem view on all pages for this stream.
+     * @param[in] lockId The lock id.
+     */
+    void ForceUnlockMemViemForPages(uint32_t lockId);
+
+    /**
+     * Crash recovery based on lockId and producerId
+     * @param[in] lockId The lock id.
+     */
+    void TryUnlockByLockId(uint32_t lockId);
+
+    /**
+     * @brief Force an update of the last page location
+     * @param[in] updateLocalPubLastPage If local producers need to get an update of the latest last page.
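+     * @note Delegates to the shared page queue when one is attached; otherwise the
+     *       exclusive page queue serves the call (see the .cpp implementation above).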
+     * @return Status object
+     */
+    Status MoveUpLastPage(const bool updateLocalPubLastPage = true);
+
+    Status AllocMemory(size_t pageSz, bool bigElement, std::shared_ptr &pageUnitInfo, bool retryOnOOM);
+
+    Status ReclaimAckedChain(uint64_t timeoutMs);
+
+    Status ReleaseMemory(const ShmView &pageView);
+
+    void DumpPoolPages(int level) const;
+
+    size_t GetPageSize() const;
+
+    void SetSharedPageQueue(std::shared_ptr &sharedPageQueue);
+
+    /**
+     * @brief Get or create the shm meta.
+     * @param[in] tenantId The ID of the tenant.
+     * @param[out] view The view of the shm meta.
+     * @return Status of the call.
+     */
+    Status GetOrCreateShmMeta(const std::string &tenantId, ShmView &view);
+
+    /**
+     * @brief Try to decrease the usage of shared memory on this node for this stream.
+     * @param[in] size The size to be decreased.
+     * @return Status of the call.
+     */
+    Status TryDecUsage(uint64_t size)
+    {
+        std::shared_lock lck(streamMetaShmMux_);
+        return streamMetaShm_ ? streamMetaShm_->TryDecUsage(size) : Status::OK();
+    }
+
+    /**
+     * @brief Get the stream meta shm.
+     * @return The pointer to the stream meta shm.
+     */
+    StreamMetaShm *GetStreamMetaShm()
+    {
+        std::shared_lock lck(streamMetaShmMux_);
+        return streamMetaShm_ ? streamMetaShm_.get() : nullptr;
+    }
+
+    /**
+     * @brief Verifies the input stream fields match the existing setting.
+     *        If the existing settings are uninitialized, updates the values.
+     * @param[in] streamFields The stream fields with page size and max stream size to check
+     * @return Status of the call.
+     */
+    Status UpdateStreamFields(const StreamFields &streamFields);
+
+private:
+    std::string LogPrefix() const;
+    Status LocatePage(const ShmView &v, std::shared_ptr &out);
+    Status UpdateLocalCursorLastDataPage(const ShmView &shmView);
+
+    std::string streamName_;
+    std::atomic<bool> enableSharedPage_;
+    mutable std::shared_timed_mutex mutex_;
+    std::shared_ptr exclusivePageQueue_;
+    std::shared_ptr sharedPageQueue_;
+
+    std::shared_timed_mutex streamMetaShmMux_;
+    std::unique_ptr shmUnitOfStreamMeta_;
+    std::unique_ptr streamMetaShm_;
+    const uint64_t streamMetaShmSize_ = 64;
+
+    // WorkArea(s) for communicating with producers/consumers.
+    mutable std::shared_timed_mutex cursorMutex_;
+    struct CursorInfo {
+        std::unique_ptr shmUnit;
+        std::shared_ptr cursor;
+        std::unique_ptr lastPageRefShmUnit;
+        std::shared_ptr lastPageRefShmViewImpl;
+    };
+    std::unordered_map> cursorMap_;
+    std::deque> cacheCursor_;
+    std::deque> cacheLastPageRef_;
+};
+} // namespace stream_cache
+} // namespace worker
+} // namespace datasystem
+#endif
diff --git a/src/datasystem/worker/stream_cache/page_queue/shared_page_queue.cpp b/src/datasystem/worker/stream_cache/page_queue/shared_page_queue.cpp
new file mode 100644
index 0000000..494aa3d
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/page_queue/shared_page_queue.cpp
@@ -0,0 +1,201 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: SharedPageQueue
+ */
+
+#include "datasystem/worker/stream_cache/page_queue/shared_page_queue.h"
+
+#include "datasystem/common/constants.h"
+#include "datasystem/common/flags/flags.h"
+#include "datasystem/common/util/format.h"
+#include "datasystem/common/util/lock_helper.h"
+#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h"
+#include "datasystem/worker/stream_cache/remote_worker_manager.h"
+#include "datasystem/worker/stream_cache/stream_manager.h"
+
+DS_DECLARE_string(sc_encrypt_secret_key);
+DS_DECLARE_string(encrypt_kit);
+
+namespace {
+static bool ValidateSharedPageSize(const char *flagname, uint32_t value)
+{
+    const uint32_t minValue = 1;
+    const uint32_t maxValue = 16;
+    if ((value < minValue) || (value > maxValue)) {
+        LOG(ERROR) << "The " << flagname << " must be between " << minValue << " and " << maxValue << ".";
+        return false;
+    }
+    return true;
+}
+} // namespace
+
+DS_DEFINE_uint32(sc_shared_page_size_mb, 4, "the shared page size in MB");
+DS_DEFINE_validator(sc_shared_page_size_mb, &ValidateSharedPageSize);
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+
+SharedPageQueue::SharedPageQueue(std::string tenantId, HostPort remoteWorker, int partId,
+                                 std::shared_ptr scAllocateManager,
+                                 ClientWorkerSCServiceImpl *scSvc)
+    : tenantId_(std::move(tenantId)),
+      remoteWorker_(std::move(remoteWorker)),
+      scAllocateManager_(scAllocateManager),
+      scSvc_(scSvc)
+{
+    isSharedPage_ = true;
+    if (tenantId_.empty()) {
+        pageQueueId_ = FormatString("%s-%d", remoteWorker_.ToString(), partId);
+    } else {
+        pageQueueId_ = FormatString("%s-%s-%d", remoteWorker_.ToString(), tenantId_, partId);
+    }
+}
+
+const std::string &SharedPageQueue::GetPageQueueId() const
+{
+    return pageQueueId_;
+}
+
+std::string SharedPageQueue::LogPrefix() const
+{
+    return FormatString("SPG:%s", GetPageQueueId());
+}
+
+size_t SharedPageQueue::GetPageSize() const
+{
+    return FLAGS_sc_shared_page_size_mb * MB_TO_BYTES;
+}
+
+Status SharedPageQueue::CheckHadEnoughMem(uint64_t memSize)
+{
+    (void)memSize;
+    return Status::OK();
+}
+
+Status SharedPageQueue::UpdateLocalCursorLastDataPage(const ShmView &shmView)
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(lastPageRefMutex_));
+    if (lastPageRefShmViewImpl_ != nullptr) {
+        WARN_IF_ERROR(lastPageRefShmViewImpl_->SetView(shmView, false, DEFAULT_TIMEOUT_MS),
+                      FormatString("[%s] UpdateLocalCursorLastDataPage error", LogPrefix()));
+    } else {
+        LOG(WARNING) << LogPrefix() << " lastPageRefShmViewImpl_ is not initialized!";
+    }
+    return Status::OK();
+}
+
+Status SharedPageQueue::AllocateMemoryImpl(size_t memSizeNeeded, ShmUnit &shmUnit, bool retryOnOOM)
+{
+    return scAllocateManager_->AllocateMemoryForStream(tenantId_, "", memSizeNeeded, true, shmUnit, retryOnOOM);
+}
+
+Status SharedPageQueue::AfterAck()
+{
+    return Status::OK();
+}
+
+Status SharedPageQueue::RemoteAck()
+{
+    auto lastAppendCursor = GetLastAppendCursor();
+    uint64_t newAckCursor = UpdateLastAckCursorUnlocked(lastAppendCursor);
+    RETURN_IF_NOT_OK(Ack(newAckCursor));
+    return Status::OK();
+}
+
+uint64_t SharedPageQueue::UpdateLastAckCursorUnlocked(uint64_t minSubsAckCursor)
+{
+    bool success = false;
+    do {
+        uint64_t val = lastAckCursor_.load();
+        // Go through all remote consumers. We may be in the process of sending elements
+        // to the remote worker.
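+        // Illustrative note: the compare_exchange loop below tolerates concurrent ackers.
+        // If another thread advances lastAckCursor_ past our candidate while we compute
+        // the minimum, the CAS fails and we reload val and retry.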
+ auto remoteWorkerManager = GetRemoteWorkerManager(); + auto remoteAckCursor = remoteWorkerManager->GetLastAckCursor(GetPageQueueId()); + minSubsAckCursor = std::min(minSubsAckCursor, remoteAckCursor); + if (minSubsAckCursor > val) { + INJECT_POINT_NO_RETURN("UpdateLastAckCursorUnlocked.sleep"); + success = lastAckCursor_.compare_exchange_strong(val, minSubsAckCursor); + if (success) { + LOG(INFO) << FormatString("[%s] The last ack of stream update from %zu to %zu", LogPrefix(), val, + minSubsAckCursor); + return minSubsAckCursor; + } + } else { + return minSubsAckCursor; + } + } while (true); +} + +bool SharedPageQueue::IsEncryptStream(const std::string &streamName) const +{ + StreamManagerMap::const_accessor accessor; + Status rc = scSvc_->GetStreamManager(streamName, accessor); + if (rc.IsError()) { + return false; + } + std::shared_ptr streamMgr = accessor->second; + StreamFields streamFields; + streamMgr->GetStreamFields(streamFields); + return streamFields.encryptStream_ && !FLAGS_sc_encrypt_secret_key.empty() + && FLAGS_encrypt_kit != ENCRYPT_KIT_PLAINTEXT; +} + +std::string SharedPageQueue::GetStreamName() const +{ + return GetPageQueueId(); +} + +Status SharedPageQueue::GetOrCreateLastPageRef(ShmView &lastPageRefShmView, + std::shared_ptr &lastPageRefShmViewImpl) +{ + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(lastPageRefMutex_)); + if (lastPageRefShmUnit_ == nullptr) { + const size_t lastPageRefSize = sizeof(SharedMemView); + auto shmUnit = std::make_unique(); + shmUnit->SetHardFreeMemory(); + RETURN_IF_NOT_OK(shmUnit->AllocateMemory(tenantId_, lastPageRefSize, false, ServiceType::STREAM)); + auto rc = memset_s(shmUnit->GetPointer(), lastPageRefSize, 0, lastPageRefSize); + CHECK_FAIL_RETURN_STATUS(rc == 0, K_RUNTIME_ERROR, + FormatString("[%s] Memset to 0 results in errno %d", LogPrefix(), rc)); + + // The lock id of worker is 0. + auto lastPageRefShmViewImpl = + std::make_shared(shmUnit->GetPointer(), lastPageRefSize, WORKER_LOCK_ID); + RETURN_IF_NOT_OK(lastPageRefShmViewImpl->Init(false)); + lastPageRefShmUnit_ = std::move(shmUnit); + lastPageRefShmViewImpl_ = std::move(lastPageRefShmViewImpl); + lastPageRefShmViewImpl_->SetView(ShmView{}, false, std::numeric_limits::max()); + } + lastPageRefShmView = lastPageRefShmUnit_->GetShmView(); + lastPageRefShmViewImpl = lastPageRefShmViewImpl_; + return Status::OK(); +} + +std::shared_ptr SharedPageQueue::SharedFromThis() +{ + return std::static_pointer_cast(shared_from_this()); +} + +RemoteWorkerManager *SharedPageQueue::GetRemoteWorkerManager() const +{ + return scSvc_->GetRemoteWorkerManager(); +} +} // namespace stream_cache +} // namespace worker +} // namespace datasystem diff --git a/src/datasystem/worker/stream_cache/page_queue/shared_page_queue.h b/src/datasystem/worker/stream_cache/page_queue/shared_page_queue.h new file mode 100644 index 0000000..2f20d1c --- /dev/null +++ b/src/datasystem/worker/stream_cache/page_queue/shared_page_queue.h @@ -0,0 +1,72 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: SharedPageQueue
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_SHARED_PAGE_QUEUE_BASE_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_SHARED_PAGE_QUEUE_BASE_H
+
+#include "datasystem/worker/stream_cache/page_queue/page_queue_base.h"
+#include "datasystem/worker/stream_cache/worker_sc_allocate_memory.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+class ClientWorkerSCServiceImpl;
+class RemoteWorkerManager;
+class SharedPageQueue : public std::enable_shared_from_this<SharedPageQueue>, public PageQueueBase {
+public:
+    SharedPageQueue(std::string tenantId, HostPort remoteWorker, int partId,
+                    std::shared_ptr scAllocateManager, ClientWorkerSCServiceImpl *scSvc);
+    virtual ~SharedPageQueue() = default;
+    Status RemoteAck();
+    uint64_t UpdateLastAckCursorUnlocked(uint64_t minSubsAckCursor);
+
+    const std::string &GetPageQueueId() const;
+
+    Status GetOrCreateLastPageRef(ShmView &lastPageRefShmView,
+                                  std::shared_ptr &lastPageRefShmViewImpl);
+
+public:
+    std::string LogPrefix() const override;
+    size_t GetPageSize() const override;
+    Status AllocateMemoryImpl(size_t memSizeNeeded, ShmUnit &shmUnit, bool retryOnOOM) override;
+    Status CheckHadEnoughMem(uint64_t memSize) override;
+    Status UpdateLocalCursorLastDataPage(const ShmView &shmView) override;
+    Status AfterAck() override;
+    bool IsEncryptStream(const std::string &streamName) const override;
+    std::string GetStreamName() const override;
+    RemoteWorkerManager *GetRemoteWorkerManager() const override;
+
+private:
+    std::shared_ptr<PageQueueBase> SharedFromThis() override;
+    const std::string tenantId_;
+    const HostPort remoteWorker_;
+    std::string pageQueueId_;
+    std::shared_ptr scAllocateManager_;
+    std::atomic<uint64_t> lastAckCursor_{ 0 };
+    ClientWorkerSCServiceImpl *scSvc_;
+
+    mutable std::shared_timed_mutex lastPageRefMutex_;
+    std::unique_ptr<ShmUnit> lastPageRefShmUnit_;
+    std::shared_ptr lastPageRefShmViewImpl_;
+};
+} // namespace stream_cache
+} // namespace worker
+} // namespace datasystem
+#endif
diff --git a/src/datasystem/worker/stream_cache/page_queue/shared_page_queue_group.cpp b/src/datasystem/worker/stream_cache/page_queue/shared_page_queue_group.cpp
new file mode 100644
index 0000000..1f8c85a
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/page_queue/shared_page_queue_group.cpp
@@ -0,0 +1,132 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: SharedPageQueueGroup
+ */
+
+#include "datasystem/worker/stream_cache/page_queue/shared_page_queue_group.h"
+#include <functional>
+#include <memory>
+#include <mutex>
+
+#include "datasystem/common/constants.h"
+#include "datasystem/common/flags/flags.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/utils/status.h"
+#include "datasystem/common/iam/tenant_auth_manager.h"
+
+namespace {
+static bool ValidateGroupCount(const char *flagname, uint32_t value)
+{
+    const int32_t minValue = 1;
+    const int32_t maxValue = 64;
+    if ((value < minValue) || (value > maxValue)) {
+        LOG(ERROR) << "The " << flagname << " must be between " << minValue << " and " << maxValue << ".";
+        return false;
+    }
+    return true;
+}
+} // namespace
+
+DS_DEFINE_uint32(sc_shared_page_group_count, 4, "the number of shared page queues per remote worker");
+DS_DEFINE_validator(sc_shared_page_group_count, &ValidateGroupCount);
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+
+SharedPageQueueGroup::SharedPageQueueGroup(HostPort remoteWorker,
+                                           std::shared_ptr scAllocateManager,
+                                           ClientWorkerSCServiceImpl *scSvc)
+    : partCount_(FLAGS_sc_shared_page_group_count),
+      remoteWorker_(std::move(remoteWorker)),
+      scAllocateManager_(std::move(scAllocateManager)),
+      scSvc_(scSvc)
+{
+}
+
+SharedPageQueueGroup::~SharedPageQueueGroup() = default;
+
+size_t SharedPageQueueGroup::GetPartId(const std::string &streamName) const
+{
+    return std::hash<std::string>{}(streamName) % partCount_;
+}
+
+std::vector<std::string> SharedPageQueueGroup::GetAllSharedPageName()
+{
+    std::vector<std::string> names;
+    std::shared_lock locker(mutex_);
+    for (auto &it : tenantPageQueues_) {
+        for (auto &page : it.second) {
+            names.emplace_back(page->GetStreamName());
+        }
+    }
+    return names;
+}
+
+Status SharedPageQueueGroup::GetSharedPageQueue(const std::string &streamName,
+                                                std::shared_ptr<SharedPageQueue> &pageQueue)
+{
+    auto tenantId = TenantAuthManager::ExtractTenantId(streamName);
+    auto realStreamName = TenantAuthManager::ExtractRealObjectKey(streamName);
+    std::shared_lock locker(mutex_);
+    auto iter = tenantPageQueues_.find(tenantId);
+    if (iter != tenantPageQueues_.end()) {
+        auto index = GetPartId(realStreamName);
+        pageQueue = iter->second[index];
+        return Status::OK();
+    }
+    RETURN_STATUS(K_NOT_FOUND, FormatString("no page queue found for tenant %s", tenantId));
+}
+
+void SharedPageQueueGroup::GetOrCreateSharedPageQueue(const std::string &streamName,
+                                                      std::shared_ptr<SharedPageQueue> &pageQueue)
+{
+    auto rc = GetSharedPageQueue(streamName, pageQueue);
+    if (rc.IsOk()) {
+        return;
+    }
+    auto tenantId = TenantAuthManager::ExtractTenantId(streamName);
+    auto realStreamName = TenantAuthManager::ExtractRealObjectKey(streamName);
+    std::lock_guard locker(mutex_);
+    auto iter = tenantPageQueues_.find(tenantId);
+    if (iter == tenantPageQueues_.end()) {
+        std::vector<std::shared_ptr<SharedPageQueue>> pageQueues(partCount_);
+        for (size_t i = 0; i < partCount_; i++) {
+            pageQueues[i] = std::make_shared<SharedPageQueue>(tenantId, remoteWorker_, i, scAllocateManager_, scSvc_);
+        }
+        auto ret = tenantPageQueues_.emplace(std::move(tenantId), std::move(pageQueues));
+        iter = ret.first;
+    }
+    auto index = GetPartId(realStreamName);
+    pageQueue = iter->second[index];
+}
+
+Status SharedPageQueueGroup::RemoveSharedPageQueueForTenant(const std::string &tenantId)
+{
+    std::lock_guard locker(mutex_);
+    auto iter = tenantPageQueues_.find(tenantId);
+    if (iter != tenantPageQueues_.end()) {
+        tenantPageQueues_.erase(iter);
+        return Status::OK();
+    }
+    RETURN_STATUS(K_NOT_FOUND, FormatString("no page queue found for tenant %s", tenantId));
+}
+
+} // namespace stream_cache
+} // namespace worker
+} // namespace datasystem
diff --git a/src/datasystem/worker/stream_cache/page_queue/shared_page_queue_group.h b/src/datasystem/worker/stream_cache/page_queue/shared_page_queue_group.h
new file mode 100644
index 0000000..bdad3de6
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/page_queue/shared_page_queue_group.h
@@ -0,0 +1,60 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: SharedPageQueueGroup
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_SHARED_PAGE_QUEUE_GROUP_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_SHARED_PAGE_QUEUE_GROUP_H
+
+#include <shared_mutex>
+#include <unordered_map>
+
+#include "datasystem/worker/stream_cache/page_queue/shared_page_queue.h"
+#include "datasystem/worker/stream_cache/worker_sc_allocate_memory.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+class SharedPageQueueGroup final {
+public:
+    SharedPageQueueGroup(HostPort remoteWorker, std::shared_ptr scAllocateManager,
+                         ClientWorkerSCServiceImpl *scSvc);
+    ~SharedPageQueueGroup();
+
+    std::vector<std::string> GetAllSharedPageName();
+
+    void GetOrCreateSharedPageQueue(const std::string &streamName, std::shared_ptr<SharedPageQueue> &pageQueue);
+
+    Status GetSharedPageQueue(const std::string &streamName, std::shared_ptr<SharedPageQueue> &pageQueue);
+
+    Status RemoveSharedPageQueueForTenant(const std::string &tenantId);
+
+private:
+    size_t GetPartId(const std::string &streamName) const;
+
+    const size_t partCount_;
+    const HostPort remoteWorker_;
+    std::shared_timed_mutex mutex_;
+    std::unordered_map<std::string, std::vector<std::shared_ptr<SharedPageQueue>>> tenantPageQueues_;
+    std::shared_ptr scAllocateManager_;
+    ClientWorkerSCServiceImpl *scSvc_;
+};
+} // namespace stream_cache
+} // namespace worker
+} // namespace datasystem
+#endif
diff --git a/src/datasystem/worker/stream_cache/producer.cpp b/src/datasystem/worker/stream_cache/producer.cpp
new file mode 100644
index 0000000..175cc17
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/producer.cpp
@@ -0,0 +1,55 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "datasystem/worker/stream_cache/producer.h"
+
+#include <utility>
+
+#include "datasystem/common/constants.h"
+#include "datasystem/common/iam/tenant_auth_manager.h"
+#include "datasystem/common/log/log.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+Producer::Producer(std::string producerId, std::string streamName, std::shared_ptr<Cursor> cursor)
+    : id_(std::move(producerId)), streamName_(std::move(streamName)), cursor_(std::move(cursor))
+{
+}
+
+std::string Producer::GetId() const
+{
+    return id_;
+}
+
+Status Producer::CleanupProducer()
+{
+    if (cursor_) {
+        RETURN_IF_NOT_OK(cursor_->Init());
+        RETURN_IF_NOT_OK(cursor_->SetLastPage(ShmView(), DEFAULT_TIMEOUT_MS));
+    }
+    return Status::OK();
+}
+
+void Producer::SetForceClose()
+{
+    if (cursor_) {
+        cursor_->SetForceClose();
+    }
+}
+} // namespace stream_cache
+} // namespace worker
+} // namespace datasystem
diff --git a/src/datasystem/worker/stream_cache/producer.h b/src/datasystem/worker/stream_cache/producer.h
new file mode 100644
index 0000000..aaefcdb
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/producer.h
@@ -0,0 +1,107 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_PRODUCER_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_PRODUCER_H
+
+#include <memory>
+#include <string>
+
+#include "datasystem/common/stream_cache/cursor.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+class Producer {
+public:
+    /**
+     * @brief Construct Producer.
+     * @param[in] producerId The producer id.
+     * @param[in] streamName The name of the stream the producer writes to.
+     * @param[in] cursor The cursor shared with the client-side producer.
+     */
+    Producer(std::string producerId, std::string streamName, std::shared_ptr<Cursor> cursor);
+    Producer(const Producer &producer) = delete;
+    Producer &operator=(const Producer &producer) = delete;
+    Producer(Producer &&producer) noexcept = delete;
+    Producer &operator=(Producer &&producer) noexcept = delete;
+
+    virtual ~Producer() = default;
+
+    /**
+     * @brief Seals the current page and clears the flush count.
+     * @return Status of the call.
+     */
+    Status CleanupProducer();
+
+    /**
+     * @brief Get producer id.
+     * @return Id of producer.
+     */
+    [[nodiscard]] std::string GetId() const;
+
+    /**
+     * Get the element count of the cursor
+     */
+    uint64_t GetElementCount()
+    {
+        return cursor_ == nullptr ? 0 : cursor_->GetElementCount();
+    }
+
+    /**
+     * Get the request count of the cursor and reset it to 0
+     */
+    uint64_t GetRequestCountAndReset()
+    {
+        return cursor_ == nullptr ? 0 : cursor_->GetRequestCountAndReset();
+    }
+
+    void SetForceClose();
+
+    /**
+     * @brief Set the element count to val
+     * @param val value to set element count to
+     */
+    void SetElementCount(uint64_t val)
+    {
+        if (cursor_) {
+            cursor_->SetElementCount(val);
+        }
+    }
+
+    /**
+     * @brief Get the element count and reset it to 0.
+     * @return The element count before the reset.
+     */
+    uint64_t GetElementCountAndReset()
+    {
+        return cursor_ == nullptr ?
0 : cursor_->GetElementCountAndReset(); + } + + void SetCursor(std::shared_ptr &&cursor) + { + cursor_ = std::move(cursor); + } + +protected: + const std::string id_; + const std::string streamName_; + // A work area that is shared between the corresponding client::stream_cache::ProducerImpl + // sz is the size of this work area. It is set up in the function StreamDataObject::AddCursor. + std::shared_ptr cursor_; +}; +} // namespace stream_cache +} // namespace worker +} // namespace datasystem +#endif diff --git a/src/datasystem/worker/stream_cache/remote_worker_manager.cpp b/src/datasystem/worker/stream_cache/remote_worker_manager.cpp new file mode 100644 index 0000000..9586970 --- /dev/null +++ b/src/datasystem/worker/stream_cache/remote_worker_manager.cpp @@ -0,0 +1,1651 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "datasystem/common/constants.h" +#include "datasystem/common/flags/flags.h" +#include "datasystem/common/perf/perf_manager.h" +#include "datasystem/common/rpc/rpc_auth_key_manager.h" +#include "datasystem/common/rpc/rpc_stub_base.h" +#include "datasystem/common/rpc/rpc_stub_cache_mgr.h" +#include "datasystem/common/util/format.h" +#include "datasystem/common/util/gflag/common_gflags.h" +#include "datasystem/common/util/memory.h" +#include "datasystem/common/util/raii.h" +#include "datasystem/common/util/rpc_util.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/protos/stream_posix.stub.rpc.pb.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h" +#include "datasystem/worker/stream_cache/page_queue/shared_page_queue.h" +#include "datasystem/worker/stream_cache/remote_worker_manager.h" +#include "datasystem/worker/stream_cache/stream_manager.h" + +DS_DEFINE_int32(remote_send_thread_num, 8, "The num of threads used to send elements to remote worker."); +DS_DEFINE_validator(remote_send_thread_num, &Validator::ValidateThreadNum); + +namespace datasystem { +namespace worker { +namespace stream_cache { + +std::string SendElementView::StreamName() const +{ + return streamName_; +} + +std::string SendElementView::ProducerName() const +{ + return remoteWorker_; +} + +std::string SendElementView::ProducerInstanceId() const +{ + // we only care about instance on recv side + return ""; +} + +uint64_t SendElementView::StreamHash() const +{ + // We will swap the position of stream and worker address so to hash differently + StreamProducerKey key(ProducerName(), KeyName(), ProducerInstanceId()); + return std::hash{}(key); +} + +Status SendElementView::CreateSendElementView(const std::shared_ptr &page, + const std::string &remoteWorker, DataElement &dataElement, + std::shared_ptr obj, + RemoteWorkerManager *remoteWorkerManager, + std::shared_ptr &out) +{ + bool isSharedPage = dataElement.GetStreamNo() != 0; + std::shared_ptr elementView = 
std::make_shared(); + std::string streamName; + + if (isSharedPage) { + streamName = ""; + Status rc = remoteWorkerManager->StreamNoToName(dataElement.GetStreamNo(), streamName); + VLOG_IF(SC_NORMAL_LOG_LEVEL, rc.IsError()) << rc.ToString(); + } else { + streamName = obj->GetStreamName(); + } + + elementView->page_ = page; + elementView->streamName_ = streamName; + elementView->begCursor_ = dataElement.id; + elementView->remote_ = dataElement.IsRemote(); + elementView->bigElement_ = dataElement.IsBigElement(); + // Note the order of constructing elementView!! After "ExtractBigElement" is executed, "dataElement.size" will be + // replaced with the actual size of the big element. + if (elementView->bigElement_) { + elementView->bigElementMetaSize_ = dataElement.size; + } + if (elementView->bigElement_ && !elementView->remote_) { + RETURN_IF_NOT_OK(obj->ExtractBigElement(dataElement, elementView->bigElementPage_)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[RW:%s, S:%s] Page<%s> Cursor %zu BigElement<%s>", remoteWorker, + elementView->streamName_, page->GetPageId(), dataElement.id, + elementView->bigElementPage_->GetPageId()); + } + + // If it is a big element, the pointer here will be the pointer to the real data of the big element. + (void)elementView->PackDataElement(dataElement, true); + elementView->dataObj_ = obj; + + if (isSharedPage) { + auto sharedPageView = std::make_shared(); + sharedPageView->sharedPageName_ = obj->GetStreamName(); + sharedPageView->remoteWorker_ = remoteWorker; + sharedPageView->traceId_ = Trace::Instance().GetTraceID(); + sharedPageView->dataObj_ = obj; + sharedPageView->elementViews_.emplace_back(elementView); + out = std::static_pointer_cast(std::move(sharedPageView)); + } else { + elementView->remoteWorker_ = remoteWorker; + elementView->traceId_ = Trace::Instance().GetTraceID(); + out = std::move(elementView); + } + + return Status::OK(); +} + +Status StreamElementView::ReleasePage() +{ + bool expected = true; + if (ref_.compare_exchange_strong(expected, false)) { + RETURN_IF_NOT_OK(page_->ReleasePage(FormatString("RW:%s, S:%s", remoteWorker_, StreamName()))); + if (bigElement_) { + RETURN_IF_NOT_OK(dataObj_->DecBigElementPageRefCount(bigElementPage_->GetPageId())); + } + } + return Status::OK(); +} + +Status StreamElementView::IncRefCount() +{ + RETURN_OK_IF_TRUE(ref_); + bool bigElementLocked = false; + Raii raii([&bigElementLocked, this]() { + // Unlock in case of error. + if (!ref_ && bigElementLocked) { + std::string pageId = bigElementPage_->GetPageId(); + (void)dataObj_->DecBigElementPageRefCount(pageId); + } + }); + if (bigElement_) { + std::string pageId = bigElementPage_->GetPageId(); + RETURN_IF_NOT_OK(dataObj_->IncBigElementPageRefCount(pageId)); + bigElementLocked = true; + } + RETURN_IF_NOT_OK(page_->RefPage(FormatString("RW:%s, S:%s", remoteWorker_, StreamName()))); + ref_ = true; + return Status::OK(); +} + +Status StreamElementView::MoveBufToAlternateMemory() +{ + std::unique_lock lock(mux_); + // Check if someone beat us to do it already + RETURN_OK_IF_TRUE(!shmEnabled_); + // This function is only called when we hit OOM on sending. Without holding up the + // shared memory page, we will save the data from shared memory to alternate place. + size_t totalSize = std::accumulate(sz_.begin(), sz_.end(), 0ul); + shmUnit_ = std::make_unique(); + // Acquire some shared memory which is already allocated from the Arena, + // as using private memory can lead to worker getting OOMKilled. 
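+    // The fallback sequence below: copy the payload out of the stream page into a
+    // freshly allocated shm unit, release the page reference so the page can be
+    // recycled, and clear shmEnabled_ so that GetBufferPointer() serves the copy
+    // from secondaryAddr_ on the next send attempt.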
+ RETURN_IF_NOT_OK(shmUnit_->AllocateMemory(DEFAULT_TENANT_ID, totalSize, true, ServiceType::STREAM)); + secondaryAddr_ = reinterpret_cast(shmUnit_->pointer); + RETURN_IF_NOT_OK(HugeMemoryCopy(secondaryAddr_, totalSize, buf_, totalSize)); + RETURN_IF_NOT_OK(ReleasePage()); + shmEnabled_ = false; + localBufSize_ = totalSize; + return Status::OK(); +} + +Status StreamElementView::MoveBufToShmUnit() +{ + std::unique_lock lock(mux_); + // Check if someone beat us to do it already + RETURN_OK_IF_TRUE(!shmEnabled_); + // This function is called when we hit OOM on sending or when the stream is already blocked before sending. + // Without holding up the shared page, we will save the data from the shared page to alternate place. + size_t totalSize = std::accumulate(sz_.begin(), sz_.end(), 0ul); + shmUnit_ = std::make_unique(); + // Fixme: use shared page queue for the allocation, + // deal with per stream memory limit, and also not skip retry. + auto tenantId = TenantAuthManager::ExtractTenantId(streamName_); + RETURN_IF_NOT_OK(shmUnit_->AllocateMemory(tenantId, totalSize, true, ServiceType::STREAM)); + secondaryAddr_ = reinterpret_cast(shmUnit_->pointer); + RETURN_IF_NOT_OK(HugeMemoryCopy(secondaryAddr_, totalSize, buf_, totalSize)); + RETURN_IF_NOT_OK(ReleasePage()); + shmEnabled_ = false; + localBufSize_ = totalSize; + return Status::OK(); +} + +uint8_t *StreamElementView::GetBufferPointer() +{ + std::shared_lock rlock(mux_); + bool shmEnabled = shmEnabled_.load(); + uint8_t *ptr = shmEnabled ? buf_.load() : secondaryAddr_; + return ptr; +} + +RemoteAckInfo::AckRange StreamElementView::GetAckRange() +{ + return std::make_pair(begCursor_, sz_.size()); +} + +bool StreamElementView::IsSharedPage() +{ + return false; +} + +bool StreamElementView::PackDataElement(const DataElement &dataElement, bool skipChecks, + RemoteWorkerManager *remoteworkerManager) +{ + (void)remoteworkerManager; + // The first element skips the checks. + if (!skipChecks) { + // We do not pack the next element into the element view if: + // 1. the element view contains big element already + // 2. the next element is big element + // 3. the element is not with contiguous memory + // 4. the remote field mismatches between the next element and the existing elements + if (bigElement_ || dataElement.IsBigElement() || dataElement.ptr + dataElement.size != buf_ + || dataElement.IsRemote() != remote_) { + return false; + } + } + buf_ = dataElement.ptr; + sz_.emplace_back(dataElement.size); + headerBits_.emplace_back(dataElement.HasHeader()); + return true; +} + +uint64_t StreamElementView::GetElementNum() +{ + return sz_.size(); +} + +void StreamElementView::DiscardBufferFromList(std::list &dataLst, std::list::iterator &iter) +{ + auto p = GetAckRange(); + LOG(INFO) << FormatString("[S:%s] Discard range [%zu, %zu)", StreamName(), p.first, p.first + p.second); + if (!ref_) { + // This buffer has been released already + iter = dataLst.erase(iter); + return; + } + (void)dataObj_->UpdatePageRefIfExist(page_->GetShmView(), FormatString("S:%s", StreamName()), false); + if (bigElement_) { + (void)dataObj_->DecBigElementPageRefCount(bigElementPage_->GetPageId()); + } + iter = dataLst.erase(iter); +} + +std::string SharedPageElementView::KeyName() const +{ + // The name for StreamProducerKey purposes. + // It is the unique identifier for the shared page queue instead of an actual stream name. 
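+    // With the id format built in SharedPageQueue's constructor this looks like
+    // "<remote-worker>-<tenantId>-<partId>" (the tenant id is omitted when empty),
+    // so all streams that hash to the same partition share one key, and hence one
+    // send queue, per remote worker.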
+ return sharedPageName_; +} + +bool SharedPageElementView::IsSharedPage() +{ + return true; +} + +bool SharedPageElementView::PackDataElement(const DataElement &dataElement, bool skipChecks, + RemoteWorkerManager *remoteWorkerManager) +{ + // For shared page scenario, elements to be packed should be from the same stream. + // Fixme: allow different streams in a shared page stream view + // Fixme: change to detect empty stream name + auto isDifferentStream = [this, &dataElement, remoteWorkerManager]() { + // stream number being 0 means shared page is not enabled for the stream, so it should be exclusive page. + if (dataElement.GetStreamNo() == 0) { + return false; + } + std::string streamNameFromNumber; + (void)remoteWorkerManager->StreamNoToName(dataElement.GetStreamNo(), streamNameFromNumber); + // If the stream number cannot be mapped, elements of the same nature can still be combined, + // data will be eventually discarded. + auto streamName = elementViews_.back()->StreamName(); + return streamNameFromNumber != streamName; + }; + if (isDifferentStream()) { + return false; + } + // Fixme: actually deal with list of element views. + return elementViews_.back()->PackDataElement(dataElement, skipChecks); +} + +uint64_t SharedPageElementView::RecordSeqNo(std::function fetchAddSeqNo) +{ + for (auto &view : elementViews_) { + seqNums_.emplace_back(view->RecordSeqNo(fetchAddSeqNo)); + } + return fetchAddSeqNo(sharedPageName_); +} + +Status SharedPageElementView::ReleasePage() +{ + for (auto &view : elementViews_) { + RETURN_IF_NOT_OK(view->ReleasePage()); + } + return Status::OK(); +} + +Status SharedPageElementView::IncRefCount() +{ + for (auto &view : elementViews_) { + RETURN_IF_NOT_OK(view->IncRefCount()); + } + return Status::OK(); +} + +RemoteAckInfo::AckRange SharedPageElementView::GetAckRange() +{ + uint64_t begCursor = elementViews_.front()->begCursor_; + return std::make_pair(begCursor, GetElementNum()); +} + +uint64_t SharedPageElementView::GetElementNum() +{ + uint64_t totalNum = 0; + for (auto &view : elementViews_) { + totalNum += view->sz_.size(); + } + return totalNum; +} + +void SharedPageElementView::DiscardBufferFromList(std::list &dataLst, std::list::iterator &iter) +{ + // Fixme: actually deal with list of stream views + elementViews_.front()->DiscardBufferFromList(dataLst, iter); +} + +Status SharedPageElementView::MoveBufToShmUnit() +{ + // Fixme: actually deal with the list of the element views. + RETURN_IF_NOT_OK(elementViews_.front()->MoveBufToShmUnit()); + return Status::OK(); +} + +// Class RemoteWorker part. 
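+// A RemoteWorker represents one peer worker that local producers push elements to.
+// It tracks the remote consumers registered per stream (remoteConsumers_), owns the
+// shared page group for that peer (sharedPageGroup_), batches pending element views
+// into push RPCs, and advances the per-stream ack cursors as the peer confirms
+// receipt.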
+RemoteWorker::RemoteWorker(HostPort localAddress, HostPort remoteAddress, std::shared_ptr akSkManager, + ClientWorkerSCServiceImpl *scSvc, std::string &workerInstanceId, + std::shared_ptr scAllocateManager, RemoteWorkerManager *manager) + : localWorkerAddr_(std::move(localAddress)), + remoteWorkerAddr_(remoteAddress), + akSkManager_(std::move(akSkManager)), + scSvc_(scSvc), + sharedPageGroup_(std::move(remoteAddress), std::move(scAllocateManager), scSvc), + workerInstanceId_(workerInstanceId), remoteWorkerManager_(manager) +{ +} + +Status RemoteWorker::Init() +{ + return Status::OK(); +} + +RemoteWorker::~RemoteWorker() +{ + LOG(INFO) << "Start Destroy RemoteWorker for remote worker:" << remoteWorkerAddr_.ToString(); + auto pages = sharedPageGroup_.GetAllSharedPageName(); + for (auto &page : pages) { + remoteWorkerManager_->RemoveStream(page, ""); + } +} + +Status RemoteWorker::GetAccessor(const std::string &streamName, RemoteStreamInfoTbbMap::accessor &accessor) +{ + return remoteConsumers_.GetAccessor(streamName, accessor, LogPrefix()); +} + +Status RemoteWorker::AddRemoteConsumer(const std::string &streamName, const SubscriptionConfig &subConfig, + const std::string &consumerId, uint64_t windowCount, uint64_t lastAckCursor) +{ + if (subConfig.subscriptionType != SubscriptionType::STREAM) { + RETURN_STATUS(StatusCode::K_INVALID, + FormatString("Only support STREAM mode. <%s> mode not supported.", subConfig.subscriptionName)); + } + // Register this consumer onto that remote worker, one remote worker contains a lot of related stream. + RETURN_IF_NOT_OK_EXCEPT(remoteConsumers_.AddConsumer(streamName, consumerId, windowCount, lastAckCursor), + K_DUPLICATED); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s, C:%s], Add remote consumer success", LogPrefix(), streamName, + consumerId); + return Status::OK(); +} + +Status RemoteWorker::DelRemoteConsumer(const std::string &streamName, const std::string &consumerId, + Optional &mapEmpty) +{ + RETURN_IF_NOT_OK(remoteConsumers_.DeleteConsumer(streamName, consumerId, mapEmpty)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s, C:%s] Delete remote consumer success", LogPrefix(), + streamName, consumerId); + return Status::OK(); +} + +bool RemoteWorker::HasRemoteConsumers(const std::string &streamName) +{ + return remoteConsumers_.HasRemoteConsumers(streamName); +} + +bool RemoteWorker::IsStreamSendBlocked(const std::string &streamName) +{ + return remoteConsumers_.IsStreamSendBlocked(streamName); +} + +uint64_t RemoteWorker::GetMaxWindowCount(const std::string &streamName) const +{ + return remoteConsumers_.GetMaxWindowCount(streamName); +} + +Status RemoteWorker::DeleteStream(const std::string &streamName, Optional &mapEmpty) +{ + LOG(INFO) << FormatString("[%s] ClearAllRemoteConsumer for stream %s", LogPrefix(), streamName); + return remoteConsumers_.DeleteStream(streamName, mapEmpty); +} + +void RemoteWorker::PostRecvCleanup(const std::string &streamName, const Status &status, + PendingFlushList &pendingFlushList, const PushReqPb &pushReq, + const PushRspPb &pushRspPb, std::unordered_map &raii) +{ + // Here worker instance id is empty as its for sender worker + const std::string &producerId = pushReq.producer_id(); + const std::string workerInstanceId = ""; + // TraceID of StreamProducerKey is stored inside request, set in ParseProducerPendingFlushList + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(pushReq.trace_id()); + // Key name is equivalent to stream name in exclusive page case. 
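+    // Three outcomes are handled below: K_SC_CONSUMER_NOT_FOUND means the data can
+    // never be delivered, so the buffers are discarded once the ack cursor is
+    // synced; any other non-OOM error just records a failed send and leaves the
+    // buffers for the async flush retry; otherwise the per-sequence results are
+    // processed and the stream manager performs the remote ack.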
+ StreamProducerKey key(streamName, producerId, workerInstanceId); + // Iterate over elements in pendingFlushList that match to current request + auto iter = std::find_if(pendingFlushList.begin(), pendingFlushList.end(), + [key](const auto &kv) { return kv.first == key; }); + if (status.GetCode() == K_SC_CONSUMER_NOT_FOUND) { + // Discard the buffer when consumer does not exist, instead of putting to retry. + // Note that currently one (batched) request is sent at a time, + // so it is feasible to discard all the remaining buffers in the list. + if (iter != pendingFlushList.end()) { + VLOG(SC_INTERNAL_LOG_LEVEL) << "No consumer found: Discarding buffers for stream: " << streamName; + RemoteStreamInfoTbbMap::accessor accessor; + if (GetAccessor(streamName, accessor).IsOk()) { + std::for_each(iter->second.begin(), iter->second.end(), [this, &accessor](const auto &kv) { + auto remoteElementView = std::static_pointer_cast(kv.first); + auto p = remoteElementView->GetAckRange(); + SyncStreamLastAckCursor(accessor, Optional(p)); + }); + } + DiscardBuffers(iter->second); + } + return; + } + if (status.IsError() && status.GetCode() != K_OUT_OF_MEMORY) { + RecordRemoteSendSuccess(false); // Rpc error, all data of the request fails to be sent. + return; + } + // Post cleanup + if (iter != pendingFlushList.end()) { + RemoteStreamInfoTbbMap::accessor accessor; + auto rc = GetAccessor(streamName, accessor); + if (rc.IsError()) { + LOG(ERROR) << FormatString("[%s, S:%s] Stream not found", LogPrefix(), streamName); + return; + } + auto streamMgr = (*(raii.find(streamName)->second))->second; + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(pushReq.trace_id()); + PostRecvCleanup(streamName, pushReq, pushRspPb, accessor, iter->second, raii); + accessor.release(); + // If all elements are received. Move up the last ack cursor. + // We may also send a portion of the frames because of the window count and in which + // case we have to check what is remaining. + // Technically, we can also do partially ack by moving up to the first element on the list. + LOG_IF_ERROR(streamMgr->RemoteAck(), FormatString("[%s, S:%s] Remote ack failed", LogPrefix(), streamName)); + } +} + +void RemoteWorker::PostRecvCleanup(const std::string &streamName, const PushReqPb &rq, const PushRspPb &pushRspPb, + RemoteStreamInfoTbbMap::accessor &accessor, std::list &dataLst, + std::unordered_map &streamRaii) +{ + uint64_t releaseSize = 0; + Raii raii([&releaseSize, &streamName, &streamRaii]() { + if (releaseSize > 0) { + auto itr = streamRaii.find(streamName); + if (itr == streamRaii.end()) { + LOG(ERROR) << FormatString( + "Decrease shared memory usage for stream[%s] failed because no worker area was found", streamName); + return; + } + LOG_IF_ERROR( + (*(itr->second))->second->TryDecUsage(releaseSize), + FormatString("Decrease shared memory usage for stream[%s] failed during push element to remote worker", + streamName)); + } + }); + for (auto i = 0; i < pushRspPb.error_size(); ++i) { + auto &err = pushRspPb.error(i); + auto seqNo = rq.seq(i); + // Find the matching seqNo. 
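+        // The response carries one status per request sequence number, so each
+        // (error, seq) pair is matched back to its in-flight element view; the list
+        // is presumably short (bounded by the send window), so a linear scan is
+        // acceptable here.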
+ auto it = std::find_if(dataLst.begin(), dataLst.end(), [seqNo](const auto &kv) { return kv.second == seqNo; }); + if (it == dataLst.end()) { + LOG(ERROR) << FormatString("[%s, S:%s] Unable to find seqNo %zu", LogPrefix(), streamName, seqNo); + continue; + } + auto status = Status(static_cast(err.error_code()), err.error_msg()); + auto remoteElementView = std::static_pointer_cast(it->first); + auto p = remoteElementView->GetAckRange(); + auto begCursor = p.first; + auto endCursor = begCursor + p.second; + if (status.IsOk()) { + const int logPerCount = VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL) ? 1 : 1000; + LOG_EVERY_N(INFO, logPerCount) << FormatString( + "[%s, S:%s] Remote send elements [seq:%zu] [%zu, %zu) to remote worker %s is successful. Ref count %zu", + LogPrefix(), streamName, seqNo, begCursor, endCursor, remoteWorkerAddr_.ToString(), + remoteElementView.use_count()); + // Ack and release page. Same logic as in BatchFlushAsyncRead + SyncStreamLastAckCursor(accessor, Optional(p)); + // If we have a reference on the page, decrement the count + remoteElementView->ReleasePage(); + releaseSize = std::accumulate(rq.element_meta(i).element_sizes().begin(), + rq.element_meta(i).element_sizes().end(), releaseSize); + if (std::static_pointer_cast(it->first)->bigElement_) { + releaseSize += std::static_pointer_cast(it->first)->bigElementMetaSize_; + } + it = dataLst.erase(it); + // If the sending status is K_OK, the success rate is 1. + RecordRemoteSendSuccess(true); + continue; + } + // The rest is error handling + LOG(INFO) << FormatString( + "[%s, S:%s, I:%s] Remote send elements [seq:%zu] [%zu, %zu) to remote worker %s gave status: %s", + LogPrefix(), streamName, workerInstanceId_, seqNo, begCursor, endCursor, remoteWorkerAddr_.ToString(), + status.ToString()); + if (status.GetCode() != K_OUT_OF_MEMORY) { + // If the sending status is not K_OK, the success rate is 0. Ignore OOM. + RecordRemoteSendSuccess(false); + continue; + } + Status allocRc = remoteElementView->MoveBufToAlternateMemory(); + if (allocRc.IsError()) { + LOG(ERROR) << FormatString("[%s, S:%s] Cursor [%zu, %zu) MoveBufToAlternateMemory failed. %s", LogPrefix(), + streamName, begCursor, endCursor, allocRc.ToString()); + continue; + } + // The function MoveBufToAlternateMemory can be called by other RemoteWorker on the same PV. + // We need to check carefully if we need to do the ack or not. + auto lastAckCursor = std::get(accessor->second).GetStreamLastAckCursor(); + SyncStreamLastAckCursor(accessor, Optional(p)); + if (lastAckCursor < std::get(accessor->second).GetStreamLastAckCursor()) { + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString( + "[%s, S:%s] Cursor [%zu, %zu) MoveBufToAlternateMemory. 
Ref count %zu", LogPrefix(), streamName, + begCursor, endCursor, remoteElementView.use_count()); + } + } +} + +void RemoteWorker::PostRecvCleanup(const std::string &keyName, const Status &status, PendingFlushList &pendingFlushList, + const SharedPagePushReqPb &pushReq, const PushRspPb &pushRspPb, + std::unordered_map &raii) +{ + (void)status; + // Here worker instance id is empty as its for sender worker + // Key name here is the shared page name + const std::string &producerId = pushReq.producer_id(); + const std::string workerInstanceId = ""; + // TraceID of StreamProducerKey is stored inside request, set in ParseProducerPendingFlushList + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(pushReq.trace_id()); + StreamProducerKey key(keyName, producerId, workerInstanceId); + // Iterate over elements in pendingFlushList that match to current request + auto iter = std::find_if(pendingFlushList.begin(), pendingFlushList.end(), + [key](const auto &kv) { return kv.first == key; }); + // Post cleanup + if (iter != pendingFlushList.end()) { + RemoteStreamInfoTbbMap::accessor accessor; + auto rc = GetAccessor(keyName, accessor); + if (rc.IsError()) { + LOG(ERROR) << FormatString("[%s, S:%s] Stream not found", LogPrefix(), keyName); + return; + } + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(pushReq.trace_id()); + auto sharedPageQueue = std::static_pointer_cast( + std::static_pointer_cast(iter->second.front().first)->dataObj_); + PostRecvCleanup(keyName, pushReq, pushRspPb, accessor, iter->second, raii); + accessor.release(); + + // If all elements are received. Move up the last ack cursor. + // We may also send a portion of the frames because of the window count and in which + // case we have to check what is remaining. + // Technically, we can also do partially ack by moving up to the first element on the list. + LOG_IF_ERROR(sharedPageQueue->RemoteAck(), FormatString("[%s, S:%s] Remote ack failed", LogPrefix(), keyName)); + } +} + +void RemoteWorker::PostRecvCleanup(const std::string &keyName, const SharedPagePushReqPb &rq, + const PushRspPb &pushRspPb, RemoteStreamInfoTbbMap::accessor &accessor, + std::list &dataLst, std::unordered_map &raii) +{ + std::unordered_map streamName2BytesNumSendSuccess; + for (auto i = 0; i < pushRspPb.error_size(); ++i) { + auto &err = pushRspPb.error(i); + auto seqNo = rq.metas(i).seq(); + // Find the matching seqNo and also matching stream name. + // Note that seqNo in shared page case is at actual stream name granularity. + const std::string &streamName = rq.stream_names(rq.metas(i).stream_index()); + auto it = std::find_if(dataLst.begin(), dataLst.end(), [seqNo, &streamName](const auto &kv) { + auto sharedPageView = std::static_pointer_cast(kv.first); + for (auto seqIter = sharedPageView->seqNums_.begin(), viewIter = sharedPageView->elementViews_.begin(); + seqIter != sharedPageView->seqNums_.end(); seqIter++, viewIter++) { + if (*seqIter == seqNo && (*viewIter)->StreamName() == streamName) { + return true; + } + } + return false; + }); + if (it == dataLst.end()) { + LOG(ERROR) << FormatString("[%s, S:%s] Unable to find seqNo %zu", LogPrefix(), keyName, seqNo); + continue; + } + auto status = Status(static_cast(err.error_code()), err.error_msg()); + auto remoteElementView = std::static_pointer_cast(it->first); + auto p = remoteElementView->GetAckRange(); + auto begCursor = p.first; + auto endCursor = begCursor + p.second; + if (status.IsOk()) { + const int logPerCount = VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL) ? 
1 : 1000; + LOG_EVERY_N(INFO, logPerCount) << FormatString( + "[%s, S:%s] Remote send elements [seq:%zu] [%zu, %zu) to remote worker %s is successful. Ref count %zu", + LogPrefix(), keyName, seqNo, begCursor, endCursor, remoteWorkerAddr_.ToString(), + remoteElementView.use_count()); + streamName2BytesNumSendSuccess[streamName] = std::accumulate( + rq.metas(i).element_meta().element_sizes().begin(), rq.metas(i).element_meta().element_sizes().end(), + streamName2BytesNumSendSuccess[streamName]); + if (remoteElementView->elementViews_.back()->bigElement_) { + streamName2BytesNumSendSuccess[streamName] += + remoteElementView->elementViews_.back()->bigElementMetaSize_; + } + // Ack and release page. Same logic as in BatchFlushAsyncRead + SyncStreamLastAckCursor(accessor, Optional(p)); + // If we have a reference on the page, decrement the count + remoteElementView->ReleasePage(); + it = dataLst.erase(it); + // If the sending status is K_OK, the success rate is 1. + RecordRemoteSendSuccess(true); + continue; + } + // The rest is error handling + LOG(INFO) << FormatString( + "[%s, S:%s, I:%s] Remote send elements [seq:%zu] [%zu, %zu) to remote worker %s gave status: %s", + LogPrefix(), keyName, workerInstanceId_, seqNo, begCursor, endCursor, remoteWorkerAddr_.ToString(), + status.ToString()); + if (status.GetCode() == K_SC_CONSUMER_NOT_FOUND) { + // Discard the buffer when consumer does not exist, instead of putting to retry. + // Note that in shared page scenario, we can only discard the current buffer, + // as the other buffers can be from different streams. + VLOG(SC_INTERNAL_LOG_LEVEL) << "No consumer found: Discarding one buffer from shared page: " << keyName; + SyncStreamLastAckCursor(accessor, Optional(p)); + DiscardBufferHelper(dataLst, it); + } + if (status.GetCode() != K_OUT_OF_MEMORY) { + // If the sending status is not K_OK, the success rate is 0. Ignore OOM. + RecordRemoteSendSuccess(false); + continue; + } + Status allocRc = remoteElementView->MoveBufToShmUnit(); + if (allocRc.IsError()) { + LOG(ERROR) << FormatString("[%s, S:%s] Cursor [%zu, %zu) MoveBufToShmUnit failed. %s", LogPrefix(), + streamName, begCursor, endCursor, allocRc.ToString()); + continue; + } + // The function MoveBufToShmUnit is for shared page scenario, currently it only supports single consumer. + // We allow ack so then the page would not be occupied in back-pressure case. + SyncStreamLastAckCursor(accessor, Optional(p)); + } + for (const auto &kv : streamName2BytesNumSendSuccess) { + auto itr = raii.find(kv.first); + if (itr == raii.end()) { + LOG(ERROR) << FormatString( + "Decrease shared memory usage for stream[%s] failed because no worker area was found", kv.first); + continue; + } + LOG_IF_ERROR( + (*(itr->second))->second->TryDecUsage(kv.second), + FormatString("Decrease shared memory usage for stream[%s] failed during push element to remote worker", + kv.first)); + } +} + +void RemoteWorker::DiscardBuffers(std::list &dataLst) +{ + // Discard all the buffers. 
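+    // DiscardBufferHelper erases the current entry through DiscardBufferFromList,
+    // which advances the iterator itself (iter = dataLst.erase(iter)), so this loop
+    // must not increment it again; the trailing clear() is only defensive.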
+ auto iter = dataLst.begin(); + while (iter != dataLst.end()) { + DiscardBufferHelper(dataLst, iter); + } + dataLst.clear(); +} + +void RemoteWorker::DiscardBufferHelper(std::list &dataLst, std::list::iterator &iter) +{ + auto remoteElementView = std::static_pointer_cast(iter->first); + remoteElementView->DiscardBufferFromList(dataLst, iter); +} + +Status RemoteWorker::ParseProducerPendingFlushList(const std::string &streamName, const std::string &producerId, + std::list &dataLst, std::vector &requests, + std::vector> &payloads, + std::unordered_map &raii, + std::list> &moveList, + std::unordered_set> &needAckList) +{ + RETURN_OK_IF_TRUE(dataLst.empty()); + uint64_t firstSeqNo = dataLst.begin()->second; + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("Processing pending send Flush List [%s, %s, %zu]", streamName, + producerId, firstSeqNo); + // Try not to send too many rpc that the worker can't handle + uint64_t windowCount = GetMaxWindowCount(streamName); + auto firstCursor = std::numeric_limits::max(); + auto lastCursor = std::numeric_limits::min(); + auto it = dataLst.begin(); + bool sharedPage = std::static_pointer_cast(it->first)->IsSharedPage(); + // back up shared page queue in case all elements are to be discarded. + std::shared_ptr backupSharedPage = + sharedPage ? std::static_pointer_cast( + std::static_pointer_cast(it->first)->dataObj_) + : nullptr; + bool needAck = true; + while (it != dataLst.end() && windowCount-- > 0) { + std::variant pushReq; + std::vector elements; + if (!sharedPage) { + PushReqPb pushReqPb; + RETURN_IF_NOT_OK(FillExclusivePushReqHelper(streamName, producerId, firstSeqNo, dataLst, it, firstCursor, + lastCursor, pushReqPb, elements, raii)); + pushReq = std::move(pushReqPb); + } else { + SharedPagePushReqPb pushReqPb; + Status rc = FillSharedPushReqHelper(producerId, dataLst, it, firstCursor, lastCursor, pushReqPb, elements, + raii, moveList); + if (rc.GetCode() == K_SC_STREAM_NOT_FOUND) { + RemoteStreamInfoTbbMap::accessor accessor; + if (GetAccessor(streamName, accessor).IsOk()) { + auto remoteElementView = std::static_pointer_cast(it->first); + auto p = remoteElementView->GetAckRange(); + SyncStreamLastAckCursor(accessor, Optional(p)); + } + DiscardBufferHelper(dataLst, it); + continue; + } else if (rc.GetCode() == K_NOT_READY) { + // Request is not ready because element views are all skipped. + // Note that this code should be internal to this function, it should not propagate further down. + continue; + } + RETURN_IF_NOT_OK(rc); + pushReq = std::move(pushReqPb); + } + requests.emplace_back(std::move(pushReq), streamName); + payloads.push_back(std::move(elements)); + needAck = false; + } + // If elements are to be sent to remote, there is no need to ack up here. + // Otherwise if elements are skipped due to blocking in shared page case, + // or if elements are being discarded, + // we need to perform ack so that elements do not occupy twice the shm. + if (needAck) { + needAckList.emplace(backupSharedPage); + } + // All cursors before firstSeqNo has been reclaimed by the remote worker. The next potential + // one is the last element on the list. + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, S:%s] Ack range [%zu, %zu]. 
Window size %zu", LogPrefix(), + streamName, firstCursor, lastCursor, requests.size()); + return Status::OK(); +} + +Status RemoteWorker::FillExclusivePushReqHelper(const std::string &streamName, const std::string &producerId, + uint64_t firstSeqNo, std::list &dataLst, + std::list::iterator &it, uint64_t &firstCursor, + uint64_t &lastCursor, PushReqPb &pushReqPb, + std::vector &elements, + std::unordered_map &raii) +{ + RETURN_IF_NOT_OK(LockStreamManagerHelper(streamName, raii)); + pushReqPb.set_stream_name(streamName); + pushReqPb.set_producer_id(producerId); + pushReqPb.set_worker_addr(localWorkerAddr_.ToString()); + pushReqPb.set_first_seq(firstSeqNo); + pushReqPb.set_worker_instance_id(workerInstanceId_); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(it->first->traceId_); + pushReqPb.set_trace_id(Trace::Instance().GetTraceID()); + size_t chunkSz = 0; + const size_t zmqChunkSz = static_cast(FLAGS_zmq_chunk_sz); + // Only batch up to FLAGS_zmq_chunk_sz. Make sure we send at least one PV + do { + auto seqNo = it->second; + auto streamElementView = std::static_pointer_cast(it->first); + auto &eleSzs = streamElementView->sz_; + size_t payloadSz = std::accumulate(eleSzs.begin(), eleSzs.end(), 0ul); + if ((chunkSz > 0 && ((payloadSz > zmqChunkSz) || chunkSz > zmqChunkSz - payloadSz))) { + break; + } + chunkSz += payloadSz; + auto *ele = pushReqPb.mutable_element_meta()->Add(); + ele->mutable_element_sizes()->Add(eleSzs.begin(), eleSzs.end()); + auto &headerBits = streamElementView->headerBits_; + ele->mutable_header_bits()->Add(headerBits.begin(), headerBits.end()); + pushReqPb.mutable_seq()->Add(seqNo); + elements.emplace_back(streamElementView->GetBufferPointer(), payloadSz); + const int logPerCount = VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL) ? 1 : 1000; + LOG_EVERY_N(INFO, logPerCount) << FormatString( + "[%s, S:%s, I:%s] Remote send elements [seq:%zu] [%zu, %zu) to remote worker %s, page: %s", LogPrefix(), + streamName, workerInstanceId_, seqNo, streamElementView->begCursor_, + streamElementView->begCursor_ + streamElementView->sz_.size(), remoteWorkerAddr_.ToString(), + streamElementView->page_->GetPageId()); + firstCursor = std::min(firstCursor, streamElementView->begCursor_); + lastCursor = std::max(lastCursor, streamElementView->begCursor_ + streamElementView->sz_.size() - 1); + ++it; + } while (it != dataLst.end()); + pushReqPb.set_chunk_size(chunkSz); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(pushReqPb)); + return Status::OK(); +} + +Status RemoteWorker::FillSharedPushReqHelper(const std::string &producerId, std::list &dataLst, + std::list::iterator &it, uint64_t &firstCursor, + uint64_t &lastCursor, SharedPagePushReqPb &pushReqPb, + std::vector &elements, + std::unordered_map &raii, + std::list> &moveList) +{ + std::unordered_map streamIndexMapping; + std::unordered_map streamBlockInfoMap; + pushReqPb.set_producer_id(producerId); + pushReqPb.set_worker_addr(localWorkerAddr_.ToString()); + pushReqPb.set_worker_instance_id(workerInstanceId_); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(it->first->traceId_); + pushReqPb.set_trace_id(Trace::Instance().GetTraceID()); + bool requestReady = false; + size_t chunkSz = 0; + const size_t zmqChunkSz = static_cast(FLAGS_zmq_chunk_sz); + // Only batch up to FLAGS_zmq_chunk_sz. Make sure we send at least one PV + do { + // Fixme: actually deal with list of element views. 
+ auto sharedPageElementView = std::static_pointer_cast(it->first); + auto streamElementView = sharedPageElementView->elementViews_.front(); + auto seqNo = sharedPageElementView->seqNums_.front(); + const std::string &streamName = streamElementView->streamName_; + RETURN_IF_NOT_OK(LockStreamManagerHelper(streamName, raii)); + auto &eleSzs = streamElementView->sz_; + size_t payloadSz = std::accumulate(eleSzs.begin(), eleSzs.end(), 0ul); + // Skip if blocked. Record the stream names so we get consistent results in this loop. + auto streamBlockInfo = streamBlockInfoMap.find(streamName); + if (streamBlockInfo == streamBlockInfoMap.end()) { + streamBlockInfo = streamBlockInfoMap.emplace(streamName, IsStreamSendBlocked(streamName)).first; + } + if (streamBlockInfo->second) { + ++it; + // delay the move to after the requests are sent. + moveList.emplace_back(sharedPageElementView); + continue; + } + auto iter = streamIndexMapping.find(streamName); + if (iter == streamIndexMapping.end()) { + iter = streamIndexMapping.emplace(streamName, streamIndexMapping.size()).first; + pushReqPb.mutable_stream_names()->Add(streamName.c_str()); + } + if ((chunkSz > 0 && ((payloadSz > zmqChunkSz) || (chunkSz + payloadSz) > zmqChunkSz))) { + break; + } + chunkSz += payloadSz; + // Fill in StreamElementsMetaPb with stream index, sequence number and the actual view meta. + auto *meta = pushReqPb.mutable_metas()->Add(); + meta->set_stream_index(iter->second); + meta->set_seq(seqNo); + auto *ele = meta->mutable_element_meta(); + ele->mutable_element_sizes()->Add(eleSzs.begin(), eleSzs.end()); + auto &headerBits = streamElementView->headerBits_; + ele->mutable_header_bits()->Add(headerBits.begin(), headerBits.end()); + elements.emplace_back(streamElementView->GetBufferPointer(), payloadSz); + const int logPerCount = VLOG_IS_ON(SC_INTERNAL_LOG_LEVEL) ? 1 : 1000; + LOG_EVERY_N(INFO, logPerCount) << FormatString( + "[%s, S:%s, I:%s] Remote send elements [seq:%zu] [%zu, %zu) to remote worker %s, page: %s", LogPrefix(), + streamName, workerInstanceId_, seqNo, streamElementView->begCursor_, + streamElementView->begCursor_ + streamElementView->sz_.size(), remoteWorkerAddr_.ToString(), + streamElementView->page_->GetPageId()); + firstCursor = std::min(firstCursor, streamElementView->begCursor_); + lastCursor = std::max(lastCursor, streamElementView->begCursor_ + streamElementView->sz_.size() - 1); + ++it; + // If all views are skipped due to blocking, then request is not prepared. + requestReady = true; + } while (it != dataLst.end()); + CHECK_FAIL_RETURN_STATUS(requestReady, K_NOT_READY, "All element views are skipped, request is not ready"); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(pushReqPb)); + return Status::OK(); +} + +Status RemoteWorker::LockStreamManagerHelper(const std::string &streamName, + std::unordered_map &raii) +{ + if (raii.find(streamName) == raii.end()) { + // We can't allow the stream to be deleted while we are traversing the shared memory. + StreamRaii rlock = std::make_unique(); + RETURN_IF_NOT_OK(scSvc_->GetStreamManager(streamName, *rlock)); + std::shared_ptr streamMgr = (*rlock)->second; + // Check the state (delete, reset, etc) + RETURN_IF_NOT_OK(streamMgr->CheckIfStreamActive()); + raii.emplace(streamName, std::move(rlock)); + } + return Status::OK(); +} + +Status RemoteWorker::ProcessEndOfStream(const std::shared_ptr &streamMgr, std::list dataLst, + const std::string &streamName, const std::string &producerId) +{ + (void)streamMgr; + (void)producerId; + // Move up the ack. 
We aren't going to send them. + RemoteStreamInfoTbbMap::accessor accessor; + Status rc = GetAccessor(streamName, accessor); + if (rc.IsOk()) { + std::get(accessor->second).Reset(); + accessor.release(); + } + // Discard all the buffers. + DiscardBuffers(dataLst); + // If the stream is blocked, unblock it. + RETURN_IF_NOT_OK_EXCEPT(remoteConsumers_.ToggleStreamBlocking(streamName, false), K_SC_STREAM_NOT_FOUND); + // Signal this job is done. + return Status::OK(); +} + +Status RemoteWorker::ParsePendingFlushList(const PendingFlushList &pendingFlushList, std::vector &requests, + std::vector> &payloads, + std::unordered_map &raii, + std::list> &moveList, + std::unordered_set> &needAckList) +{ + for (const auto &ele : pendingFlushList) { + const StreamProducerKey key = ele.first; + std::list &dataLst = ele.second; + const std::string &streamName = key.firstKey_; + const std::string &producerId = key.producerId_; + if (IsStreamSendBlocked(streamName)) { + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Ignore stream %s producer %s", LogPrefix(), streamName, + producerId); + continue; + } + if (dataLst.empty()) { + continue; + } + Status rc = ParseProducerPendingFlushList(streamName, producerId, dataLst, requests, payloads, raii, moveList, + needAckList); + if (rc.GetCode() == K_SC_STREAM_NOT_FOUND || rc.GetCode() == K_SC_STREAM_DELETE_IN_PROGRESS + || rc.GetCode() == K_SC_STREAM_IN_RESET_STATE) { + continue; + } + RETURN_IF_NOT_OK(rc); + } + return Status::OK(); +} + +int RemoteWorker::BatchFlushAsyncWrite(const std::shared_ptr &stub, + std::vector &requests, std::vector> &payloads) +{ + int numReqSent = 0; + for (size_t i = 0; i < requests.size(); ++i) { + Status &status = requests.at(i).rc_; + auto &pushReq = requests.at(i).req_; + const auto visitor = [&](auto &&pushReqPb) { + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(pushReqPb.trace_id()); + if constexpr (std::is_same_v, PushReqPb>) { + VLOG(SC_DEBUG_LOG_LEVEL) << FormatString("Calling PushElementsCursorsAsyncWrite for %s with %zu PV", + pushReqPb.stream_name(), pushReqPb.element_meta_size()); + status = stub->PushElementsCursorsAsyncWrite(pushReqPb, requests.at(i).tag_, payloads.at(i)); + } else { + VLOG(SC_DEBUG_LOG_LEVEL) << FormatString( + "Calling PushSharedPageCursorsAsyncWrite for shared page with %zu PV", pushReqPb.metas_size()); + status = stub->PushSharedPageCursorsAsyncWrite(pushReqPb, requests.at(i).tag_, payloads.at(i)); + } + }; + + PerfPoint point(PerfKey::REMOTE_WORKER_SEND_ONE_STREAM); + std::visit(visitor, pushReq); + // Need to count AsyncWrite even if the AsyncRead returns K_TRY_AGAIN + numReqSent++; + } + VLOG(SC_DEBUG_LOG_LEVEL) << FormatString("Number of outstanding PushElementsCursorsAsyncWrite request %d", + numReqSent); + return numReqSent; +} + +void RemoteWorker::BatchFlushAsyncRead(const std::shared_ptr &stub, + PendingFlushList &pendingFlushList, std::vector &requests, + std::unordered_map &raii) +{ + size_t numAsync = requests.size(); + for (size_t i = 0; i < numAsync; ++i) { + Status &status = requests.at(i).rc_; + // Check the return code from PushElementsCursorsAsyncWrite + if (status.IsError()) { + continue; + } + const auto visitor = [&](auto &&pushReq) { + // TraceID of StreamProducerKey is stored inside request, set in ParseProducerPendingFlushList + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(pushReq.trace_id()); + PushRspPb pushRspPb; + PerfPoint point(PerfKey::REMOTE_WORKER_MAIN_RECV); + INJECT_POINT("RemoteWorker.SleepBeforeAsyncRead", [](uint64_t timeoutMs) -> void { + 
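+                // Test-only hook: the injected callback simulates a slow peer by
+                // sleeping before the async read is issued.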
std::this_thread::sleep_for(std::chrono::milliseconds(timeoutMs)); + return; + }); + if constexpr (std::is_same_v, PushReqPb>) { + status = stub->PushElementsCursorsAsyncRead(requests.at(i).tag_, pushRspPb, RpcRecvFlags::NONE); + } else { + status = stub->PushSharedPageCursorsAsyncRead(requests.at(i).tag_, pushRspPb, RpcRecvFlags::NONE); + } + point.Record(); + INJECT_POINT_NO_RETURN("RemoteWorker.BatchFlushAsyncRead.rpc.timeout", [&status]() { + status = { K_RPC_UNAVAILABLE, "Fake worker not responding" }; + }); + VLOG(SC_DEBUG_LOG_LEVEL) << FormatString("PushElementsCursorsAsyncRead rc for stream %s: %s", + requests.at(i).keyName_, status.ToString()); + PostRecvCleanup(requests.at(i).keyName_, status, pendingFlushList, pushReq, pushRspPb, raii); + }; + std::visit(visitor, requests.at(i).req_); + } +} + +void RemoteWorker::HandleBlockedElements(std::list> &moveList, + std::unordered_set> &needAckList) +{ + std::unordered_set oomList; + for (auto &sharedPageElementView : moveList) { + // Avoid occupying the shared page when stream is blocked. + // Move buf to shm when we are about to batch and send to remote. + // And then allow page to be acked. + const auto &keyName = sharedPageElementView->KeyName(); + // If we previously got OOM, skip it and continue to the next stream. + // 1. we want to preserve the order so that we are more likely to ack pages in order + // 2. it is likely that OOM will still happen at least for the same stream, so skip the stream to save the time. + if (oomList.find(keyName) != oomList.end()) { + continue; + } + // Fixme: Actually deal with partial of the StreamElementView that needs to be moved. + Status allocRc = sharedPageElementView->MoveBufToShmUnit(); + if (allocRc.IsError()) { + auto p = sharedPageElementView->GetAckRange(); + LOG(WARNING) << FormatString("[%s, S:%s] Cursor [%zu, %zu) MoveBufToShmUnit failed. %s", LogPrefix(), + sharedPageElementView->streamName_, p.first, p.first + p.second, + allocRc.ToString()); + if (allocRc.GetCode() == K_OUT_OF_MEMORY) { + oomList.emplace(keyName); + } + continue; + } + RemoteStreamInfoTbbMap::accessor accessor; + if (GetAccessor(keyName, accessor).IsOk()) { + auto p = sharedPageElementView->GetAckRange(); + SyncStreamLastAckCursor(accessor, Optional(p)); + } + } + for (auto &sharedPageQueue : needAckList) { + LOG_IF_ERROR(sharedPageQueue->RemoteAck(), + FormatString("[%s, S:%s] Remote ack failed", LogPrefix(), sharedPageQueue->GetPageQueueId())); + } +} + +Status RemoteWorker::BatchAsyncFlushEntry(PendingFlushList &pendingFlushList) +{ + INJECT_POINT("RemoteWorker.BatchAsyncFlushEntry.Sleep", [](int sleepSecond) { + std::this_thread::sleep_for(std::chrono::seconds(sleepSecond)); + return Status::OK(); + }); + std::vector requests; + std::vector> payloads; + std::unordered_map raii; // holds all the const_accessor + std::list> moveList; + std::unordered_set> needAckList; + RETURN_IF_NOT_OK(ParsePendingFlushList(pendingFlushList, requests, payloads, raii, moveList, needAckList)); + std::shared_ptr stub; + RETURN_IF_NOT_OK(RpcStubCacheMgr::Instance().GetStub(remoteWorkerAddr_, StubType::WORKER_WORKER_SC_SVC, stub)); + auto derivedStub = std::dynamic_pointer_cast(stub); + RETURN_RUNTIME_ERROR_IF_NULL(derivedStub); + // This code is driven by BufferPool async flush code path which will retry on error. + auto numRequestSent = BatchFlushAsyncWrite(derivedStub, requests, payloads); + // Handle the blocked elements in between async write and read. 
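+    // Doing this between the write phase and the read phase overlaps work with the
+    // network: while the peer processes the pushes just written, blocked element
+    // views are copied out of their shared pages and acked locally; only then are
+    // the responses collected.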
+ HandleBlockedElements(moveList, needAckList); + // If nothing was sent out, there is no need to continue. + RETURN_OK_IF_TRUE(numRequestSent == 0); + BatchFlushAsyncRead(derivedStub, pendingFlushList, requests, raii); + // Return the first non-OK error. + for (auto &req : requests) { + if (req.rc_.IsError()) { + return req.rc_; + } + } + return Status::OK(); +} + +Status RemoteWorker::GetStreamLastAckCursor(const std::string &streamName, uint64_t &cursor) +{ + RETURN_IF_NOT_OK(remoteConsumers_.GetStreamLastAckCursor(streamName, cursor)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] remoteAck = %zu", LogPrefix(), streamName, cursor); + return Status::OK(); +} + +void RemoteWorker::SyncStreamLastAckCursor(RemoteStreamInfoTbbMap::accessor &accessor, + Optional<RemoteAckInfo::AckRange> ackRange) +{ + remoteConsumers_.SyncStreamLastAckCursor(accessor, ackRange, + FormatString("%s, S:%s", LogPrefix(), accessor->first)); +} + +std::string RemoteWorker::LogPrefix() const +{ + return FormatString("RW:%s", remoteWorkerAddr_.ToString()); +} + +bool RemoteWorker::ExistsRemoteConsumer() +{ + return !remoteConsumers_.Empty(); +} + +void RemoteWorker::GetOrCreateSharedPageQueue(const std::string &namespaceUri, + std::shared_ptr<SharedPageQueue> &pageQueue) +{ + sharedPageGroup_.GetOrCreateSharedPageQueue(namespaceUri, pageQueue); +} + +// Class RemoteWorkerManager part +RemoteWorkerManager::RemoteWorkerManager(ClientWorkerSCServiceImpl *scSvc, std::shared_ptr akSkManager, + std::shared_ptr scAllocateManager) + : akSkManager_(std::move(akSkManager)), scSvc_(scSvc), scAllocateManager_(scAllocateManager) +{ +} + +RemoteWorkerManager::~RemoteWorkerManager() +{ + // The RemoteWorker objects in remoteWorkerDict_ hold a pointer back to this manager, so the dictionary must be cleared first. + remoteWorkerDict_.clear(); + if (dataMap_) { + dataMap_->Stop(); + } +} + +Status RemoteWorkerManager::Init() +{ + // for remote send + dataMap_ = std::make_unique<BufferPool>( + FLAGS_remote_send_thread_num, "ScPushToRemote", + std::bind(&RemoteWorkerManager::BatchAsyncFlushEntry, this, std::placeholders::_1, std::placeholders::_2)); + RETURN_IF_NOT_OK(dataMap_->Init()); + // for scan + dataPool_ = std::make_unique<StreamDataPool>(); + RETURN_IF_NOT_OK(dataPool_->Init()); + // A unique id is generated on each restart. + workerInstanceId_ = GetStringUuid(); + return Status::OK(); +} + +Status RemoteWorkerManager::GetRemoteWorker(const std::string &address, std::shared_ptr<RemoteWorker> &remoteWorker) +{ + std::shared_lock lock(mutex_); + auto iter = remoteWorkerDict_.find(address); + CHECK_FAIL_RETURN_STATUS(iter != remoteWorkerDict_.end(), StatusCode::K_NOT_FOUND, + FormatString("Remote worker:<%s> does not exist", address)); + RETURN_RUNTIME_ERROR_IF_NULL(iter->second); + remoteWorker = iter->second; + return Status::OK(); +} + +uint64_t RemoteWorkerManager::GetLastAckCursor(const std::string &streamName) +{ + std::shared_lock lock(mutex_); + if (remoteWorkerDict_.empty()) { + return 0; + } + uint64_t lastAckCursor = std::numeric_limits<uint64_t>::max(); + for (auto &ele : remoteWorkerDict_) { + auto &remoteWorker = ele.second; + uint64_t cursor; + Status rc = remoteWorker->GetStreamLastAckCursor(streamName, cursor); + if (rc.IsOk()) { + lastAckCursor = std::min(lastAckCursor, cursor); + } + } + const int logPerCount = VLOG_IS_ON(SC_NORMAL_LOG_LEVEL) ?
1 : 1000; + LOG_EVERY_N(INFO, logPerCount) << FormatString("[S:%s] Remote consumer(s) lastAckCursor = %zu", streamName, + lastAckCursor); + return lastAckCursor; +} + +void RemoteWorkerManager::RemoveStream(const std::string &keyName, const std::string &sharedPageName) +{ + dataMap_->RemoveStream(keyName, sharedPageName); +} + +void RemoteWorkerManager::PurgeBuffer(const std::shared_ptr<StreamManager> &streamMgr) +{ + dataMap_->PurgeBuffer(streamMgr->GetStreamName(), + std::bind(&RemoteWorkerManager::ProcessEndOfStream, this, streamMgr, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3)); +} + +Status RemoteWorkerManager::ProcessEndOfStream(const std::shared_ptr<StreamManager> &streamMgr, + std::list dataLst, const std::string &streamName, + const std::string &producerId) +{ + std::shared_lock lock(mutex_); + RETURN_OK_IF_TRUE(remoteWorkerDict_.empty()); + std::vector<Status> status(remoteWorkerDict_.size()); + size_t i = 0; + auto iter = remoteWorkerDict_.begin(); + while (iter != remoteWorkerDict_.end()) { + auto rw = iter->second; + status.at(i) = rw->ProcessEndOfStream(streamMgr, dataLst, streamName, producerId); + ++iter; + ++i; + } + auto rc = std::find_if(status.begin(), status.end(), [](const auto &st) { return st.IsError(); }); + if (rc != status.end()) { + return (*rc); + } + return Status::OK(); +} + +Status RemoteWorkerManager::StreamNoToName(uint64_t streamNo, std::string &streamName) +{ + return scSvc_->StreamNoToName(streamNo, streamName); +} + +Status RemoteWorkerManager::SendElementsView(const std::shared_ptr<SendElementView> &eleView) +{ + const std::string &streamName = eleView->StreamName(); + auto &remoteWorker = eleView->remoteWorker_; + std::shared_ptr<RemoteWorker> rw; + if (!eleView->remote_) { + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[RW:%s, S:%s] Flush %zu elements remotely", remoteWorker, + streamName, eleView->GetElementNum()); + RETURN_IF_NOT_OK(eleView->IncRefCount()); + INJECT_POINT("RemoteWorkerManager.SendElementsView.PostIncRefCount"); + dataMap_->Insert(eleView); + } else { + // We don't send remote elements back to remote workers; that would create an infinite loop. + // But we need to keep track of the gap and move up the ack cursor accordingly. + RETURN_IF_NOT_OK(GetRemoteWorker(remoteWorker, rw)); + RemoteStreamInfoTbbMap::accessor accessor; + RETURN_IF_NOT_OK(rw->GetAccessor(streamName, accessor)); + RemoteAckInfo::AckRange p = eleView->GetAckRange(); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, S:%s] Ack cursor [%zu, %zu)", rw->LogPrefix(), streamName, + p.first, p.first + p.second); + // Move up the ack cursor if necessary. + rw->SyncStreamLastAckCursor(accessor, Optional(p)); + } + return Status::OK(); +} + +Status RemoteWorkerManager::BatchAsyncFlushEntry(int myId, const PendingFlushList &pendingFlushList) +{ + (void)myId; + auto traceGuard = Trace::Instance().SetTraceUUID(); + std::unordered_map<std::string, std::pair<std::shared_ptr<RemoteWorker>, PendingFlushList>> flushMap; + for (const auto &ele : pendingFlushList) { + const StreamProducerKey key = ele.first; + std::list &dataLst = ele.second; + const std::string &streamName = key.firstKey_; + const std::string &remoteWorker = key.producerId_; + auto it = flushMap.find(remoteWorker); + if (it == flushMap.end()) { + std::shared_lock lock(mutex_); + auto iter = remoteWorkerDict_.find(remoteWorker); + if (iter == remoteWorkerDict_.end()) { + // Discard all the buffers.
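+ // The remote worker entry may already have been erased (e.g. after its last remote consumer was removed), in which case there is no destination left and the data can only be dropped.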
+ RemoteWorker::DiscardBuffers(dataLst); + continue; + } + it = flushMap.emplace(remoteWorker, std::make_pair(iter->second, PendingFlushList())).first; + } + auto &rw = it->second.first; + // Check again whether we still have a remote consumer. + if (rw->HasRemoteConsumers(streamName)) { + it->second.second.push_back(ele); + } else { + RemoteWorker::DiscardBuffers(dataLst); + continue; + } + } + RETURN_OK_IF_TRUE(flushMap.empty()); + std::vector<Status> status(flushMap.size()); + size_t i = 0; + auto it = flushMap.begin(); + while (it != flushMap.end()) { + auto &rw = it->second.first; + status.at(i) = rw->BatchAsyncFlushEntry(it->second.second); + ++it; + ++i; + } + auto rc = std::find_if(status.begin(), status.end(), [](const auto &st) { return st.IsError(); }); + if (rc != status.end()) { + return (*rc); + } + return Status::OK(); +} + +bool RemoteWorkerManager::HasRemoteConsumers(const std::string &streamName) +{ + std::shared_lock rlock(mutex_); + return std::any_of(remoteWorkerDict_.begin(), remoteWorkerDict_.end(), + [&streamName](const auto &kv) { return kv.second->HasRemoteConsumers(streamName); }); +} + +Status RemoteWorkerManager::DeleteStream(const std::string &streamName) +{ + // Stop pushing new buffers into the RemoteWorker. + RETURN_IF_NOT_OK_EXCEPT(dataPool_->RemoveStreamObject(streamName, {}), K_SC_STREAM_NOT_FOUND); + std::lock_guard lock(mutex_); + auto iter = remoteWorkerDict_.begin(); + while (iter != remoteWorkerDict_.end()) { + RETURN_RUNTIME_ERROR_IF_NULL(iter->second); + auto &rw = iter->second; + Optional<bool> mapEmpty(false); + RETURN_IF_NOT_OK_EXCEPT(rw->DeleteStream(streamName, mapEmpty), K_SC_STREAM_NOT_FOUND); + if (mapEmpty.value()) { + LOG(INFO) << "Erase remote worker " << rw->remoteWorkerAddr_.ToString() << " from remoteWorkerDict_"; + iter = remoteWorkerDict_.erase(iter); + } else { + ++iter; + } + } + return Status::OK(); +} + +Status RemoteWorkerManager::DoneScanning(const std::string &streamName) +{ + RETURN_IF_NOT_OK_EXCEPT(dataPool_->RemoveStreamObject(streamName, {}), K_SC_STREAM_NOT_FOUND); + return Status::OK(); +} + +std::string RemoteWorkerManager::GetSCRemoteSendSuccessRate() +{ + return remoteSendRateVec_.BlockingGetRateToStringAndClean(); +} + +Status RemoteWorkerManager::ToggleStreamBlocking(const std::string &workerAddr, const std::string &streamName, + bool enable) +{ + if (enable) { + INJECT_POINT("RemoteWorker.EnableStreamBlocking.sleep"); + } + std::shared_ptr<RemoteWorker> rw; + RETURN_IF_NOT_OK(GetRemoteWorker(workerAddr, rw)); + VLOG(SC_NORMAL_LOG_LEVEL) << (enable ?
"Blocking" : "Unblocking") << " Producer for stream: " << streamName + << " From remote worker: " << workerAddr; + RemoteStreamInfoTbbMap::accessor accessor; + RETURN_IF_NOT_OK(rw->GetAccessor(streamName, accessor)); + std::get(accessor->second) = enable; + // Update stream metrics if it is enabled + if (ScMetricsMonitor::Instance()->IsEnabled()) { + uint64_t numRemoteConsumers = std::get(accessor->second).size(); + accessor.release(); + StreamManagerMap::const_accessor streamMgrAccessor; + RETURN_IF_NOT_OK(scSvc_->GetStreamManager(streamName, streamMgrAccessor)); + if (enable) { + streamMgrAccessor->second->GetSCStreamMetrics()->IncrementMetric(StreamMetric::NumRemoteConsumersBlocking, + numRemoteConsumers); + } else { + streamMgrAccessor->second->GetSCStreamMetrics()->DecrementMetric(StreamMetric::NumRemoteConsumersBlocking, + numRemoteConsumers); + } + } + return Status::OK(); +} + +Status RemoteWorkerManager::DelRemoteConsumer(const std::string &workerAddr, const std::string &streamName, + const std::string &consumerId) +{ + std::vector dest; + { + std::unique_lock lock(mutex_); + auto iter = remoteWorkerDict_.find(workerAddr); + CHECK_FAIL_RETURN_STATUS(iter != remoteWorkerDict_.end(), K_NOT_FOUND, + FormatString("Remote worker:<%s> does not exist", workerAddr)); + auto &rw = iter->second; + Optional mapEmpty(false); + RETURN_IF_NOT_OK(rw->DelRemoteConsumer(streamName, consumerId, mapEmpty)); + if (mapEmpty.value()) { + LOG(INFO) << "Erase remote worker " << workerAddr << " from remoteWorkerDict_"; + (void)remoteWorkerDict_.erase(iter); + } + dest = GetRemoteWorkers(streamName); + } + // Update the scan list. + RETURN_IF_NOT_OK_EXCEPT(dataPool_->RemoveStreamObject(streamName, dest), K_SC_STREAM_NOT_FOUND); + return Status::OK(); +} + +Status RemoteWorkerManager::AddRemoteConsumer(const std::shared_ptr &streamMgr, + const HostPort &localWorkerAddress, const HostPort &remoteWorkerAddress, + const std::string &streamName, const SubscriptionConfig &subConfig, + const std::string &consumerId, uint64_t lastAckCursor) +{ + StreamFields streamFields; + streamMgr->GetStreamFields(streamFields); + std::shared_ptr sharedPage; + std::vector dest; + { + std::shared_ptr rw; + std::lock_guard lock(mutex_); + auto iter = remoteWorkerDict_.find(remoteWorkerAddress.ToString()); + if (iter == remoteWorkerDict_.end()) { + auto remoteWorker = std::make_shared(localWorkerAddress, remoteWorkerAddress, akSkManager_, + scSvc_, workerInstanceId_, scAllocateManager_, this); + RETURN_IF_NOT_OK(remoteWorker->Init()); + remoteWorker->RegisterRecordRemoteSendRateCallBack( + [this](int successNum, int totalNum) { remoteSendRateVec_.BlockingEmplaceBack(successNum, totalNum); }); + + iter = remoteWorkerDict_.emplace(remoteWorkerAddress.ToString(), remoteWorker).first; + } + rw = iter->second; + RETURN_IF_NOT_OK( + rw->AddRemoteConsumer(streamName, subConfig, consumerId, streamMgr->GetMaxWindowCount(), lastAckCursor)); + if (StreamManager::EnableSharedPage(streamFields.streamMode_)) { + rw->GetOrCreateSharedPageQueue(streamName, sharedPage); + streamMgr->SetSharedPageQueue(sharedPage); + const std::string &keyName = sharedPage->GetStreamName(); + // Add fake consumer for shared page remote ack purposes. + // Calculate the ack cursor in case of shared page queue re-added to scan list. + auto lastAppendCursor = sharedPage->GetLastAppendCursor(); + lastAckCursor = lastAppendCursor; + uint64_t cursor; + // Shared page only supports single consumer, so only need to check this remote worker for ack cursor. 
+ Status rc = rw->GetStreamLastAckCursor(keyName, cursor); + if (rc.IsOk()) { + lastAckCursor = std::min(lastAckCursor, cursor); + } + RETURN_IF_NOT_OK(rw->AddRemoteConsumer(keyName, SubscriptionConfig(), keyName, 1, 0)); + } + dest = GetRemoteWorkers(streamName); + } + // Add the stream to the scan list, and update the destination. + if (!StreamManager::EnableSharedPage(streamFields.streamMode_)) { + RETURN_IF_NOT_OK_EXCEPT(dataPool_->AddStreamObject(streamMgr, streamName, dest, lastAckCursor), K_DUPLICATED); + } else { + RETURN_IF_NOT_OK_EXCEPT(dataPool_->AddSharedPageObject(sharedPage, streamName, dest, lastAckCursor), + K_DUPLICATED); + } + return Status::OK(); +} + +Status RemoteWorkerManager::ClearAllRemoteConsumer(const std::string &streamName, bool forceClose) +{ + if (forceClose) { + LOG(INFO) << "Client has crashed; cleaning up the stream: " << streamName; + // If the last producer is force closed due to a client crash, discard the data + // irrespective of the ack position: stop scanning and remove the stream from + // the RemoteConsumerMap. + RETURN_IF_NOT_OK(DeleteStream(streamName)); + } + return Status::OK(); +} + +Status RemoteWorkerManager::ResetStreamScanList(const std::string &streamName) +{ + RETURN_IF_NOT_OK_EXCEPT(dataPool_->ResetStreamScanPosition(streamName), K_SC_STREAM_NOT_FOUND); + return Status::OK(); +} + +std::vector<std::string> RemoteWorkerManager::GetRemoteWorkers(const std::string &streamName) +{ + std::vector<std::string> v; + std::for_each(remoteWorkerDict_.begin(), remoteWorkerDict_.end(), [&v, &streamName](const auto &kv) { + auto &rw = kv.second; + if (rw->HasRemoteConsumers(streamName)) { + v.emplace_back(kv.first); + } + }); + return v; +} + +void RemoteAckInfo::SyncStreamLastAckCursor(Optional<AckRange> ackRange, const std::string &logPrefix) +{ + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Last remote ack cursor %zu", logPrefix, lastAckCursor_); + if (!ackQue_.empty()) { + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Most recent ack cursor %zu", logPrefix, ackQue_.top().first); + } + if (ackRange) { + auto begCursor = ackRange.value().first; + if (lastAckCursor_ < begCursor) { + ackQue_.push(ackRange.value()); + } + } + while (!ackQue_.empty() && lastAckCursor_ + 1 == ackQue_.top().first) { + // The top of the queue is contiguous with the cursor, so pop it and move the cursor up. + auto ele = ackQue_.top(); + ackQue_.pop(); + lastAckCursor_ = ele.first + ele.second - 1; + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Move up ack cursor to %zu", logPrefix, lastAckCursor_); + } +} + +uint64_t RemoteAckInfo::GetStreamLastAckCursor() const +{ + return lastAckCursor_; +} + +void RemoteAckInfo::Reset() +{ + lastAckCursor_ = 0; + // std::priority_queue has no clear(), so pop all entries. + while (!ackQue_.empty()) { + ackQue_.pop(); + } +} + +RemoteAckInfo::RemoteAckInfo(uint64_t cursor) : lastAckCursor_(cursor) +{ +} + +Status RemoteConsumerMap::AddConsumer(const std::string &streamName, const std::string &consumerId, + uint64_t windowCount, uint64_t lastAckCursor) +{ + std::shared_lock lock(consumerMutex_); + RemoteStreamInfoTbbMap::accessor accessor; + if (streamConsumers_.find(accessor, streamName)) { + auto ret = std::get<K_CONSUMER_ID>(accessor->second).emplace(consumerId); + CHECK_FAIL_RETURN_STATUS( + ret.second, K_DUPLICATED, + FormatString("[S:%s] Add remote consumer error.
Duplicate consumer id %s", streamName, consumerId)); + } else { + std::set<std::string> remoteConsumerId; + remoteConsumerId.emplace(consumerId); + auto remoteConsumer = + std::make_tuple(false, RemoteAckInfo(lastAckCursor), windowCount, std::move(remoteConsumerId)); + streamConsumers_.emplace(accessor, streamName, std::move(remoteConsumer)); + } + return Status::OK(); +} + +Status RemoteConsumerMap::DeleteConsumer(const std::string &streamName, const std::string &consumerId, + Optional<bool> &mapEmpty) +{ + std::unique_lock lock(consumerMutex_); + RemoteStreamInfoTbbMap::accessor accessor; + bool find = streamConsumers_.find(accessor, streamName); + CHECK_FAIL_RETURN_STATUS(find, StatusCode::K_SC_STREAM_NOT_FOUND, + FormatString("[S:%s] Stream not found", streamName)); + CHECK_FAIL_RETURN_STATUS(std::get<K_CONSUMER_ID>(accessor->second).erase(consumerId) == 1, + StatusCode::K_SC_CONSUMER_NOT_FOUND, + FormatString("[S:%s, C:%s] Consumer does not belong to stream", streamName, consumerId)); + if (std::get<K_CONSUMER_ID>(accessor->second).empty()) { + (void)streamConsumers_.erase(accessor); + } + if (mapEmpty) { + *mapEmpty = streamConsumers_.empty(); + } + return Status::OK(); +} + +Status RemoteConsumerMap::DeleteStream(const std::string &streamName, Optional<bool> &mapEmpty) +{ + std::unique_lock lock(consumerMutex_); + RemoteStreamInfoTbbMap::accessor accessor; + bool find = streamConsumers_.find(accessor, streamName); + CHECK_FAIL_RETURN_STATUS(find, StatusCode::K_SC_STREAM_NOT_FOUND, + FormatString("Cannot find stream:<%s>", streamName)); + (void)streamConsumers_.erase(accessor); + if (mapEmpty) { + *mapEmpty = streamConsumers_.empty(); + } + return Status::OK(); +} + +bool RemoteConsumerMap::HasRemoteConsumers(const std::string &streamName) +{ + std::shared_lock lock(consumerMutex_); + RemoteStreamInfoTbbMap::const_accessor accessor; + return streamConsumers_.find(accessor, streamName); +} + +Status RemoteConsumerMap::ToggleStreamBlocking(const std::string &streamName, bool enable) +{ + std::shared_lock lock(consumerMutex_); + LOG(INFO) << (enable ?
"Block" : "Unblock") << " Producer for stream " << streamName; + RemoteStreamInfoTbbMap::accessor accessor; + bool find = streamConsumers_.find(accessor, streamName); + CHECK_FAIL_RETURN_STATUS(find, StatusCode::K_SC_STREAM_NOT_FOUND, + FormatString("Can not find stream:<%s>", streamName)); + std::get(accessor->second) = enable; + return Status::OK(); +} + +bool RemoteConsumerMap::IsStreamSendBlocked(const std::string &streamName) +{ + std::shared_lock lock(consumerMutex_); + RemoteStreamInfoTbbMap::const_accessor accessor; + return streamConsumers_.find(accessor, streamName) && std::get(accessor->second); +} + +Status RemoteConsumerMap::GetStreamLastAckCursor(const std::string &streamName, uint64_t &cursor) +{ + std::shared_lock lock(consumerMutex_); + RemoteStreamInfoTbbMap::accessor accessor; + if (streamConsumers_.find(accessor, streamName)) { + cursor = std::get(accessor->second).GetStreamLastAckCursor(); + return Status::OK(); + } + RETURN_STATUS(K_SC_CONSUMER_NOT_FOUND, FormatString("Can not find stream:<%s>", streamName)); +} + +void RemoteConsumerMap::SyncStreamLastAckCursor(RemoteStreamInfoTbbMap::accessor &accessor, + Optional ackRange, + const std::string &logPrefix) +{ + std::get(accessor->second).SyncStreamLastAckCursor(ackRange, logPrefix); +} + +bool RemoteConsumerMap::Empty() const +{ + std::unique_lock lock(consumerMutex_); + return streamConsumers_.empty(); +} + +uint64_t RemoteConsumerMap::GetMaxWindowCount(const std::string &streamName) const +{ + std::shared_lock lock(consumerMutex_); + RemoteStreamInfoTbbMap::accessor accessor; + if (streamConsumers_.find(accessor, streamName)) { + return std::get(accessor->second); + } + return 1; +} + +Status RemoteConsumerMap::GetAccessor(const std::string &streamName, RemoteStreamInfoTbbMap::accessor &accessor, + const std::string &logPrefix) +{ + std::shared_lock lock(consumerMutex_); + auto success = streamConsumers_.find(accessor, streamName); + CHECK_FAIL_RETURN_STATUS(success, K_SC_STREAM_NOT_FOUND, + FormatString("[%s, S:%s] Stream not found", logPrefix, streamName)); + return Status::OK(); +} +} // namespace stream_cache +} // namespace worker +} // namespace datasystem diff --git a/src/datasystem/worker/stream_cache/remote_worker_manager.h b/src/datasystem/worker/stream_cache/remote_worker_manager.h new file mode 100644 index 0000000..25bb9e1 --- /dev/null +++ b/src/datasystem/worker/stream_cache/remote_worker_manager.h @@ -0,0 +1,667 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef DATASYSTEM_WORKER_STREAM_CACHE_REMOTE_WORKER_MANAGER_H +#define DATASYSTEM_WORKER_STREAM_CACHE_REMOTE_WORKER_MANAGER_H + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "datasystem/common/ak_sk/ak_sk_manager.h" +#include "datasystem/common/metrics/metrics_vector/metrics_sc_remote_vector.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/common/util/thread_pool.h" +#include "datasystem/common/util/wait_post.h" +#include "datasystem/protos/stream_posix.service.rpc.pb.h" +#include "datasystem/protos/worker_stream.stub.rpc.pb.h" +#include "datasystem/worker/stream_cache/buffer_pool.h" +#include "datasystem/worker/stream_cache/page_queue/exclusive_page_queue.h" +#include "datasystem/worker/stream_cache/page_queue/shared_page_queue_group.h" +#include "datasystem/worker/stream_cache/stream_data_pool.h" +#include "datasystem/worker/stream_cache/subscription.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +struct PushReq { + std::variant req_; + std::string keyName_; + int64_t tag_; + Status rc_; + PushReq(std::variant &&req, const std::string &keyName) + : req_(std::move(req)), keyName_(keyName){}; +}; +/** + * A small class to keep track of which sequence of elements have been acknowledged + * by the remote worker. + */ +class RemoteAckInfo { +public: + explicit RemoteAckInfo(uint64_t cursor); + ~RemoteAckInfo() = default; + using AckRange = std::pair; + struct Compare { + bool operator()(const AckRange &a, const AckRange &b) + { + return a.first > b.first; + } + }; + + /** + * @brief Get the last ack cursor + * @param streamName + * @param cursor + * @return + */ + uint64_t GetStreamLastAckCursor() const; + + /** + * @brief Move up last ack cursor + * @param nextAckCursor + * @return + */ + void SyncStreamLastAckCursor(Optional ackRange, const std::string &logPrefix); + + void Reset(); + +private: + uint64_t lastAckCursor_; + std::priority_queue, Compare> ackQue_; +}; + +// View of an element on a shared memory page. +// Note that Elements are packed in reverse order +struct SendElementView : public BaseBufferData { + std::string streamName_; + std::shared_ptr dataObj_{ nullptr }; + bool remote_{ false }; + std::string remoteWorker_; + std::string StreamName() const override; + std::string ProducerName() const override; + std::string ProducerInstanceId() const override; + uint64_t StreamHash() const override; + virtual bool IsSharedPage() = 0; + virtual bool PackDataElement(const DataElement &element, bool skipChecks, + RemoteWorkerManager *remoteworkerManager = nullptr) = 0; + virtual Status ReleasePage() = 0; + virtual Status IncRefCount() = 0; + virtual RemoteAckInfo::AckRange GetAckRange() = 0; + virtual uint64_t GetElementNum() = 0; + virtual void DiscardBufferFromList(std::list &dataLst, std::list::iterator &iter) = 0; + + /** + * @brief Create StreamElementView for the remote send. + * @param[in] page The stream data page of the element. + * @param[in] remoteWorker The remote worker of the element. + * @param[in] dataElement The element. + * @param[in] obj The page queue base class object. + * @param[in] remoteWorkerManager The remote worker manager ptr. + * @param[out] out The send element view. + * @return Status of the call. 
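+ * @note Depending on the page queue type, the concrete view created is assumed to be a StreamElementView (exclusive page) or a SharedPageElementView (shared page).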
+ */ + static Status CreateSendElementView(const std::shared_ptr &page, const std::string &remoteWorker, + DataElement &dataElement, std::shared_ptr obj, + RemoteWorkerManager *remoteWorkerManager, + std::shared_ptr &out); +}; + +struct StreamElementView : public SendElementView { + std::shared_ptr page_; + std::shared_ptr bigElementPage_; + uint64_t begCursor_; + std::vector sz_; + std::atomic buf_; + bool bigElement_{ false }; + size_t bigElementMetaSize_{ 0 }; + std::unique_ptr shmUnit_; + uint8_t *secondaryAddr_{ nullptr }; + std::unique_ptr localBuf_; // when oom hit or to perform encryption + size_t localBufSize_; + std::atomic shmEnabled_{ true }; + std::atomic ref_{ false }; + std::shared_timed_mutex mux_; + std::vector headerBits_; + + Status ReleasePage() override; + virtual Status IncRefCount(); + Status MoveBufToAlternateMemory(); + uint8_t *GetBufferPointer(); + RemoteAckInfo::AckRange GetAckRange() override; + bool IsSharedPage() override; + Status MoveBufToShmUnit(); + bool PackDataElement(const DataElement &element, bool skipChecks, + RemoteWorkerManager *remoteworkerManager = nullptr) override; + uint64_t GetElementNum() override; + void DiscardBufferFromList(std::list &dataLst, std::list::iterator &iter) override; +}; + +struct SharedPageElementView : public SendElementView { + std::string sharedPageName_; + // Element views can be of different stream names. + std::list> elementViews_; + std::list seqNums_; + bool IsSharedPage() override; + std::string KeyName() const override; + bool PackDataElement(const DataElement &element, bool skipChecks, + RemoteWorkerManager *remoteworkerManager = nullptr) override; + uint64_t RecordSeqNo(std::function fetchAddSeqNo) override; + Status ReleasePage() override; + Status IncRefCount() override; + Status MoveBufToShmUnit(); + RemoteAckInfo::AckRange GetAckRange() override; + uint64_t GetElementNum() override; + void DiscardBufferFromList(std::list &dataLst, std::list::iterator &iter) override; +}; + +using RemoteConsumers = std::tuple>; +using RemoteStreamInfoTbbMap = tbb::concurrent_hash_map; +using StreamManagerMap = tbb::concurrent_hash_map>; +using StreamRaii = std::unique_ptr; +constexpr static int K_BLOCKED = 0; +constexpr static int K_ACK = 1; +constexpr static int K_WINDOW_COUNT = 2; +constexpr static int K_CONSUMER_ID = 3; + +class RemoteConsumerMap { +public: + RemoteConsumerMap() = default; + ~RemoteConsumerMap() = default; + + /** + * @brief Add remote consumer to the map. + * @param[in] streamName The stream name. + * @param[in] consumerId consumer id + * @param[in] windowCount tcp/ip window count + * @param[in] lastAckCursor starting cursor + * @return Status of the call. + */ + Status AddConsumer(const std::string &streamName, const std::string &consumerId, uint64_t windowCount, + uint64_t lastAckCursor); + + /** + * @brief Delete remote consumer from the map. + * @param[in] streamName The stream name. + * @param[in] consumerId The consumer id. + * @return Status of the call. + */ + Status DeleteConsumer(const std::string &streamName, const std::string &consumerId, Optional &lastConsumer); + + /** + * @brief Delete a stream entry + * @param streamName + * @return + */ + Status DeleteStream(const std::string &streamName, Optional &mapEmpty); + + /** + * @brief Enable/Disable blocking of a stream so no remote push is done. + * @param[in] streamName Target stream. + * @param[in] enable T/F + * @return Status of the call. 
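+ * @note Blocking only stops new pushes from being scheduled; data already handed to the RPC layer is not recalled.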
+ */ + Status ToggleStreamBlocking(const std::string &streamName, bool enable); + + /** + * @brief If the stream is blocked + * @param streamName + * @return T/F + */ + bool IsStreamSendBlocked(const std::string &streamName); + + /** + * @brief Check if there is any remote consumer for a given stream + * @return + */ + bool HasRemoteConsumers(const std::string &streamName); + + /** + * @brief Get the last ack cursor + * @param streamName + * @param cursor + * @return + */ + Status GetStreamLastAckCursor(const std::string &streamName, uint64_t &cursor); + + /** + * @brief Move up last ack cursor to K_NEXT value + * @param streamName + * @param nextAckCursor + * @return + */ + void SyncStreamLastAckCursor(RemoteStreamInfoTbbMap::accessor &accessor, Optional ackRange, + const std::string &logPrefix); + + /** + * @brief Identify whether have remote consumer. + * @return True if have no remote consumer. + */ + bool Empty() const; + + /** + * @brief Get max window count + */ + uint64_t GetMaxWindowCount(const std::string &streamName) const; + + Status GetAccessor(const std::string &streamName, RemoteStreamInfoTbbMap::accessor &accessor, + const std::string &logPrefix); + +private: + // key: streamName, value: dictionary of consumers on the remote node for corresponding stream. + RemoteStreamInfoTbbMap streamConsumers_; + mutable std::shared_timed_mutex consumerMutex_; // protect streamConsumers_. +}; + +class RemoteWorker { +public: + RemoteWorker(HostPort localAddress, HostPort remoteAddress, std::shared_ptr akSkManager, + ClientWorkerSCServiceImpl *scSvc, std::string &workerInstanceId, + std::shared_ptr scAllocateManager, RemoteWorkerManager *manager); + ~RemoteWorker(); + + /** + * @brief Init thread pool used by remote worker. + * @return Status of the call. + */ + Status Init(); + + /** + * @brief stream with streamName has remote consumer + * @param streamName + * @return T/F + */ + bool HasRemoteConsumers(const std::string &streamName); + + /** + * @brief Add one remote consumer for a stream on this remote worker node. + * @param[in] streamName Target stream. + * @param[in] subConfig Remote consumer's subscription config. + * @param[in] consumerId Remote consumer's id. + * @param[in] lastAckCursor Remote consumer's last ack cursor. + * @return Status of the call. + */ + Status AddRemoteConsumer(const std::string &streamName, const SubscriptionConfig &subConfig, + const std::string &consumerId, uint64_t windowCount, uint64_t lastAckCursor); + + /** + * @brief Del one remote consumer for a stream on this remote worker node. Search dict by streamName at first, then + * search the target remote consumer in dict by consumerId. + * @param[in] streamName Target stream. + * @param[in] consumerId Remote consumer's id. + * @return Status of the call. + */ + Status DelRemoteConsumer(const std::string &streamName, const std::string &consumerId, + Optional &lastConsumer); + + /** + * @brief Delete a remote stream + * @param streamName + * @return + */ + Status DeleteStream(const std::string &streamName, Optional &mapEmpty); + + /** + * @brief Call back function from BufferPool class + * @return Status object + */ + Status BatchAsyncFlushEntry(PendingFlushList &pendingFlushList); + + /** + * @brief Check whether exists remote consumer. + * @return True if exists. 
+ */ + bool ExistsRemoteConsumer(); + + /** + * @brief Get log prefix + * @return The log prefix + */ + std::string LogPrefix() const; + + /** + * @brief Get the last ack cursor from the remote worker + * @param streamName + * @param cursor + * @return + */ + Status GetStreamLastAckCursor(const std::string &streamName, uint64_t &cursor); + + /** + * @brief Used by reset to stop sending to remote node. + */ + Status ProcessEndOfStream(const std::shared_ptr &streamMgr, std::list dataLst, + const std::string &streamName, const std::string &producerId); + + /** + * @brief Obtain the success rate of sending data to the remote worker manager. + */ + void RegisterRecordRemoteSendRateCallBack(std::function callBackFunction) + { + recordRemoteSendRate_ = callBackFunction; + } + + bool IsStreamSendBlocked(const std::string &streamName); + + /** + * @brief Get max window count + */ + uint64_t GetMaxWindowCount(const std::string &streamName) const; + + /** + * @brief Get the or create SharedPageQueue. + * @param[in] namespaceUri The stream name. + * @param[out] pageQueue The instance of SharedPageQueue. + */ + void GetOrCreateSharedPageQueue(const std::string &namespaceUri, std::shared_ptr &pageQueue); + +private: + friend class RemoteWorkerManager; + + const HostPort localWorkerAddr_; + const HostPort remoteWorkerAddr_; + + RemoteConsumerMap remoteConsumers_; + std::shared_ptr akSkManager_{ nullptr }; + ClientWorkerSCServiceImpl *scSvc_; + SharedPageQueueGroup sharedPageGroup_; + + /** + * @brief Record remote send success rate if callback function is exist. + * @param[in] success Success or not. + */ + void RecordRemoteSendSuccess(bool success) + { + const int totalNum = 1; + const int successNum = success ? 1 : 0; + if (recordRemoteSendRate_ != nullptr) { + recordRemoteSendRate_(successNum, totalNum); + } + } + + int BatchFlushAsyncWrite(const std::shared_ptr &stub, std::vector &requests, + std::vector> &payloads); + void BatchFlushAsyncRead(const std::shared_ptr &stub, + PendingFlushList &pendingFlushList, std::vector &requests, + std::unordered_map &raii); + void HandleBlockedElements(std::list> &moveList, + std::unordered_set> &needAckList); + Status ParseProducerPendingFlushList(const std::string &streamName, const std::string &producerId, + std::list &dataLst, std::vector &requests, + std::vector> &payloads, + std::unordered_map &raii, + std::list> &moveList, + std::unordered_set> &needAckList); + Status ParsePendingFlushList(const PendingFlushList &pendingFlushList, std::vector &requests, + std::vector> &payloads, + std::unordered_map &raii, + std::list> &moveList, + std::unordered_set> &needAckList); + Status FillExclusivePushReqHelper(const std::string &streamName, const std::string &producerId, uint64_t firstSeqNo, + std::list &dataLst, std::list::iterator &it, + uint64_t &firstCursor, uint64_t &lastCursor, PushReqPb &pushReqPb, + std::vector &elements, + std::unordered_map &raii); + Status FillSharedPushReqHelper(const std::string &producerId, std::list &dataLst, + std::list::iterator &it, uint64_t &firstCursor, uint64_t &lastCursor, + SharedPagePushReqPb &pushReqPb, std::vector &elements, + std::unordered_map &raii, + std::list> &moveList); + Status LockStreamManagerHelper(const std::string &streamName, std::unordered_map &raii); + void PostRecvCleanup(const std::string &streamName, const Status &status, PendingFlushList &pendingFlushList, + const PushReqPb &pushReq, const PushRspPb &pushRspPb, + std::unordered_map &raii); + void PostRecvCleanup(const std::string &streamName, const 
PushReqPb &rq, const PushRspPb &pushRspPb, + RemoteStreamInfoTbbMap::accessor &accessor, std::list &dataLst, + std::unordered_map &raii); + void PostRecvCleanup(const std::string &keyName, const Status &status, PendingFlushList &pendingFlushList, + const SharedPagePushReqPb &pushReq, const PushRspPb &pushRspPb, + std::unordered_map &raii); + void PostRecvCleanup(const std::string &keyName, const SharedPagePushReqPb &rq, const PushRspPb &pushRspPb, + RemoteStreamInfoTbbMap::accessor &accessor, std::list &dataLst, + std::unordered_map &raii); + + /** + * @brief Helper function to discard buffers when consumer does not exist. + * @param[in] dataLst list of PVs. + */ + static void DiscardBuffers(std::list &dataLst); + + /** + * @brief Helper function to discard one buffer. + * @param[in] dataLst list of PVs. + * @param[in/out] iter The iterator of dataLst. + */ + static void DiscardBufferHelper(std::list &dataLst, std::list::iterator &iter); + + void SyncStreamLastAckCursor(RemoteStreamInfoTbbMap::accessor &accessor, + Optional ackRange); + Status GetAccessor(const std::string &streamName, RemoteStreamInfoTbbMap::accessor &accessor); + + std::function recordRemoteSendRate_; + std::string workerInstanceId_; // unique id generated for each worker instance + RemoteWorkerManager *remoteWorkerManager_; +}; + +// The RemoteWorkerManager structure introduction. +// remoteWorkerDict_: +// Key: remote worker address, Value: RemoteWorker +// ================================== +// address0 -> RemoteWorker0 +// address1 -> RemoteWorker1 +// address2 -> RemoteWorker2 +// ... ... +// ... ... +// addressI -> RemoteWorkerI +// ================================== +// The RemoteWorker structure +// streamConsumers_: +// Key: stream name, Value: stream's consumer dict map +// ================================= +// stream0 -> ConsumerDict0 +// stream1 -> ConsumerDict1 +// ... ... +// streamJ -> ConsumerDictJ +// ================================= +// The ConsumerDict structure +// Key: consumer name, Value: pair(SubscriptionConfig, Consumer object) +// ===================================== +// consumerId0 -> (sub0, Consumer0) +// consumerId1 -> (sub1, Consumer1) +// ... ... +// consumerIdK -> (subK, ConsumerK) +// ===================================== + +class RemoteWorkerManager { +public: + explicit RemoteWorkerManager(ClientWorkerSCServiceImpl *scSvc, std::shared_ptr akSkManager, + std::shared_ptr scAllocateManager); + ~RemoteWorkerManager(); + + /** + * @brief Init thread pool used by remote worker manager. + * @return Status of the call. + */ + Status Init(); + + /** + * @brief Check if there are tasks to be processed + * @return T/F + */ + bool HaveTasksToProcess() + { + return dataMap_->HaveTasksToProcess(); + } + + /** + * @brief Add stream data in list of pending send data. + * @param[in] eleView The data to send to remote worker. + * @return Status of the call. + */ + Status SendElementsView(const std::shared_ptr &eleView); + + /** + * @brief Get the remote ack from remote workers + * @param streamName + * @return remote ack + */ + uint64_t GetLastAckCursor(const std::string &streamName); + + /** + * @brief Purge buffer from the given stream manager + * @param streamMgr + */ + void PurgeBuffer(const std::shared_ptr &streamMgr); + + /** + * @brief Remove the info of useless stream from BufferPool + * @param keyName The stream name or page name. + * @param sharedPageName The shared page name. Empty if the stream use exclusive page or the keyName is page. 
+ */ + void RemoveStream(const std::string &keyName, const std::string &sharedPageName); + + /** + * @brief stream with streamName has remote consumer + * @param streamName + * @return T/F + */ + bool HasRemoteConsumers(const std::string &streamName); + + /** + * @brief Delete a remote stream from all workers. + * @param streamName + * @return + */ + Status DeleteStream(const std::string &streamName); + + /** + * @brief Stop to push new buffer into RW. + * @param streamName + * @return Status of the call. + */ + Status DoneScanning(const std::string &streamName); + + std::string GetSCRemoteSendSuccessRate(); + + /** + * @brief Del one remote consumer for a stream on this remote worker node. Search dict by streamName at first, then + * search the target remote consumer in dict by consumerId. + * @param[in] streamName Target stream. + * @param[in] consumerId Remote consumer's id. + * @return Status of the call. + */ + Status DelRemoteConsumer(const std::string &workerAddr, const std::string &streamName, + const std::string &consumerId); + + /** + * @brief Clear all remote consumer node for target stream on current worker node. + * Invoked when last producer disappears within current node. + * @param[in] forceClose Force close from master. + * @return K_OK on success; the error code otherwise. + */ + Status ClearAllRemoteConsumer(const std::string &streamName, bool forceClose); + + /** + * @brief Enable/Disable blocking of a stream so no remote push is done or resume. + * @param[in] workerAddr remote worker + * @param[in] streamName Target stream. + * @param[in] enable T/F + * @return Status of the call. + */ + Status ToggleStreamBlocking(const std::string &workerAddr, const std::string &streamName, bool enable); + + /** + * @brief Add one remote consumer for a stream on this remote worker node. + * @param[in] streamManager The stream manager ptr. + * @param[in] localWorkerAddress Local worker's address. + * @param[in] remoteWorkerAddress Target remote worker's address. + * @param[in] streamName Target stream. + * @param[in] subConfig Remote consumer's subscription config. + * @param[in] consumerId Remote consumer's id. + * @param[in] lastAckCursor Remote consumer's last ack cursor. + * @return Status of the call. + */ + Status AddRemoteConsumer(const std::shared_ptr &streamMgr, const HostPort &localWorkerAddress, + const HostPort &remoteWorkerAddress, const std::string &streamName, + const SubscriptionConfig &subConfig, const std::string &consumerId, + uint64_t lastAckCursor); + + /** + * @brief Reset scan position + * @param[in] streamName + * @return + */ + Status ResetStreamScanList(const std::string &streamName); + + /** + * @brief Convert from stream number to the corresponding stream name. + * @param[in] streamNo The stream number. + * @param[out] streamName The stream name. + * @return Status of the call. + */ + Status StreamNoToName(uint64_t streamNo, std::string &streamName); + +private: + /** + * @brief Call back function from BufferPool class + * @return Status object + */ + Status BatchAsyncFlushEntry(int myId, const PendingFlushList &pendingFlushList); + + /** + * @brief Used by reset to stop sending to remote node. + */ + Status ProcessEndOfStream(const std::shared_ptr &streamMgr, std::list dataLst, + const std::string &streamName, const std::string &producerId); + + /** + * @brief Get remote worker object. + * @param[in] address Remote worker's address. + * @param[out] remoteWorker Pointer to the remote worker object. + * @return Status of the call. 
+ */ + Status GetRemoteWorker(const std::string &address, std::shared_ptr &remoteWorker); + + /** + * @brief Get a list of remote worker address from a given streamName + */ + std::vector GetRemoteWorkers(const std::string &streamName); + + MetricsScRemoteVector remoteSendRateVec_; + std::unordered_map> remoteWorkerDict_; + mutable std::shared_timed_mutex mutex_; + std::shared_ptr akSkManager_{ nullptr }; + ClientWorkerSCServiceImpl *scSvc_; + std::unique_ptr dataMap_; + std::unique_ptr dataPool_; // holds all the stream pages + std::string workerInstanceId_; // unique id generated for each worker instance + std::shared_ptr scAllocateManager_; +}; +} // namespace stream_cache +} // namespace worker +} // namespace datasystem +#endif // DATASYSTEM_WORKER_STREAM_CACHE_REMOTE_WORKER_MANAGER_H diff --git a/src/datasystem/worker/stream_cache/stream_data_pool.cpp b/src/datasystem/worker/stream_cache/stream_data_pool.cpp new file mode 100644 index 0000000..4224a52 --- /dev/null +++ b/src/datasystem/worker/stream_cache/stream_data_pool.cpp @@ -0,0 +1,380 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Stream data page pool + */ +#include + +#include "datasystem/common/constants.h" +#include "datasystem/common/flags/flags.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/util/lock_helper.h" +#include "datasystem/common/util/request_counter.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h" +#include "datasystem/worker/stream_cache/consumer.h" +#include "datasystem/worker/stream_cache/stream_data_pool.h" +#include "datasystem/worker/stream_cache/stream_manager.h" + +DS_DEFINE_int32(sc_scan_num_buckets, 1024, "Number of partitions for scanning streams"); +DS_DEFINE_int32(sc_scan_interval_ms, 10, "Scan interval for remote send. 
Defaults to 10ms"); +DS_DEFINE_int32(sc_scan_thread_num, 16, "Number of threads for scanning shared memory changes"); +DS_DEFINE_validator(sc_scan_thread_num, &Validator::ValidateThreadNum); + +namespace datasystem { +namespace worker { +namespace stream_cache { + +StreamDataPool::StreamDataPool() : interrupt_(false), numPartitions_(std::max(1, FLAGS_sc_scan_num_buckets)) +{ + const size_t MIN_THREADS = 1; + const size_t MAX_THREADS = std::max(1, FLAGS_sc_scan_thread_num); + threadPool_ = std::make_unique<ThreadPool>(MIN_THREADS, MAX_THREADS, "RemoteWorkerManager", true); + threadPool_->SetWarnLevel(ThreadPool::WarnLevel::LOW); + partitionList_.reserve(numPartitions_); + for (auto i = 0; i < numPartitions_; ++i) { + partitionList_.emplace_back(std::make_unique<ObjectPartition>(i)); + } +} + +StreamDataPool::~StreamDataPool() +{ + Stop(); + if (scanner_.joinable()) { + scanner_.join(); + } + if (threadPool_) { + threadPool_.reset(); + } +} + +Status StreamDataPool::Init() +{ + RETURN_IF_EXCEPTION_OCCURS(scanner_ = Thread([this] { + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + ScanChanges(); + })); + return Status::OK(); +} + +void StreamDataPool::Stop() +{ + interrupt_ = true; + for (auto &part : partitionList_) { + part->interrupt_ = true; + } +} + +template +Status StreamDataPool::ObjectPartition::AddScanObject(const std::shared_ptr &streamObj, const std::string &keyName, + const std::vector<std::string> &dest, uint64_t lastAckCursor, + std::unique_ptr<ThreadPool> &pool) +{ + RETURN_RUNTIME_ERROR_IF_NULL(streamObj); + LOG(INFO) << FormatString("[S:%s, P:%zu] Started adding Data object", keyName, myId_); + WriteLockHelper wlock(objMux_, [this, &keyName, funName = __FUNCTION__] { + return FormatString("S:%s P:%zu %s:%s", keyName, myId_, funName, __LINE__); + }); + auto it = objMap_.find(keyName); + if (it == objMap_.end()) { + auto future = std::make_unique<std::future<Status>>( + pool->Submit([this, keyName]() { return SendElementsToRemote(keyName); })); + std::shared_ptr scanInfo = std::make_shared(streamObj, lastAckCursor, dest, std::move(future)); + objMap_.emplace(keyName, std::static_pointer_cast<ScanInfo>(scanInfo)); + for (auto &rw : dest) { + LOG(INFO) << FormatString("[RW:%s, S:%s, P:%zu] Data object added to scan list", rw, keyName, myId_); + } + return Status::OK(); + } + LOG(INFO) << FormatString("[S:%s, P:%zu] Found in scan list", keyName, myId_); + // Update the destination; everything else remains unchanged. + auto &scanInfo = *(it->second); + scanInfo.dest_ = dest; + return Status::OK(); +} + +void StreamDataPool::ObjectPartition::ScanChanges(std::unique_ptr<ThreadPool> &pool) +{ + std::shared_lock rlock(objMux_); + if (objMap_.empty()) { + return; + } + Timer timer; + std::for_each(objMap_.begin(), objMap_.end(), [this, &pool](auto &kv) { + if (interrupt_) { + return; + } + const std::string streamName = kv.first; + auto &scanInfo = *(kv.second); + auto &fut = scanInfo.future_; + if (fut->valid()) { + if (fut->wait_for(std::chrono::seconds(0)) == std::future_status::ready) { + Status rc = fut->get(); + if (rc.IsError() && rc.GetCode() != K_NOT_FOUND && rc.GetCode() != K_TRY_AGAIN) { + LOG(INFO) << FormatString("[S:%s] Scan changes failed. %s", kv.first, rc.ToString()); + } + } else { + // Scan result not ready.
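+ // Leave the pending future in place; the next ScanChanges pass will poll it again.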
+ return; + } + } + // Submit a new scan after the specified interval has elapsed. + auto start = scanInfo.start_; + auto now = std::chrono::high_resolution_clock::now(); + if (std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count() >= FLAGS_sc_scan_interval_ms + && !interrupt_) { + scanInfo.start_ = now; + scanInfo.future_ = std::make_unique<std::future<Status>>( + pool->Submit([this, streamName]() { return SendElementsToRemote(streamName); })); + } + }); + const uint32_t intervalMs = 1000; + if (timer.ElapsedMilliSecond() > intervalMs) { + LOG(WARNING) << FormatString("[P:%zu] Data object map traversal takes %d ms for %d streams.", myId_, + timer.ElapsedMilliSecond(), objMap_.size()); + } +} + +Status StreamDataPool::ObjectPartition::RemoveScanObject(const std::string &streamName, + const std::vector<std::string> &dest) +{ + LOG(INFO) << FormatString("[S:%s, P:%zu] Started removing Data object", streamName, myId_); + INJECT_POINT("StreamDataPool::ObjectPartition::RemoveStreamObject.sleep"); + WriteLockHelper wlock(objMux_, [this, &streamName, funName = __FUNCTION__] { + return FormatString("S:%s P:%zu %s:%s", streamName, myId_, funName, __LINE__); + }); + auto it = objMap_.find(streamName); + CHECK_FAIL_RETURN_STATUS(it != objMap_.end(), K_SC_STREAM_NOT_FOUND, + FormatString("Stream %s is not on the scan list", streamName)); + // If there is no more remote worker, remove it from the scan list. + if (dest.empty()) { + LOG(INFO) << FormatString("[S:%s, P:%zu] Data object removed from scan list", streamName, myId_); + // We no longer scan this stream for newly added elements. + (void)objMap_.erase(it); + } else { + // Otherwise, update the destination. + auto &scanInfo = *(it->second); + scanInfo.dest_ = dest; + for (auto &rw : dest) { + LOG(INFO) << FormatString("[RW:%s, S:%s, P:%zu] Data object updated in scan list", rw, streamName, myId_); + } + } + return Status::OK(); +} + +Status StreamDataPool::ObjectPartition::ResetStreamScanPosition(const std::string &streamName) +{ + // A scanner thread continuously polls the objects for changes; + // taking the write lock here pauses it while the position is reset. + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[S:%s, P:%zu] Started resetting scan position", streamName, myId_); + WriteLockHelper wlock(objMux_, [this, &streamName, funName = __FUNCTION__] { + return FormatString("S:%s P:%zu %s:%s", streamName, myId_, funName, __LINE__); + }); + auto it = objMap_.find(streamName); + CHECK_FAIL_RETURN_STATUS(it != objMap_.end(), K_SC_STREAM_NOT_FOUND, + FormatString("Stream %s is not on the scan list", streamName)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[S:%s, P:%zu] Reset data object found in scan list", streamName, myId_); + auto &scanInfo = *(it->second); + scanInfo.cursor_ = 0; + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[S:%s, P:%zu] ResetStreamScanPosition successful", streamName, myId_); + return Status::OK(); +} + +uint64_t StreamDataPool::GetPartId(const std::string &streamName) const +{ + return std::hash<std::string>{}(streamName) % numPartitions_; +} + +Status StreamDataPool::AddStreamObject(std::shared_ptr<StreamManager> streamMgr, const std::string &streamName, + const std::vector<std::string> &dest, uint64_t lastAckCursor) +{ + auto partitionID = GetPartId(streamName); + auto &part = partitionList_[partitionID]; + // We aren't passing the accessor to the RWM. That is, we have a copy of the + // stream manager but there is no lock protection.
We will check again + // later at the RW layer + return part->AddScanObject(streamMgr, streamName, dest, lastAckCursor, threadPool_); +} + +Status StreamDataPool::AddSharedPageObject(std::shared_ptr sharedPageQueue, + const std::string &streamName, const std::vector &dest, + uint64_t lastAckCursor) +{ + auto queueId = sharedPageQueue->GetPageQueueId(); + auto partitionID = GetPartId(queueId); + auto &part = partitionList_[partitionID]; + RETURN_IF_NOT_OK((part->AddScanObject(sharedPageQueue, queueId, dest, + lastAckCursor, threadPool_))); + std::unique_lock xlock(queueIdMux_); + auto iter = queueIdMap_.find(queueId); + if (iter == queueIdMap_.end()) { + bool success; + std::tie(iter, success) = queueIdMap_.emplace(queueId, std::unordered_set()); + } + iter->second.emplace(streamName); + return Status::OK(); +} + +Status StreamDataPool::RemoveStreamObject(const std::string &streamName, const std::vector &dest) +{ + std::string queueId; + bool unregister = false; + { + std::unique_lock xlock(queueIdMux_); + for (auto pair : queueIdMap_) { + auto &streamNames = pair.second; + auto it = streamNames.find(streamName); + if (streamNames.find(streamName) != streamNames.end()) { + queueId = pair.first; + // When destination is empty, we no longer scan a stream. + // And then when all of the streams sharing the shared page queue are gone, + // the scan for shared page queue can be stopped. + if (dest.empty()) { + streamNames.erase(it); + } + if (streamNames.empty()) { + unregister = true; + } + break; + } + } + } + const std::string &keyName = unregister ? queueId : streamName; + auto partitionID = GetPartId(keyName); + return partitionList_[partitionID]->RemoveScanObject(keyName, dest); +} + +Status StreamDataPool::ResetStreamScanPosition(const std::string &streamName) +{ + auto partitionID = GetPartId(streamName); + return partitionList_[partitionID]->ResetStreamScanPosition(streamName); +} + +Status StreamDataPool::ObjectPartition::ScanChangesAndEval( + std::unordered_map>::iterator &iter) +{ + auto &scanInfo = *(iter->second); + auto &mux = *(scanInfo.mux_); + std::unique_lock xlock(mux); + std::shared_ptr pageQueue; + RETURN_IF_NOT_OK(scanInfo.GetPageQueue(pageQueue)); + uint64_t lastAckCursor = scanInfo.cursor_; + std::vector remoteWorkers = scanInfo.dest_; + const std::string keyName = iter->first; + ScanFlags flag = ScanFlags::PAGE_BREAK | ScanFlags::EVAL_BREAK; + const int timeoutMs = 10; + Status rc = pageQueue->ScanAndEval(lastAckCursor, timeoutMs, remoteWorkers, flag); + // All these errors are okay + // K_TRY_AGAIN is ok when we reach the last cursor on the last page. + // K_SC_END_OF_PAGE is okay when we do page break. 
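+ // Any other status is a real error; it propagates back through the scan future and is logged by ScanChanges.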
+ if (rc.IsOk() || rc.GetCode() == K_TRY_AGAIN || rc.GetCode() == K_SC_END_OF_PAGE) { + rc = Status::OK(); + INJECT_POINT("StreamDataPool.ScanChangesAndEval.delaywakeup", [&lastAckCursor](int lastAckReplace) { + lastAckCursor = lastAckReplace; + return Status::OK(); + }); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[S:%s, P:%zu] Scan position moves from %zu to %zu", keyName, myId_, + scanInfo.cursor_, lastAckCursor); + scanInfo.cursor_ = lastAckCursor; + } + return rc; +} + +Status StreamDataPool::ObjectPartition::SendElementsToRemote(const std::string &streamName) +{ + INJECT_POINT("StreamDataPool.SendElementsToRemote.wait"); + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[S:%s, P:%zu] Scan for changes ...", streamName, myId_); + ReadLockHelper rlock(objMux_, [this, &streamName, funName = __FUNCTION__] { + return FormatString("S:%s P:%zu %s:%s", streamName, myId_, funName, __LINE__); + }); + Timer timer; + auto iter = objMap_.find(streamName); + CHECK_FAIL_RETURN_STATUS(iter != objMap_.end(), K_SC_STREAM_NOT_FOUND, + FormatString("Stream %s not found", streamName)); + auto rc = ScanChangesAndEval(iter); + const uint32_t intervalMs = 1000; + if (timer.ElapsedMilliSecond() > intervalMs) { + LOG(WARNING) << FormatString("[S:%s, P:%zu] Scan for changes takes %d ms.", streamName, myId_, + timer.ElapsedMilliSecond()); + } + return rc; +} + +void StreamDataPool::ScanChanges() +{ + LOG(INFO) << "StreamDataPool scanner starts up"; + const int intervalMs = FLAGS_sc_scan_interval_ms; + while (true) { + for (auto &part : partitionList_) { + part->ScanChanges(threadPool_); + } + if (interrupt_) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(intervalMs)); + } +} + +StreamDataPool::ScanInfo::ScanInfo(uint64_t cursor, std::vector dest, + std::unique_ptr> future) + : cursor_(cursor), + dest_(std::move(dest)), + future_(std::move(future)), + mux_(std::make_unique()), + start_(std::chrono::high_resolution_clock::now()) +{ +} + +StreamDataPool::StreamScanInfo::StreamScanInfo(std::shared_ptr mgr, uint64_t cursor, + std::vector dest, + std::unique_ptr> future) + : ScanInfo(cursor, dest, std::move(future)), mgr_(mgr) +{ +} + +Status StreamDataPool::StreamScanInfo::GetPageQueue(std::shared_ptr &pageQueue) +{ + RETURN_RUNTIME_ERROR_IF_NULL(mgr_); + // If the stream is being deleted, move on + RETURN_IF_NOT_OK(mgr_->CheckIfStreamActive()); + if (mgr_->IsRetainData()) { + return { K_TRY_AGAIN, + FormatString("[S:%s] The expected num of consumers are not yet created.", mgr_->GetStreamName()) }; + } + pageQueue = std::static_pointer_cast(mgr_->GetExclusivePageQueue()); + return Status::OK(); +} + +StreamDataPool::SharedPageScanInfo::SharedPageScanInfo(std::shared_ptr sharedPageQueue, + uint64_t cursor, std::vector dest, + std::unique_ptr> future) + : ScanInfo(cursor, dest, std::move(future)), sharedPageQueue_(sharedPageQueue) +{ +} + +Status StreamDataPool::SharedPageScanInfo::GetPageQueue(std::shared_ptr &pageQueue) +{ + std::shared_ptr sharedPageQue = sharedPageQueue_.lock(); + RETURN_RUNTIME_ERROR_IF_NULL(sharedPageQue); + pageQueue = std::static_pointer_cast(sharedPageQue); + return Status::OK(); +} + +} // namespace stream_cache +} // namespace worker +} // namespace datasystem diff --git a/src/datasystem/worker/stream_cache/stream_data_pool.h b/src/datasystem/worker/stream_cache/stream_data_pool.h new file mode 100644 index 0000000..b45120b --- /dev/null +++ b/src/datasystem/worker/stream_cache/stream_data_pool.h @@ -0,0 
+1,145 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Stream data page pool + */ + +#ifndef DATASYSTEM_WORKER_STREAM_CACHE_STREAM_DATA_POOL_H +#define DATASYSTEM_WORKER_STREAM_CACHE_STREAM_DATA_POOL_H + +#include + +#include +#include "datasystem/common/rpc/rpc_server_stream_base.h" +#include "datasystem/common/util/bitmask_enum.h" +#include "datasystem/common/stream_cache/stream_data_page.h" +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/worker/stream_cache/page_queue/shared_page_queue.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +class ClientWorkerSCServiceImpl; +class StreamManager; + +class StreamDataPool { +public: + struct ScanInfo { + uint64_t cursor_; + std::vector dest_; + std::unique_ptr> future_; + std::unique_ptr mux_; + std::chrono::high_resolution_clock::time_point start_; + ScanInfo(uint64_t cursor, std::vector dest, std::unique_ptr> future); + virtual Status GetPageQueue(std::shared_ptr &pageQueue) = 0; + }; + struct StreamScanInfo : ScanInfo { + std::shared_ptr mgr_; + StreamScanInfo(std::shared_ptr mgr, uint64_t cursor, std::vector dest, + std::unique_ptr> future); + virtual Status GetPageQueue(std::shared_ptr &pageQueue) override; + }; + struct SharedPageScanInfo : ScanInfo { + std::weak_ptr sharedPageQueue_; + SharedPageScanInfo(std::shared_ptr sharedPageQueue, uint64_t cursor, + std::vector dest, std::unique_ptr> future); + virtual Status GetPageQueue(std::shared_ptr &pageQueue) override; + }; + StreamDataPool(); + ~StreamDataPool(); + + /** + * Initialization + * @return + */ + Status Init(); + + /** + * @brief Add a stream data object to scan list + * @param[in] streamMgr The stream manager of the actual stream, for exclusive page queue purpose. + * @param[in] streamName The stream name to scan. + * @param[in] dest The remote worker destination. + * @param[in] lastAckCursor The last ack cursor. + * @return Status of the call. + */ + Status AddStreamObject(std::shared_ptr streamMgr, const std::string &streamName, + const std::vector &dest, uint64_t lastAckCursor); + + /** + * @brief Add a shared page queue to scan list. + * @param[in] sharedPageQueue The shared page queue to scan. + * @param[in] streamName The actual stream name to scan. + * @param[in] dest The remote worker destination. + * @param[in] lastAckCursor The last ack cursor. + * @return Status of the call. 
+ */ + Status AddSharedPageObject(std::shared_ptr sharedPageQueue, const std::string &streamName, + const std::vector &dest, uint64_t lastAckCursor); + + /** + * @brief Remove a stream object from scan list + * @param mgr + * @return + */ + Status RemoveStreamObject(const std::string &streamName, const std::vector &dest); + + /** + * @brief Reset the scan position + * @param streamName + * @return + */ + Status ResetStreamScanPosition(const std::string &streamName); + +private: + std::atomic interrupt_; + const int numPartitions_; + std::unique_ptr threadPool_; + Thread scanner_; + mutable std::shared_timed_mutex queueIdMux_; + std::unordered_map> queueIdMap_; + struct ObjectPartition { + uint64_t myId_; + std::atomic interrupt_; + mutable std::shared_timed_mutex objMux_; + // The key is stream name for the normal streams, destination for the merge-streams. + std::unordered_map> objMap_; + explicit ObjectPartition(uint64_t i) : myId_(i), interrupt_(false) + { + } + ~ObjectPartition() = default; + + template + Status AddScanObject(const std::shared_ptr &streamObj, const std::string &keyName, + const std::vector &dest, uint64_t lastAckCursor, + std::unique_ptr &pool); + Status RemoveScanObject(const std::string &streamName, const std::vector &dest); + Status ScanChangesAndEval(std::unordered_map>::iterator &iter); + Status ResetStreamScanPosition(const std::string &streamName); + Status SendElementsToRemote(const std::string &streamName); + void ScanChanges(std::unique_ptr &pool); + }; + std::vector> partitionList_; + + void ScanChanges(); + void Stop(); + uint64_t GetPartId(const std::string &streamName) const; +}; +} // namespace stream_cache +} // namespace worker +} // namespace datasystem + +#endif // DATASYSTEM_WORKER_STREAM_CACHE_STREAM_DATA_POOL_H diff --git a/src/datasystem/worker/stream_cache/stream_manager.cpp b/src/datasystem/worker/stream_cache/stream_manager.cpp new file mode 100644 index 0000000..395067c --- /dev/null +++ b/src/datasystem/worker/stream_cache/stream_manager.cpp @@ -0,0 +1,1638 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "datasystem/worker/stream_cache/stream_manager.h"
+
+#include
+
+#include "datasystem/common/eventloop/timer_queue.h"
+#include "datasystem/common/inject/inject_point.h"
+#include "datasystem/common/log/log_helper.h"
+#include "datasystem/common/log/log.h"
+#include "datasystem/common/perf/perf_manager.h"
+#include "datasystem/common/rpc/rpc_auth_key_manager.h"
+#include "datasystem/common/rpc/rpc_unary_client_impl.h"
+#include "datasystem/common/stream_cache/stream_data_page.h"
+#include "datasystem/common/util/format.h"
+#include "datasystem/common/util/lock_helper.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/common/util/strings_util.h"
+#include "datasystem/protos/stream_posix.stub.rpc.pb.h"
+#include "datasystem/stream/stream_config.h"
+#include "datasystem/utils/status.h"
+#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h"
+#include "datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h"
+#include "datasystem/worker/stream_cache/page_queue/page_queue_handler.h"
+
+DS_DECLARE_string(sc_encrypt_secret_key);
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+template Status StreamManager::HandleBlockedRequestImpl(
+    std::shared_ptr> &&blockedReq, bool lock);
+template Status StreamManager::HandleBlockedRequestImpl(
+    std::shared_ptr> &&blockedReq, bool lock);
+
+StreamManager::StreamManager(std::string streamName, RemoteWorkerManager *remoteWorkerManager,
+                             std::string localWorkerAddr, std::shared_ptr akSkManager,
+                             std::weak_ptr scSvc,
+                             std::shared_ptr manager,
+                             std::weak_ptr workerWorkerSCService, uint64_t localStreamNum)
+    : workerAddr_(std::move(localWorkerAddr)),
+      streamName_(std::move(streamName)),
+      remoteWorkerManager_(remoteWorkerManager),
+      akSkManager_(std::move(akSkManager)),
+      scSvc_(std::move(scSvc)),
+      lastAckCursor_(0),
+      wakeupPendingRecvOnProdFault_(false),
+      scAllocateManager_(std::move(manager)),
+      workerWorkerSCService_(std::move(workerWorkerSCService)),
+      localStreamNum_(localStreamNum)
+{
+    ackWp_.Set();
+    reclaimWp_.Set();
+}
+
+StreamManager::~StreamManager()
+{
+    // Update stream metrics one final time before exit.
+    if (scStreamMetrics_) {
+        UpdateStreamMetrics();
+    }
+    // Remove the stream number from the map at stream manager deletion.
+    auto scSvc = scSvc_.lock();
+    if (scSvc != nullptr) {
+        scSvc->RemoveStreamNo(localStreamNum_);
+        // Explicitly undo the memory reservation at stream deletion.
+        scSvc->UndoReserveMemoryFromUsageMonitor(GetStreamName());
+    }
+    // Remove the stream info in BufferPool.
+    if (auto workerWorkerSCServicePtr = workerWorkerSCService_.lock()) {
+        workerWorkerSCServicePtr->RemoveStream(GetStreamName(), pageQueueHandler_->GetSharedPageQueueId());
+    }
+    remoteWorkerManager_->RemoveStream(GetStreamName(), pageQueueHandler_->GetSharedPageQueueId());
+}
+
+Status StreamManager::CreatePageQueueHandler(Optional cfg)
+{
+    pageQueueHandler_ = std::make_unique(this, cfg);
+    // Reserve memory when shared page is not enabled.
+    // If shared page is enabled:
+    // 1. The producer node will not reserve memory.
+    // 2. The consumer node will reserve memory after the StreamFields update in UpdateStreamFields.
+    if (cfg && !EnableSharedPage(cfg->streamMode_)) {
+        auto pageQueue = GetExclusivePageQueue();
+        Status rc = pageQueue->ReserveStreamMemory();
+        if (rc.IsOk()) {
+            LOG(INFO) << FormatString("[%s] %zu bytes of shared memory have been reserved", LogPrefix(),
+                                      pageQueue->GetReserveSize());
+        }
+        // Page size unknown at this point is not an error.
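+        // Presumably the size becomes known once the stream fields arrive (see the
+        // UpdateStreamFields note above), so K_NOT_READY is downgraded to OK here.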
+ if (rc.GetCode() == K_NOT_READY) { + rc = Status::OK(); + } + return rc; + } + return Status::OK(); +} + +void StreamManager::BlockMemoryReclaim() +{ + reclaimMutex_.lock_shared(); // To be unlocked by the caller, not by this function. + reclaimWp_.Wait(); +} + +void StreamManager::UnblockMemoryReclaim() +{ + reclaimMutex_.unlock(); +} + +Status StreamManager::AddCursorForProducer(const std::string &producerId, ShmView &shmView) +{ + std::shared_ptr cursor; + RETURN_IF_NOT_OK(pageQueueHandler_->AddCursor(producerId, true, cursor, shmView)); + bool needRollback = true; + Raii raii([this, &needRollback, &producerId]() { + if (needRollback) { + pageQueueHandler_->DeleteCursor(producerId); + } + }); + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + auto iter = pubs_.find(producerId); + CHECK_FAIL_RETURN_STATUS(iter != pubs_.end(), K_RUNTIME_ERROR, + FormatString("can not find producer[%s] when add cursor", producerId)); + needRollback = false; + iter->second->SetCursor(std::move(cursor)); + iter->second->SetElementCount(0); + return Status::OK(); +} + +Status StreamManager::AddProducer(const std::string &producerId, + DataVerificationHeader::SenderProducerNo &senderProducerNo) +{ + PerfPoint point(PerfKey::MANAGER_ADD_PRODUCER); + // Allocate a work area (in shared memory) to be shared between this worker and the client producer + + bool needRollback = true; + Raii raii([this, &needRollback, &producerId]() { + if (needRollback) { + pubs_.erase(producerId); + } + }); + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + auto ret = pubs_.emplace(producerId, std::make_shared(producerId, GetStreamName(), nullptr)); + CHECK_FAIL_RETURN_STATUS(ret.second, StatusCode::K_DUPLICATED, + FormatString("Failed to add new producer <%s> into streamManager", producerId)); + + // Assign the new producer with a locally unique number for data verification. 
+ senderProducerNo = ++lifetimeLocalProducerCount_; + needRollback = false; + if (scStreamMetrics_) { + scStreamMetrics_->LogMetric(StreamMetric::NumLocalProducers, pubs_.size()); + } + return Status::OK(); +} + +void StreamManager::ForceUnlockByCursor(const std::string &cursorId, bool isProducer, uint32_t lockId) +{ + if (pageQueueHandler_) { + pageQueueHandler_->ForceUnlockByCursor(cursorId, isProducer, lockId); + } +} + +void StreamManager::ForceUnlockMemViemForPages(uint32_t lockId) +{ + if (pageQueueHandler_) { + pageQueueHandler_->ForceUnlockMemViemForPages(lockId); + } +} + +Status StreamManager::CloseProducer(const std::string &producerId, bool forceClose) +{ + INJECT_POINT("StreamManager.CloseProducer.timing"); + std::shared_ptr producerPtr; + bool isLastProducer = false; + { + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + auto producer = pubs_.find(producerId); + CHECK_FAIL_RETURN_STATUS(producer != pubs_.end(), StatusCode::K_SC_PRODUCER_NOT_FOUND, + FormatString("Stream:<%s>, Producer:<%s> does not exist", streamName_, producerId)); + auto elementCount = producer->second->GetElementCountAndReset(); + LOG(INFO) << FormatString("[%s] Stream manager close producer: %s, element sent: %zu", LogPrefix(), producerId, + elementCount); + if (scStreamMetrics_) { + scStreamMetrics_->IncrementMetric(StreamMetric::NumTotalElementsSent, elementCount); + scStreamMetrics_->IncrementMetric(StreamMetric::NumSendRequests, + producer->second->GetRequestCountAndReset()); + } + pubs_.erase(producer); + RETURN_IF_NOT_OK_EXCEPT(pageQueueHandler_->DeleteCursor(producerId), K_NOT_FOUND); + // Process local ClearAllRemoteConsumer when it is the last producer on the worker for the stream. + // This is to replace the ClearAllRemoteConsumer RPC. + isLastProducer = pubs_.empty(); + if (isLastProducer) { + RETURN_IF_NOT_OK_EXCEPT(remoteWorkerManager_->ClearAllRemoteConsumer(streamName_, forceClose), + K_SC_STREAM_NOT_FOUND); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ClearAllRemoteConsumerUnlocked(forceClose), + "streamManager ClearAllRemoteConsumer failed"); + LOG(INFO) << "worker ClearAllRemoteConsumer done, streamname: " << streamName_; + } + } + if (pageQueueHandler_ && isLastProducer) { + // At the local ClearAllRemoteConsumer handling, we no longer flush through FlushAllChanges, + // but now still log the last append cursor at the last CloseProducer for diagnostic purposes. + RETURN_IF_NOT_OK(pageQueueHandler_->MoveUpLastPage()); + uint64_t lastAppendCursor = GetExclusivePageQueue()->GetLastAppendCursor(); + LOG(INFO) << FormatString("[S:%s] Last append cursor at %zu when producer %s close", streamName_, + lastAppendCursor, producerId); + if (!IsRetainData()) { + RETURN_IF_NOT_OK(EarlyReclaim()); + } + } + if (CheckIfStreamInState(StreamState::RESET_IN_PROGRESS)) { + std::unique_lock lock(resetMutex_); + std::vector prodList(1, producerId); + (void)RemovePubSubFromResetList(prodList); + } + if (scStreamMetrics_) { + scStreamMetrics_->LogMetric(StreamMetric::NumLocalProducers, pubs_.size()); + } + return Status::OK(); +} + +Status StreamManager::AddConsumer(const SubscriptionConfig &config, const std::string &consumerId, + uint64_t &lastAckCursor, ShmView &waView) +{ + Raii resumeAck([this]() { ResumeAckThread(); }); + // We are adding a local consumer and on return we will set up a cursor to begin with. + // We also need to ensure the garbage collector thread is not purging the required page + // from memory. 
+ std::shared_ptr cursor; + bool needRollback = true; + // Force update last page for last append cursor. + RETURN_IF_NOT_OK(pageQueueHandler_->MoveUpLastPage()); + RETURN_IF_NOT_OK(pageQueueHandler_->AddCursor(consumerId, false, cursor, waView)); + Raii raii([this, &needRollback, &consumerId]() { + if (needRollback) { + pageQueueHandler_->DeleteCursor(consumerId); + } + }); + // Trigger AckCursors to update last ack cursor. + RETURN_IF_NOT_OK(AckCursors()); + PauseAckThread(); + PerfPoint point(PerfKey::MANAGER_ADD_CONSUMER); + // Should the subscriber wakeup pending Receive() on getting notification about a Pub node crash/force close. + // Implicitly create subscription for the target stream. + RETURN_IF_NOT_OK(CreateSubscriptionIfMiss(config, lastAckCursor)); + // By now we can obtain the target subscription definitely. + std::shared_ptr subscription; + RETURN_IF_NOT_OK(GetSubscription(config.subscriptionName, subscription)); + RETURN_IF_NOT_OK(subscription->AddConsumer(config, consumerId, lastAckCursor, cursor)); + if (scStreamMetrics_) { + scStreamMetrics_->LogMetric(StreamMetric::NumLocalConsumers, subs_.size()); + } + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, C:%s] AddConsumer success, lastAckCursor:%llu", LogPrefix(), + consumerId, lastAckCursor); + needRollback = false; + return Status::OK(); +} + +Status StreamManager::CloseConsumer(const std::string &subName, const std::string &consumerId) +{ + std::shared_ptr subPtr; + RETURN_IF_NOT_OK(GetSubscription(subName, subPtr)); + CHECK_FAIL_RETURN_STATUS(subPtr != nullptr, StatusCode::K_RUNTIME_ERROR, + "Failed to get stream by name: " + subName); + PerfPoint point(PerfKey::MANAGER_CLOSE_CONSUMER); + if (scStreamMetrics_) { + scStreamMetrics_->IncrementMetric(StreamMetric::NumReceiveRequests, subPtr->GetRequestCountAndReset()); + } + RETURN_IF_NOT_OK(subPtr->RemoveConsumer(consumerId)); + RETURN_IF_NOT_OK_EXCEPT(pageQueueHandler_->DeleteCursor(consumerId), K_NOT_FOUND); + bool isLastConsumer = false; + if (!subPtr->HasConsumer()) { // If this subscription has no consumer, we delete it from subs_ hash map. + { + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, Sub:%s] Delete sub due to no consumer inside it", + LogPrefix(), subName); + CHECK_FAIL_RETURN_STATUS( + subs_.erase(subName) == 1, StatusCode::K_SC_CONSUMER_NOT_FOUND, + FormatString("Consumer <%s> does not exist in Subscription <%s>", consumerId, subName)); + isLastConsumer = subs_.empty(); + if (isLastConsumer) { + // Early reclaim of local cache memory reservation when consumers are all closed. + auto scSvc = scSvc_.lock(); + if (scSvc != nullptr) { + scSvc->UndoReserveMemoryFromUsageMonitor(GetStreamName()); + } + // If this stream has no more consumers clear all remote pubs. 
+ ClearAllRemotePubUnlocked(); + } + } + if (CheckIfStreamInState(StreamState::RESET_IN_PROGRESS)) { + std::unique_lock lock(resetMutex_); + std::vector conList(1, consumerId); + (void)RemovePubSubFromResetList(conList); + } + } + point.Record(); + if (scStreamMetrics_) { + scStreamMetrics_->LogMetric(StreamMetric::NumLocalConsumers, subs_.size()); + } + if (isLastConsumer && !IsRetainData()) { + RETURN_IF_NOT_OK(EarlyReclaim()); + } + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, Sub:%s, C:%s] CloseConsumer success.", LogPrefix(), subName, + consumerId); + return Status::OK(); +} + +Status StreamManager::CheckDeleteStreamCondition() +{ + PerfPoint point(PerfKey::MANAGER_DELETE_STREAM); + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + if (pubs_.empty() && subs_.empty() && remotePubWorkerDict_.empty() && remoteSubWorkerDict_.empty()) { + return Status::OK(); + } + if (!pubs_.empty() || !subs_.empty()) { + LOG(ERROR) << "Not allowed to delete stream, pub count:" << pubs_.size() << ", sub count:" << subs_.size(); + RETURN_STATUS(StatusCode::K_RUNTIME_ERROR, "Not allowed to delete stream when producer/consumer is running."); + } + if (!remotePubWorkerDict_.empty()) { + RETURN_STATUS(StatusCode::K_RUNTIME_ERROR, + FormatString("Not allowed to delete stream when remote producer is running.\nList: [%s]", + VectorToString(remotePubWorkerDict_))); + } + if (!remoteSubWorkerDict_.empty()) { + std::stringstream ss; + for (const auto &entry : remoteSubWorkerDict_) { + ss << entry.first << " "; + } + RETURN_STATUS( + StatusCode::K_RUNTIME_ERROR, + FormatString( + "Not allowed to delete stream when remote consumer is running\nList: [%s]\n Possibility:\n1. Remote " + "Consumer not closed yet\n2. Sending data on local node to remote consumer.", + ss.str())); + } + point.Record(); + return Status::OK(); +} + +Status StreamManager::AllocDataPage(BlockedCreateRequest *blockedReq) +{ + auto req = blockedReq->GetCreateRequest(); + const auto &producerId = req.producer_id(); + ShmView curView = { .fd = req.cur_view().fd(), + .mmapSz = req.cur_view().mmap_size(), + .off = static_cast(req.cur_view().offset()), + .sz = req.cur_view().size() }; + std::shared_ptr lastPage; + RETURN_IF_NOT_OK(CreateOrGetLastDataPage(producerId, RPC_TIMEOUT, curView, lastPage, false)); + CreateShmPageRspPb &rsp = blockedReq->rsp_; + ShmView shmView = lastPage->GetShmView(); + ShmViewPb pb; + pb.set_fd(shmView.fd); + pb.set_mmap_size(shmView.mmapSz); + pb.set_offset(shmView.off); + pb.set_size(shmView.sz); + rsp.mutable_last_page_view()->CopyFrom(pb); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(blockedReq->Write(), "Write reply to client stream failed"); + const int logPerCount = VLOG_IS_ON(SC_NORMAL_LOG_LEVEL) ? 1 : 1000; + LOG_EVERY_N(INFO, logPerCount) << FormatString( + "[%s, P:%s] CreateShmPage success. ProdId: %s, PageId: %s. 
Retry count: %zu", LogPrefix(), producerId, + req.producer_id(), lastPage->GetPageId(), blockedReq->retryCount_.load()); + return Status::OK(); +} + +Status StreamManager::AllocDataPageInternalReq(uint64_t timeoutMs, const ShmView &curView, ShmView &outView) +{ + CreateShmPageReqPb req; + req.set_stream_name(streamName_); + // We need to fake a producer id as a unique key into MemAllocRequestList + auto producerId = GetStringUuid(); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, P:%s] Send an internal AllocDataPage request for size %zu.", + LogPrefix(), producerId, GetStreamPageSize()); + req.set_producer_id(producerId); + ShmViewPb pb; + pb.set_fd(curView.fd); + pb.set_mmap_size(curView.mmapSz); + pb.set_offset(curView.off); + pb.set_size(curView.sz); + req.mutable_cur_view()->CopyFrom(pb); + + auto fn = std::bind(&StreamManager::AllocDataPage, shared_from_this(), std::placeholders::_1); + auto blockedReq = std::make_shared>( + streamName_, req, GetStreamPageSize(), nullptr, fn); + // Lock to compete with StreamManager::UnblockProducers + auto scSvc = scSvc_.lock(); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(scSvc != nullptr, K_SHUTTING_DOWN, "worker shutting down."); + RETURN_IF_NOT_OK(AddBlockedCreateRequest(scSvc.get(), blockedReq, true)); + scSvc->AsyncSendMemReq(streamName_); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, P:%s] Wait for internal AllocDataPage reply.", LogPrefix(), + blockedReq->req_.producer_id()); + RETURN_IF_NOT_OK(blockedReq->Wait(timeoutMs)); + CreateShmPageRspPb &rsp = blockedReq->rsp_; + outView.off = static_cast(rsp.last_page_view().offset()); + outView.sz = rsp.last_page_view().size(); + outView.mmapSz = rsp.last_page_view().mmap_size(); + outView.fd = rsp.last_page_view().fd(); + return Status::OK(); +} + +Status StreamManager::CreateOrGetLastDataPage(const std::string &producerId, uint64_t timeoutMs, + const ShmView &lastView, std::shared_ptr &lastPage, + bool retryOnOOM) +{ + PerfPoint point(PerfKey::MANAGER_CREATE_STREAM_PAGE); + // Create new page or return existing one + RETURN_IF_NOT_OK(pageQueueHandler_->CreateOrGetLastDataPage(timeoutMs, lastView, lastPage, retryOnOOM)); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, P:%s] LastPage %s", LogPrefix(), producerId, + lastPage->GetPageId()); + // Notify a new page has been created + TryWakeUpPendingReceive(); + return Status::OK(); +} + +Status StreamManager::AllocBigShmMemory(BlockedCreateRequest *blockedReq) +{ + auto req = blockedReq->GetCreateRequest(); + const auto &producerId = req.producer_id(); + std::shared_ptr pageUnitInfo; + size_t pageSize = req.page_size(); + Status allocRc = pageQueueHandler_->AllocMemory(pageSize, true, pageUnitInfo, false); + if (allocRc.GetCode() == K_OUT_OF_MEMORY) { + LOG_IF_ERROR(pageQueueHandler_->ReclaimAckedChain(blockedReq->req_.sub_timeout()), "Reclaim ack chain error"); + if (!CheckHadEnoughMem(pageSize)) { + pageQueueHandler_->DumpPoolPages(FLAGS_v); + } + } + RETURN_IF_NOT_OK(allocRc); + CHECK_FAIL_RETURN_STATUS(pageUnitInfo != nullptr, K_RUNTIME_ERROR, "pageUnitInfo is nullptr"); + LOG(INFO) << FormatString("[%s, P:%s] AllocBigShmMemory success.", LogPrefix(), producerId); + // From now on make sure we free the memory on error exit + bool needRollback = true; + Raii raii([this, &producerId, &needRollback, &pageUnitInfo]() { + if (needRollback) { + ShmView v{ .fd = pageUnitInfo->fd, + .mmapSz = pageUnitInfo->mmapSize, + .off = pageUnitInfo->offset, + .sz = pageUnitInfo->size }; + LOG(INFO) << FormatString("[%s, P:%s] Undo previous AllocBigShmMemory", 
LogPrefix(), producerId); + (void)pageQueueHandler_->ReleaseMemory(v); + } + }); + ShmViewPb pb; + pb.set_fd(pageUnitInfo->fd); + pb.set_mmap_size(pageUnitInfo->mmapSize); + pb.set_offset(pageUnitInfo->offset); + pb.set_size(pageUnitInfo->size); + CreateLobPageRspPb &rsp = blockedReq->rsp_; + rsp.mutable_page_view()->CopyFrom(pb); + RETURN_IF_NOT_OK(blockedReq->Write()); + // For internal request, we need to coordinate with the caller because we don't chain big element like + // data page. If the requester has gone (due to timeout), the memory is stale. It is hard to check + // for rpc requester, but we can check that for internal requester. + INJECT_POINT("StreamManager.AllocBigShmMemory.NoHandShake1", [&needRollback]() { + needRollback = false; + return Status::OK(); + }); + INJECT_POINT("StreamManager.AllocBigShmMemory.NoHandShake2"); + RETURN_IF_NOT_OK(blockedReq->SenderHandShake()); + needRollback = false; + return Status::OK(); +} + +Status StreamManager::AllocBigShmMemoryInternalReq(uint64_t timeoutMs, size_t sz, ShmView &outView) +{ + CreateLobPageReqPb req; + req.set_stream_name(streamName_); + // We need to fake a producer id as a unique key into MemAllocRequestList + auto producerId = GetStringUuid(); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, P:%s] Send an internal AllocBigShmMemory request for size %zu.", + LogPrefix(), producerId, sz); + req.set_producer_id(producerId); + req.set_page_size(sz); + auto fn = std::bind(&StreamManager::AllocBigShmMemory, shared_from_this(), std::placeholders::_1); + auto blockedReq = std::make_shared>(streamName_, req, + sz, nullptr, fn); + auto scSvc = scSvc_.lock(); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(scSvc != nullptr, K_SHUTTING_DOWN, "worker shutting down."); + RETURN_IF_NOT_OK(AddBlockedCreateRequest(scSvc.get(), blockedReq, true)); + scSvc->AsyncSendMemReq(streamName_); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, P:%s] Wait for internal AllocBigShmMemory reply.", LogPrefix(), + blockedReq->req_.producer_id()); + auto waitTime = [timeoutMs]() { + INJECT_POINT("StreamManager.AllocBigShmMemoryInternalReq.SetTimeoutMs", [](uint64_t val) { return val; }); + return timeoutMs; + }; + INJECT_POINT("StreamManager.AllocBigShmMemoryInternalReq.sleep"); + RETURN_IF_NOT_OK(blockedReq->Wait(waitTime())); + // Handshake with the MemPool thread + INJECT_POINT("StreamManager.AllocBigShmMemoryInternalReq.NoHandShake"); + RETURN_IF_NOT_OK(blockedReq->ReceiverHandShake()); + CreateLobPageRspPb &rsp = blockedReq->rsp_; + outView.off = static_cast(rsp.page_view().offset()); + outView.sz = rsp.page_view().size(); + outView.mmapSz = rsp.page_view().mmap_size(); + outView.fd = rsp.page_view().fd(); + return Status::OK(); +} + +Status StreamManager::ReleaseBigShmMemory( + const std::shared_ptr> &serverApi, + const ReleaseLobPageReqPb &req) +{ + ShmView v; + v.fd = req.page_view().fd(); + v.mmapSz = req.page_view().mmap_size(); + v.off = static_cast(req.page_view().offset()); + v.sz = req.page_view().size(); + Status rc = pageQueueHandler_->ReleaseMemory(v); + if (rc.IsError()) { + return serverApi->SendStatus(rc); + } + ReleaseLobPageRspPb rsp; + return serverApi->Write(rsp); +} + +Status StreamManager::AddBlockedCreateRequest( + ClientWorkerSCServiceImpl *scSvc, + std::shared_ptr> blockedReq, bool lock) +{ + // Compete with UnblockCreators + std::shared_lock rlock(streamManagerBlockedListsMutex_, std::defer_lock); + if (lock) { + INJECT_POINT("StreamManager.AddBlockCreateRequest.sleep"); + rlock.lock(); + } + return 
dataBlockedList_.AddBlockedCreateRequest(scSvc, std::move(blockedReq)); +} + +Status StreamManager::AddBlockedCreateRequest( + ClientWorkerSCServiceImpl *scSvc, + std::shared_ptr> blockedReq, bool lock) +{ + // Compete with UnblockCreators + std::shared_lock rlock(streamManagerBlockedListsMutex_, std::defer_lock); + if (lock) { + rlock.lock(); + } + return lobBlockedList_.AddBlockedCreateRequest(scSvc, std::move(blockedReq)); +} + +Status StreamManager::GetBlockedCreateRequest( + std::shared_ptr> &blockedReq) +{ + return dataBlockedList_.GetBlockedCreateRequest(blockedReq); +} + +Status StreamManager::GetBlockedCreateRequest( + std::shared_ptr> &blockedReq) +{ + return lobBlockedList_.GetBlockedCreateRequest(blockedReq); +} + +Status StreamManager::UnblockCreators() +{ + // We want to clear as much as backlog as possible. + // At the same time, block new requests coming in. + + // We will handle BigElement first. + if (!lobBlockedList_.Empty()) { + // Block AddBlockedCreateRequest + std::unique_lock xlock(streamManagerBlockedListsMutex_); + LOG(INFO) << FormatString("[%s] Freed page result in unblocking a waiting AllocBigShmMemory.", LogPrefix()); + std::shared_ptr> blockedReq; + RETURN_IF_NOT_OK_EXCEPT(lobBlockedList_.GetBlockedCreateRequest(blockedReq), K_TRY_AGAIN); + if (blockedReq) { + // To avoid deadlock with itself, don't lock the streamManagerBlockedListsMutex_ again + RETURN_IF_NOT_OK_EXCEPT(HandleBlockedRequestImpl(std::move(blockedReq), false), K_OUT_OF_MEMORY); + } + } + // Next we handle regular page + if (!dataBlockedList_.Empty()) { + // Block AddBlockedCreateRequest + INJECT_POINT("UnblockCreators.sleep"); + std::unique_lock xlock(streamManagerBlockedListsMutex_); + LOG(INFO) << FormatString("[%s] Freed page result in unblocking a waiting CreateShmPage.", LogPrefix()); + // Because producers are sharing pages, we will need to keep popping. + Status rc; + while (rc.IsOk()) { + std::shared_ptr> blockedReq; + rc = dataBlockedList_.GetBlockedCreateRequest(blockedReq); + if (rc.IsOk()) { + // To avoid deadlock with itself, don't lock the streamManagerBlockedListsMutex_ again + rc = HandleBlockedRequestImpl(std::move(blockedReq), false); + } + if (rc.GetCode() == K_OUT_OF_MEMORY) { + break; + } + RETURN_IF_NOT_OK_EXCEPT(rc, K_TRY_AGAIN); + } + } + return Status::OK(); +} + +std::pair StreamManager::GetNextBlockedRequestSize() +{ + // Block AddBlockedCreateRequest + std::unique_lock xlock(streamManagerBlockedListsMutex_); + // We have two lists, and we will look at the oldest one. + if (lobBlockedList_.GetNextStartTime() < dataBlockedList_.GetNextStartTime()) { + return std::make_pair(lobBlockedList_.GetNextBlockedRequestSize(), true); + } + return std::make_pair(dataBlockedList_.GetNextBlockedRequestSize(), false); +} + +template +Status StreamManager::HandleBlockedRequestImpl(std::shared_ptr> &&blockedReq, + bool lockBeforeAdd) +{ + Status rc; + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(blockedReq->traceId_); + auto retryCount = blockedReq->retryCount_.load(std::memory_order_relaxed); + auto req = blockedReq->GetCreateRequest(); + const auto &producerId = req.producer_id(); + auto subTimeout = blockedReq->GetRemainingTimeMs(); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, P:%s] Allocating shared memory. subTimeout: %zu", LogPrefix(), + producerId, subTimeout); + // Check for stream state. If CreatePage goes to sleep for OOM, + // it will return after waking up and seeing reset is going on. 
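+    // In that case the error goes back to the waiting client via SendStatus() rather than
+    // to our caller, since the client is the one blocked on this request.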
+ rc = CheckIfStreamActive(); + if (rc.IsError()) { + return blockedReq->SendStatus(rc); + } + // The launch of this thread and creation of stream manager may have used up some time. + // If this elapsed time was more than the initial sub-time, then return now with timeout error. + if (retryCount > 0) { + RETURN_IF_NOT_OK(blockedReq->HandleBlockedCreateTimeout()); + } + // Invoke the call back to allocate the memory + rc = (*blockedReq)(); + // Refresh how much time left + subTimeout = blockedReq->GetRemainingTimeMs(); + INJECT_POINT("HandleBlockedRequestImpl.subTimeout", [&subTimeout]() mutable { + subTimeout = 0; + return Status::OK(); + }); + if (rc.GetCode() == K_OUT_OF_MEMORY && req.sub_timeout() > 0 && subTimeout > 0) { + // Add this request to a queue of blocked requests and then return. + // The client will remain waiting until a timer unblocks or another event (free pages) executes the request + // and returns to the client. + LOG(INFO) << FormatString( + "OOM. retry a blocked request to the blocked queue for stream %s with producer %s and new timeout %zu. " + "retry count %zu", + streamName_, producerId, subTimeout, retryCount); + + auto scSvc = scSvc_.lock(); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(scSvc != nullptr, K_SHUTTING_DOWN, "worker shutting down."); + + Status blocked_rc = AddBlockedCreateRequest(scSvc.get(), std::move(blockedReq), lockBeforeAdd); + // Log error if we can not block and return original OOM error to the user + LOG_IF_ERROR(blocked_rc, "error while producer blocking"); + if (blocked_rc.IsError()) { + return blockedReq->SendStatus(rc); + } + // Return OOM back to the caller so the caller can distinguish between a successful retry vs + // a failed-but-requeue retry + return rc; + } + if (rc.IsError()) { + return blockedReq->SendStatus(rc); + } + return Status::OK(); +} + +Status StreamManager::GetDataPage( + const GetDataPageReqPb &req, const std::shared_ptr &sub, + const std::shared_ptr> &serverApi) +{ + const auto &consumerId = req.consumer_id(); + CHECK_FAIL_RETURN_STATUS(sub->GetSubscriptionType() == SubscriptionType::STREAM, StatusCode::K_INVALID, + "Only support STREAM mode."); + std::shared_ptr consumer; + RETURN_IF_NOT_OK(sub->GetConsumer(consumerId, consumer)); + RETURN_IF_NOT_OK(GetExclusivePageQueue()->GetDataPage(req, consumer, serverApi)); + return Status::OK(); +} + +void StreamManager::TryWakeUpPendingReceive() +{ + if (pageQueueHandler_->ExistsSharedPageQueue()) { + return; + } + // Consumer node using exclusive page. + uint64_t lastCursor = GetExclusivePageQueue()->GetLastAppendCursor(); + PerfPoint point(PerfKey::MANAGER_TRY_WAKE_UP_RECV_GET_LOCK); + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + point.Record(); + PerfPoint point1(PerfKey::MANAGER_TRY_WAKE_UP_RECV_LOGIC); + for (const auto &sub : subs_) { + auto status = sub.second->TryWakeUpPendingReceive(lastCursor); + if (status.IsError()) { + LOG(WARNING) << "Failed to wake up pending recv for sub:" << sub.first << ", " << status.ToString(); + } + } +} + +uint64_t StreamManager::UpdateLastAckCursorUnlocked(uint64_t minSubsAckCursor) +{ + if (pageQueueHandler_ == nullptr || IsRetainData()) { + return 0; + } + bool success = false; + do { + uint64_t val = lastAckCursor_.load(); + // Stream's lastAckCursor = min{sub0 lastAckCursor, sub1 lastAckCursor,..., subN lastAckCursor}. 
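+        // The CAS loop below publishes the new minimum; if another thread advanced
+        // lastAckCursor_ first, the whole minimum is recomputed instead of moving backwards.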
+        for (const auto &sub : subs_) {
+            const auto &lastSubAck = sub.second->UpdateLastAckCursor();
+            VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[Stream %s, subscription %s] lastAckCursor = %zu", streamName_,
+                                                      sub.first, lastSubAck);
+            minSubsAckCursor = std::min(minSubsAckCursor, lastSubAck);
+        }
+        // Also go through all remote consumers. We may be in the process of sending elements
+        // to the remote worker.
+        if (!remoteSubWorkerDict_.empty() || remoteWorkerManager_->HasRemoteConsumers(streamName_)) {
+            auto remoteAckCursor = remoteWorkerManager_->GetLastAckCursor(streamName_);
+            minSubsAckCursor = std::min(minSubsAckCursor, remoteAckCursor);
+        }
+        if (minSubsAckCursor > val) {
+            INJECT_POINT_NO_RETURN("UpdateLastAckCursorUnlocked.sleep");
+            success = lastAckCursor_.compare_exchange_strong(val, minSubsAckCursor);
+            if (success) {
+                LOG(INFO) << FormatString("[%s] The last ack cursor of the stream moved from %zu to %zu", LogPrefix(),
+                                          val, minSubsAckCursor);
+                return minSubsAckCursor;
+            }
+        } else {
+            return minSubsAckCursor;
+        }
+    } while (true);
+}
+
+Status StreamManager::RemoteAck()
+{
+    auto lastAppendCursor = GetLastAppendCursor();
+    uint64_t newAckCursor = 0;
+    {
+        INJECT_POINT("StreamManager.RemoteAck.delay");
+        ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+        RETURN_OK_IF_TRUE(pageQueueHandler_ == nullptr);
+        // If a local consumer exists, leave the ack to be done by the ack thread itself.
+        RETURN_OK_IF_TRUE(!subs_.empty());
+        newAckCursor = UpdateLastAckCursorUnlocked(lastAppendCursor);
+        // Early release of the lock since StreamManager::mutex_ is mostly there to protect the
+        // pubs and subs structures.
+    }
+    RETURN_IF_NOT_OK(GetExclusivePageQueue()->Ack(newAckCursor));
+    EarlyReclaim(true, lastAppendCursor, newAckCursor);
+    return Status::OK();
+}
+
+Status StreamManager::AckCursors()
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(ackMutex_));
+    ackWp_.Wait();
+    RETURN_IF_NOT_OK(CheckIfStreamActive());
+    RETURN_OK_IF_TRUE(pageQueueHandler_ == nullptr);
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] GC starts", LogPrefix());
+    auto lastAppendCursor = GetLastAppendCursor();
+    uint64_t newAckCursor;
+    {
+        ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+        INJECT_POINT("StreamManager.AckCursors.delay");
+        newAckCursor = UpdateLastAckCursorUnlocked(lastAppendCursor);
+    }
+    RETURN_IF_NOT_OK(GetExclusivePageQueue()->Ack(newAckCursor, GetStreamMetaShm()));
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] GC ends", LogPrefix());
+    return Status::OK();
+}
+
+Status StreamManager::AddRemotePubNode(const std::string &pubWorkerAddr)
+{
+    WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    auto ret = remotePubWorkerDict_.emplace(pubWorkerAddr);
+    CHECK_FAIL_RETURN_STATUS(ret.second, StatusCode::K_DUPLICATED,
+                             "One remote pub node can only make one one-time broadcast to all sub nodes");
+    if (scStreamMetrics_) {
+        scStreamMetrics_->IncrementMetric(StreamMetric::NumRemoteProducers, 1);
+    }
+    VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s], Add remote pub node <%s> success", LogPrefix(), pubWorkerAddr);
+    return Status::OK();
+}
+
+Status StreamManager::HandleClosedRemotePubNode(bool forceClose)
+{
+    WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    // Do not actually remove the remote publisher. Even if the remote side has closed and the master has given this
+    // notification, there are likely many elements still pending receipt. They cannot be received unless there is a
+    // remote publisher instance. The remote pubs will be cleaned up later when the stream is removed.
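+    // What we can do is wake up local consumers blocked in Receive() (when configured via
+    // wakeupPendingRecvOnProdFault_) so they observe the producer failure instead of waiting.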
+ Status rc; + if (forceClose && wakeupPendingRecvOnProdFault_) { + for (const auto &sub : subs_) { + Status rc1 = sub.second->SetForceClose(); + if (rc.IsOk()) { + rc = rc1; + } + } + } + return rc; +} + +Status StreamManager::AddRemoteSubNode(const HostPort &subWorker, const SubscriptionConfig &subConfig, + const std::string &consumerId, uint64_t &lastAckCursor) +{ + Raii resumeAck([this]() { ResumeAckThread(); }); + // We are adding a remote consumer and on return we will set up a cursor to begin with. + // We also need to ensure the garbage collector thread is not purging the required page + // from memory. + PauseAckThread(); + { + auto lastAppendCursor = GetLastAppendCursor(); + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + lastAckCursor = UpdateLastAckCursorUnlocked(lastAppendCursor); + // If a new worker node, we add it into remote subscription dict. + const auto &subWorkerHost = subWorker.ToString(); + auto iter = remoteSubWorkerDict_.find(subWorkerHost); + if (iter == remoteSubWorkerDict_.end()) { + bool success; + std::tie(iter, success) = + remoteSubWorkerDict_.emplace(subWorkerHost, std::make_shared(subWorker)); + } + RETURN_IF_NOT_OK(iter->second->AddConsumer(subConfig, consumerId)); + } + if (scStreamMetrics_) { + scStreamMetrics_->IncrementMetric(StreamMetric::NumRemoteConsumers, 1); + } + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, RW:%s, C:%s], Add remote consumer succeeded", LogPrefix(), + subWorker.ToString(), consumerId); + return Status::OK(); +} + +Status StreamManager::DelRemoteSubNode(const HostPort &subWorker, const std::string &consumerId) +{ + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, RW:%s, C:%s], Delete remote consumer begin", LogPrefix(), + subWorker.ToString(), consumerId); + // key: subWorker value: SubWorkerDesc object. So if not exist, we raise runtime error. + auto iter = remoteSubWorkerDict_.find(subWorker.ToString()); + CHECK_FAIL_RETURN_STATUS(iter != remoteSubWorkerDict_.end(), StatusCode::K_NOT_FOUND, + FormatString("[%s]-[%s] Remote Sub node:<%s> not exist on worker:<%s>'s remoteSubDict", + streamName_, consumerId, subWorker.ToString(), workerAddr_)); + + RETURN_IF_NOT_OK(iter->second->DelConsumer(consumerId)); + if (iter->second->ConsumerNum() == 0) { + (void)remoteSubWorkerDict_.erase(iter); + } + if (scStreamMetrics_) { + scStreamMetrics_->DecrementMetric(StreamMetric::NumRemoteConsumers, 1); + } + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, RW:%s, C:%s], Delete remote consumer succeeded", LogPrefix(), + subWorker.ToString(), consumerId); + return Status::OK(); +} + +Status StreamManager::SyncSubTable(const std::vector &subTable, bool isRecon, uint64_t &lastAckCursor) +{ + Raii resumeAck([this]() { ResumeAckThread(); }); + // We are adding a remote consumer and on return we will set up a cursor to begin with. + // We also need to ensure the garbage collector thread is not purging the required page + // from memory. + PauseAckThread(); + { + auto lastAppendCursor = GetLastAppendCursor(); + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + lastAckCursor = UpdateLastAckCursorUnlocked(lastAppendCursor); + // Definition: ConsumerMeta = (consumerId_, workerAddress_, subConfig_, lastAckCursor_). 
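+        // A full sync (isRecon == false) rebuilds the remote-sub dictionary from scratch;
+        // a reconciliation pass only adds entries on top of what is already known.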
+ if (!isRecon) { + remoteSubWorkerDict_.clear(); + if (scStreamMetrics_) { + scStreamMetrics_->LogMetric(StreamMetric::NumRemoteConsumers, 0); + } + } + for (const auto &sub : subTable) { + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s, RW:%s, C:%s]", LogPrefix(), + sub.WorkerAddress().ToString(), sub.ConsumerId()); + auto iter = remoteSubWorkerDict_.find(sub.WorkerAddress().ToString()); + if (iter == remoteSubWorkerDict_.end()) { + auto newSubWorkerDesc = std::make_shared(sub.WorkerAddress()); + iter = remoteSubWorkerDict_.emplace(sub.WorkerAddress().ToString(), std::move(newSubWorkerDesc)).first; + } + auto &remoteWorkerDesc = iter->second; + RETURN_IF_NOT_OK(remoteWorkerDesc->AddConsumer(sub.SubConfig(), sub.ConsumerId())); + if (scStreamMetrics_) { + scStreamMetrics_->IncrementMetric(StreamMetric::NumRemoteConsumers, 1); + } + } + } + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] SyncSubTable success, table size:%zu", LogPrefix(), + remoteSubWorkerDict_.size()); + return Status::OK(); +} + +Status StreamManager::SyncPubTable(const std::vector &pubTable, bool isRecon) +{ + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + // Always clear the elder remote dict if we get a new first consumer on local node. + if (!remotePubWorkerDict_.empty() && !isRecon) { + RETURN_STATUS_LOG_ERROR(StatusCode::K_RUNTIME_ERROR, + FormatString("Stream:<%s>, State:", + streamName_)); + } + for (const auto &pub : pubTable) { + auto ret = remotePubWorkerDict_.emplace(pub.ToString()); + CHECK_FAIL_RETURN_STATUS(ret.second, StatusCode::K_DUPLICATED, + "Runtime error: Fail to add pub worker into dict"); + } + if (scStreamMetrics_) { + scStreamMetrics_->LogMetric(StreamMetric::NumRemoteProducers, remotePubWorkerDict_.size()); + } + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] SyncPubTable success, table size:%d", LogPrefix(), + remotePubWorkerDict_.size()); + return Status::OK(); +} + +void StreamManager::GetLocalProducers(std::vector &localProducers) +{ + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + for (const auto &kv : pubs_) { + const auto &producer = kv.second; + localProducers.emplace_back(producer->GetId()); + } +} + +void StreamManager::GetLocalConsumers(std::vector> &localConsumers) +{ + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + localConsumers.clear(); + for (const auto &kv : subs_) { + const auto &sub = kv.second; + std::vector consumerIds; + sub->GetAllConsumers(consumerIds); + for (auto &consumerId : consumerIds) { + localConsumers.emplace_back(std::move(consumerId), kv.second->GetSubscriptionConfig()); + } + } +} + +Status StreamManager::CreateSubscriptionIfMiss(const SubscriptionConfig &config, uint64_t &lastAckCursor) +{ + auto lastAppendCursor = GetLastAppendCursor(); + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + if (subs_.empty()) { + // Reserve local cache memory for stream consumer using the batch size FLAGS_zmq_chunk_sz, + // fail the request if memory is not available. 
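+        // The reservation is taken once, when the first subscription appears; the matching
+        // undo happens in CloseConsumer and in the destructor via UndoReserveMemoryFromUsageMonitor.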
+ auto scSvc = scSvc_.lock(); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(scSvc != nullptr, K_SHUTTING_DOWN, "worker shutting down."); + RETURN_IF_NOT_OK(scSvc->ReserveMemoryFromUsageMonitor(GetStreamName(), FLAGS_zmq_chunk_sz)); + } + lastAckCursor = UpdateLastAckCursorUnlocked(lastAppendCursor); + auto iter = subs_.find(config.subscriptionName); + if (iter == subs_.end()) { + auto ret = subs_.emplace(config.subscriptionName, + std::make_shared(config, lastAckCursor, GetStreamName())); + CHECK_FAIL_RETURN_STATUS(ret.second, StatusCode::K_DUPLICATED, + "Failed to add subscription into stream manager"); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, Sub:%s] Create new subscription succeeded", LogPrefix(), + config.subscriptionName); + } else { + CHECK_FAIL_RETURN_STATUS(iter->second->GetSubscriptionType() == config.subscriptionType, StatusCode::K_INVALID, + "The subscription type of request subscription is inconsistent with the type " + "stored in subs_ dict"); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, Sub:%s] Subscription already exist.", LogPrefix(), + config.subscriptionName); + } + return Status::OK(); +} + +Status StreamManager::ClearAllRemotePub() +{ + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + ClearAllRemotePubUnlocked(); + return Status::OK(); +} + +void StreamManager::ClearAllRemotePubUnlocked() +{ + remotePubWorkerDict_.clear(); + if (scStreamMetrics_) { + scStreamMetrics_->LogMetric(StreamMetric::NumRemoteProducers, 0); + } +} + +Status StreamManager::EarlyReclaim(bool remoteAck, uint64_t lastAppendCursor, uint64_t newAckCursor) +{ + { + // If CreateProducer is running at this point, we will wait on the reclaimMutex_. + WriteLockHelper reclaimLck(STREAM_COMMON_LOCK_ARGS(reclaimMutex_)); + reclaimWp_.Clear(); + // We don't need to hold reclaimMutex_. The reclaim wait post is on. + // If CreateProducer is running at this point, it will be blocked on the WaitPost + // until we finish. + } + Raii wpRaii([this]() { reclaimWp_.Set(); }); + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + // If local consumer still exists, leave the ack to be done by the ack thread itself. + RETURN_OK_IF_TRUE(!subs_.empty()); + // We will check if all remote consumers are gone. If it is empty, it means + // (a) all the local producers are gone, or + // (b) remote consumer has done an early exit. + // Both cases are driven by the master rpc to this worker. + // Either case there is no need to cache pages for the future. + RETURN_OK_IF_TRUE(!remoteSubWorkerDict_.empty()); + if (remoteAck) { + const uint64_t timeoutMs = 10; + LOG(INFO) << FormatString("[%s] Reclaim memory. Last append %zu. Last ack %zu", LogPrefix(), lastAppendCursor, + newAckCursor); + RETURN_IF_NOT_OK_EXCEPT(GetExclusivePageQueue()->ReclaimAckedChain(timeoutMs), K_TRY_AGAIN); + } + // Finally if all elements have been acked, or all producers are gone, we release all the shared memory. + RETURN_OK_IF_TRUE(!pubs_.empty()); + bool hasRemoteConsumers = remoteWorkerManager_->HasRemoteConsumers(streamName_); + auto remoteAckCursor = remoteWorkerManager_->GetLastAckCursor(streamName_); + // refresh last page and last append cursor since there can be producer inserted elements and closed since + // the last time we get last append cursor. 
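+    // Judging by the parameter name, passing false moves up only the queue-side last page
+    // without touching the local producers' view of it.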
+ const bool updateLocalPubLastPage = false; + RETURN_IF_NOT_OK(GetExclusivePageQueue()->MoveUpLastPage(updateLocalPubLastPage)); + lastAppendCursor = GetLastAppendCursor(); + LOG(INFO) << FormatString("[%s] HasRemoteConsumers = %s remoteAckCursor = %zu lastAppendCursor = %zu", LogPrefix(), + (hasRemoteConsumers ? "true" : "false"), remoteAckCursor, lastAppendCursor); + if (!hasRemoteConsumers || remoteAckCursor == lastAppendCursor) { + // Data should have been purged by this time at early reclaim logic, so stop to push new buffer into RW + if (!pageQueueHandler_->ExistsSharedPageQueue()) { + RETURN_IF_NOT_OK(remoteWorkerManager_->DoneScanning(streamName_)); + } + RETURN_IF_NOT_OK(GetExclusivePageQueue()->ReleaseAllPages()); + } + return Status::OK(); +} + +Status StreamManager::ClearAllRemoteConsumerUnlocked(bool forceClose) +{ + Status rc; + if (forceClose && wakeupPendingRecvOnProdFault_) { + // add current node (self node) in the list of forced nodes. + for (const auto &sub : subs_) { + Status rc1 = sub.second->SetForceClose(); + if (rc.IsOk()) { + rc = rc1; + } + } + } + remoteSubWorkerDict_.clear(); + if (scStreamMetrics_) { + scStreamMetrics_->LogMetric(StreamMetric::NumRemoteConsumers, 0); + } + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Clear all remote consumer succeeded", LogPrefix()); + return rc; +} + +std::string StreamManager::LogPrefix() const +{ + return FormatString("S:%s", streamName_); +} + +Status StreamManager::GetSubscription(const std::string &subName, std::shared_ptr &subscription) +{ + PerfPoint point(PerfKey::MANAGE_GET_SUB); + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + auto iter = subs_.find(subName); + if (iter == subs_.end()) { + RETURN_STATUS(StatusCode::K_SC_CONSUMER_NOT_FOUND, "Subscription not found" + subName); + } + RETURN_RUNTIME_ERROR_IF_NULL(iter->second); + subscription = iter->second; + point.Record(); + return Status::OK(); +} + +Status StreamManager::RemovePubSubFromResetList(std::vector &prodConList) +{ + Status sc = Status::OK(); + for (auto pubSubId : prodConList) { + bool found = false; + for (auto iter = prodConResetList_.begin(); iter != prodConResetList_.end(); ++iter) { + if (*iter == pubSubId) { + prodConResetList_.erase(iter); + found = true; + break; + } + } + if (!found) { + sc = Status(K_NOT_FOUND, FormatString("%s Not found in the list of resetting pubs/subs", pubSubId)); + LOG(ERROR) << sc.GetMsg(); + } + } + if (prodConResetList_.empty() && CheckIfStreamInState(StreamState::RESET_IN_PROGRESS)) { + return ResetStreamEnd(); + } + return sc; +} + +Status StreamManager::ResetStreamStart(std::vector &prodConList) +{ + // Stop remote producer pushing more data with the stream status flag. 
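+    // Reset handshake: mark RESET_IN_PROGRESS, snapshot every local producer and consumer
+    // into prodConResetList_, then let each of them check in through
+    // RemovePubSubFromResetList(); the last one to check in triggers ResetStreamEnd().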
+ std::unique_lock lock(resetMutex_); + { + // protect create/close pubs_/subs_ + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + if (CheckIfStreamInState(StreamState::ACTIVE)) { + RETURN_IF_NOT_OK(SetNewState(StreamState::RESET_IN_PROGRESS)); + prodConResetList_.clear(); + for (const auto &prod : pubs_) { + prodConResetList_.emplace_back(prod.first); + } + for (auto &sub : subs_) { + std::vector consumerIds; + sub.second->GetAllConsumers(consumerIds); + prodConResetList_.insert(prodConResetList_.end(), consumerIds.begin(), consumerIds.end()); + } + } else if (CheckIfStreamInState(StreamState::DELETE_IN_PROGRESS)) { + RETURN_STATUS(K_SC_STREAM_DELETE_IN_PROGRESS, + FormatString("Delete is in progress on Stream [%s].", streamName_)); + } else if (CheckIfStreamInState(StreamState::RESET_COMPLETE)) { + LOG(WARNING) << "Reset is already completed for stream: " << streamName_; + return Status::OK(); + } + } + return RemovePubSubFromResetList(prodConList); +} + +Status StreamManager::ResetStreamEnd() +{ + // As we have one BufferPool per remote worker + // We will be inserting EOS num of remoteWorkers x num of producers + // Create a placeholder element for producer + { + remoteWorkerManager_->PurgeBuffer(shared_from_this()); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] ResetStreamStart for remote consumer", LogPrefix()); + if (auto workerWorkerSCServicePtr = workerWorkerSCService_.lock()) { + workerWorkerSCServicePtr->PurgeBuffer(shared_from_this()); + } + } + + // Check we got replies from all EOS messages that we inserted + INJECT_POINT("worker.stream.sleep_while_reset"); + // If we got replies from all -> Start cleanup + // Clear pointer and indexes of each producer + Status rc = Status::OK(); + { + // protect read pubs_/subs_ + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + for (const auto &pub : pubs_) { + Status status = pub.second->CleanupProducer(); + if (status.IsError()) { + LOG(ERROR) << status.GetMsg(); + rc = status; + } + } + + // Clear pointer and indexes of each consumer + for (const auto &sub : subs_) { + sub.second->CleanupSubscription(); + } + } + + // Clear all the indexes and maps in StreamDataObject + RETURN_IF_NOT_OK(GetExclusivePageQueue()->Reset()); + RETURN_IF_NOT_OK(remoteWorkerManager_->ResetStreamScanList(streamName_)); + { + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + blockOnOOM_.clear(); + } + Status status = UnblockCreators(); + if (status.IsError()) { + LOG(ERROR) << status.GetMsg(); + rc = status; + } + // Notify the client + rc = SetNewState(StreamState::RESET_COMPLETE); + VLOG(SC_INTERNAL_LOG_LEVEL) << "Reset complete for " << streamName_; + return rc; +} + +void StreamManager::ForceCloseClients() +{ + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + for (const auto &pub : pubs_) { + pub.second->SetForceClose(); + } + + for (const auto &sub : subs_) { + sub.second->SetForceClose(); + } +} + +Status StreamManager::GetSubType(const std::string &subName, SubscriptionType &type) +{ + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + auto iter = subs_.find(subName); + if (iter == subs_.end()) { + RETURN_STATUS(StatusCode::K_SC_CONSUMER_NOT_FOUND, "Subscription not found" + subName); + } + type = iter->second->GetSubscriptionType(); + return Status::OK(); +} + +int64_t StreamManager::GetStreamPageSize() +{ + return pageQueueHandler_->GetPageSize(); +} + +double StreamManager::GetStreamMemAllocRatio() +{ + auto maxAllocatedMemorySC = scAllocateManager_->GetTotalMaxStreamSHMSize(); + if (maxAllocatedMemorySC != 0) 
{ + return (GetExclusivePageQueue()->GetMaxStreamSize() / (double)maxAllocatedMemorySC); + } else { + return 1.0; + } +} + +Status StreamManager::CheckConsumerExist(const std::string &workerAddr) +{ + // No consumer on this node, deny data push. + ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + CHECK_FAIL_RETURN_STATUS(!subs_.empty(), StatusCode::K_SC_CONSUMER_NOT_FOUND, + FormatString("No consumer on this node [%s - %s]", streamName_, workerAddr)); + return Status::OK(); +} + +Status StreamManager::SendBlockProducerReq(const std::string &remoteWorkerAddr) +{ + TraceGuard traceGuard = Trace::Instance().SetTraceUUID(); + VLOG(SC_NORMAL_LOG_LEVEL) << "Blocking Producer for stream: " << streamName_ + << " sending to remote worker: " << remoteWorkerAddr; + HostPort workerHostPort; + RETURN_IF_NOT_OK(workerHostPort.ParseString(remoteWorkerAddr)); + std::shared_ptr stub; + auto scSvc = scSvc_.lock(); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(scSvc != nullptr, K_SHUTTING_DOWN, "worker shutting down."); + RETURN_IF_NOT_OK(scSvc->GetWorkerStub(workerHostPort, stub)); + std::unique_ptr> clientApi; + RETURN_IF_NOT_OK(stub->BlockProducer(RpcOptions(), &clientApi)); + BlockProducerReqPb req; + req.set_stream_name(streamName_); + req.set_worker_addr(workerAddr_); + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + req.set_timeout(opts.GetTimeout()); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + // Send only. No need to wait for any reply, and let clientApi goes out of scope by itself. + RETURN_IF_NOT_OK(clientApi->Write(req)); + // Wait for the reply to ensure ordering + BlockProducerRspPb rsp; + RETURN_IF_NOT_OK(clientApi->Read(rsp)); + + if (scStreamMetrics_) { + scStreamMetrics_->IncrementMetric(StreamMetric::NumRemoteProducersBlocked, 1); + } + VLOG(SC_NORMAL_LOG_LEVEL) << "Blocking Producer for stream: " << streamName_ + << " sent to remote worker: " << remoteWorkerAddr << " is Successful"; + INJECT_POINT("StreamManager.SendBlockProducerReq.delay"); + return Status::OK(); +} + +Status StreamManager::BlockProducer(const std::string &workerAddr, bool addCallBack) +{ + { + WriteLockHelper xlock(STREAM_COMMON_LOCK_ARGS(mutex_)); + auto it = blockOnOOM_.find(workerAddr); + if (it == blockOnOOM_.end()) { + bool success; + std::tie(it, success) = blockOnOOM_.emplace(workerAddr, false); + } + if (it->second) { + return Status::OK(); + } + it->second = true; + LOG(INFO) << FormatString("[%s] BlockProducer from remote worker %s", LogPrefix(), workerAddr); + } + auto scSvc = scSvc_.lock(); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(scSvc != nullptr, K_SHUTTING_DOWN, "worker shutting down."); + auto streamName = streamName_; + scSvc->GetThreadPool()->Execute([scSvc, streamName, workerAddr, addCallBack]() { + // Send the request + StreamManagerMap::const_accessor accessor; + Status rc = scSvc->GetStreamManager(streamName, accessor); + if (rc.IsError()) { + return; + } + std::shared_ptr streamMgr = accessor->second; + LOG_IF_ERROR(streamMgr->SendBlockProducerReq(workerAddr), "block error"); + // Add a call back after send block request to maintain ordering + std::weak_ptr weakStreamMgr = streamMgr; + if (!addCallBack) { + return; + } + streamMgr->AddUnblockCallback(workerAddr, [weakStreamMgr, workerAddr, streamName]() { + auto streamMgr = weakStreamMgr.lock(); + if (streamMgr != nullptr) { + LOG_IF_ERROR(streamMgr->UnBlockProducer(workerAddr), "unblock error"); + } else { + LOG(WARNING) + << "The StreamManager already destroy when execute UnBlockProducer callback for streamName " + 
+
+Status StreamManager::SendUnBlockProducerReq(const std::string &remoteWorkerAddr)
+{
+    TraceGuard traceGuard = Trace::Instance().SetTraceUUID();
+    LOG(INFO) << FormatString("[%s] UnBlockProducer from remote worker %s", LogPrefix(), remoteWorkerAddr);
+    ResetOOMState(remoteWorkerAddr);  // Producer is unblocked
+    HostPort workerHostPort;
+    RETURN_IF_NOT_OK(workerHostPort.ParseString(remoteWorkerAddr));
+    std::shared_ptr stub;
+    auto scSvc = scSvc_.lock();
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(scSvc != nullptr, K_SHUTTING_DOWN, "worker shutting down.");
+    RETURN_IF_NOT_OK(scSvc->GetWorkerStub(workerHostPort, stub));
+    std::unique_ptr> clientApi;
+    RETURN_IF_NOT_OK(stub->UnblockProducer(RpcOptions(), &clientApi));
+    UnblockProducerReqPb req;
+    req.set_stream_name(streamName_);
+    req.set_worker_addr(workerAddr_);
+    RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req));
+    // Wait for the reply to ensure ordering
+    RETURN_IF_NOT_OK(clientApi->Write(req));
+    UnblockProducerRspPb rsp;
+    RETURN_IF_NOT_OK(clientApi->Read(rsp));
+    if (scStreamMetrics_) {
+        scStreamMetrics_->DecrementMetric(StreamMetric::NumRemoteProducersBlocked, 1);
+    }
+    VLOG(SC_NORMAL_LOG_LEVEL) << "UnBlocking Producer for stream: " << streamName_
+                              << " sent to remote worker: " << remoteWorkerAddr << " was successful";
+    return Status::OK();
+}
+
+void StreamManager::ResetOOMState(const std::string &remoteWorkerAddr)
+{
+    // Unblock callback is already set
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    auto it = blockOnOOM_.find(remoteWorkerAddr);
+    if (it != blockOnOOM_.end()) {
+        it->second = false;
+    }
+}
+
+Status StreamManager::UnBlockProducer(const std::string &workerAddr)
+{
+    auto weakThis = weak_from_this();
+    auto scSvc = scSvc_.lock();
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(scSvc != nullptr, K_SHUTTING_DOWN, "worker shutting down.");
+    scSvc->GetThreadPool()->Execute([weakThis, workerAddr, streamName = streamName_]() {
+        auto streamManager = weakThis.lock();
+        if (streamManager != nullptr) {
+            LOG_IF_ERROR(streamManager->SendUnBlockProducerReq(workerAddr), "unblock error");
+        } else {
+            LOG(WARNING) << "The StreamManager was already destroyed when asynchronously executing UnBlockProducer for stream "
+                         << streamName;
+        }
+    });
+    return Status::OK();
+}
+
+bool StreamManager::IsProducerBlocked(const std::string &workerAddr)
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    auto it = blockOnOOM_.find(workerAddr);
+    if (it == blockOnOOM_.end()) {
+        return false;
+    }
+    return it->second;
+}
+
+Status StreamManager::CopyElementView(std::shared_ptr<RecvElementView> &recvElementView, UsageMonitor &usageMonitor,
+                                      uint64_t timeoutMs)
+{
+    // Decrypt failure is a non-recoverable error.
+    auto pageQueue = GetExclusivePageQueue();
+    size_t totalLength = 0;
+    // Close consumer can trigger memory reclaim, but at the same time,
+    // remote elements might still get written into a shm page.
+    // So we block the potential reclaim before we finish the ongoing BatchInsert.
+    // Later push requests will get K_SC_CONSUMER_NOT_FOUND if the consumer is indeed closed.
+    BlockMemoryReclaim();
+    Raii raii([this]() { UnblockMemoryReclaim(); });
+    std::vector sz(recvElementView->sz_.begin() + recvElementView->idx_, recvElementView->sz_.end());
+    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(!sz.empty(), K_INVALID,
+                                         FormatString("[%s] invalid idx %zu", LogPrefix(), recvElementView->idx_));
+    // To reduce the chance of OOM and sealing a page that is only partially filled, we do a few rows at a time
+    // while competing with local producers. We may also be resuming from where we left off last time.
+    std::pair res(0, 0);
+    auto rc = pageQueue->BatchInsert(recvElementView->GetBufferPointer(), sz, res, timeoutMs,
+                                     recvElementView->headerBits_, GetStreamMetaShm());
+    totalLength = res.second;
+    recvElementView->idx_ += res.first;
+    // PageView is processed and will be removed from local cache
+    if (totalLength > 0) {
+        usageMonitor.DecUsage(streamName_, recvElementView->workerAddr_, totalLength);
+    }
+    if (rc.IsOk()) {
+        // Sanity check. If all successful, idx_ should now be at the end.
+        CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(recvElementView->idx_ == recvElementView->sz_.size(), K_RUNTIME_ERROR,
+                                             FormatString("[%s] Expect %zu but got %zu", LogPrefix(),
+                                                          recvElementView->sz_.size(), recvElementView->idx_));
+        VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString(
+            "Page is copied successfully, stream name: %s, worker addr: %s, "
+            "seq: %zu, count: %zu, size: %zu, stream state: %s",
+            recvElementView->StreamName(), recvElementView->workerAddr_, recvElementView->seqNo_,
+            recvElementView->sz_.size(), totalLength, PrintStreamStatus());
+    } else {
+        switch (rc.GetCode()) {
+            case K_OUT_OF_MEMORY:
+                LOG(WARNING) << FormatString("Out of memory for stream: %s, status %s, Stream Status: %s", streamName_,
+                                             rc.GetMsg(), PrintStreamStatus());
+                LOG_IF_ERROR(BlockProducer(recvElementView->workerAddr_, true),
+                             "Failed to block sender");  // Sends a Block RPC to the other worker to wait
+                return rc;
+            case K_TRY_AGAIN:
+                return rc;
+            default:
+                LOG(ERROR) << FormatString("[%s] Non-recoverable error. %s", LogPrefix(), rc.ToString());
+        }
+    }
+    recvElementView.reset();  // Release the element view; the OOM and try-again cases returned above so a retry can resume
+    return Status::OK();
+}
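+
+// Worked example of the partial-resume contract in CopyElementView() above (the numbers are
+// illustrative, not from the source): suppose recvElementView->sz_ holds 10 element sizes and a
+// first BatchInsert copies only 6 rows before hitting K_OUT_OF_MEMORY. Then res.first == 6, idx_
+// advances from 0 to 6, the usage monitor is decremented by the bytes actually copied (res.second),
+// and the element view is kept alive, so a retry after the producer is unblocked resumes from
+// index 6 instead of re-copying rows 0..5.
+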
+
+uint64_t StreamManager::GetLastAppendCursor() const
+{
+    if (pageQueueHandler_ != nullptr) {
+        return GetExclusivePageQueue()->GetLastAppendCursor();
+    }
+    return 0;
+}
+
+void StreamManager::PauseAckThread()
+{
+    // To ensure the GC thread is paused, not only do we clear the wait post,
+    // we also hold the lock exclusively, in the same order the GC thread does.
+    WriteLockHelper wlock(STREAM_COMMON_LOCK_ARGS(ackMutex_));
+    ackWp_.Clear();
+}
+
+void StreamManager::ResumeAckThread()
+{
+    ackWp_.Set();
+}
+
+uint64_t StreamManager::GetEleCount()
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    uint64_t val = 0;
+    for (auto &sub : subs_) {
+        val += sub.second->GetElementCountReceived();
+    }
+    return val;
+}
+
+uint64_t StreamManager::GetEleCountAcked()
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    uint64_t val = subs_.size() > 0 ? std::numeric_limits<uint64_t>::max() : 0;
+    uint64_t count = 0;
+    for (auto &sub : subs_) {
+        count = sub.second->UpdateLastAckCursor();
+        if (val > count) {
+            val = count;
+        }
+    }
+    return val;
+}
+
+uint64_t StreamManager::GetEleCountSentAndReset()
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    uint64_t val = 0;
+    for (auto &pub : pubs_) {
+        val += pub.second->GetElementCountAndReset();
+    }
+    return val;
+}
+
+uint64_t StreamManager::GetEleCountReceived()
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    uint64_t val = subs_.size() > 0 ? std::numeric_limits<uint64_t>::max() : 0;
+    uint64_t count = 0;
+    for (auto &sub : subs_) {
+        count = sub.second->GetElementCountReceived();
+        if (val > count) {
+            val = count;
+        }
+    }
+    return val;
+}
+
+uint64_t StreamManager::GetSendRequestCountAndReset()
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    uint64_t val = 0;
+    for (auto &pub : pubs_) {
+        val += pub.second->GetRequestCountAndReset();
+    }
+    return val;
+}
+
+uint64_t StreamManager::GetReceiveRequestCountAndReset()
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    uint64_t val = 0;
+    for (auto &sub : subs_) {
+        val += sub.second->GetRequestCountAndReset();
+    }
+    return val;
+}
+
+void StreamManager::AddUnblockCallback(const std::string &addr, std::function<void()> unblockCallback)
+{
+    if (pageQueueHandler_) {
+        GetExclusivePageQueue()->AddUnblockCallback(addr, std::move(unblockCallback));
+    }
+}
+
+bool StreamManager::AutoCleanup() const
+{
+    return GetExclusivePageQueue()->AutoCleanup();
+}
+
+std::vector<std::string> StreamManager::GetRemoteWorkers() const
+{
+    std::vector<std::string> remoteWorkers;
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    std::transform(remoteSubWorkerDict_.begin(), remoteSubWorkerDict_.end(), std::back_inserter(remoteWorkers),
+                   [](auto &kv) { return kv.first; });
+    return remoteWorkers;
+}
+
+bool StreamManager::IsRemotePubEmpty()
+{
+    ReadLockHelper rlock(STREAM_COMMON_LOCK_ARGS(mutex_));
+    return remotePubWorkerDict_.empty();
+}
+
+Status StreamManager::UpdateStreamFields(const StreamFields &streamFields, bool reserveShm)
+{
+    RETURN_IF_NOT_OK(pageQueueHandler_->UpdateStreamFields(streamFields));
+    if (EnableSharedPage(streamFields.streamMode_)) {
+        ShmView shmViewOfStreamMeta;
+        RETURN_IF_NOT_OK(
+            GetOrCreateShmMeta(TenantAuthManager::Instance()->ExtractTenantId(streamName_), shmViewOfStreamMeta));
+    }
+    if (reserveShm) {
+        // Reserve shared memory if we haven't done it previously
+        auto pageQueue = GetExclusivePageQueue();
+        RETURN_IF_NOT_OK(pageQueue->ReserveStreamMemory());
+        LOG(INFO) << FormatString("[%s] %zu bytes of shared memory have been reserved", LogPrefix(),
+                                  pageQueue->GetReserveSize());
+    }
+    TryWakeUpPendingReceive();
+    return Status::OK();
+}
+
+void StreamManager::GetStreamFields(StreamFields &streamFields)
+{
+    GetExclusivePageQueue()->GetStreamFields(streamFields);
+}
+
+Status StreamManager::InitStreamMetrics()
+{
+    return ScMetricsMonitor::Instance()->AddStream(streamName_, weak_from_this(), scStreamMetrics_);
+}
+
+bool StreamManager::CheckHadEnoughMem(size_t sz) const
+{
+    if (pageQueueHandler_->ExistsSharedPageQueue()) {
+        // Fixme stage 2
+        return true;
+    }
+    return GetExclusivePageQueue()->CheckHadEnoughMem(sz).IsOk();
+}
+
+void StreamManager::UpdateStreamMetrics()
+{
+    if (scStreamMetrics_) {
+        scStreamMetrics_->IncrementMetric(StreamMetric::NumTotalElementsSent, GetEleCountSentAndReset());
+        scStreamMetrics_->LogMetric(StreamMetric::NumTotalElementsReceived, GetEleCountReceived());
+        scStreamMetrics_->LogMetric(StreamMetric::NumTotalElementsAcked, GetEleCountAcked());
+        scStreamMetrics_->IncrementMetric(StreamMetric::NumSendRequests, GetSendRequestCountAndReset());
+        scStreamMetrics_->IncrementMetric(StreamMetric::NumReceiveRequests, GetReceiveRequestCountAndReset());
+        scStreamMetrics_->LogMetric(StreamMetric::NumLocalProducersBlocked,
+                                    lobBlockedList_.Size() + dataBlockedList_.Size());
+
+        // Make sure WorkerWorkerService is not destructed
+        if (auto workerWorkerSCServicePtr = workerWorkerSCService_.lock()) {
+            scStreamMetrics_->LogMetric(StreamMetric::LocalMemoryUsed,
+                                        workerWorkerSCServicePtr->GetUsageMonitor().GetLocalMemoryUsed(streamName_));
+        }
+        auto pageQueue = GetExclusivePageQueue();
+        if (pageQueue) {
+            const int workAreaSize = 64;
+            uint64_t workAreaMemUsed = (scStreamMetrics_->GetMetric(StreamMetric::NumLocalProducers)
+                                        + scStreamMetrics_->GetMetric(StreamMetric::NumLocalConsumers))
+                                       * workAreaSize;
+            scStreamMetrics_->LogMetric(StreamMetric::SharedMemoryUsed,
+                                        pageQueue->GetSharedMemoryUsed() + workAreaMemUsed);
+            scStreamMetrics_->LogMetric(StreamMetric::NumPagesCreated, pageQueue->GetNumPagesCreated());
+            scStreamMetrics_->LogMetric(StreamMetric::NumPagesReleased, pageQueue->GetNumPagesReleased());
+            scStreamMetrics_->LogMetric(StreamMetric::NumPagesInUse, pageQueue->GetNumPagesInUse());
+            scStreamMetrics_->LogMetric(StreamMetric::NumPagesCached, pageQueue->GetNumPagesCached());
+            scStreamMetrics_->LogMetric(StreamMetric::NumBigPagesCreated, pageQueue->GetNumBigPagesCreated());
+            scStreamMetrics_->LogMetric(StreamMetric::NumBigPagesReleased, pageQueue->GetNumBigPagesReleased());
+        }
+    }
+}
+
+void StreamManager::ClearBlockedList()
+{
+    std::unique_lock xlock(streamManagerBlockedListsMutex_);
+    dataBlockedList_.ClearBlockedList();
+    lobBlockedList_.ClearBlockedList();
+}
+
+bool StreamManager::EnableSharedPage(StreamMode streamMode)
+{
+    return streamMode == StreamMode::MPSC || streamMode == StreamMode::SPSC;
+}
+}  // namespace stream_cache
+}  // namespace worker
+}  // namespace datasystem
diff --git a/src/datasystem/worker/stream_cache/stream_manager.h b/src/datasystem/worker/stream_cache/stream_manager.h
new file mode 100644
index 0000000..d8cec12
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/stream_manager.h
@@ -0,0 +1,995 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_STREAM_MANAGER_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_STREAM_MANAGER_H
+#include
+#include
+#include
+#include
+
+#include "datasystem/common/stream_cache/consumer_meta.h"
+#include "datasystem/common/stream_cache/stream_data_page.h"
+#include "datasystem/common/stream_cache/stream_fields.h"
+#include "datasystem/common/shared_memory/shm_unit.h"
+#include "datasystem/common/util/net_util.h"
+#include "datasystem/common/util/raii.h"
+#include "datasystem/protos/stream_posix.pb.h"
+#include "datasystem/protos/worker_stream.pb.h"
+#include "datasystem/stream/stream_config.h"
+#include "datasystem/utils/optional.h"
+#include "datasystem/worker/stream_cache/consumer.h"
+#include "datasystem/worker/stream_cache/page_queue/page_queue_handler.h"
+#include "datasystem/worker/stream_cache/remote_worker_manager.h"
+#include "datasystem/worker/stream_cache/producer.h"
+#include "datasystem/worker/stream_cache/stream_data_pool.h"
+#include "datasystem/worker/stream_cache/subscription.h"
+#include "datasystem/worker/stream_cache/worker_sc_allocate_memory.h"
+#include "datasystem/worker/stream_cache/worker_worker_sc_service_impl.h"
+
+DS_DECLARE_int32(log_monitor_interval_ms);
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+class ClientWorkerSCServiceImpl;
+struct RecvElementView;
+class WorkerWorkerSCServiceImpl;
+enum class StreamState { ACTIVE = 0, RESET_IN_PROGRESS = 1, RESET_COMPLETE = 2, DELETE_IN_PROGRESS = 3 };
+class SubWorkerDesc {
+public:
+    explicit SubWorkerDesc(HostPort subWorkerAddress) : hostPort_(std::move(subWorkerAddress))
+    {
+    }
+
+    ~SubWorkerDesc() = default;
+    SubWorkerDesc(const SubWorkerDesc &other) = delete;
+    SubWorkerDesc &operator=(const SubWorkerDesc &other) = delete;
+    SubWorkerDesc(SubWorkerDesc &&other) noexcept = default;
+    SubWorkerDesc &operator=(SubWorkerDesc &&other) noexcept = default;
+
+    /**
+     * @brief Add consumer for this subWorkerDesc.
+     * @param[in] subConfig Consumer sub config.
+     * @param[in] consumerId Consumer id.
+     * @return Status of the call.
+     */
+    Status AddConsumer(const SubscriptionConfig &subConfig, const std::string &consumerId)
+    {
+        CHECK_FAIL_RETURN_STATUS(
+            subConfig.subscriptionType == SubscriptionType::STREAM, K_INVALID,
+            FormatString("Only support STREAM mode. <%s> mode not supported yet.", subConfig.subscriptionName));
+        auto ret = consumers_.emplace(consumerId);
+        CHECK_FAIL_RETURN_STATUS(ret.second, K_DUPLICATED, "duplicate consumer");
+        return Status::OK();
+    }
+
+    /**
+     * @brief Delete consumer for this subWorkerDesc.
+     * @param[in] consumerId Consumer id.
+     * @return Status of the call.
+     */
+    Status DelConsumer(const std::string &consumerId)
+    {
+        CHECK_FAIL_RETURN_STATUS(
+            consumers_.find(consumerId) != consumers_.end(), StatusCode::K_NOT_FOUND,
+            FormatString("Consumer:<%s>, Worker:<%s>, State:", consumerId, hostPort_.ToString()));
+        consumers_.erase(consumerId);
+        return Status::OK();
+    }
+
+    /**
+     * @brief Get the consumer number for a remote sub worker.
+     * @return The consumer number.
+     */
+    size_t ConsumerNum() const
+    {
+        return consumers_.size();
+    }
+
+private:
+    HostPort hostPort_;
+
+    // Key: consumerName Value: Information structure for a consumer.
+    std::set<std::string> consumers_;
+};
+
+class StreamManager : public std::enable_shared_from_this<StreamManager> {
+public:
+    explicit StreamManager(std::string streamName, RemoteWorkerManager *remoteWorkerManager,
+                           std::string localWorkerAddr, std::shared_ptr akSkManager,
+                           std::weak_ptr<ClientWorkerSCServiceImpl> scSvc,
+                           std::shared_ptr manager,
+                           std::weak_ptr<WorkerWorkerSCServiceImpl> workerWorkerSCService, uint64_t localStreamNum);
+    ~StreamManager();
+
+    StreamManager(const StreamManager &streamManager) = delete;
+    StreamManager &operator=(const StreamManager &streamManager) = delete;
+    StreamManager(StreamManager &&streamManager) = delete;
+    StreamManager &operator=(StreamManager &&streamManager) = delete;
+
+    /**
+     * @brief Create a producer, i.e., register a publisher to a stream.
+     * @details Update producer session.
+     * @param[in] producerId The generated producer id.
+     * @param[out] senderProducerNo A locally unique number for the new producer within this stream.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status AddProducer(const std::string &producerId, DataVerificationHeader::SenderProducerNo &senderProducerNo);
+
+    /**
+     * @brief Set the cursor to producer.
+     * @param[in] producerId The generated producer id.
+     * @param[out] shmView The work area SHM view.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status AddCursorForProducer(const std::string &producerId, ShmView &shmView);
+
+    /**
+     * @brief Close a producer, force flushing and page seal, unregister a publisher to a stream.
+     * @param[in] producerId The generated producer id.
+     * @param[in] forceClose If the pub node had a crash or regular close.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CloseProducer(const std::string &producerId, bool forceClose);
+
+    /**
+     * @brief Subscribe to a stream, using a subscription name, i.e., register a consumer to a subscription.
+     * @param[in] config Subscription config.
+     * @param[in] consumerId Consumer id.
+     * @param[out] lastAckCursor The last ack cursor of the new Consumer.
+     * @param[out] waView The work area SHM view.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status AddConsumer(const SubscriptionConfig &config, const std::string &consumerId, uint64_t &lastAckCursor,
+                       ShmView &waView);
+
+    /**
+     * @brief Close a consumer, trigger subscription cursor change and unregister a subscribed consumer to a stream.
+     * @param[in] subName Subscription name.
+     * @param[in] consumerId Consumer id.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CloseConsumer(const std::string &subName, const std::string &consumerId);
+
+    /**
+     * @brief Check whether no producer/consumer session remains on the stream.
+     * @return Status of the call.
+     */
+    Status CheckDeleteStreamCondition();
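+
+    // Hedged usage sketch of the producer lifecycle declared above (error handling elided;
+    // identifiers other than the three methods are hypothetical):
+    //
+    //     DataVerificationHeader::SenderProducerNo producerNo;
+    //     streamMgr->AddProducer(producerId, producerNo);       // register the publisher
+    //     ...                                                   // produce elements
+    //     streamMgr->CloseProducer(producerId, /*forceClose=*/false);
+    //     if (streamMgr->CheckDeleteStreamCondition().IsOk()) {
+    //         // no producer/consumer sessions remain; the stream may be deleted
+    //     }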
+
+    /**
+     * @brief Create a stream page, or get the last data page, to perform zero-copy send.
+     * @param[in] producerId Producer id.
+     * @param[in] timeoutMs Timeout in milliseconds.
+     * @param[in] lastView The shm view of the last page known to the caller.
+     * @param[out] lastPage The created or existing last data page.
+     * @param[in] retryOnOOM Retry on out of memory if set to true.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CreateOrGetLastDataPage(const std::string &producerId, uint64_t timeoutMs, const ShmView &lastView,
+                                   std::shared_ptr &lastPage, bool retryOnOOM);
+
+    /**
+     * @brief Callback function for BlockedCreateRequest.
+     */
+    Status AllocDataPage(BlockedCreateRequest *);
+
+    Status AllocDataPageInternalReq(uint64_t timeoutMs, const ShmView &curView, ShmView &outView);
+
+    /**
+     * @brief Create a BigElement page.
+     * @return Status of the call.
+     */
+    Status AllocBigShmMemory(BlockedCreateRequest *);
+
+    Status AllocBigShmMemoryInternalReq(uint64_t timeoutMs, size_t sz, ShmView &outView);
+
+    /**
+     * @brief Release a BigElement page.
+     */
+    Status ReleaseBigShmMemory(
+        const std::shared_ptr> &serverApi,
+        const ReleaseLobPageReqPb &req);
+
+    /**
+     * @brief Adds a blocked CreateShmPage request with timer.
+     * @param blockedReq The info about the blocked request to add.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status AddBlockedCreateRequest(
+        ClientWorkerSCServiceImpl *scSvc,
+        std::shared_ptr> blockedReq, bool lock);
+
+    /**
+     * @brief Adds a blocked CreateLobPage request with timer.
+     * @param blockedReq The info about the blocked request to add.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status AddBlockedCreateRequest(
+        ClientWorkerSCServiceImpl *scSvc,
+        std::shared_ptr> blockedReq, bool lock);
+
+    /**
+     * @brief Pop a blocked CreateShmPage request.
+     * @param blockedReq The popped blocked request.
+     * @return K_OK on success.
+     */
+    Status GetBlockedCreateRequest(
+        std::shared_ptr> &blockedReq);
+
+    /**
+     * @brief Pop a blocked CreateLobPage request.
+     * @param blockedReq The popped blocked request.
+     * @return K_OK on success.
+     */
+    Status GetBlockedCreateRequest(
+        std::shared_ptr> &blockedReq);
+
+    /**
+     * @brief If there were any BlockedCreateRequests waiting for the stream, it will fetch the blocked request
+     * and then launch an async thread to execute it now (and cancel the timeout timer).
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status UnblockCreators();
+
+    /**
+     * @brief Return the oldest request's requested size and type.
+     * @return The requested size and type of the oldest blocked request.
+     */
+    std::pair GetNextBlockedRequestSize();
+
+    /**
+     * @brief Handle a blocked request.
+     * @param blockedReq The blocked request to handle.
+     * @return K_OK on success.
+     */
+    template <typename T>
+    Status HandleBlockedRequestImpl(std::shared_ptr> &&blockedReq, bool lock);
+
+    /**
+     * @brief As part of the Receive api to locate the starting page.
+     * @param[in] req The request of ReceiveElements.
+     * @param[in] sub The subscription.
+     * @param[in] serverApi The stream rpc channel used to read requests and write responses.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status GetDataPage(const GetDataPageReqPb &req, const std::shared_ptr<Subscription> &sub,
+                       const std::shared_ptr> &serverApi);
+
+    /**
+     * @brief Check whether consumer exists.
+     * @param[in] workerAddr Remote pub worker addr.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status CheckConsumerExist(const std::string &workerAddr);
+
+    /**
+     * @brief Add a remote pub node for this worker in particular stream.
+     * @param[in] pubWorkerAddr Remote pub node address.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status AddRemotePubNode(const std::string &pubWorkerAddr);
+
+    /**
+     * @brief Performs handling for a closed remote publisher for this worker in particular stream.
+     * @param[in] forceClose If the pub node had a crash or regular close.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status HandleClosedRemotePubNode(bool forceClose);
+
+    /**
+     * @brief Add a remote sub consumer node for this worker in particular stream.
+     * @param[in] subWorker Remote sub consumer node address protobuf.
+     * @param[in] subConfig Remote sub consumer node subscription config protobuf.
+     * @param[in] consumerId Remote sub consumer node consumer id.
+     * @param[out] lastAckCursor Remote sub consumer last acknowledge cursor.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status AddRemoteSubNode(const HostPort &subWorker, const SubscriptionConfig &subConfig,
+                            const std::string &consumerId, uint64_t &lastAckCursor);
+
+    /**
+     * @brief Delete a remote sub consumer node for this worker in particular stream.
+     * @param[in] subWorker Remote sub consumer node address.
+     * @param[in] consumerId Remote sub consumer node consumer id.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status DelRemoteSubNode(const HostPort &subWorker, const std::string &consumerId);
+
+    /**
+     * @brief Synchronize all remote sub consumer nodes for this worker in particular stream.
+     * @param[in] subTable A vector of all remote sub consumer nodes.
+     * @param[in] isRecon Is this part of the reconciliation process.
+     * @param[out] lastAckCursor Remote sub consumer last acknowledge cursor.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status SyncSubTable(const std::vector &subTable, bool isRecon, uint64_t &lastAckCursor);
+
+    /**
+     * @brief Synchronize all remote pub worker nodes for this worker in particular stream.
+     * @param[in] pubTable A vector of all remote pub worker nodes.
+     * @param[in] isRecon Is this part of the reconciliation process.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status SyncPubTable(const std::vector &pubTable, bool isRecon);
+
+    /**
+     * @brief Find all local producers.
+     * @param[out] localProducers Producer name list on local node.
+     */
+    void GetLocalProducers(std::vector<std::string> &localProducers);
+
+    /**
+     * @brief Find all local consumers.
+     * @param[out] localConsumers Consumer name list on local node.
+     */
+    void GetLocalConsumers(std::vector> &localConsumers);
+
+    /**
+     * @brief Clear all remote pub.
+     * @return Status of the call.
+     */
+    Status ClearAllRemotePub();
+
+    /**
+     * @brief Helper function to clear all remote pub when the lock is already held.
+     */
+    void ClearAllRemotePubUnlocked();
+
+    /**
+     * @brief Clear all remote consumers, without taking the lock.
+     * @param[in] forceClose If the pub client had a crash or regular close.
+     * @return Status of the call.
+     */
+    Status ClearAllRemoteConsumerUnlocked(bool forceClose);
+
+    /**
+     * @brief Get subscription type by its subName.
+     * @param[in] subName The name of the subscription.
+     * @param[out] type The output sub type.
+     * @return Status of the call.
+     */
+    Status GetSubType(const std::string &subName, SubscriptionType &type);
+
+    /**
+     * @brief Verifies the input stream fields match the existing setting.
+     * If the existing settings are uninitialized, updates the values.
+     * @param[in] streamFields The stream fields with page size and max stream size to check.
+     * @param[in] reserveShm Reserve shared memory if it has not been reserved previously.
+     * @return Status of the call.
+     */
+    Status UpdateStreamFields(const StreamFields &streamFields, bool reserveShm);
+
+    /**
+     * @brief Send back the stream fields for the stream of this object.
+     * @param[in] streamFields The stream fields with page size and max stream size.
+     */
+    void GetStreamFields(StreamFields &streamFields);
+
+    /**
+     * @brief Copies the given element view into stream shm pages.
+     * @param[in] recvElementView The received element view object.
+     * @return Status of the call.
+     */
+    Status CopyElementView(std::shared_ptr<RecvElementView> &recvElementView, UsageMonitor &usageMonitor,
+                           uint64_t timeoutMs);
+
+    /**
+     * @brief Getter of lastAppendCursor_.
+     * @return lastAppendCursor_.
+     */
+    uint64_t GetLastAppendCursor() const;
+
+    /**
+     * @return stream name
+     */
+    auto GetStreamName() const
+    {
+        return streamName_;
+    }
+
+    /**
+     * @brief Block remote producer.
+     * @param[in] workerAddr Address of the remote producer.
+     * @return Status of the call.
+     */
+    Status BlockProducer(const std::string &workerAddr, bool addCallBack);
+
+    /**
+     * @brief UnBlock remote producer.
+     * @param[in] workerAddr Address of the remote producer.
+     * @return Status of the call.
+     */
+    Status UnBlockProducer(const std::string &workerAddr);
+
+    /**
+     * @brief Check whether remote producer is blocked.
+     * @param[in] workerAddr Address of the remote producer.
+     * @return True if blocked.
+     */
+    bool IsProducerBlocked(const std::string &workerAddr);
+
+    /**
+     * @brief Wake up pending receive if the element is enough.
+     */
+    void TryWakeUpPendingReceive();
+
+    /**
+     * @brief Get subscription by subName.
+     * @param[in] subName The name of the subscription.
+     * @param[out] subscription The output subscription.
+     * @return Status of the call.
+     */
+    Status GetSubscription(const std::string &subName, std::shared_ptr<Subscription> &subscription);
+
+    /**
+     * @brief Set stream state and start the process of cleaning up the buffer pool.
+     * @param[in] prodConList The list of producers and consumers of the invoking client.
+     * @return Status of the call.
+     */
+    Status ResetStreamStart(std::vector<std::string> &prodConList);
+
+    /**
+     * @brief Complete cleaning up stream data and metadata. Wake up pending reset requests on this stream.
+     * @return Status of the call.
+     */
+    Status ResetStreamEnd();
+
+    /**
+     * @brief Force close all local producers and consumers.
+     */
+    void ForceCloseClients();
+
+    std::string GetStateString()
+    {
+        switch (streamState_) {
+            case StreamState::ACTIVE:
+                return "ACTIVE";
+            case StreamState::RESET_IN_PROGRESS:
+                return "RESET_IN_PROGRESS";
+            case StreamState::RESET_COMPLETE:
+                return "RESET_COMPLETE";
+            case StreamState::DELETE_IN_PROGRESS:
+                return "DELETE_IN_PROGRESS";
+            default:
+                return "INVALID";
+        }
+        return "INVALID";
+    }
+
+    /**
+     * @brief Check if stream is in active state or not.
+     * @return Status of the call.
+     */
+    inline Status CheckIfStreamActive()
+    {
+        std::shared_lock lock(streamStateMutex_);
+        if (streamState_ == StreamState::RESET_IN_PROGRESS || streamState_ == StreamState::RESET_COMPLETE) {
+            RETURN_STATUS(K_SC_STREAM_IN_RESET_STATE,
+                          FormatString("Reset is invoked on Stream [%s]. Current state: %s. Resume is needed.",
+                                       streamName_, GetStateString()));
+        } else if (streamState_ == StreamState::DELETE_IN_PROGRESS) {
+            RETURN_STATUS(K_SC_STREAM_DELETE_IN_PROGRESS,
+                          FormatString("Delete is in progress on Stream [%s].", streamName_));
+        }
+        return Status::OK();
+    }
+
+    inline std::string PrintStreamStatus()
+    {
+        Status rc = CheckIfStreamActive();
+        if (rc.IsOk()) {
+            return "Active";
+        }
+        return rc.GetMsg();
+    }
+
+    /**
+     * @brief Check whether the stream is in the given state.
+     * @param[in] state The state to check against.
+     * @return true if the state is the same, false otherwise.
+     */
+    inline bool CheckIfStreamInState(StreamState state)
+    {
+        std::shared_lock lock(streamStateMutex_);
+        return streamState_ == state;
+    }
+
+    /**
+     * @brief Set stream into delete state.
+     */
+    Status SetDeleteState(bool ignore = false)
+    {
+        // lock is used to protect streamState
+        std::unique_lock lock(streamStateMutex_);
+        // If status is already delete then return error
+        if (streamState_ == StreamState::DELETE_IN_PROGRESS && ignore) {
+            deleteStateRefCount_ += 1;
+            LOG(INFO) << FormatString("[S:%s] Ref Count is %d", streamName_, deleteStateRefCount_);
+            RETURN_STATUS(K_SC_STREAM_DELETE_IN_PROGRESS,
+                          FormatString("Delete is in progress on Stream [%s].", streamName_));
+        } else if (streamState_ == StreamState::DELETE_IN_PROGRESS) {
+            RETURN_STATUS(K_SC_STREAM_DELETE_IN_PROGRESS,
+                          FormatString("Delete is in progress on Stream [%s].", streamName_));
+        }
+        deleteStateRefCount_ += 1;
+        streamState_ = StreamState::DELETE_IN_PROGRESS;
+        if (scStreamMetrics_) {
+            scStreamMetrics_->LogMetric(StreamMetric::StreamState, (int)streamState_);
+        }
+        return Status::OK();
+    }
+
+    /**
+     * @brief Sets stream into Active state.
+     */
+    void SetActiveState()
+    {
+        // lock is used to protect streamState
+        // only re-set Active if there are no other deletes running
+        std::unique_lock lock(streamStateMutex_);
+        if (streamState_ == StreamState::DELETE_IN_PROGRESS) {
+            deleteStateRefCount_ -= 1;
+        }
+        if (deleteStateRefCount_ == 0) {
+            streamState_ = StreamState::ACTIVE;
+            LOG(INFO) << FormatString("[S:%s] Active State Set", streamName_);
+            if (scStreamMetrics_) {
+                scStreamMetrics_->LogMetric(StreamMetric::StreamState, (int)streamState_);
+            }
+        } else {
+            LOG(INFO) << FormatString("[S:%s] Active State Not Set, refCount %d", streamName_, deleteStateRefCount_);
+        }
+    }
+
+    /**
+     * @brief Set new stream state.
+     * @param[in] newState The new state to set.
+     * @return K_SC_STREAM_DELETE_IN_PROGRESS or OK.
+     */
+    inline Status SetNewState(StreamState newState)
+    {
+        // lock is used to protect streamState
+        std::unique_lock lock(streamStateMutex_);
+        if (streamState_ == StreamState::DELETE_IN_PROGRESS) {
+            RETURN_STATUS(K_SC_STREAM_DELETE_IN_PROGRESS,
+                          FormatString("Delete is in progress on Stream [%s].", streamName_));
+        }
+        streamState_ = newState;
+        if (scStreamMetrics_) {
+            scStreamMetrics_->LogMetric(StreamMetric::StreamState, (int)streamState_);
+        }
+        return Status::OK();
+    }
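+
+    // Worked example of the delete reference counting above (the sequence is illustrative): two
+    // concurrent deletes call SetDeleteState(/*ignore=*/true); the first sets DELETE_IN_PROGRESS
+    // with deleteStateRefCount_ == 1, the second bumps the count to 2 and returns
+    // K_SC_STREAM_DELETE_IN_PROGRESS. Each delete path later calls SetActiveState(); the state
+    // only returns to ACTIVE on the second call, once the count drops back to 0.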
+
+    /**
+     * @brief Allow stream to resume operation.
+     * @return Status of the call.
+     */
+    inline Status ResumeStream()
+    {
+        if (CheckIfStreamInState(StreamState::ACTIVE) || CheckIfStreamInState(StreamState::RESET_COMPLETE)) {
+            SetActiveState();
+            return Status::OK();
+        } else if (CheckIfStreamInState(StreamState::DELETE_IN_PROGRESS)) {
+            RETURN_STATUS(K_SC_STREAM_DELETE_IN_PROGRESS,
+                          FormatString("Delete is in progress on Stream [%s].", streamName_));
+        }
+        RETURN_STATUS(K_TRY_AGAIN, FormatString("Reset is still going on for stream [%s]", streamName_));
+    }
+    /**
+     * @brief Create the underlying page queue handler where all stream pages are stored.
+     * @return Status of the call.
+     */
+    Status CreatePageQueueHandler(Optional cfg);
+
+    /**
+     * Pause the GC thread.
+     */
+    void PauseAckThread();
+
+    /**
+     * Resume the GC thread.
+     */
+    void ResumeAckThread();
+
+    /**
+     * Called by the remote worker manager to move up the ack cursor
+     * when there are no local consumers.
+     */
+    Status RemoteAck();
+
+    /**
+     * @brief Crash recovery for a lost client to unlock by cursor.
+     * @param[in] cursorId The cursorId.
+     * @param[in] isProducer True for producer.
+     * @param[in] lockId The lock id.
+     */
+    void ForceUnlockByCursor(const std::string &cursorId, bool isProducer, uint32_t lockId);
+
+    /**
+     * @brief Crash recovery for a lost client to unlock mem view on all pages.
+     * @param[in] lockId The lock id.
+     */
+    void ForceUnlockMemViemForPages(uint32_t lockId);
+
+    /**
+     * @brief Garbage collection by scanning all consumers' last ack cursors.
+     * @return Status object.
+     */
+    Status AckCursors();
+
+    /**
+     * @return True if auto cleanup is on.
+     */
+    bool AutoCleanup() const;
+
+    /**
+     * @brief Get ratio of memory allocated to the StreamCache.
+     * @return Ratio of Mem Allocated to Stream / Total Mem Allocated to Stream Cache.
+     */
+    double GetStreamMemAllocRatio();
+
+    /**
+     * @brief Gets the page size of the stream.
+     * @return Page size of the stream.
+     */
+    int64_t GetStreamPageSize();
+
+    /**
+     * @return the stream memory manager
+     */
+    auto GetAllocManager()
+    {
+        return scAllocateManager_;
+    }
+
+    /**
+     * @return ClientService pointer
+     */
+    auto GetClientService()
+    {
+        return scSvc_.lock();
+    }
+
+    /**
+     * @brief Get max window count.
+     */
+    auto GetMaxWindowCount() const
+    {
+        return GetExclusivePageQueue()->GetMaxWindowCount();
+    }
+
+    /**
+     * @brief Get remote worker manager.
+     */
+    auto GetRemoteWorkerManager()
+    {
+        return remoteWorkerManager_;
+    }
+
+    /**
+     * @brief Get log prefix.
+     * @return The log prefix.
+     */
+    std::string LogPrefix() const;
+
+    /**
+     * @brief Get stream data object.
+     * @return The exclusive page queue.
+     */
+    std::shared_ptr GetExclusivePageQueue() const
+    {
+        return pageQueueHandler_->GetExclusivePageQueue();
+    }
+
+    /**
+     * @brief Get the stream number of the stream.
+     * @return stream number.
+     */
+    uint64_t GetStreamNo() const
+    {
+        return localStreamNum_;
+    }
+
+    /**
+     * @brief Get stream metrics.
+     * @return The stream metrics object.
+     */
+    auto GetSCStreamMetrics()
+    {
+        return scStreamMetrics_;
+    }
+
+    void InitRetainData(uint64_t retainForNumConsumers)
+    {
+        retainData_.Init(retainForNumConsumers);
+        if (scStreamMetrics_) {
+            scStreamMetrics_->LogMetric(StreamMetric::RetainDataState, retainData_.GetRetainDataState());
+        }
+        LOG(INFO) << "[RetainData] state changed for the stream: " << streamName_
+                  << " current state: " << retainData_.PrintCurrentState();
+    }
+
+    void RollBackRetainDataStateToInit()
+    {
+        retainData_.RollBackToInit();
+        if (scStreamMetrics_) {
+            scStreamMetrics_->LogMetric(StreamMetric::RetainDataState, retainData_.GetRetainDataState());
+        }
+        LOG(INFO) << "[RetainData] state is rolled back to Init for the stream: " << streamName_;
+    }
+
+    bool IsRetainData()
+    {
+        return retainData_.GetRetainDataState() == RetainDataState::RETAIN;
+    }
+
+    void SetRetainData(uint32_t state)
+    {
+        retainData_.SetRetainDataState(static_cast(state));
+        if (scStreamMetrics_) {
+            scStreamMetrics_->LogMetric(StreamMetric::RetainDataState, retainData_.GetRetainDataState());
+        }
+        LOG(INFO) << "[RetainData] state changed for the stream: " << streamName_
+                  << " current state: " << retainData_.PrintCurrentState();
+    }
+
+    std::vector<std::string> GetRemoteWorkers() const;
+
+    /**
+     * @brief Check if RemotePub is empty.
+     * @return True if empty.
+     */
+    bool IsRemotePubEmpty();
+
+    /**
+     * @brief Handle a timed-out memory alloc rpc request.
+     * @param producerId The producer id.
+     * @param subTimeout The sub timeout.
+     * @param startTime The time when BlockedCreateRequest was created.
+     */
+    template <typename T>
+    void HandleBlockedCreateTimeout(const std::string &producerId, int64_t subTimeout,
+                                    const std::chrono::steady_clock::time_point &startTime)
+    {
+        (void)producerId;
+        (void)subTimeout;
+        (void)startTime;
+    }
+
+    /**
+     * @brief Update stream metrics.
+     */
+    void UpdateStreamMetrics();
+
+    /**
+     * @brief Initialize stream metrics for this stream.
+     * @return Status of the call.
+     */
+    Status InitStreamMetrics();
+
+    /**
+     * @brief A preliminary check whether we can allocate a page or a big element.
+     * @param sz The requested size.
+     * @return True if there appears to be enough memory.
+     * @note Not to be considered an absolute check.
+     */
+    bool CheckHadEnoughMem(size_t sz) const;
+
+    /**
+     * @brief Get number of local producers.
+     * @return Number of local producers.
+     */
+    size_t GetLocalProducerCount() const
+    {
+        std::shared_lock lock(mutex_);
+        return pubs_.size();
+    }
+
+    /**
+     * @brief Clear blocked request list.
+     */
+    void ClearBlockedList();
+
+    /**
+     * (Un)block memory reclaim.
+     */
+    void BlockMemoryReclaim();
+    void UnblockMemoryReclaim();
+
+    /**
+     * @brief Check whether the shared page is enabled or not.
+     * @param[in] streamMode The stream mode.
+     * @return True if the shared page queue is enabled for the stream mode.
+     */
+    static bool EnableSharedPage(StreamMode streamMode);
+
+    void SetSharedPageQueue(std::shared_ptr sharedPageQueue)
+    {
+        pageQueueHandler_->SetSharedPageQueue(sharedPageQueue);
+    }
+
+    /**
+     * @brief Get or create shm meta.
+     * @param[in] tenantId The ID of tenant.
+     * @param[out] view The view of shm meta.
+     * @return Status of the call.
+     */
+    Status GetOrCreateShmMeta(const std::string &tenantId, ShmView &view)
+    {
+        return pageQueueHandler_->GetOrCreateShmMeta(tenantId, view);
+    }
+
+    /**
+     * @brief Try to decrease the usage of shared memory in this node for this stream.
+     * @param[in] size The size to be decreased.
+     * @return Status of the call.
+     */
+    Status TryDecUsage(uint64_t size)
+    {
+        return pageQueueHandler_->TryDecUsage(size);
+    }
+
+    /**
+     * @brief Get stream meta shm.
+     * @return The pointer to stream meta shm.
+     */
+    StreamMetaShm *GetStreamMetaShm()
+    {
+        return pageQueueHandler_->GetStreamMetaShm();
+    }
+
+protected:
+    /**
+     * @brief Create the subscription if it does not exist.
+     * @param[in] config The config of the subscription.
+     * @return Status of the call.
+     */
+    Status CreateSubscriptionIfMiss(const SubscriptionConfig &config, uint64_t &lastAckCursor);
+
+    /**
+     * @brief Get the min ack cursor of all subscriptions.
+     * @return The min ack cursor of all subscriptions.
+     */
+    uint64_t UpdateLastAckCursorUnlocked(uint64_t lastAppendCursor);
+
+private:
+    /**
+     * @brief Helper function to return the total number of elements of this stream received by consumers.
+     * @return Returns the amount.
+     */
+    uint64_t GetEleCount();
+
+    /**
+     * @brief Helper function to return the number of elements of this stream that were acked.
+     * @return Returns the amount.
+     */
+    uint64_t GetEleCountAcked();
+
+    /**
+     * @brief Helper function to return the number of elements of this stream sent, and reset the count.
+     * @return Returns the amount.
+     */
+    uint64_t GetEleCountSentAndReset();
+
+    /**
+     * @brief Helper function to return the number of elements of this stream received.
+     * @return Returns the amount.
+     */
+    uint64_t GetEleCountReceived();
+
+    /**
+     * @brief Helper function to return the number of send requests this stream received, and reset the counts.
+     * @return Returns the amount.
+     */
+    uint64_t GetSendRequestCountAndReset();
+
+    /**
+     * @brief Helper function to return the number of receive requests this stream received.
+     * @return Returns the amount.
+     */
+    uint64_t GetReceiveRequestCountAndReset();
+
+    /**
+     * @brief Inline function to add a callback to unblock the sending stream.
+     * @param[in] addr The producer worker address.
+     * @param[in] unblockCallback The callback function to unblock the stream.
+     */
+    void AddUnblockCallback(const std::string &addr, std::function<void()> unblockCallback);
+
+    /**
+     * @brief Removes the producers and consumers received from a client from the reset pub/sub list.
+     * @param[in] prodConList The list of producers and consumers which should be removed from the reset pub/sub list.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status RemovePubSubFromResetList(std::vector<std::string> &prodConList);
+
+    /**
+     * @brief Helper function to reclaim shared memory pages when producers and consumers are all gone.
+     * @return K_OK on success; the error code otherwise.
+     */
+    Status EarlyReclaim(bool remoteAck = false, uint64_t lastAppendCursor = 0, uint64_t newAckCursor = 0);
+
+    Status SendBlockProducerReq(const std::string &remoteWorkerAddr);
+    Status SendUnBlockProducerReq(const std::string &remoteWorkerAddr);
+    void ResetOOMState(const std::string &remoteWorkerAddr);
+
+    std::string workerAddr_;
+    const std::string streamName_;
+    RemoteWorkerManager *remoteWorkerManager_;
+    // protect pubs_/subs_/remoteSubWorkerDict_/remotePubWorkerDict_/blockOnOOM_
+    mutable std::shared_timed_mutex mutex_;
+    mutable std::shared_timed_mutex resetMutex_;
+    // protect streamState_
+    mutable std::shared_timed_mutex streamStateMutex_;
+    // deleteStateRefCount_ to protect stream state from being reactivated
+    int deleteStateRefCount_ = 0;
+
+    std::unordered_map<std::string, std::shared_ptr<Producer>> pubs_;
+    std::unordered_map<std::string, std::shared_ptr<Subscription>> subs_;
+    MemAllocRequestList dataBlockedList_;
+    MemAllocRequestList lobBlockedList_;
+    std::shared_timed_mutex streamManagerBlockedListsMutex_;
+    std::unordered_map<std::string, std::atomic<bool>> blockOnOOM_;  // block at a worker level
+
+    // Remote sub workers, consumers. Key: remote sub worker address, Value: remote SubWorkerDesc
+    std::unordered_map<std::string, std::shared_ptr<SubWorkerDesc>> remoteSubWorkerDict_;
+    // Remote pub workers. Key: remote pub worker address
+    std::unordered_set<std::string> remotePubWorkerDict_;
+    std::shared_ptr akSkManager_;
+    std::weak_ptr<ClientWorkerSCServiceImpl> scSvc_;
+    std::unique_ptr pageQueueHandler_;
+    mutable std::shared_timed_mutex ackMutex_;
+    WaitPost ackWp_;
+    std::atomic_uint64_t lastAckCursor_;
+    bool wakeupPendingRecvOnProdFault_;
+    StreamState streamState_{ StreamState::ACTIVE };
+    std::shared_ptr scAllocateManager_{ nullptr };
+    std::vector<std::string> prodConResetList_;
+    std::weak_ptr<WorkerWorkerSCServiceImpl> workerWorkerSCService_;
+    RetainDataState retainData_;
+    std::atomic<bool> pendingLastProducerClose_{ false };
+    std::atomic<bool> pendingLastProducerForceClose_{ false };
+    std::shared_ptr scStreamMetrics_{ nullptr };
+    mutable std::shared_timed_mutex reclaimMutex_;
+    WaitPost reclaimWp_;
+    // +1 every time a new producer is added to pubs_.
+    // Used to identify each local producer within the stream with a unique number in data verification.
+    // The count will not be decreased even if create producer failed, or the producer is closed.
+    DataVerificationHeader::SenderProducerNo lifetimeLocalProducerCount_{ 0 };
+    uint64_t localStreamNum_{ 0 };
+};
+
+template <>
+inline void StreamManager::HandleBlockedCreateTimeout(
+    const std::string &producerId, int64_t subTimeout, const std::chrono::steady_clock::time_point &startTime)
+{
+    dataBlockedList_.HandleBlockedCreateTimeout(streamName_, producerId, subTimeout, startTime);
+}
+
+template <>
+inline void StreamManager::HandleBlockedCreateTimeout(
+    const std::string &producerId, int64_t subTimeout, const std::chrono::steady_clock::time_point &startTime)
+{
+    lobBlockedList_.HandleBlockedCreateTimeout(streamName_, producerId, subTimeout, startTime);
+}
+}  // namespace stream_cache
+}  // namespace worker
+}  // namespace datasystem
+#endif  // DATASYSTEM_WORKER_STREAM_CACHE_STREAM_MANAGER_H
diff --git a/src/datasystem/worker/stream_cache/stream_producer.h b/src/datasystem/worker/stream_cache/stream_producer.h
new file mode 100644
index 0000000..75f5f24
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/stream_producer.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_STREAM_PRODUCER_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_STREAM_PRODUCER_H
+
+#include <string>
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+/**
+ * @brief This simple class provides a pair of named strings for improved code readability (better than std::pair)
+ * since it avoids getting mixed up on which string is first or second. Used for making lists of producers by stream.
+ */
+class StreamProducer {
+public:
+    StreamProducer(const std::string &streamName, const std::string &producerId)
+        : streamName_(streamName), producerId_(producerId)
+    {
+    }
+    ~StreamProducer() = default;
+    std::string streamName_;
+    std::string producerId_;
+};
+}  // namespace stream_cache
+}  // namespace worker
+}  // namespace datasystem
+#endif  // DATASYSTEM_WORKER_STREAM_CACHE_STREAM_PRODUCER_H
diff --git a/src/datasystem/worker/stream_cache/subscription.cpp b/src/datasystem/worker/stream_cache/subscription.cpp
new file mode 100644
index 0000000..68c2609
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/subscription.cpp
@@ -0,0 +1,204 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "datasystem/worker/stream_cache/subscription.h"
+
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/common/util/strings_util.h"
+#include "datasystem/common/util/uuid_generator.h"
+#include "datasystem/worker/stream_cache/consumer.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+Subscription::Subscription(SubscriptionConfig subConfig, uint64_t lastStreamAck, std::string streamName)
+    : subConfig_(std::move(subConfig)), streamName_(std::move(streamName)), lastSubAckCursor_(lastStreamAck)
+{
+}
+
+Status Subscription::AddConsumer(const SubscriptionConfig &config, const std::string &consumerId,
+                                 uint64_t lastAckCursor, std::shared_ptr<Cursor> cursor)
+{
+    CHECK_FAIL_RETURN_STATUS(config == this->subConfig_, StatusCode::K_RUNTIME_ERROR,
+                             "The subscription config is different.");
+    std::lock_guard lock(mutex_);
+    // Initialize cursor by SubscriptionType, available data range is [lastSubAckCursor_, lastAppendCursor).
+    CHECK_FAIL_RETURN_STATUS(config.subscriptionType == SubscriptionType::STREAM, K_INVALID, "Not supported config.");
+    CHECK_FAIL_RETURN_STATUS(consumers_.empty(), StatusCode::K_RUNTIME_ERROR,
+                             "In STREAM mode, 1 Subscription can only contain 1 Consumer");
+    // In stream mode, the consumer is the only one in the current subscription, so we don't have to update ackCursor.
+    auto consumer = std::make_shared<Consumer>(consumerId, lastAckCursor, streamName_, cursor);
+    consumer->SetElementCount(lastAckCursor);
+    auto ret = consumers_.emplace(consumerId, std::move(consumer));
+    CHECK_FAIL_RETURN_STATUS(ret.second, StatusCode::K_DUPLICATED, "Failed to add consumer into subscription");
+    return Status::OK();
+}
+
+Status Subscription::RemoveConsumer(const std::string &consumerId)
+{
+    std::shared_ptr<Consumer> consumerPtr;
+    RETURN_IF_NOT_OK(GetConsumer(consumerId, consumerPtr));
+    uint64_t consumerAck = consumerPtr->GetWALastAckCursor();
+    std::lock_guard lock(mutex_);
+    if (ConsumerNum() == 1) {  // If only one consumer is left (stream mode, or queue mode with the last consumer).
+        VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[%s] Remove this sub since its last consumer is closed",
+                                                    LogPrefix());
+        consumers_.clear();
+    } else {  // Otherwise
+        CHECK_FAIL_RETURN_STATUS(
+            consumers_.erase(consumerId) == 1, StatusCode::K_RUNTIME_ERROR,
+            FormatString("Failed to remove consumer by consumerId %s in current Subscription", consumerId));
+        if (consumerAck == lastSubAckCursor_) {
+            // If the target consumer's ack is the minimum in this subscription, we should update the ack cursor.
+            lastSubAckCursor_ = CalcMinAckCursorNoLock();
+        }
+    }
+    return Status::OK();
+}
+
+SubscriptionType Subscription::GetSubscriptionType() const
+{
+    PerfPoint point(PerfKey::MANAGER_GET_SUB_TYPE);
+    std::shared_lock lock(mutex_);
+    return subConfig_.subscriptionType;
+}
+
+Status Subscription::GetConsumer(const std::string &consumerId, std::shared_ptr<Consumer> &consumer)
+{
+    PerfPoint point(PerfKey::MANAGER_GET_CONSUMER);
+    std::shared_lock lock(mutex_);
+    auto iter = consumers_.find(consumerId);
+    if (iter == consumers_.end()) {
+        RETURN_STATUS(StatusCode::K_NOT_FOUND, "Consumer not found " + consumerId);
+    }
+    RETURN_RUNTIME_ERROR_IF_NULL(iter->second);
+    consumer = iter->second;
+    point.Record();
+    return Status::OK();
+}
+
+std::string Subscription::LogPrefix() const
+{
+    return FormatString("Sub:%s", subConfig_.subscriptionName);
+}
+
+uint64_t Subscription::CalcMinAckCursorNoLock() const
+{
+    uint64_t newAckCursor = std::numeric_limits<uint64_t>::max();
+    for (auto &ele : consumers_) {
+        uint64_t lastAckCursor = ele.second->GetWALastAckCursor();
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s %s] lastAckCursor = %zu", LogPrefix(), ele.second->LogPrefix(),
+                                                  lastAckCursor);
+        newAckCursor = std::min(newAckCursor, lastAckCursor);
+    }
+    return newAckCursor;
+}
+
+uint64_t Subscription::UpdateLastAckCursor()
+{
+    std::shared_lock lock(mutex_);
+    auto newAckCursor = CalcMinAckCursorNoLock();
+    bool subAckForward = newAckCursor > lastSubAckCursor_;
+    if (subAckForward) {
+        VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("Update min ack cursor from %zu to %zu for subscription %s",
                                                  lastSubAckCursor_.load(), newAckCursor, subConfig_.subscriptionName);
+        lastSubAckCursor_.store(newAckCursor);
+    }
+    return lastSubAckCursor_.load(std::memory_order_relaxed);
+}
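+
+// Worked example for CalcMinAckCursorNoLock()/UpdateLastAckCursor() above (the cursor values are
+// illustrative, not from the source): with three consumers whose last ack cursors are 5, 9 and 7,
+// the subscription-level ack cursor becomes min(5, 9, 7) = 5, so garbage collection may only
+// reclaim pages up to cursor 5. The cursor is also monotonic: a later scan yielding 4 would leave
+// lastSubAckCursor_ at 5, since the store only happens when the new minimum moves forward.
+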
+
+Status Subscription::TryWakeUpPendingReceive(uint64_t lastAppendCursor)
+{
+    if (subConfig_.subscriptionType != SubscriptionType::STREAM) {
+        RETURN_STATUS(StatusCode::K_INVALID, "Only support stream mode");
+    }
+    std::shared_lock lock(mutex_);
+    for (const auto &consumer : consumers_) {
+        RETURN_IF_NOT_OK(consumer.second->WakeUpPendingReceive(lastAppendCursor));
+    }
+    return Status::OK();
+}
+
+Status Subscription::SetForceClose()
+{
+    if (subConfig_.subscriptionType != SubscriptionType::STREAM) {
+        RETURN_STATUS(StatusCode::K_INVALID, "Only support stream mode");
+    }
+    std::shared_lock lock(mutex_);
+    Status rc;
+    for (const auto &consumer : consumers_) {
+        Status rc1 = consumer.second->SetForceClose();
+        if (rc.IsOk()) {
+            rc = rc1;
+        }
+    }
+    return rc;
+}
+
+uint64_t Subscription::GetElementCountAndReset()
+{
+    uint64_t val = 0;
+    std::shared_lock lock(mutex_);
+    for (auto &consumer : consumers_) {
+        val += consumer.second->GetElementCountAndReset();
+    }
+    return val;
+}
+
+uint64_t Subscription::GetElementCountReceived()
+{
+    uint64_t val = std::numeric_limits<uint64_t>::max();
+    uint64_t count = 0;
+    std::shared_lock lock(mutex_);
+    for (auto &consumer : consumers_) {
+        count = consumer.second->GetElementCount();
+        if (val > count) {
+            val = count;
+        }
+    }
+    return val;
+}
+
+uint64_t Subscription::GetRequestCountAndReset()
+{
+    uint64_t val = 0;
+    std::shared_lock lock(mutex_);
+    for (auto &consumer : consumers_) {
+        val += consumer.second->GetRequestCountAndReset();
+    }
+    return val;
+}
+
+void Subscription::GetAllConsumers(std::vector<std::string> &consumers) const
+{
+    std::shared_lock lock(mutex_);
+    for (const auto &kv : consumers_) {
+        auto &consumerName = kv.first;
+        consumers.emplace_back(consumerName);
+    }
+}
+
+void Subscription::CleanupSubscription()
+{
+    std::shared_lock lock(mutex_);
+    for (const auto &kv : consumers_) {
+        kv.second->CleanupConsumer();
+    }
+    lastSubAckCursor_ = 0;
+}
+}  // namespace stream_cache
+}  // namespace worker
+}  // namespace datasystem
diff --git a/src/datasystem/worker/stream_cache/subscription.h b/src/datasystem/worker/stream_cache/subscription.h
new file mode 100644
index 0000000..28beca6
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/subscription.h
@@ -0,0 +1,182 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_SUBSCRIPTION_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_SUBSCRIPTION_H
+
+#include "datasystem/common/shared_memory/shm_unit_info.h"
+#include "datasystem/common/stream_cache/cursor.h"
+#include "datasystem/stream/stream_config.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+class Consumer;
+/**
+ * @brief Two Modes, workload assignment logic is here.
+ * @details For assignment see AssignCursors.
+ */
+class Subscription {
+public:
+    /**
+     * @brief Create a new Subscription object.
+     * @param[in] subConfig The config of this subscription.
+     * @param[in] lastStreamAck The stream ackCursor, which should be recorded by Subscription.
+     * @param[in] streamName The name of the stream.
+     */
+    Subscription(SubscriptionConfig subConfig, uint64_t lastStreamAck, std::string streamName);
+    ~Subscription() = default;
+
+    /**
+     * @brief Subscribe to a stream, using a subscription name, i.e., register a consumer to a subscription.
+     * @param[in] config The subscription config.
+     * @param[in] consumerId The id of consumer.
+     * @param[in] lastAckCursor The last ack cursor of the consumer.
+     * @param[in] cursor The cursor of the consumer.
+     * @return Status of the call.
+     */
+    Status AddConsumer(const SubscriptionConfig &config, const std::string &consumerId, uint64_t lastAckCursor,
+                       std::shared_ptr<Cursor> cursor);
+
+    /**
+     * @brief Close a consumer, trigger subscription cursor change and unregister a subscribed consumer to a stream.
+     * @param[in] consumerId Consumer id.
+     * @return Status of the call.
+     */
+    Status RemoveConsumer(const std::string &consumerId);
+
+    /**
+     * @brief Get consumer by id.
+     * @param[in] consumerId The id of consumer.
+     * @param[out] consumer The output consumer.
+     * @return Status of the call.
+     */
+    Status GetConsumer(const std::string &consumerId, std::shared_ptr<Consumer> &consumer);
+
+    /**
+     * @brief Get subscription type.
+     * @return SubscriptionType of the subscription.
+     */
+    SubscriptionType GetSubscriptionType() const;
+
+    /**
+     * @brief Get last ack cursor of the subscription.
+     * @return The last ack cursor of the subscription.
+     */
+    uint64_t GetLastSubAckCursor() const
+    {
+        return lastSubAckCursor_;
+    }
+
+    /**
+     * @brief Wake up pending receive if the element is enough.
+     * @param[in] lastAppendCursor The last append cursor of the stream.
+     * @return Status of the call.
+     */
+    Status TryWakeUpPendingReceive(uint64_t lastAppendCursor);
+
+    /**
+     * @brief Wake up pending receives if a producer is forcing consumers to be interrupted.
+     * @return Status of the call.
+     */
+    Status SetForceClose();
+
+    /**
+     * @brief Get the number of consumers.
+     * @return The number of consumers in one subscription.
+     */
+    size_t ConsumerNum() const
+    {
+        return consumers_.size();
+    }
+
+    /**
+     * @brief Identify whether this subscription has any consumers.
+     * @return True if this subscription has at least one consumer.
+     */
+    bool HasConsumer() const
+    {
+        std::shared_lock lock(mutex_);
+        return !consumers_.empty();
+    }
+
+    void GetAllConsumers(std::vector<std::string> &consumers) const;
+
+    /**
+     * @brief Get the subscription config.
+     * @return The subscription config.
+     */
+    const SubscriptionConfig &GetSubscriptionConfig() const
+    {
+        return subConfig_;
+    }
+
+    /**
+     * @brief Get log prefix.
+     * @return The log prefix.
+     */
+    std::string LogPrefix() const;
+
+    /**
+     * @brief Garbage collection by scanning all consumers' last ack cursors.
+     * @return The updated last ack cursor.
+     */
+    uint64_t UpdateLastAckCursor();
+
+    /**
+     * @brief Get the element count and reset it to 0.
+     * @return The element count before the reset.
+     */
+    uint64_t GetElementCountAndReset();
+
+    /**
+     * @brief Get the received element count.
+     * @return The received element count.
+     */
+    uint64_t GetElementCountReceived();
+
+    /**
+     * @brief Get the request count and reset it to 0.
+     * @return The request count before the reset.
+     */
+    uint64_t GetRequestCountAndReset();
+
+    const std::string SubName() const
+    {
+        return subConfig_.subscriptionName;
+    }
+
+    /**
+     * @brief Cleanup indexes for this subscription.
+     */
+    void CleanupSubscription();
+
+protected:
+    /**
+     * @brief Get the min ack cursor of all the consumers of this subscription.
+     * @return The min ack cursor of all the consumers of this subscription.
diff --git a/src/datasystem/worker/stream_cache/usage_monitor.cpp b/src/datasystem/worker/stream_cache/usage_monitor.cpp
new file mode 100644
index 0000000..e79573f
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/usage_monitor.cpp
@@ -0,0 +1,406 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "datasystem/worker/stream_cache/usage_monitor.h"
+#include
+
+#include "datasystem/common/log/log_helper.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/stream/stream_config.h"
+#include "datasystem/common/inject/inject_point.h"
+#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+UsageItem::UsageItem() : usage(0), usageBlocked(false)
+{
+}
+
+UsageItem::UsageItem(std::string streamName, std::string remoteWorkerAddr, std::uint64_t usage)
+    : streamName(std::move(streamName)),
+      remoteWorkerAddr(std::move(remoteWorkerAddr)),
+      usage(usage),
+      usageBlocked(false)
+{
+}
+
+MemReserveEntry::MemReserveEntry(uint64_t reserve) : reserveSize(reserve), usedSize(0)
+{
+}
+
+UsageMonitor::UsageMonitor(ClientWorkerSCServiceImpl *clientWorkerScService, const uint64_t maxBufferPoolMem)
+    : clientWorkerScService_(clientWorkerScService),
+      totalUsedSize_(0),
+      maxBufferPoolMem_(maxBufferPoolMem),
+      interrupt_(false)
+{
+}
+
+Status UsageMonitor::Init()
+{
+    producerBlockerThreadPool_ = std::make_unique<ThreadPool>(1, 0, "ScUsageMonitor");
+    try {
+        producerBlockerThreadPool_->Execute([this]() { BlockProducersIfNeeded(); });
+        return Status::OK();
+    } catch (const std::exception &e) {
+        RETURN_STATUS(K_RUNTIME_ERROR, e.what());
+    }
+}
+
+void UsageMonitor::Stop()
+{
+    interrupt_ = true;
+    cv_.notify_all();
+    producerBlockerThreadPool_.reset();
+}
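+// Note on shutdown ordering: setting interrupt_ and notifying cv_ unblocks
+// WaitForOverUseCondition(), and resetting the pool is expected to join the
+// ScUsageMonitor thread before the monitor's dependencies go away.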
+
+Status UsageMonitor::IncUsage(const std::string &streamName, const std::string &workerAddr, const std::uint64_t size)
+{
+    std::shared_lock l(usageMutex_);
+
+    // Update per stream usage
+    TbbStreamReserveTable::Accessor accessorStTbl;
+    if (streamMemoryMap_.Find(accessorStTbl, streamName)) {
+        accessorStTbl.entry->data->usedSize += size;
+    } else {
+        LOG(WARNING) << "Stream name not found for the reservation";
+    }
+
+    return IncTotalUsageUnlocked(streamName, workerAddr, size);
+}
+
+Status UsageMonitor::IncTotalUsageUnlocked(const std::string &streamName, const std::string &workerAddr,
+                                           const std::uint64_t size)
+{
+    VLOG(SC_NORMAL_LOG_LEVEL) << "[UsageMonitor] Increase Usage for stream " << streamName << " Remote worker "
+                              << workerAddr << " by size " << size << " max size available " << maxBufferPoolMem_;
+    std::string id = streamName + workerAddr;
+    // Update per stream per remote worker usage
+    TbbUsageTable::accessor accessor;
+    if (usage_.insert(accessor, id)) {
+        // Insert a new item
+        accessor->second = std::make_shared<UsageItem>(streamName, workerAddr, size);
+    } else {
+        // Update usage.
+        accessor->second->usage += size;
+    }
+
+    // Update total usage count
+    totalUsedSize_.fetch_add(size, std::memory_order_relaxed);
+    VLOG(SC_NORMAL_LOG_LEVEL) << "[UsageMonitor] Total memory used " << totalUsedSize_;
+    return Status::OK();
+}
+
+void UsageMonitor::DecUsage(const std::uint64_t size)
+{
+    do {
+        uint64_t val = totalUsedSize_.load();
+        uint64_t updatedVal = val;
+        // Update the total usage count, saturating at 0
+        if (updatedVal > size) {
+            updatedVal -= size;
+        } else {
+            updatedVal = 0;
+        }
+        if (totalUsedSize_.compare_exchange_strong(val, updatedVal)) {
+            return;
+        }
+    } while (true);
+}
+
+Status UsageMonitor::DecUsage(const std::string &streamName, const std::string &workerAddr, const std::uint64_t size)
+{
+    VLOG(SC_NORMAL_LOG_LEVEL) << "[UsageMonitor] Decrease Usage for stream " << streamName << " Remote worker "
+                              << workerAddr << " by size " << size << " max size available " << maxBufferPoolMem_;
+    std::string id = streamName + workerAddr;
+    std::shared_lock l(usageMutex_);
+    // Update per stream usage
+    TbbStreamReserveTable::Accessor accessorStTbl;
+    if (streamMemoryMap_.Find(accessorStTbl, streamName)) {
+        accessorStTbl.entry->data->usedSize -= size;
+    }
+
+    // Update per stream per remote worker usage
+    TbbUsageTable::accessor accessor;
+    if (usage_.find(accessor, id)) {
+        // If usage would drop to 0 or below, delete the entry
+        if (accessor->second->usage > size) {
+            accessor->second->usage -= size;
+        } else {
+            (void)usage_.erase(accessor);
+        }
+    } else {
+        // If the key is not found, it's an error: this usage was never recorded
+        RETURN_STATUS(StatusCode::K_NOT_FOUND, "usage key not found");
+    }
+
+    DecUsage(size);
+
+    VLOG(SC_NORMAL_LOG_LEVEL) << "[UsageMonitor] Total memory used " << totalUsedSize_;
+    return Status::OK();
+}
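+// A minimal sketch of the intended bracketing (hypothetical caller; the real call sites are in
+// the worker-worker stream service later in this patch):
+//     RETURN_IF_NOT_OK(monitor.IncUsage(stream, remoteWorker, bufSize));  // page buffered
+//     ... deliver the page to local consumers ...
+//     RETURN_IF_NOT_OK(monitor.DecUsage(stream, remoteWorker, bufSize));  // page released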
+
+Status UsageMonitor::RemoveUsageStats(const std::string &streamName, const std::string &workerAddr)
+{
+    std::string id = streamName + workerAddr;
+    std::shared_lock l(usageMutex_);
+    TbbUsageTable::accessor accessor;
+    if (usage_.find(accessor, id)) {
+        (void)usage_.erase(accessor);
+    } else {
+        RETURN_STATUS(StatusCode::K_NOT_FOUND, "usage key not found");
+    }
+    return Status::OK();
+}
+
+Status UsageMonitor::CheckOverUsed(const double threshold, const uint64_t size)
+{
+    if (totalUsedSize_.load() + size > maxBufferPoolMem_ * threshold) {
+        VLOG(SC_NORMAL_LOG_LEVEL) << "BufferPool is out of Memory, Total used: " << totalUsedSize_.load()
+                                  << " Total allocated " << maxBufferPoolMem_ << " Limited by "
+                                  << maxBufferPoolMem_ * threshold;
+        RETURN_STATUS(StatusCode::K_OUT_OF_MEMORY, "BufferPool is out of memory");
+    }
+    return Status::OK();
+}
+
+Status UsageMonitor::CheckNIncOverUsedForStream(const std::string &streamName, const std::string &workerAddr,
+                                                const uint64_t lowerBound, const double threshold, const uint64_t size)
+{
+    INJECT_POINT("worker.UsageMonitor.CheckOverUsedForStream.MockError");
+    // The per-stream limit is the larger of the stream's lower bound and its share of the pool
+    auto limit = std::max<double>(lowerBound, maxBufferPoolMem_ * threshold);
+    std::shared_lock l(usageMutex_);
+    TbbStreamReserveTable::Accessor accessor;
+    if (streamMemoryMap_.Find(accessor, streamName)) {
+        auto &entry = accessor.entry->data;
+        if (entry->usedSize > limit) {
+            RETURN_STATUS(StatusCode::K_OUT_OF_MEMORY,
+                          FormatString("BufferPool is out of memory per stream %s, TotalUsed=%llu, LimitedBy=%f",
+                                       streamName, entry->usedSize, limit));
+        }
+        uint64_t remainingReservedSize = entry->reserveSize - std::min(entry->reserveSize, entry->usedSize);
+        // Best-effort check that the remaining fair-share reserved memory plus the available
+        // mutual memory is enough for the request
+        uint64_t totalUsedSize = totalUsedSize_.load();
+        uint64_t totalReservedSize = totalReservedSize_.load();
+        if (std::max(totalUsedSize, totalReservedSize) + size
+            > remainingReservedSize + maxBufferPoolMem_ * DEFAULT_THRESHOLD) {
+            VLOG(SC_NORMAL_LOG_LEVEL) << FormatString(
+                "BufferPool is out of Memory, Total used: %llu, Total reserved: %llu, Total allocated %llu Limited by "
+                "%f",
+                totalUsedSize, totalReservedSize, maxBufferPoolMem_, maxBufferPoolMem_ * threshold);
+            RETURN_STATUS(StatusCode::K_OUT_OF_MEMORY, "BufferPool is out of memory");
+        }
+        INJECT_POINT("CheckNIncOverUsedForStream.TbbStreamReserveTable.CPU");
+        // Increase the memory usage so that the reserved memory will not be counted multiple times
+        entry->usedSize += size;
+        (void)IncTotalUsageUnlocked(streamName, workerAddr, size);
+        return Status::OK();
+    }
+    // The memory has to be reserved upfront
+    RETURN_STATUS(StatusCode::K_INVALID, "Local cache memory is not reserved.");
+}
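+// Worked example of the per-stream cap above (illustrative numbers only): with a 1 GiB pool,
+// threshold = 0.5 and lowerBound = 4 MiB, the limit is max(4 MiB, 512 MiB) = 512 MiB; a stream
+// already using 512 MiB therefore fails the check even if the pool as a whole still has room.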
+
+Status UsageMonitor::GetMostUsed(std::shared_ptr<UsageItem> &usageItem)
+{
+    std::lock_guard l(usageMutex_);
+    if (usage_.empty()) {
+        RETURN_STATUS(StatusCode::K_INVALID, "usage vector is empty");
+    }
+    auto iter = std::max_element(usage_.begin(), usage_.end(), [](const auto &a, const auto &b) {
+        return ((a.second->usage < b.second->usage) && !b.second->usageBlocked);
+    });
+    RETURN_RUNTIME_ERROR_IF_NULL(iter->second);
+    VLOG(SC_NORMAL_LOG_LEVEL) << "[UsageMonitor] Most used memory stream " << iter->second->streamName << " producer "
+                              << iter->second->remoteWorkerAddr << " by size " << iter->second->usage;
+    usageItem = iter->second;
+    return Status::OK();
+}
+
+Status UsageMonitor::BlockUsage(std::shared_ptr<UsageItem> &usageItem)
+{
+    Status rc = clientWorkerScService_->SendBlockProducerReq(usageItem->streamName, usageItem->remoteWorkerAddr);
+    VLOG(SC_NORMAL_LOG_LEVEL) << "[UsageMonitor] Usage Blocked for stream: " << usageItem->streamName
+                              << " remote producer " << usageItem->remoteWorkerAddr;
+    // If ok, or if the stream or producer is already deleted, then mark it blocked;
+    // if the producer is already gone there is no need to block it
+    if (rc.IsOk() || rc.GetCode() == StatusCode::K_SC_STREAM_NOT_FOUND
+        || rc.GetCode() == StatusCode::K_SC_PRODUCER_NOT_FOUND) {
+        usageItem->usageBlocked = true;
+        LOG_IF_ERROR(rc, "Error while blocking");
+        return Status::OK();
+    }
+    return rc;
+}
+
+Status UsageMonitor::UnBlockUsage(std::shared_ptr<UsageItem> &usageItem)
+{
+    Status rc = clientWorkerScService_->SendUnBlockProducerReq(usageItem->streamName, usageItem->remoteWorkerAddr);
+    VLOG(SC_NORMAL_LOG_LEVEL) << "[UsageMonitor] Usage UnBlocked for stream: " << usageItem->streamName
+                              << " remote producer " << usageItem->remoteWorkerAddr;
+    // If ok, or if the stream or producer is already deleted, then mark it unblocked;
+    // if the producer is already gone there is no need to unblock it
+    if (rc.IsOk() || rc.GetCode() == StatusCode::K_SC_STREAM_NOT_FOUND
+        || rc.GetCode() == StatusCode::K_SC_PRODUCER_NOT_FOUND) {
+        usageItem->usageBlocked = false;
+        LOG_IF_ERROR(rc, "Error while unblocking");
+        return Status::OK();
+    }
+    return rc;
+}
+
+bool UsageMonitor::WaitForOverUseCondition(const uint64_t timeoutMs, const double threshold)
+{
+    std::unique_lock lock(mux_);
+    // CheckOverUsed returns an error when no memory is available
+    return cv_.wait_for(lock, std::chrono::milliseconds(timeoutMs),
+                        [this, threshold]() { return CheckOverUsed(threshold).IsError(); });
+}
+
+Status UsageMonitor::BlockMostUsed()
+{
+    // Get the producer that produces the most
+    std::shared_ptr<UsageItem> usageItem;
+    RETURN_IF_NOT_OK(GetMostUsed(usageItem));
+    // Block the producer
+    RETURN_IF_NOT_OK(BlockUsage(usageItem));
+    // Add it to the list so that we can unblock it later
+    blockedStreamProducers_.push_back(usageItem);
+    return Status::OK();
+}
+
+Status UsageMonitor::UnBlockAllProducers(const double unBlockThreshold)
+{
+    // Only if more than (1 - unBlockThreshold) of memory is available;
+    // CheckOverUsed OK means memory is available
+    if (!blockedStreamProducers_.empty() && CheckOverUsed(unBlockThreshold).IsOk()) {
+        for (auto iter = blockedStreamProducers_.begin(); iter != blockedStreamProducers_.end();) {
+            // We only try once; if it does not work, it will time out at the sender side
+            LOG_IF_ERROR(UnBlockUsage(*iter), "Error in unblocking producer");
+            iter = blockedStreamProducers_.erase(iter);
+        }
+    }
+    return Status::OK();
+}
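+// The background loop below applies hysteresis: it blocks the heaviest producer once pool
+// usage crosses 90%, and only unblocks the backlog once usage falls back under 70%, so
+// producers are not rapidly toggled around a single threshold.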
+
+void UsageMonitor::BlockProducersIfNeeded()
+{
+    const uint64_t timeoutMs = 100;
+    // We block when 90% of memory is used and unblock when more than 30% of memory
+    // is available again (i.e. < 70% in use)
+    const auto blockThreshold = 0.9;
+    const auto unBlockThreshold = 0.7;
+    const auto waitTimeSecs = 1;
+    while (true) {
+        // Wait on the cv for 0.1s for work or interrupt
+        auto overUsed = WaitForOverUseCondition(timeoutMs, blockThreshold);
+        if (interrupt_) {
+            VLOG(SC_INTERNAL_LOG_LEVEL) << "BlockProducers thread exits";
+            break;
+        }
+        // Memory is available
+        if (!overUsed) {
+            // Unblock all producers that were blocked
+            (void)UnBlockAllProducers(unBlockThreshold);
+            continue; // No need to block, continue to check
+        }
+        // Memory is not available:
+        // block the producer that uses the most memory
+        auto rc = BlockMostUsed();
+        if (rc.IsError()) {
+            LOG_IF_ERROR(rc, "Error while blocking most used producer");
+            continue; // try again
+        }
+        std::this_thread::sleep_for(std::chrono::seconds(waitTimeSecs)); // Wait for memory usage to reduce
+    }
+}
+
+Status UsageMonitor::ReserveMemory(const std::string &streamName, size_t reserveSize)
+{
+    VLOG(SC_NORMAL_LOG_LEVEL) << "[UsageMonitor] Reserve memory for: " << streamName << " with size = " << reserveSize;
+    std::shared_lock l(usageMutex_);
+    TbbStreamReserveTable::Accessor accessor;
+    if (streamMemoryMap_.Insert(accessor, streamName)) {
+        // As long as the total reserved memory is still in bound, it is allowed to reserve.
+        Status rc = CheckNIncTotalReservedMemory(reserveSize);
+        if (rc.IsError()) {
+            streamMemoryMap_.BlockingErase(accessor);
+            return rc;
+        }
+        accessor.entry->data = std::make_shared<MemReserveEntry>(reserveSize);
+    } else {
+        // If the reservation already exists, update the entry if applicable.
+        // The reserve size should be the max of chunk size and page size, so it should not shrink.
+        auto &entry = accessor.entry->data;
+        if (reserveSize > entry->reserveSize) {
+            uint64_t difference = reserveSize - entry->reserveSize;
+            RETURN_IF_NOT_OK(CheckNIncTotalReservedMemory(difference));
+            entry->reserveSize = reserveSize;
+        }
+    }
+    return Status::OK();
+}
+
+void UsageMonitor::UndoReserveMemory(const std::string &streamName)
+{
+    VLOG(SC_NORMAL_LOG_LEVEL) << "[UsageMonitor] Undo the memory reservation for: " << streamName;
+    std::shared_lock l(usageMutex_);
+    TbbStreamReserveTable::Accessor accessor;
+    if (streamMemoryMap_.Find(accessor, streamName)) {
+        const auto &entry = accessor.entry->data;
+        totalReservedSize_ -= entry->reserveSize;
+        streamMemoryMap_.BlockingErase(accessor);
+    }
+}
+
+uint64_t UsageMonitor::GetLocalMemoryUsed(const std::string &streamName)
+{
+    TbbStreamReserveTable::ConstAccessor accessor;
+    uint64_t val = 0;
+    if (streamMemoryMap_.Find(accessor, streamName)) {
+        val = accessor.entry->data->usedSize;
+    }
+    return val;
+}
+
+std::string UsageMonitor::GetLocalMemoryUsed()
+{
+    return FormatString("%lu/%lu/%lu/%.3f", totalUsedSize_.load(), totalReservedSize_.load(), maxBufferPoolMem_,
+                        totalUsedSize_ / static_cast<double>(maxBufferPoolMem_));
+}
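+// For instance (illustrative numbers), 1 MiB used with 2 MiB reserved out of a 4 MiB pool
+// renders as "1048576/2097152/4194304/0.250".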
+
+Status UsageMonitor::CheckNIncTotalReservedMemory(uint64_t reserveSize)
+{
+    bool success = false;
+    do {
+        uint64_t totalReservedSize = totalReservedSize_.load();
+        uint64_t remainingForReserve = maxBufferPoolMem_ - totalReservedSize;
+        // If not enough memory is left for the reservation, fail the CreateProducer/Subscribe
+        CHECK_FAIL_RETURN_STATUS(
+            remainingForReserve >= reserveSize, K_OUT_OF_MEMORY,
+            FormatString("Reserve local cache memory failed, need %lu, remaining %lu", reserveSize,
+                         remainingForReserve));
+        success = totalReservedSize_.compare_exchange_weak(totalReservedSize, totalReservedSize + reserveSize,
+                                                           std::memory_order_release, std::memory_order_acquire);
+    } while (!success);
+    return Status::OK();
+}
+} // namespace stream_cache
+} // namespace worker
+} // namespace datasystem
diff --git a/src/datasystem/worker/stream_cache/usage_monitor.h b/src/datasystem/worker/stream_cache/usage_monitor.h
new file mode 100644
index 0000000..1031fea
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/usage_monitor.h
@@ -0,0 +1,245 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASYSTEM_WORKER_STREAM_CACHE_USAGE_MONITOR_H
+#define DATASYSTEM_WORKER_STREAM_CACHE_USAGE_MONITOR_H
+
+#include
+
+#include
+
+#include "datasystem/common/util/lock_map.h"
+#include "datasystem/common/util/thread_pool.h"
+
+namespace datasystem {
+namespace worker {
+namespace stream_cache {
+class ClientWorkerSCServiceImpl;
+static constexpr double DEFAULT_THRESHOLD = 1.0;
+struct UsageItem {
+    std::string streamName;
+    std::string remoteWorkerAddr;
+    std::uint64_t usage;
+    bool usageBlocked;
+    UsageItem();
+    UsageItem(std::string streamName, std::string remoteWorkerAddr, std::uint64_t usage);
+};
+
+struct MemReserveEntry {
+    // The reserved memory size.
+    uint64_t reserveSize;
+    // The memory used by the stream.
+    uint64_t usedSize;
+    MemReserveEntry(uint64_t reserve);
+};
+
+using TbbUsageTable = tbb::concurrent_hash_map<std::string, std::shared_ptr<UsageItem>>;
+using TbbStreamReserveTable = LockMap<std::string, std::shared_ptr<MemReserveEntry>>;
+/**
+ * UsageMonitor runs a background thread that continuously monitors the size of the
+ * BufferPool; if usage goes above a threshold, it sends a blocking RPC call to the
+ * most offending producer.
+ */
+class UsageMonitor {
+public:
+    /**
+     * @brief Construct the UsageMonitor.
+     * @param[in] clientWorkerScService The owning client-worker SC service, used to send block/unblock requests.
+     * @param[in] maxMemThresholdBytes BufferPool memory limit set by the user.
+     */
+    UsageMonitor(ClientWorkerSCServiceImpl *clientWorkerScService, const uint64_t maxMemThresholdBytes);
+
+    ~UsageMonitor() = default;
+
+    /**
+     * @brief Initialize the monitor and start the background blocker thread.
+     * @return Status of the call.
+     */
+    Status Init();
+
+    /**
+     * @brief Shut down the usage monitor.
+     */
+    void Stop();
+
+    /**
+     * @brief Adds usage of a stream and remote worker to the BufferPool.
+     * @param[in] streamName stream name
+     * @param[in] workerAddr remote worker address
+     * @param[in] size size of the new PageView added
+     * @return Status of the call.
+     */
+    Status IncUsage(const std::string &streamName, const std::string &workerAddr, const std::uint64_t size);
+
+    /**
+     * @brief Decreases usage of a stream and remote worker from the BufferPool.
+     * @param[in] streamName stream name
+     * @param[in] workerAddr remote worker address
+     * @param[in] size size of the released PageView
+     * @return Status of the call.
+     */
+    Status DecUsage(const std::string &streamName, const std::string &workerAddr, const std::uint64_t size);
+
+    /**
+     * @brief Removes a stream and remote worker from the usage stats.
+     * @param[in] streamName stream name
+     * @param[in] workerAddr remote worker address
+     * @return Status of the call.
+     */
+    Status RemoveUsageStats(const std::string &streamName, const std::string &workerAddr);
+
+    /**
+     * @brief Checks whether the current BufferPool usage exceeds the user-defined limit.
+     * @param[in] threshold e.g. if set to 0.8, checks against 80% of the max
+     * @param[in] size The size for the check
+     * @return Status of the call.
+     */
+    Status CheckOverUsed(const double threshold = DEFAULT_THRESHOLD, const uint64_t size = 0);
+
+    /**
+     * @brief Checks whether the stream exceeds the user-defined ratio, and increases the memory
+     *        usage accordingly.
+     * @param[in] streamName stream name
+     * @param[in] workerAddr remote worker address
+     * @param[in] lowerBound stream lower bound limit
+     * @param[in] threshold % of total memory allowed for the stream (ratio)
+     * @param[in] size The size for the check
+     * @return Status of the call.
+     */
+    Status CheckNIncOverUsedForStream(const std::string &streamName, const std::string &workerAddr,
+                                      const uint64_t lowerBound, const double threshold, const uint64_t size);
+
+    /**
+     * @brief Gets the stream name and remote worker combination that uses the most space.
+     * @param[out] usageItem stream name and remote worker
+     * @return Status of the call.
+     */
+    Status GetMostUsed(std::shared_ptr<UsageItem> &usageItem);
+
+    /**
+     * @brief Reserve local cache memory for a stream.
+     * @param[in] streamName stream name
+     * @param[in] reserveSize The size to reserve
+     * @return Status of the call.
+     */
+    Status ReserveMemory(const std::string &streamName, size_t reserveSize);
+
+    /**
+     * @brief Undo the reserved local cache memory for a stream.
+     * @param[in] streamName stream name
+     */
+    void UndoReserveMemory(const std::string &streamName);
+
+    /**
+     * @brief Gets the amount of local memory used for a stream.
+     * @param[in] streamName stream name
+     * @return The amount of memory
+     */
+    uint64_t GetLocalMemoryUsed(const std::string &streamName);
+
+    /**
+     * @brief Get the usage of scLocalCache. The format is totalUsedSize/totalReservedSize/totalLimit/usage.
+     * @return The usage string
+     */
+    std::string GetLocalMemoryUsed();
+
+private:
+    /**
+     * @brief If total memory exceeds the user-set limit, block the most offending producers.
+     */
+    void BlockProducersIfNeeded();
+
+    /**
+     * @brief Blocks usage for the stream and remote worker.
+     * @param[in] usageItem gives the stream name and remote worker
+     * @return Status of the call.
+     */
+    Status BlockUsage(std::shared_ptr<UsageItem> &usageItem);
+
+    /**
+     * @brief Unblocks usage for the stream and remote worker.
+     * @param[in] usageItem gives the stream name and remote worker
+     * @return Status of the call.
+     */
+    Status UnBlockUsage(std::shared_ptr<UsageItem> &usageItem);
+
+    /**
+     * @brief Waits up to timeoutMs for the memory over-use condition.
+     * @param[in] timeoutMs wait timeout in milliseconds
+     * @param[in] threshold e.g. if set to 0.8, checks against 80% of the max
+     * @return True if the pool is over-used, false on timeout or interrupt.
+     */
+    bool WaitForOverUseCondition(const uint64_t timeoutMs, const double threshold);
+
+    /**
+     * @brief Find the producer that used the local cache the most and send it a blocking call.
+     * @return Status of the call.
+     */
+    Status BlockMostUsed();
+
+    /**
+     * @brief Unblock all previously blocked producers if memory availability exceeds the threshold.
+     * @param[in] unBlockThreshold e.g. if set to 0.8, usage should be less than 80% of the max
+     * @return Status of the call.
+     */
+    Status UnBlockAllProducers(const double unBlockThreshold);
+
+    /**
+     * @brief Helper function to decrease the total memory usage.
+     * @param[in] size The size to subtract.
+     */
+    void DecUsage(const std::uint64_t size);
+
+    /**
+     * @brief Helper function to increase the total usage of a stream, and the usage regarding the
+     *        remote worker, in the BufferPool.
+     * @param[in] streamName stream name
+     * @param[in] workerAddr remote worker address
+     * @param[in] size size of the new PageView added
+     * @return Status of the call.
+     */
+    Status IncTotalUsageUnlocked(const std::string &streamName, const std::string &workerAddr,
+                                 const std::uint64_t size);
+
+    /**
+     * @brief Helper function to increment the total reserved memory while making sure the limit
+     *        is not exceeded.
+     * @param[in] reserveSize The size to increment with.
+     * @return Status of the call.
+     */
+    Status CheckNIncTotalReservedMemory(uint64_t reserveSize);
+
+    ClientWorkerSCServiceImpl *clientWorkerScService_;
+    // Backend thread to check memory usage and invoke blocking callbacks
+    std::unique_ptr<ThreadPool> producerBlockerThreadPool_;
+    // Protects the usage_ table
+    mutable std::shared_timed_mutex usageMutex_;
+    // Stores usage per stream per remote worker
+    TbbUsageTable usage_;
+    // Stores the total size of all elements in the BufferPool
+    std::atomic_uint64_t totalUsedSize_;
+    // Max memory that can be used by the buffer pool
+    const uint64_t maxBufferPoolMem_;
+    // The memory reservation helpers, including the usage per stream
+    TbbStreamReserveTable streamMemoryMap_;
+    std::atomic_uint64_t totalReservedSize_{ 0 };
+    // Used to interrupt the background thread
+    std::atomic<bool> interrupt_;
+    mutable std::mutex mux_; // protects cv_
+    std::condition_variable cv_;
+    std::vector<std::shared_ptr<UsageItem>> blockedStreamProducers_;
+};
+} // namespace stream_cache
+} // namespace worker
+} // namespace datasystem
+#endif // DATASYSTEM_WORKER_STREAM_CACHE_USAGE_MONITOR_H
\ No newline at end of file
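Taken together, the intended lifecycle is roughly the following sketch. It is hypothetical wiring for illustration only; in this patch the monitor is owned by WorkerWorkerSCServiceImpl further below, and `service` is assumed to be a valid ClientWorkerSCServiceImpl pointer.

    // Hypothetical standalone wiring of the usage monitor; illustrative only.
    UsageMonitor monitor(service, 512ULL << 20);                      // 512 MiB buffer pool
    RETURN_IF_NOT_OK(monitor.Init());                                 // starts the "ScUsageMonitor" thread
    RETURN_IF_NOT_OK(monitor.ReserveMemory("stream-0", 4ULL << 20));  // fair-share reservation
    // ... data path brackets pages with IncUsage/DecUsage ...
    monitor.UndoReserveMemory("stream-0");
    monitor.Stop();                                                   // interrupts and joins the thread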
diff --git a/src/datasystem/worker/stream_cache/worker_master_sc_api.cpp b/src/datasystem/worker/stream_cache/worker_master_sc_api.cpp
new file mode 100644
index 0000000..678744d
--- /dev/null
+++ b/src/datasystem/worker/stream_cache/worker_master_sc_api.cpp
@@ -0,0 +1,315 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "datasystem/worker/stream_cache/worker_master_sc_api.h" + +#include + +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/log/log.h" +#include "datasystem/common/rpc/rpc_auth_key_manager.h" +#include "datasystem/common/rpc/rpc_stub_base.h" +#include "datasystem/common/rpc/rpc_stub_cache_mgr.h" +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/common/util/thread_local.h" +#include "datasystem/master/stream_cache/master_sc_service_impl.h" +#include "datasystem/utils/optional.h" + +DS_DECLARE_string(unix_domain_socket_dir); + +namespace datasystem { +namespace worker { +namespace stream_cache { +static constexpr int64_t WORKER_TIMEOUT_MINUS_MILLISECOND = 5 * 1000; +static constexpr float WORKER_TIMEOUT_DESCEND_FACTOR = 0.9; + +#define CHECK_AND_SET_TIMEOUT(timeoutDuration_, request_, opts_) \ + do { \ + int64_t remainingTime_ = timeoutDuration_.CalcRemainingTime(); \ + CHECK_FAIL_RETURN_STATUS(remainingTime_ > 0, K_RPC_DEADLINE_EXCEEDED, \ + FormatString("Request timeout (%lld ms).", -remainingTime_)); \ + request_.set_timeout(WorkerGetRequestTimeout(remainingTime_)); \ + opts_.SetTimeout(remainingTime_); \ + } while (false) + +inline int64_t WorkerGetRequestTimeout(int32_t timeout) +{ + return std::max(int64_t(timeout * WORKER_TIMEOUT_DESCEND_FACTOR), timeout - WORKER_TIMEOUT_MINUS_MILLISECOND); +} + +// Base class methods +WorkerMasterSCApi::WorkerMasterSCApi(const HostPort &localWorkerAddress, std::shared_ptr akSkManager) + : localWorkerAddress_(localWorkerAddress), akSkManager_(std::move(akSkManager)) +{ +} + +std::shared_ptr WorkerMasterSCApi::CreateWorkerMasterSCApi(const HostPort &hostPort, + const HostPort &localHostPort, + std::shared_ptr akSkManager, + master::MasterSCServiceImpl *service) +{ + if (hostPort != localHostPort) { + LOG(INFO) << "Worker and master are not collocated. Creating a WorkerMasterSCApi as RPC-based api."; + return std::make_shared(hostPort, localHostPort, akSkManager); + } + + if (service == nullptr) { + LOG(INFO) << "Worker and master are collocated but the master service is not provided. Local bypass disabled."; + return std::make_shared(hostPort, localHostPort, akSkManager); + } + + LOG(INFO) << "Worker and master are collocated. 
Creating a WorkerMasterSCApi with local bypass optimization."; + return std::make_shared(service, localHostPort, akSkManager); +} + +// WorkerRemoteMasterSCApi methods + +WorkerRemoteMasterSCApi::WorkerRemoteMasterSCApi(const HostPort &masterAddress, const HostPort &localHostPort, + std::shared_ptr akSkManager) + : WorkerMasterSCApi(localHostPort, akSkManager), masterAddress_(masterAddress) +{ +} + +Status WorkerRemoteMasterSCApi::Init() +{ + std::shared_ptr rpcStub; + RETURN_IF_NOT_OK( + RpcStubCacheMgr::Instance().GetStub(masterAddress_, StubType::WORKER_MASTER_SC_SVC, rpcStub)); + rpcSession_ = std::dynamic_pointer_cast(rpcStub); + RETURN_RUNTIME_ERROR_IF_NULL(rpcSession_); + return Status::OK(); +} + +Status WorkerRemoteMasterSCApi::CreateProducer(master::CreateProducerReqPb &req, master::CreateProducerRspPb &rsp) +{ + RpcOptions opts; + CHECK_AND_SET_TIMEOUT(scTimeoutDuration, req, opts); + INJECT_POINT("worker.CreateProducer.beforeSendToMaster", [&opts](const std::string &code) { + const int timeout = 10 * 1000; // 10s; + opts.SetTimeout(timeout); + RETURN_STATUS(GetStatusCodeByName(code), "inject status"); + }); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(rpcSession_->CreateProducer(opts, req, rsp)); + INJECT_POINT("worker.CreateProducer.afterSendToMaster"); + + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Add new pub node on master success", LogPrefix(), + req.producer_meta().stream_name()); + return Status::OK(); +} + +Status WorkerRemoteMasterSCApi::CloseProducer(master::CloseProducerReqPb &req, master::CloseProducerRspPb &rsp) +{ + RpcOptions opts; + CHECK_AND_SET_TIMEOUT(scTimeoutDuration, req, opts); + INJECT_POINT("worker.CloseProducer.beforeSendToMaster", [&opts](const std::string &code) { + const int timeout = 10 * 1000; // 10s; + opts.SetTimeout(timeout); + RETURN_STATUS(GetStatusCodeByName(code), "inject status"); + }); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(rpcSession_->CloseProducer(opts, req, rsp)); + INJECT_POINT("worker.CloseProducer.afterSendToMaster"); + + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Closing %d producers on master success", LogPrefix(), + req.producer_infos_size()); + return Status::OK(); +} + +Status WorkerRemoteMasterSCApi::Subscribe(master::SubscribeReqPb &req, master::SubscribeRspPb &rsp) +{ + // Construct master::SubscribeReqPb req + RpcOptions opts; + CHECK_AND_SET_TIMEOUT(scTimeoutDuration, req, opts); + INJECT_POINT("worker.Subscribe.beforeSendToMaster", [&opts](const std::string &code) { + const int timeout = 10 * 1000; // 10s; + opts.SetTimeout(timeout); + RETURN_STATUS(GetStatusCodeByName(code), "inject status"); + }); + + INJECT_POINT("worker.Subscribe.sleepReturnTimeout", [&opts](int sleepTimeMs) { + std::this_thread::sleep_for(std::chrono::milliseconds(sleepTimeMs)); + // Simulates timeout + opts.SetTimeout(0); + return Status::OK(); + }); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(rpcSession_->Subscribe(opts, req, rsp)); + INJECT_POINT("worker.Subscribe.afterSendToMaster"); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s, C:%s] Add new consumer on master succeeded", LogPrefix(), + req.consumer_meta().stream_name(), req.consumer_meta().consumer_id()); + return Status::OK(); +} + +Status WorkerRemoteMasterSCApi::CloseConsumer(master::CloseConsumerReqPb &req, master::CloseConsumerRspPb &rsp) +{ + RpcOptions opts; + CHECK_AND_SET_TIMEOUT(scTimeoutDuration, req, opts); + 
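// CHECK_AND_SET_TIMEOUT above derives the server-side budget from the remaining client budget
+ // via WorkerGetRequestTimeout(), i.e. max(remaining * 0.9, remaining - 5000) ms: for example a
+ // remaining budget of 60 s yields a 55 s server-side timeout, leaving headroom for the RPC
+ // round trip itself. (Derived from WORKER_TIMEOUT_DESCEND_FACTOR and
+ // WORKER_TIMEOUT_MINUS_MILLISECOND defined at the top of this file.)
+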
INJECT_POINT("worker.CloseConsumer.beforeSendToMaster", [&opts](const std::string &code) { + const int timeout = 10 * 1000; // 10s; + opts.SetTimeout(timeout); + RETURN_STATUS(GetStatusCodeByName(code), "inject status"); + }); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(rpcSession_->CloseConsumer(opts, req, rsp)); + INJECT_POINT("worker.CloseConsumer.afterSendToMaster"); + + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s, C:%s] Delete consumer on master succeeded", LogPrefix(), + req.consumer_meta().stream_name(), req.consumer_meta().consumer_id()); + return Status::OK(); +} + +Status WorkerRemoteMasterSCApi::DeleteStream(master::DeleteStreamReqPb &req, master::DeleteStreamRspPb &rsp) +{ + RpcOptions opts; + CHECK_AND_SET_TIMEOUT(scTimeoutDuration, req, opts); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(rpcSession_->DeleteStream(opts, req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Delete stream succeeded.", LogPrefix(), req.stream_name()); + return Status::OK(); +} + +Status WorkerRemoteMasterSCApi::QueryGlobalProducersNum(master::QueryGlobalNumReqPb &req, + master::QueryGlobalNumRsqPb &rsp) +{ + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + return rpcSession_->QueryGlobalProducersNum(req, rsp); +} + +Status WorkerRemoteMasterSCApi::QueryGlobalConsumersNum(master::QueryGlobalNumReqPb &req, + master::QueryGlobalNumRsqPb &rsp) +{ + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + return rpcSession_->QueryGlobalConsumersNum(req, rsp); +} + +std::string WorkerRemoteMasterSCApi::LogPrefix() const +{ + return FormatString("WorkerMasterApi, EndPoint:%s", masterAddress_.ToString()); +} + +// WorkerLocalMasterSCApi methods + +WorkerLocalMasterSCApi::WorkerLocalMasterSCApi(master::MasterSCServiceImpl *service, const HostPort &localHostPort, + std::shared_ptr akSkManager) + : WorkerMasterSCApi(localHostPort, akSkManager), masterSC_(service) +{ +} + +Status WorkerLocalMasterSCApi::Init() +{ + RETURN_RUNTIME_ERROR_IF_NULL(masterSC_); + return Status::OK(); +} + +Status WorkerLocalMasterSCApi::CreateProducer(master::CreateProducerReqPb &req, master::CreateProducerRspPb &rsp) +{ + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + INJECT_POINT("worker.CreateProducer.beforeSendToMaster", [&opts](const std::string &code) { + const int timeout = 10 * 1000; // 10s; + opts.SetTimeout(timeout); + RETURN_STATUS(GetStatusCodeByName(code), "inject status"); + }); + req.set_timeout(opts.GetTimeout()); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(masterSC_->CreateProducerImpl(nullptr, req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Add new pub node on master success", LogPrefix(), + req.producer_meta().stream_name()); + return Status::OK(); +} + +Status WorkerLocalMasterSCApi::CloseProducer(master::CloseProducerReqPb &req, master::CloseProducerRspPb &rsp) +{ + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + req.set_timeout(opts.GetTimeout()); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(masterSC_->CloseProducerImpl(nullptr, req, rsp)); + + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s] Closing %d producers on master success", LogPrefix(), + req.producer_infos_size()); + return Status::OK(); +} + +Status WorkerLocalMasterSCApi::Subscribe(master::SubscribeReqPb &req, master::SubscribeRspPb &rsp) +{ + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + 
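// Local bypass: no RPC hop here. After stamping the timeout below, the request is handed
+ // directly to the collocated master service (SubscribeImpl), avoiding the network round trip.
+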
req.set_timeout(opts.GetTimeout()); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(masterSC_->SubscribeImpl(nullptr, req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s, C:%s] Add new consumer on master succeeded", LogPrefix(), + req.consumer_meta().stream_name(), req.consumer_meta().consumer_id()); + return Status::OK(); +} + +Status WorkerLocalMasterSCApi::CloseConsumer(master::CloseConsumerReqPb &req, master::CloseConsumerRspPb &rsp) +{ + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + req.set_timeout(opts.GetTimeout()); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(masterSC_->CloseConsumerImpl(nullptr, req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s, C:%s] Delete consumer on master succeeded", LogPrefix(), + req.consumer_meta().stream_name(), req.consumer_meta().consumer_id()); + return Status::OK(); +} + +Status WorkerLocalMasterSCApi::DeleteStream(master::DeleteStreamReqPb &req, master::DeleteStreamRspPb &rsp) +{ + RpcOptions opts; + SET_RPC_TIMEOUT(scTimeoutDuration, opts); + req.set_timeout(opts.GetTimeout()); + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + RETURN_IF_NOT_OK(masterSC_->DeleteStream(req, rsp)); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString("[%s, S:%s] Delete stream succeeded.", LogPrefix(), req.stream_name()); + return Status::OK(); +} + +Status WorkerLocalMasterSCApi::QueryGlobalProducersNum(master::QueryGlobalNumReqPb &req, + master::QueryGlobalNumRsqPb &rsp) +{ + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + return masterSC_->QueryGlobalProducersNum(req, rsp); +} + +Status WorkerLocalMasterSCApi::QueryGlobalConsumersNum(master::QueryGlobalNumReqPb &req, + master::QueryGlobalNumRsqPb &rsp) +{ + RETURN_IF_NOT_OK(akSkManager_->GenerateSignature(req)); + return masterSC_->QueryGlobalConsumersNum(req, rsp); +} + +std::string WorkerLocalMasterSCApi::LogPrefix() const +{ + // local version of the api, the endpoint is ourself! (localWorkerAddress) + return FormatString("WorkerMasterApi, EndPoint:%s", localWorkerAddress_.ToString()); +} + +WorkerMasterSCApiManager::WorkerMasterSCApiManager(HostPort &hostPort, std::shared_ptr akSkManager, + master::MasterSCServiceImpl *masterSCService) + : WorkerMasterApiManagerBase(hostPort, akSkManager), masterSCService_(masterSCService) +{ +} + +std::shared_ptr WorkerMasterSCApiManager::CreateWorkerMasterApi(const HostPort &masterAddress) +{ + return WorkerMasterSCApi::CreateWorkerMasterSCApi(masterAddress, workerAddr_, akSkManager_, masterSCService_); +} +} // namespace stream_cache +} // namespace worker +} // namespace datasystem diff --git a/src/datasystem/worker/stream_cache/worker_master_sc_api.h b/src/datasystem/worker/stream_cache/worker_master_sc_api.h new file mode 100644 index 0000000..5d1f00b --- /dev/null +++ b/src/datasystem/worker/stream_cache/worker_master_sc_api.h @@ -0,0 +1,240 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASYSTEM_WORKER_STREAM_CACHE_WORKER_MASTER_SC_API_H +#define DATASYSTEM_WORKER_STREAM_CACHE_WORKER_MASTER_SC_API_H +#include + +#include "datasystem/common/ak_sk/ak_sk_manager.h" +#include "datasystem/common/rpc/rpc_channel.h" +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/master/stream_cache/master_sc_service_impl.h" +#include "datasystem/protos/master_stream.stub.rpc.pb.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/utils/optional.h" +#include "datasystem/worker/stream_cache/stream_producer.h" +#include "datasystem/worker/worker_master_api_manager_base.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +using MasterSCService_Stub = master::MasterSCService_Stub; +/** + * @brief The WorkerMasterSCApi is an abstract class that defines the interface for interactions with the stream cache + * master service. + */ +class WorkerMasterSCApi { +public: + virtual ~WorkerMasterSCApi() = default; + + /** + * @brief Initialize the WorkerMasterSCApi Object. + * @return Status of the call. + */ + virtual Status Init() = 0; + + /** + * @brief Create producer service onto master request + * @param[in] req The req protobuf. + * @param[out] rsp The rsp protobuf. + * @return Status of the call. + */ + virtual Status CreateProducer(master::CreateProducerReqPb &req, master::CreateProducerRspPb &rsp) = 0; + + /** + * @brief Close producers service onto master request. List version. + * @param[in] req The req protobuf. + * @param[out] rsp The rsp protobuf. + * @return Status of the call. + */ + virtual Status CloseProducer(master::CloseProducerReqPb &req, master::CloseProducerRspPb &rsp) = 0; + + /** + * @brief Subscribe a new consumer onto master request. + * @param[in] req The req protobuf. + * @param[out] rsp The rsp protobuf. + * @return Status of the call. + */ + virtual Status Subscribe(master::SubscribeReqPb &req, master::SubscribeRspPb &rsp) = 0; + + /** + * @brief Close a consumer onto master request. + * @param[in] req The req protobuf. + * @param[out] rsp The rsp protobuf. + * @return Status of the call. + */ + virtual Status CloseConsumer(master::CloseConsumerReqPb &req, master::CloseConsumerRspPb &rsp) = 0; + + /** + * @brief Delete stream on master request. + * @param[in] req The req protobuf. + * @param[out] rsp The rsp protobuf. + * @return Status of the call. + */ + virtual Status DeleteStream(master::DeleteStreamReqPb &req, master::DeleteStreamRspPb &rsp) = 0; + + /** + * @brief Query each worker's producers in global scope for one stream. + * @param[in] streamName The target stream. + * @param[out] allWorkerProducers Vector of producers on every worker node. + * @return K_OK on success; the error code otherwise. + */ + virtual Status QueryGlobalProducersNum(master::QueryGlobalNumReqPb &req, master::QueryGlobalNumRsqPb &rsp) = 0; + + /** + * @brief Query each worker's consumers in global scope for one stream. + * @param[in] streamName The target stream. + * @param[out] allWorkerProducers Vector of consumers on every worker node. + * @return K_OK on success; the error code otherwise. 
+ */ + virtual Status QueryGlobalConsumersNum(master::QueryGlobalNumReqPb &req, master::QueryGlobalNumRsqPb &rsp) = 0; + + /** + * @brief Get log prefix + * @return The log prefix + */ + virtual std::string LogPrefix() const = 0; + + /** + * @brief Get master address. + * @return Master address. + */ + virtual std::string Address() const = 0; + + /** + * @brief A factory method to instantiate the correct derived version of the api. Remote masters will use an + * rpc-based api, whereas local masters can be optimized for in-process pointer based api. + * @param[in] hostPort The host port of the target master + * @param[in] localHostPort The local worker rpc service host port. + * @param[in] akSkManager default to the RPC-based version. + * @param[in] service The local pointer to the master SC service implementation. If null, the created api must + * @return A base class pointer to the correct derived type of api. + */ + static std::shared_ptr CreateWorkerMasterSCApi(const HostPort &hostPort, + const HostPort &localHostPort, + std::shared_ptr akSkManager, + master::MasterSCServiceImpl *service = nullptr); + +protected: + /** + * @brief Construct WorkerMasterSCApi. Protected constructor enforces class instantiation through the factory method + * CreateWorkerMasterSCApi. + * @param[in] localWorkerAddress The local worker rpc service host port + * @param[in] akSkManager Used to do AK/SK authenticate. + */ + explicit WorkerMasterSCApi(const HostPort &localWorkerAddress, std::shared_ptr akSkManager); + + HostPort localWorkerAddress_; // The HostPort of the local worker node + std::shared_ptr akSkManager_{ nullptr }; +}; + +/** + * @brief WorkerRemoteMasterApi is the derived remote version of the api for sending and receiving master SC requests + * where the master is on a different host. This class will use an RPC mechanism for communication to the remote + * location. + * Callers will access this class naturally through base class polymorphism. + * See the parent interface for function argument documentation. + */ +class WorkerRemoteMasterSCApi : public WorkerMasterSCApi { +public: + /** + * @brief Constructor for the remote version of the api + * @param[in] masterAddress The host port of the target master + * @param[in] localHostPort The local worker rpc service host port + * @param[in] akSkManager Used to do AK/SK authenticate. 
+ */ + explicit WorkerRemoteMasterSCApi(const HostPort &masterAddress, const HostPort &localHostPort, + std::shared_ptr akSkManager); + ~WorkerRemoteMasterSCApi() override = default; + Status Init() override; + Status CreateProducer(master::CreateProducerReqPb &req, master::CreateProducerRspPb &rsp) override; + Status CloseProducer(master::CloseProducerReqPb &req, master::CloseProducerRspPb &rsp) override; + Status Subscribe(master::SubscribeReqPb &req, master::SubscribeRspPb &rsp) override; + Status CloseConsumer(master::CloseConsumerReqPb &req, master::CloseConsumerRspPb &rsp) override; + Status DeleteStream(master::DeleteStreamReqPb &req, master::DeleteStreamRspPb &rsp) override; + Status QueryGlobalProducersNum(master::QueryGlobalNumReqPb &req, master::QueryGlobalNumRsqPb &rsp) override; + Status QueryGlobalConsumersNum(master::QueryGlobalNumReqPb &req, master::QueryGlobalNumRsqPb &rsp) override; + std::string LogPrefix() const override; + + std::string Address() const override + { + return masterAddress_.ToString(); + } + +private: + HostPort masterAddress_; // The HostPort of the master node + std::shared_ptr rpcSession_{ nullptr }; // Session to the master rpc service +}; + +/** + * @brief WorkerLocalMasterSCApi is the derived local version of the api for sending and receiving master OC requests + * where the master exists in the same process as the service. This class will directly reference the service through a + * pointer and does not use any RPC mechanism for communication. + * Callers will access this class naturally through base class polymorphism. + * See the parent interface for function argument documentation. + */ +class WorkerLocalMasterSCApi : public WorkerMasterSCApi { +public: + /** + * @brief Constructor for the local version of the api + * @param[in] service The pointer to the master SC service implementation + * @param[in] localHostPort The local worker service host port. + * @param[in] akSkManager Used to do AK/SK authenticate. 
+ */ + explicit WorkerLocalMasterSCApi(master::MasterSCServiceImpl *service, const HostPort &localHostPort, + std::shared_ptr akSkManager); + ~WorkerLocalMasterSCApi() override = default; + Status Init() override; + Status CreateProducer(master::CreateProducerReqPb &req, master::CreateProducerRspPb &rsp) override; + Status CloseProducer(master::CloseProducerReqPb &req, master::CloseProducerRspPb &rsp) override; + Status Subscribe(master::SubscribeReqPb &req, master::SubscribeRspPb &rsp) override; + Status CloseConsumer(master::CloseConsumerReqPb &req, master::CloseConsumerRspPb &rsp) override; + Status DeleteStream(master::DeleteStreamReqPb &req, master::DeleteStreamRspPb &rsp) override; + Status QueryGlobalProducersNum(master::QueryGlobalNumReqPb &req, master::QueryGlobalNumRsqPb &rsp) override; + Status QueryGlobalConsumersNum(master::QueryGlobalNumReqPb &req, master::QueryGlobalNumRsqPb &rsp) override; + std::string LogPrefix() const override; + + std::string Address() const override + { + return localWorkerAddress_.ToString(); + } + +private: + master::MasterSCServiceImpl *masterSC_; +}; + +class WorkerMasterSCApiManager : public WorkerMasterApiManagerBase { +public: + WorkerMasterSCApiManager(HostPort &hostPort, std::shared_ptr manager, + master::MasterSCServiceImpl *masterSCService); + virtual ~WorkerMasterSCApiManager() = default; + + /** + * @brief Create a worker to Master api object for masterAddress + * @param[in] masterAddress The remote master ip address + * @return The WorkerMasterSCApi + */ + std::shared_ptr CreateWorkerMasterApi(const HostPort &masterAddress) override; + +private: + master::MasterSCServiceImpl *masterSCService_{ nullptr }; +}; +} // namespace stream_cache +} // namespace worker +} // namespace datasystem +#endif // DATASYSTEM_WORKER_STREAM_CACHE_WORKER_MASTER_SC_API_H diff --git a/src/datasystem/worker/stream_cache/worker_sc_allocate_memory.cpp b/src/datasystem/worker/stream_cache/worker_sc_allocate_memory.cpp new file mode 100644 index 0000000..bfd66cc --- /dev/null +++ b/src/datasystem/worker/stream_cache/worker_sc_allocate_memory.cpp @@ -0,0 +1,95 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "datasystem/worker/stream_cache/worker_sc_allocate_memory.h" + +#include +#include "datasystem/common/object_cache/safe_table.h" +#include "datasystem/common/shared_memory/allocator.h" +#include "datasystem/common/shared_memory/arena.h" +#include "datasystem/common/util/format.h" +#include "datasystem/utils/status.h" +#include "datasystem/worker/object_cache/worker_oc_eviction_manager.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +WorkerSCAllocateMemory::WorkerSCAllocateMemory(std::shared_ptr manager) + : ocEvictManager_(std::move(manager)) +{ + streamMaxSize_ = datasystem::memory::Allocator::Instance()->GetMaxMemorySize(ServiceType::STREAM); +} + +Status WorkerSCAllocateMemory::AllocateMemoryForStream(const std::string &tenantId, const std::string &streamId, + const uint64_t needSize, bool populate, ShmUnit &shmUnit, + bool retryOnOOM) +{ + Timer timer; + PerfPoint point(PerfKey::WORKER_MEMORY_ALLOCATE); + auto streamMemoryUsageSize = + datasystem::memory::Allocator::Instance()->GetTotalRealMemoryUsage(ServiceType::STREAM); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + UINT64_MAX - needSize >= streamMemoryUsageSize, K_OUT_OF_RANGE, + FormatString("The size is overflow, stream cache memory use size:%d + add:%d > UINT64_MAX:%d", + streamMemoryUsageSize, needSize, UINT64_MAX)); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + streamMemoryUsageSize + needSize <= streamMaxSize_, K_OUT_OF_MEMORY, + FormatString( + "Stream cache memory size overflow, maxStreamSize is: %d, need size is: %d, stream cache use size: %d", + streamMaxSize_, needSize, streamMemoryUsageSize)); + bool evict = false; + if (datasystem::memory::Allocator::Instance()->GetTotalRealMemoryUsage(ServiceType::OBJECT) > 0) { + evict = EvictWhenMemoryExceedThrehold(streamId, needSize, ocEvictManager_, ServiceType::STREAM); + } + // Allocate some memory into this shmUnit + // if object used size = 0, no need to evict object, return OOM, if stream size is max return OOM + Status rc = shmUnit.AllocateMemory(tenantId, needSize, populate, ServiceType::STREAM); + static const std::vector WAIT_MSECOND = { 1, 10, 50, 100, 200, 400, 800, 1600, 3200 }; + if (rc.GetCode() == K_OUT_OF_MEMORY && evict && retryOnOOM) { + for (int t : WAIT_MSECOND) { + auto remainingTime = reqTimeoutDuration.CalcRealRemainingTime(); + auto sleepTime = std::min(remainingTime, t); + streamMemoryUsageSize = + datasystem::memory::Allocator::Instance()->GetTotalRealMemoryUsage(ServiceType::STREAM); + if (streamMemoryUsageSize + needSize > streamMaxSize_) { + return Status(K_OUT_OF_MEMORY, FormatString("Stream cache memory size overflow, maxStreamSize is: %d, " + "need size is: %d, stream cache use size: %d", + streamMaxSize_, needSize, streamMemoryUsageSize)); + } + if (remainingTime <= 0) { + break; + } + VLOG(1) << FormatString("OOM, sleep time: %ld, streamId: %s, needSize %ld", sleepTime, streamId, + needSize); + std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); + rc = shmUnit.AllocateMemory(tenantId, needSize, populate, ServiceType::STREAM); + if (rc.GetCode() != K_OUT_OF_MEMORY) { + break; + } + + (void)EvictWhenMemoryExceedThrehold(streamId, needSize, ocEvictManager_, ServiceType::STREAM); + } + } + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + rc, FormatString("[stream %s] Error while allocating memory size %ld", streamId, needSize)); + VLOG(DEBUG_LOG_LEVEL) << "allocate for stream success, allocate size: " << needSize << " stream cache use size: " + << datasystem::memory::Allocator::Instance()->GetTotalRealMemoryUsage( 
+ ServiceType::STREAM); + return Status::OK(); +} +} // namespace stream_cache +} // namespace worker +} // namespace datasystem \ No newline at end of file diff --git a/src/datasystem/worker/stream_cache/worker_sc_allocate_memory.h b/src/datasystem/worker/stream_cache/worker_sc_allocate_memory.h new file mode 100644 index 0000000..fe58e82 --- /dev/null +++ b/src/datasystem/worker/stream_cache/worker_sc_allocate_memory.h @@ -0,0 +1,70 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Defines the worker sc allocate memory manager class. + */ + +#ifndef DATASYSTEM_WORKER_STREAM_CACHE_WORKER_SC_ALLOCATE_MEMORY_H +#define DATASYSTEM_WORKER_STREAM_CACHE_WORKER_SC_ALLOCATE_MEMORY_H + +#include + +#include "datasystem/common/shared_memory/shm_unit.h" +namespace datasystem { +namespace object_cache { +class WorkerOcEvictionManager; +} +namespace worker { +namespace stream_cache { +class WorkerSCAllocateMemory { +public: + /** + * @brief Construct a new Worker S C Evict Object object + * @param manager + */ + WorkerSCAllocateMemory(std::shared_ptr manager); + + /** + * @brief + * @param tenantId + * @param streamId + * @param needSize + * @param populate + * @param shmUnit + * @param retryOnOOM + * @return Status + */ + Status AllocateMemoryForStream(const std::string &tenantId, const std::string &streamId, + const uint64_t needSize, bool populate, ShmUnit &shmUnit, bool retryOnOOM); + + /** + * @brief Gets Total SHM memory allocated to the Stream Cache + * @return Memory Size + */ + uint64_t GetTotalMaxStreamSHMSize() + { + return streamMaxSize_; + } + +private: + std::shared_ptr ocEvictManager_; + uint64_t streamMaxSize_ = 0; +}; +} // namespace stream_cache +} // namespace worker +} // namespace datasystem +#endif // DATASYSTEM_WORKER_STREAM_CACHE_WORKER_SC_ALLOCATE_MEMORY_H \ No newline at end of file diff --git a/src/datasystem/worker/stream_cache/worker_worker_sc_service_impl.cpp b/src/datasystem/worker/stream_cache/worker_worker_sc_service_impl.cpp new file mode 100644 index 0000000..0a2099a --- /dev/null +++ b/src/datasystem/worker/stream_cache/worker_worker_sc_service_impl.cpp @@ -0,0 +1,405 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "datasystem/worker/stream_cache/worker_worker_sc_service_impl.h" + +#include +#include + +#include "datasystem/worker/stream_cache/stream_manager.h" +#include "datasystem/worker/stream_cache/usage_monitor.h" + +DS_DECLARE_int32(sc_regular_socket_num); +DS_DECLARE_uint64(sc_local_cache_memory_size_mb); + +namespace datasystem { +namespace worker { +namespace stream_cache { + +std::string RecvElementView::StreamName() const +{ + return streamName_; +} + +std::string RecvElementView::ProducerName() const +{ + // All local and remote producers write to the same page. + // However, we still need to distinguish elements from different worker + // because they can have the same seqNo + return workerAddr_; +} + +std::string RecvElementView::ProducerInstanceId() const +{ + // Sequence number will get reset if producer worker restarts + return workerInstanceId_; +} + +uint64_t RecvElementView::StreamHash() const +{ + // We will swap the position of stream and worker address so to hash differently + StreamProducerKey key(ProducerName(), KeyName(), ProducerInstanceId()); + return std::hash{}(key); +} + +Status RecvElementView::ReleasePage() +{ + return Status::OK(); +} + +void *RecvElementView::GetBufferPointer() +{ + void *ptr = decrypted_.load() ? localBuf_.get() : recvBuffer_.Data(); + return ptr; +} + +WorkerWorkerSCServiceImpl::WorkerWorkerSCServiceImpl(ClientWorkerSCServiceImpl *impl, + std::shared_ptr akSkManager) + : akSkManager_(std::move(akSkManager)), + clientWorkerScService_(impl), + usageMonitor_(impl, FLAGS_sc_local_cache_memory_size_mb * 1024 * 1024) +{ +} + +WorkerWorkerSCServiceImpl::~WorkerWorkerSCServiceImpl() +{ + if (dataMap_) { + dataMap_->Stop(); + } +} + +Status WorkerWorkerSCServiceImpl::Init() +{ + dataMap_ = std::make_unique(FLAGS_sc_regular_socket_num, "ScCopyToShm", + std::bind(&WorkerWorkerSCServiceImpl::BatchAsyncFlushEntry, this, + std::placeholders::_1, std::placeholders::_2)); + RETURN_IF_NOT_OK(dataMap_->Init()); + clientWorkerScService_->SetWorkerWorkerSCServiceImpl(weak_from_this()); + return Status::OK(); +} + +UsageMonitor &WorkerWorkerSCServiceImpl::GetUsageMonitor() +{ + return usageMonitor_; +} + +Status WorkerWorkerSCServiceImpl::ProcessEndOfStream(const std::shared_ptr &streamMgr, + std::list dataLst, const std::string &streamName, + const std::string &workerAddr) +{ + (void)streamMgr; + // Clean up the usage in the UsageMonitor + for (auto &ele : dataLst) { + auto data = std::static_pointer_cast(ele.first); + auto sz = data->recvBuffer_.Size(); + usageMonitor_.DecUsage(streamName, workerAddr, sz); + } + // Discard all the buffers. + dataLst.clear(); + // Signal this job is done. 
+ return Status::OK(); +} + +Status WorkerWorkerSCServiceImpl::CheckStreamState(const std::string &streamName, + StreamManagerMap::const_accessor &accessor, + std::shared_ptr &mgr) +{ + Status rc = clientWorkerScService_->GetStreamManager(streamName, accessor); + if (rc.IsOk()) { + mgr = accessor->second; + // Check its state (delete, reset) + return mgr->CheckIfStreamActive(); + } else if (rc.GetCode() == K_SC_STREAM_NOT_FOUND) { + return Status::OK(); + } + return rc; +} + +Status WorkerWorkerSCServiceImpl::ProcessRecvElementView(std::shared_ptr &baseBufferData, + const std::string &streamName, + const std::string &sendWorkerAddr, bool &isBlocked) +{ + auto data = std::static_pointer_cast(baseBufferData); + auto workerAddr = data->workerAddr_; + auto bufSz = data->recvBuffer_.Size(); + auto seqNo = data->seqNo_; + auto count = data->sz_.size(); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(data->traceId_); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString( + "[S:%s, W:%s] Processing RecvElementView. Seq %zu, Count %zu, " + "Size %zu", + streamName, workerAddr, seqNo, count, bufSz); + // Get stream manager. If it is gone, purge the buffers. + StreamManagerMap::const_accessor accessor; + std::shared_ptr streamMgr; + auto rc = CheckStreamState(data->StreamName(), accessor, streamMgr); + if (streamMgr) { + if (streamMgr->IsProducerBlocked(sendWorkerAddr)) { + isBlocked = true; + return Status::OK(); + } + rc = streamMgr->CopyElementView(data, usageMonitor_, RPC_POLL_TIME); + } + return rc; +} + +Status WorkerWorkerSCServiceImpl::BatchAsyncFlushEntry(int myId, const PendingFlushList &pendingFlushList) +{ + (void)myId; + size_t numProducers = pendingFlushList.size(); + std::vector rc(numProducers); + for (size_t i = 0; i < numProducers; ++i) { + const StreamProducerKey &key = pendingFlushList.at(i).first; + std::list &dataList = pendingFlushList.at(i).second; + const std::string &streamName = key.firstKey_; + const std::string &sendWorkerAddr = key.producerId_; + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("[S:%s] Processing RecvElementsViews. 
Number of elements %zu", + streamName, dataList.size()); + auto it = dataList.begin(); + while (it != dataList.end()) { + bool isBlocked = false; + rc[i] = ProcessRecvElementView(it->first, streamName, sendWorkerAddr, isBlocked); + if (rc[i].IsError() || isBlocked) { + break; + } + it = dataList.erase(it); + } + } + return ReturnFirstErrorStatus(rc); +} + +Status WorkerWorkerSCServiceImpl::ReturnFirstErrorStatus(const std::vector &rc) +{ + // Return the first non-ok error + for (const auto &status : rc) { + if (status.IsError()) { + return status; + } + } + return Status::OK(); +} + +void WorkerWorkerSCServiceImpl::RemoveStream(const std::string &keyName, const std::string &sharedPageName) +{ + dataMap_->RemoveStream(keyName, sharedPageName); +} + +void WorkerWorkerSCServiceImpl::PurgeBuffer(const std::shared_ptr &streamManager) + +{ + dataMap_->PurgeBuffer(streamManager->GetStreamName(), + std::bind(&WorkerWorkerSCServiceImpl::ProcessEndOfStream, this, streamManager, + std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); +} + +Status WorkerWorkerSCServiceImpl::ParsePushData(const PushReqPb &pushReqPb, std::vector &payloads, + std::shared_ptr &streamManager, + const std::string &workerAddr, PushRspPb &pushRspPb, + std::vector> &flushList) +{ + const std::string streamName = streamManager->GetStreamName(); + auto streamAllocRatio = streamManager->GetStreamMemAllocRatio(); + auto streamPageSize = streamManager->GetStreamPageSize(); + int numPageViews = pushReqPb.seq_size(); + const std::string &producerId = pushReqPb.producer_id(); + const std::string &workerInstanceId = pushReqPb.worker_instance_id(); + Status allocRc; + uint64_t totalSize = 0; + for (int i = 0; i < numPageViews; ++i) { + // Step 1: Check for OOM + // Check if local memory is over used. If OOM, no need to continue + // We can't use the payload size because it can be encrypted and different from the real size. + size_t totalLength = 0; + std::vector sz; + // Elements are packed in reverse order. + auto &eleMeta = pushReqPb.element_meta(i); + for (auto k = 0; k < eleMeta.element_sizes_size(); ++k) { + auto eleSz = eleMeta.element_sizes(k); + sz.emplace_back(eleSz); + totalLength += eleSz; + } + allocRc = usageMonitor_.CheckNIncOverUsedForStream(streamName, workerAddr, streamPageSize, streamAllocRatio, + totalLength); + if (allocRc.IsError()) { + for (int k = i; k < numPageViews; ++k) { + auto *rsp = pushRspPb.mutable_error(k); + rsp->Clear(); + rsp->set_error_code(allocRc.GetCode()); + rsp->set_error_msg(allocRc.GetMsg()); + VLOG(SC_NORMAL_LOG_LEVEL) << FormatString( + "[RW:%s, S:%s, I:%s, seq:%zu, count: %d, sz:%zu] " + "Not enough memory to satisfy the request", + workerAddr, streamName, workerInstanceId, pushReqPb.seq(k), eleMeta.element_sizes_size(), + payloads[k].Size()); + } + // No need to continue; + break; + } + // Step 2: Parse meta data. + // Will be freed by processing thread (threadpool_) + uint64_t seqNo = pushReqPb.seq(i); + auto recvElementView = std::make_shared(); + // Step 2: Parse payloads. 
+ recvElementView->streamName_ = streamName; + recvElementView->workerAddr_ = workerAddr; + recvElementView->recvBuffer_ = std::move(payloads[i]); + recvElementView->traceId_ = Trace::Instance().GetTraceID(); + recvElementView->workerInstanceId_ = workerInstanceId; + recvElementView->seqNo_ = seqNo; + recvElementView->totalLength_ = totalLength; + recvElementView->sz_ = sz; + for (auto k = 0; k < eleMeta.header_bits_size(); ++k) { + recvElementView->headerBits_.emplace_back(eleMeta.header_bits(k)); + } + auto status = dataMap_->UnsortedInsert(recvElementView, pushReqPb.seq(i), pushReqPb.first_seq()); + // Rollback reserved memory for duplicate element + if (status.GetCode() == K_DUPLICATED) { + LOG_IF_ERROR(usageMonitor_.DecUsage(streamName, workerAddr, totalLength), + FormatString("%s:%s", __FUNCTION__, __LINE__)); + continue; + } + auto bufSz = recvElementView->recvBuffer_.Size(); + totalSize += bufSz; + // Accept and add the pageviews for processing + flushList.emplace_back(std::move(recvElementView)); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString( + "Finished req, stream:[%s], worker:[%s], producer:[%s], Instance:[%s], seq:[%zu], count:[%d], size:[%zu], " + "total size:[%zu]", + streamName, workerAddr, producerId, workerInstanceId, seqNo, eleMeta.element_sizes_size(), bufSz, + totalSize); + } + return Status::OK(); +} + +Status WorkerWorkerSCServiceImpl::PushElementsCursors( + std::shared_ptr> serverApi) +{ + INJECT_POINT("PushElementsCursors.begin"); + PerfPoint point(PerfKey::PUSH_CURSOR_ALL); + VLOG(SC_DEBUG_LOG_LEVEL) << FormatString("Preparing to receive pushed data."); + PushReqPb pushReqPb; + PushRspPb pushRspPb; + std::vector payloads; + RETURN_IF_NOT_OK(serverApi->Read(pushReqPb)); + RETURN_IF_NOT_OK(serverApi->ReceivePayload(payloads)); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(pushReqPb), "AK/SK failed"); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(pushReqPb.trace_id()); + RETURN_IF_NOT_OK(PushElementsCursorsHelper(pushReqPb, payloads, pushRspPb)); + point.RecordAndReset(PerfKey::PUSH_CURSOR_RESPONSE); + VLOG(SC_DEBUG_LOG_LEVEL) << "worker PushElementsCursors done"; + // We reply to the client at this point. + RETURN_IF_NOT_OK(serverApi->Write(pushRspPb)); + return Status::OK(); +} + +Status WorkerWorkerSCServiceImpl::PushElementsCursorsHelper(PushReqPb &pushReqPb, std::vector &payloads, + PushRspPb &pushRspPb) +{ + const auto &streamName = pushReqPb.stream_name(); + const auto &workerAddr = pushReqPb.worker_addr(); + StreamManagerMap::const_accessor accessor; + RETURN_IF_NOT_OK(clientWorkerScService_->GetStreamManager(streamName, accessor)); + std::shared_ptr streamManager = accessor->second; + auto streamPageSize = streamManager->GetStreamPageSize(); + CHECK_FAIL_RETURN_STATUS(streamPageSize > 0, K_TRY_AGAIN, "Uninitialized page size, try again later."); + RETURN_IF_NOT_OK(streamManager->CheckConsumerExist(workerAddr)); + // Set up response. We may change the error code later. + for (int i = 0; i < pushReqPb.seq_size(); ++i) { + ErrorInfoPb rsp; + rsp.set_error_code(StatusCode::K_OK); + pushRspPb.mutable_error()->Add(std::move(rsp)); + } + // Reject any new data if stream is already in reset mode + // Send OK to remote producer so that it does not try to resend old data. 
But we have just + // set up the response for each PV + if (streamManager->CheckIfStreamActive().IsError()) { + return Status::OK(); + } + std::vector> flushList; + RETURN_IF_NOT_OK(ParsePushData(pushReqPb, payloads, streamManager, workerAddr, pushRspPb, flushList)); + return Status::OK(); +} + +Status WorkerWorkerSCServiceImpl::PushSharedPageCursors( + std::shared_ptr> serverApi) +{ + INJECT_POINT("PushElementsCursors.begin"); + PerfPoint point(PerfKey::PUSH_CURSOR_ALL); + VLOG(SC_DEBUG_LOG_LEVEL) << FormatString("Preparing to receive pushed shared page data."); + SharedPagePushReqPb pushReqPb; + PushRspPb pushRspPb; + std::vector payloads; + RETURN_IF_NOT_OK(serverApi->Read(pushReqPb)); + RETURN_IF_NOT_OK(serverApi->ReceivePayload(payloads)); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(akSkManager_->VerifySignatureAndTimestamp(pushReqPb), "AK/SK failed"); + TraceGuard traceGuard = Trace::Instance().SetTraceNewID(pushReqPb.trace_id()); + // Group requests by stream name, and also record the indexes so the response is still in order. + std::unordered_map streamDataMap; + std::vector errorList(pushReqPb.metas_size()); + for (int i = 0; i < pushReqPb.metas_size(); ++i) { + const std::string &streamName = pushReqPb.stream_names(pushReqPb.metas(i).stream_index()); + VLOG(SC_INTERNAL_LOG_LEVEL) << FormatString("PushSharedPageCursors req involves stream %s at iteration %d", + streamName, i); + auto iter = streamDataMap.find(streamName); + if (iter == streamDataMap.end()) { + PushReqPb req; + req.set_stream_name(streamName); + req.set_worker_addr(pushReqPb.worker_addr()); + req.set_producer_id(pushReqPb.producer_id()); + // Force the first sequence number to be 1, + // so that UnsortedInsert only waits for the expected sequence number. + req.set_first_seq(1); + req.set_worker_instance_id(pushReqPb.worker_instance_id()); + iter = streamDataMap.emplace(streamName, PushSharedPageTuple()).first; + iter->second.req_ = std::move(req); + } + iter->second.index_.emplace_back(i); + auto &req = iter->second.req_; + *req.mutable_element_meta()->Add() = pushReqPb.metas(i).element_meta(); + *req.mutable_seq()->Add() = pushReqPb.metas(i).seq(); + iter->second.payload_.emplace_back(std::move(payloads[i])); + } + for (auto &tuple : streamDataMap) { + PushReqPb &req = tuple.second.req_; + std::vector &payload = tuple.second.payload_; + PushRspPb rsp; + auto rc = PushElementsCursorsHelper(req, payload, rsp); + ErrorInfoPb err; + if (rc.IsError()) { + err.set_error_code(rc.GetCode()); + err.set_error_msg(rc.GetMsg()); + } else { + uint64_t expectedSize = tuple.second.index_.size(); + uint64_t actualSize = static_cast(rsp.error_size()); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + actualSize == expectedSize, K_RUNTIME_ERROR, + FormatString("Unexpected number of error info, expected %zu, actual %zu", expectedSize, actualSize)); + } + uint64_t rspIndex = 0; + for (const auto &index : tuple.second.index_) { + errorList[index] = rc.IsError() ? err : rsp.error(rspIndex++); + } + } + *pushRspPb.mutable_error() = { errorList.begin(), errorList.end() }; + point.RecordAndReset(PerfKey::PUSH_CURSOR_RESPONSE); + VLOG(SC_DEBUG_LOG_LEVEL) << "worker PushSharedPageCursors done"; + // We reply to the client at this point. 
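+    // Each element carries its own error code in the response, so one failed
+    // element (for example on memory admission) does not fail the whole batch.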
+ RETURN_IF_NOT_OK(serverApi->Write(pushRspPb)); + return Status::OK(); +} +} // namespace stream_cache +} // namespace worker +} // namespace datasystem diff --git a/src/datasystem/worker/stream_cache/worker_worker_sc_service_impl.h b/src/datasystem/worker/stream_cache/worker_worker_sc_service_impl.h new file mode 100644 index 0000000..bee9a8b --- /dev/null +++ b/src/datasystem/worker/stream_cache/worker_worker_sc_service_impl.h @@ -0,0 +1,176 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Defines the worker worker service processing main class. + */ + +#ifndef DATASYSTEM_WORKER_STREAM_CACHE_WORKER_WORKER_SC_SERVICE_IMPL_H +#define DATASYSTEM_WORKER_STREAM_CACHE_WORKER_WORKER_SC_SERVICE_IMPL_H + +#include "datasystem/common/ak_sk/ak_sk_manager.h" +#include "datasystem/protos/stream_posix.service.rpc.pb.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/worker/stream_cache/buffer_pool.h" +#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h" +#include "datasystem/worker/stream_cache/usage_monitor.h" + +namespace datasystem { +namespace worker { +namespace stream_cache { +struct RecvElementView : public BaseBufferData { + std::string streamName_; + RpcMessage recvBuffer_; // Holding this buffer so that we free it later + std::string workerAddr_; + std::string workerInstanceId_; + uint64_t seqNo_{ 0 }; + std::vector sz_; + std::vector headerBits_; + std::unique_ptr localBuf_; // For decrypted data + std::atomic decrypted_{ false }; + size_t totalLength_{ 0 }; + uint64_t idx_{ 0 }; + + std::string StreamName() const override; + std::string ProducerName() const override; + std::string ProducerInstanceId() const override; + uint64_t StreamHash() const override; + Status ReleasePage() override; + void *GetBufferPointer(); +}; + +class WorkerWorkerSCServiceImpl : public WorkerWorkerSCService, + public std::enable_shared_from_this { +public: + WorkerWorkerSCServiceImpl(ClientWorkerSCServiceImpl *impl, std::shared_ptr akSkManager); + ~WorkerWorkerSCServiceImpl() override; + + /** + * @brief Init Worker worker sc service. + * @return Status of the call. + */ + Status Init() override; + + /** + * @brief Receive pushed elements and cursors from pub workers. + * @param[in, out] serverApi The server reader writer session. + * @return K_OK on success; the error code otherwise. + */ + Status PushElementsCursors(std::shared_ptr> serverApi) override; + + /** + * @brief Helper function to handle the pushed elements. + * @param[in] req The request protobuf. + * @param[in] payloads The actual data payloads. + * @param[out] rsp The response protobuf. + * @return K_OK on success; the error code otherwise. + */ + Status PushElementsCursorsHelper(PushReqPb &req, std::vector &payloads, PushRspPb &rsp); + + /** + * @brief Receive pushed elements and cursors from pub workers, for shared page scenario. 
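+ * Elements for several streams can arrive in one request; they are regrouped per
+ * stream internally and the per-element error codes are returned in request order.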
+ * @param[in, out] serverApi The server reader writer session. + * @return K_OK on success; the error code otherwise. + */ + Status PushSharedPageCursors( + std::shared_ptr> serverApi) override; + + Status ProcessEndOfStream(const std::shared_ptr &streamMgr, std::list dataLst, + const std::string &streamName, const std::string &workerAddr); + + void PurgeBuffer(const std::shared_ptr &streamManager); + + /** + * @brief Remove the info of useless stream from BufferPool + * @param keyName The stream name or page name. + * @param sharedPageName The shared page name. Empty if the stream use exclusive page or the keyName is page. + */ + void RemoveStream(const std::string &keyName, const std::string &sharedPageName); + + /** + * @brief Get usage monitor. + * @return The reference to the usage monitor. + */ + UsageMonitor &GetUsageMonitor(); + +private: + /** + * @brief Batch Async flush entry. + * @param[in] myId The num of Partitions. + * @param[in] pendingFlushList Flush list for pending. + * @return Status of the call. + */ + Status BatchAsyncFlushEntry(int myId, const PendingFlushList &pendingFlushList); + + /** + * @brief Async flush entry. + * @param[in] baseBufferData The entry to be flush. + * @param[in] streamName The stream name. + * @param[in] sendWorkerAddr The send worker address. + * @param[in] isBlocked Whether the remote producer is blocked. + * @return Status of the call. + */ + Status ProcessRecvElementView(std::shared_ptr &baseBufferData, const std::string &streamName, + const std::string &sendWorkerAddr, bool &isBlocked); + + /** + * @brief Check the stream state. + * @param[in] streamName The stream name. + * @param[in] accessor The StreamManagerMap accessor. + * @param[in] mgr The instance of StreamManager. + * @return Status of the call. + */ + Status CheckStreamState(const std::string &streamName, StreamManagerMap::const_accessor &accessor, + std::shared_ptr &mgr); + + /** + * @brief Parse the data from payload. + * @param[in] pushReqPb PushReqPb message. + * @param[in] payloads Payloads data list. + * @param[in] streamManager StreamManager pointer. + * @param[in] workerAddr Worker address. + * @param[out] pushRspPb PushRspPb message. + * @param[out] flushList Flush list for BaseBufferData. + * @param[out] totalSize Total Size of element. + * @return Status of the call. + */ + Status ParsePushData(const PushReqPb &pushReqPb, std::vector &payloads, + std::shared_ptr &streamManager, const std::string &workerAddr, + PushRspPb &pushRspPb, std::vector> &flushList); + + /** + * @brief Checks all error and returns the first one. + * @param[in] rc vector of errors for each producer + * @return status of the call + */ + Status ReturnFirstErrorStatus(const std::vector &rc); + + struct PushSharedPageTuple { + PushReqPb req_; + std::vector payload_; + std::vector index_; + }; + + std::shared_ptr akSkManager_; + ClientWorkerSCServiceImpl *clientWorkerScService_; + std::unique_ptr dataMap_; + UsageMonitor usageMonitor_; +}; +} // namespace stream_cache +} // namespace worker +} // namespace datasystem + +#endif // DATASYSTEM_WORKER_STREAM_CACHE_WORKER_WORKER_SC_SERVICE_IMPL_H diff --git a/src/datasystem/worker/worker_cli.h b/src/datasystem/worker/worker_cli.h index a76043a..69c6543 100644 --- a/src/datasystem/worker/worker_cli.h +++ b/src/datasystem/worker/worker_cli.h @@ -20,7 +20,18 @@ namespace datasystem { namespace cli { +/** + * @brief Save hash ring to the file. + * @param[in] filename The filename to save the hash ring information. 
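+ * @note The saved file can later be passed to UpdateHashRingFromFile to restore the ring.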
+ * @return Status of this call + */ Status SaveHashRingToFile(const std::string &filename); + +/** + * @brief Updating the hash ring from the file. + * @param[in] filename Update the hash ring based on the information in the file. + * @return Status of this call + */ Status UpdateHashRingFromFile(const std::string &filename); bool HandleCli(); } // namespace cli diff --git a/src/datasystem/worker/worker_liveness_check.cpp b/src/datasystem/worker/worker_liveness_check.cpp index e91a549..f563c9e 100644 --- a/src/datasystem/worker/worker_liveness_check.cpp +++ b/src/datasystem/worker/worker_liveness_check.cpp @@ -46,6 +46,8 @@ DS_DECLARE_bool(enable_distributed_master); DS_DECLARE_string(master_address); DS_DECLARE_uint32(node_timeout_s); +DS_DECLARE_int32(sc_regular_socket_num); +DS_DECLARE_int32(sc_stream_socket_num); namespace datasystem { namespace worker { @@ -66,6 +68,11 @@ WorkerLivenessCheck::WorkerLivenessCheck(WorkerOCServer *workerOcServer, std::st } } +inline bool EnableSCService() +{ + return FLAGS_sc_regular_socket_num > 0 && FLAGS_sc_stream_socket_num > 0; +} + Status WorkerLivenessCheck::Init() { livenessKey_ = FormatString("liveness-%s;%s", GetStringUuid(), workerUuid_); @@ -79,6 +86,14 @@ Status WorkerLivenessCheck::Init() if (IsMasterNode()) { servicesNames_.emplace_back("MasterOCService"); } + if (EnableSCService()) { + servicesNames_.emplace_back("ClientWorkerSCService"); + servicesNames_.emplace_back("WorkerWorkerSCService"); + servicesNames_.emplace_back("MasterWorkerSCService"); + if (IsMasterNode()) { + servicesNames_.emplace_back("MasterSCService"); + } + } LivenessHealthCheckEvent::GetInstance().AddSubscriber( "WORKER_LIVENESS_CHECK", [this](Timer &timer, const Status &lastStatus) { return CheckLivenessProbeFile(timer, lastStatus); }); diff --git a/src/datasystem/worker/worker_master_api_manager_base.h b/src/datasystem/worker/worker_master_api_manager_base.h index 2dc584b..f070e14 100644 --- a/src/datasystem/worker/worker_master_api_manager_base.h +++ b/src/datasystem/worker/worker_master_api_manager_base.h @@ -36,7 +36,7 @@ public: /** * @brief Get or Create a worker to Master api object according to an identifier. - * @param[in] id An identifier, can be an object key in OC scenario. + * @param[in] id An identifier, can be an object key in OC scenario, or a stream name in SC scenario. * @param[in] etcdCm The cluster manager pointer to assign. * @param[out] api The WorkerMasterApi instance. * @return The status of this call. @@ -57,7 +57,7 @@ public: /** * @brief Get or Create a worker to Master api object according to an identifier. - * @param[in] id An identifier, can be an object key in OC scenario. + * @param[in] id An identifier, can be an object key in OC scenario, or a stream name in SC scenario. * @param[in] etcdCm The cluster manager pointer to assign. 
* @return The WorkerMasterApi */ diff --git a/src/datasystem/worker/worker_oc_server.cpp b/src/datasystem/worker/worker_oc_server.cpp index 0398dd7..48195a5 100644 --- a/src/datasystem/worker/worker_oc_server.cpp +++ b/src/datasystem/worker/worker_oc_server.cpp @@ -56,6 +56,8 @@ #include "datasystem/common/util/strings_util.h" #include "datasystem/common/util/uri.h" #include "datasystem/common/util/validator.h" +#include "datasystem/master/stream_cache/rpc_session_manager.h" +#include "datasystem/master/stream_cache/sc_metadata_manager.h" #include "datasystem/protos/hash_ring.pb.h" #include "datasystem/protos/object_posix.stub.rpc.pb.h" #include "datasystem/protos/worker_object.service.rpc.pb.h" @@ -65,6 +67,8 @@ #include "datasystem/worker/hash_ring/hash_ring.h" #include "datasystem/worker/hash_ring/hash_ring_event.h" #include "datasystem/worker/object_cache/worker_oc_spill.h" +#include "datasystem/worker/stream_cache/metrics/sc_metrics_monitor.h" +#include "datasystem/worker/stream_cache/worker_sc_allocate_memory.h" #include "datasystem/worker/cluster_manager/worker_health_check.h" #include "datasystem/common/metrics/res_metric_collector.h" #include "datasystem/worker/worker_liveness_check.h" @@ -84,10 +88,20 @@ DS_DEFINE_uint64(shared_memory_size_mb, 1024, #endif DS_DEFINE_uint64(shared_disk_size_mb, 0, "Upper limit of the shared disk, the unit is mb."); +#ifdef WITH_TESTS +DS_DEFINE_uint64(sc_local_cache_memory_size_mb, 128, + "Upper limit of the shared memory, the unit is mb, must be greater than 0."); +#else +DS_DEFINE_uint64(sc_local_cache_memory_size_mb, 1024, + "Upper limit of the SC local cache, the unit is mb, must be greater than 0."); +#endif DS_DEFINE_uint32(oc_shm_threshold_percentage, 100, "Upper limit of the shared memory in percentage can be used by OC, must be within (0, 100]"); - +DS_DEFINE_uint32(sc_shm_threshold_percentage, 100, + "Upper limit of the shared memory in percentage can be used by SC, must be within (0, 100]."); +DS_DEFINE_uint32(page_size, 1024 * 1024, + "Size of the page used for caching worker files. The valid range is 4096-1073741824."); DS_DEFINE_bool(ipc_through_shared_memory, true, "Using shared memory to exchange data between client and worker."); DS_DECLARE_bool(authorization_enable); DS_DEFINE_string(ready_check_path, "", @@ -106,6 +120,10 @@ DS_DEFINE_uint32(request_expire_time_s, 300, DS_DEFINE_validator(ready_check_path, &Validator::ValidatePathString); DS_DEFINE_validator(shared_memory_size_mb, &Validator::ValidateSharedMemSize); DS_DEFINE_validator(shared_disk_size_mb, &Validator::ValidateSharedDiskSize); +DS_DEFINE_validator(sc_local_cache_memory_size_mb, &Validator::ValidateLocalCacheMemSize); +DS_DEFINE_validator(page_size, &Validator::ValidatePageSize); +DS_DECLARE_int32(sc_regular_socket_num); +DS_DECLARE_int32(sc_stream_socket_num); DS_DECLARE_string(unix_domain_socket_dir); DS_DECLARE_string(etcd_address); DS_DEFINE_bool(async_delete, false, "Master notify workers to delete objects asynchronously."); @@ -114,6 +132,8 @@ DS_DEFINE_bool(cross_az_get_data_from_worker, true, "Control whether try to get DS_DECLARE_uint32(node_timeout_s); DS_DEFINE_int32(oc_worker_worker_direct_port, 0, "Direct tcp/ip port for WorkerWorkerOCService. 0 -- disable this direction connection"); +DS_DEFINE_int32(sc_worker_worker_direct_port, 0, + "Direct tcp/ip port for WorkerWorkerSCService. 
0 -- disable this direction connection"); DS_DEFINE_bool(enable_hash_ring_self_healing, false, "Whether to support self-healing when the hash ring is in an abnormal state, default is false."); DS_DEFINE_string(liveness_check_path, "", @@ -125,6 +145,10 @@ DS_DEFINE_uint32(liveness_probe_timeout_s, 150, "Liveness probe timeout in secon DS_DEFINE_uint32(check_async_queue_empty_time_s, 15, "The async queue needs to be empty for a certain period of time before worker can exist."); DS_DECLARE_string(rocksdb_store_dir); +DS_DEFINE_string(sc_encrypt_secret_key, "", + "The encrypted secret key for stream cache. The key length is up to 1024 bytes and must be 32 bytes " + "after decryption."); +DS_DEFINE_validator(sc_encrypt_secret_key, &Validator::ValidateScEncryptSecretKey); DS_DEFINE_int32(max_rpc_session_num, 2048, "Maximum number of sessions that can be cached, must be within [512, 10'000]"); DS_DEFINE_validator(max_rpc_session_num, &Validator::ValidateMaxRpcSessionNum); @@ -152,7 +176,7 @@ static bool ValidatePopulate(const char *flagName, bool value) } DS_DEFINE_validator(shared_memory_populate, &ValidatePopulate); DS_DECLARE_string(sfs_path); -DS_DECLARE_string(az_name); +DS_DECLARE_string(cluster_name); DS_DECLARE_string(log_dir); namespace datasystem { @@ -169,6 +193,11 @@ bool EnableOCService() { return FLAGS_rpc_thread_num > 0; } + +bool EnableSCService() +{ + return FLAGS_sc_regular_socket_num > 0 && FLAGS_sc_stream_socket_num > 0; +} } // namespace WorkerOCServer::~WorkerOCServer() @@ -191,6 +220,9 @@ WorkerOCServer::~WorkerOCServer() objCacheWorkerWkSvc_.reset(); objCacheWorkerMsSvc_.reset(); objCacheClientWorkerSvc_.reset(); + streamCacheWorkerWorkerSvc_.reset(); + streamCacheClientWorkerSvc_.reset(); + streamCacheMasterSvc_.reset(); etcdCM_.reset(); replicaSvc_.reset(); datasystem::memory::Allocator::Instance()->Shutdown(); @@ -239,6 +271,71 @@ Status WorkerOCServer::InitMasterWorkerOCService() return Status::OK(); } +Status WorkerOCServer::CheckScEncryptSecretKey() +{ + if (FLAGS_sc_encrypt_secret_key.empty() || !SecretManager().Instance()->IsRootKeyActive()) { + return Status::OK(); + } + std::unique_ptr keyContent; + int outSize; + RETURN_IF_NOT_OK_PRINT_ERROR_MSG( + SecretManager::Instance()->Decrypt(FLAGS_sc_encrypt_secret_key, keyContent, outSize), + "Sc encrypt secret key decrypt failed."); + (void)memset_s(keyContent.get(), outSize, 0, outSize); + const int AES_256_GCM_KEY_LEN = 32; + CHECK_FAIL_RETURN_STATUS(outSize == AES_256_GCM_KEY_LEN, StatusCode::K_INVALID, + "The decrypted length is incorrect."); + return Status::OK(); +} + +Status WorkerOCServer::InitClientWorkerSCService() +{ + RETURN_OK_IF_TRUE(!EnableSCService()); + RETURN_IF_NOT_OK(streamCacheClientWorkerSvc_->Init()); + CHECK_FAIL_RETURN_STATUS(FLAGS_sc_stream_socket_num + FLAGS_sc_regular_socket_num <= THREAD_POOL_SIZE_LIMIT, + StatusCode::K_INVALID, + "The number of service threads exceeds the upper limit, please adjust it"); + RETURN_IF_NOT_OK(CheckScEncryptSecretKey()); + RpcServiceCfg cfg; + cfg.numRegularSockets_ = FLAGS_sc_regular_socket_num; + cfg.numStreamSockets_ = 0; + cfg.hwm_ = RPC_HEAVY_SERVICE_HWM; + cfg.udsEnabled_ = FLAGS_ipc_through_shared_memory; + builder_.AddService(streamCacheClientWorkerSvc_.get(), cfg); + return Status::OK(); +} + +Status WorkerOCServer::InitWorkerWorkerSCService() +{ + RETURN_OK_IF_TRUE(!EnableSCService()); + RETURN_IF_NOT_OK(streamCacheWorkerWorkerSvc_->Init()); + CHECK_FAIL_RETURN_STATUS(FLAGS_sc_stream_socket_num + FLAGS_sc_regular_socket_num <= 
THREAD_POOL_SIZE_LIMIT, + StatusCode::K_INVALID, + "The number of service threads exceeds the upper limit, please adjust it"); + RpcServiceCfg cfg; + cfg.numRegularSockets_ = FLAGS_sc_regular_socket_num; + cfg.numStreamSockets_ = 0; + cfg.hwm_ = RPC_HEAVY_SERVICE_HWM; + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + Validator::ValidatePort("FLAGS_sc_worker_worker_direct_port", FLAGS_sc_worker_worker_direct_port), K_INVALID, + FormatString("Invalid tcp/ip port value %d", FLAGS_sc_worker_worker_direct_port)); + cfg.tcpDirect_ = std::to_string(FLAGS_sc_worker_worker_direct_port); + builder_.AddService(streamCacheWorkerWorkerSvc_.get(), cfg); + return Status::OK(); +} + +Status WorkerOCServer::InitMasterWorkerSCService() +{ + RETURN_OK_IF_TRUE(!EnableSCService()); + RETURN_IF_NOT_OK(streamCacheMasterWorkerSvc_->Init()); + RpcServiceCfg cfg; + cfg.numRegularSockets_ = std::max(FLAGS_sc_regular_socket_num, LIGHTWEIGHT_SERVICE_THREAD_NUM); + cfg.numStreamSockets_ = DEFAULT_STREAM_SOCKET_NUM; + cfg.hwm_ = RPC_HEAVY_SERVICE_HWM; + builder_.AddService(streamCacheMasterWorkerSvc_.get(), cfg); + return Status::OK(); +} + Status WorkerOCServer::InitWorkerService() { RETURN_IF_NOT_OK(workerSvc_->Init()); @@ -283,6 +380,18 @@ Status WorkerOCServer::InitReplicaService() return Status::OK(); } +Status WorkerOCServer::InitMasterSCService() +{ + RETURN_OK_IF_TRUE(!EnableSCService()); + RETURN_IF_NOT_OK(streamCacheMasterSvc_->Init()); + RpcServiceCfg cfg; + cfg.numRegularSockets_ = std::max(FLAGS_rpc_thread_num, LIGHTWEIGHT_SERVICE_THREAD_NUM); + cfg.numStreamSockets_ = 0; + cfg.hwm_ = RPC_HEAVY_SERVICE_HWM; + builder_.AddService(streamCacheMasterSvc_.get(), cfg); + return Status::OK(); +} + #ifdef WITH_TESTS Status WorkerOCServer::InitUtOCService() { @@ -313,6 +422,12 @@ void WorkerOCServer::EnableLocalBypass() // Pass the receiving-side ptr of the WorkerMaster service so that the sending side can implement local objCacheMasterSvc_->AssignLocalWorker(objCacheWorkerMsSvc_.get()); } + + if (EnableSCService()) { + // MasterSCServiceImpl uses the RpcSessionManager singleton for managing the MasterWorkerSCApi. Provide the + // session manager with the fields needed to enable local bypass. 
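+        // With these set, master-side stream cache calls that target this same worker
+        // can invoke streamCacheMasterWorkerSvc_ directly instead of going through RPC.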
+ rpcSessionManager_->SetLocalArgs(hostPort_, streamCacheMasterWorkerSvc_); + } } Status WorkerOCServer::InitAkSk() @@ -345,6 +460,12 @@ void WorkerOCServer::CreateMasterServices() std::make_unique(hostPort_, persistenceApi_, akSkManager_, replicaManager_.get()); objCacheMasterSvc_->SetClusterManager(etcdCM_.get()); } + if (EnableSCService()) { + // create MasterSCServiceImpl + rpcSessionManager_ = std::make_shared(); + streamCacheMasterSvc_ = std::make_unique(hostPort_, akSkManager_, replicaManager_.get()); + streamCacheMasterSvc_->SetClusterManager(etcdCM_.get()); + } if (replicaManager_->MultiReplicaEnabled()) { replicaSvc_ = std::make_unique(hostPort_, replicaManager_.get(), akSkManager_); } @@ -377,6 +498,19 @@ void WorkerOCServer::CreateWorkerServices() objCacheWorkerMsSvc_ = std::make_shared( objCacheClientWorkerSvc_, akSkManager_); } + if (EnableSCService()) { + auto scAllocateManager = std::make_shared(evictionManager); + // create ClientWorkerSCService + streamCacheClientWorkerSvc_ = std::make_shared( + hostPort_, masterAddr_, streamCacheMasterSvc_.get(), akSkManager_, scAllocateManager); + streamCacheClientWorkerSvc_->SetClusterManager(etcdCM_.get()); + // create MasterWorkerSCServiceImpl + streamCacheMasterWorkerSvc_ = std::make_shared( + hostPort_, masterAddr_, streamCacheClientWorkerSvc_.get(), akSkManager_); + // create WorkerWorkerSCService + streamCacheWorkerWorkerSvc_ = + std::make_unique(streamCacheClientWorkerSvc_.get(), akSkManager_); + } } void WorkerOCServer::CreateAllServices() @@ -399,6 +533,7 @@ Status WorkerOCServer::InitializeMasterServices(const ClusterInfo &clusterInfo) { RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitMasterService(), "InitMasterService failed"); RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitMasterOCService(), "InitMasterOCService failed"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitMasterSCService(), "InitMasterSCService failed"); if (replicaManager_->MultiReplicaEnabled()) { RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitReplicaService(), "InitReplicaService failed"); } @@ -433,6 +568,10 @@ Status WorkerOCServer::InitializeWorkerServices() if (etcdCM_->IsCurrentNodeMaster()) { EnableLocalBypass(); } + // Init the stream services and hook them up to the RPC server + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitClientWorkerSCService(), "InitClientWorkerSCService failed"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitMasterWorkerSCService(), "InitMasterWorkerSCService failed"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitWorkerWorkerSCService(), "InitWorkerWorkerSCService failed"); RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitWorkerService(), "InitWorkerService failed"); return Status::OK(); } @@ -530,7 +669,7 @@ Status WorkerOCServer::LoadHashRingFromRocksDb(ClusterInfo &clusterInfo, HashRin return Status(K_RUNTIME_ERROR, "Failed to parse HashRingPb from string"); } auto azName = GetSubStringBeforeField(itr->first, std::string(ETCD_RING_PREFIX) + "/").erase(0, 1); - if (!FLAGS_az_name.empty() && azName != FLAGS_az_name) { + if (!FLAGS_cluster_name.empty() && azName != FLAGS_cluster_name) { clusterInfo.otherAzHashrings.emplace_back(std::move(azName), std::move(itr->second)); } else { clusterInfo.localHashRing.emplace_back(std::move(*itr)); @@ -551,7 +690,7 @@ Status WorkerOCServer::LoadWorkersFromRocksDb(ClusterInfo &clusterInfo, auto workerAddr = GetSubStringAfterField(itr->first, std::string(ETCD_CLUSTER_TABLE) + "/"); CHECK_FAIL_RETURN_STATUS(!workerAddr.empty(), K_RUNTIME_ERROR, "The loaded cluster information is incomplete"); auto azName = GetSubStringBeforeField(itr->first, "/" + 
std::string(ETCD_CLUSTER_TABLE) + "/").erase(0, 1); - if (!FLAGS_az_name.empty() && azName != FLAGS_az_name) { + if (!FLAGS_cluster_name.empty() && azName != FLAGS_cluster_name) { clusterInfo.otherAzWorkers.emplace_back(std::move(workerAddr), std::move(itr->second)); } else { if (workerAddr != hostPort_.ToString()) { @@ -644,11 +783,11 @@ Status WorkerOCServer::Init() ssize_t decayMs = FLAGS_memory_reclamation_time_second * 1000; // convert to ms. RETURN_IF_NOT_OK_PRINT_ERROR_MSG(datasystem::memory::Allocator::Instance()->Init( sharedMemoryBytes, sharedDiskBytes, FLAGS_shared_memory_populate, true, - decayMs, FLAGS_oc_shm_threshold_percentage), + decayMs, FLAGS_oc_shm_threshold_percentage, FLAGS_sc_shm_threshold_percentage), "Init allocator failed"); // Call base class to init common service RETURN_IF_NOT_OK_PRINT_ERROR_MSG(CommonServer::Init(), "CommonServer init failed"); - RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitializeUrmaManager(hostPort_.Host()), "URMA init failed"); + RETURN_IF_NOT_OK_PRINT_ERROR_MSG(InitializeUrmaManager(hostPort_), "URMA init failed"); RETURN_IF_NOT_OK(RpcStubCacheMgr::Instance().Init(FLAGS_max_rpc_session_num, hostPort_)); if (IsSupportL2Storage(GetCurrentStorageType())) { persistenceApi_ = std::make_shared(); @@ -728,7 +867,9 @@ Status WorkerOCServer::InitReplicaManager() param.etcdCM = etcdCM_.get(); param.masterWorkerService = objCacheWorkerMsSvc_.get(); param.workerWorkerService = objCacheWorkerWkSvc_.get(); + param.rpcSessionManager = rpcSessionManager_; param.isOcEnabled = EnableOCService(); + param.isScEnabled = EnableSCService(); return replicaManager_->Init(param); } @@ -771,6 +912,26 @@ void WorkerOCServer::RegisteringWorkerCallbackFunc() return std::to_string(objCacheClientWorkerSvc_->GetTotalObjectSize()); }); } + + if (EnableSCService()) { + instance.RegisterCollectHandler(ResMetricName::STREAM_COUNT, + [this]() { return streamCacheClientWorkerSvc_->GetTotalStreamCount(); }); + + // The usage of WorkerSCService + instance.RegisterCollectHandler(ResMetricName::WORKER_SC_SERVICE_THREAD_POOL, + [this]() { return GetRpcServicesUsage("ClientWorkerSCService").ToString(); }); + + // The usage of WorkerSCService + instance.RegisterCollectHandler(ResMetricName::WORKER_WORKER_SC_SERVICE_THREAD_POOL, + [this]() { return GetRpcServicesUsage("WorkerWorkerSCService").ToString(); }); + + instance.RegisterCollectHandler(ResMetricName::STREAM_REMOTE_SEND_SUCCESS_RATE, + [this]() { return streamCacheClientWorkerSvc_->GetSCRemoteSendSuccessRate(); }); + + instance.RegisterCollectHandler(ResMetricName::SC_LOCAL_CACHE, [this]() { + return streamCacheWorkerWorkerSvc_->GetUsageMonitor().GetLocalMemoryUsed(); + }); + } } void WorkerOCServer::RegisteringMasterCallbackFunc() @@ -801,6 +962,15 @@ void WorkerOCServer::RegisteringMasterCallbackFunc() []() { return RES_THREAD_POOL_DEFAULT_USAGE; }); } } + + if (EnableSCService()) { + // The usage of MasterWorkerOCService + instance.RegisterCollectHandler(ResMetricName::MASTER_WORKER_SC_SERVICE_THREAD_POOL, + [this]() { return GetRpcServicesUsage("MasterWorkerSCService").ToString(); }); + // The usage of MasterOcService + instance.RegisterCollectHandler(ResMetricName::MASTER_SC_SERVICE_THREAD_POOL, + [this]() { return GetRpcServicesUsage("MasterSCService").ToString(); }); + } } void WorkerOCServer::RegisteringThirdComponentCallbackFunc() @@ -879,6 +1049,9 @@ Status WorkerOCServer::Start() // The task via uds accept fd is started here. 
clientWorkerCommonSvcStatus_ = loadFunctor(*workerSvc_); if (etcdCM_->IsCurrentNodeMaster()) { + if (EnableSCService()) { + RETURN_IF_NOT_OK_APPEND_MSG(streamCacheMasterSvc_->StartCheckMetadata(), "\nmaster Start failed."); + } if (EnableOCService()) { RETURN_IF_NOT_OK_APPEND_MSG(objCacheClientWorkerSvc_->WhetherNonRestart(), "\nWorker Start failed."); } else { @@ -965,7 +1138,7 @@ Status WorkerOCServer::PreShutDown() bool waitFlag = false; auto traceId = Trace::Instance().GetTraceID(); - if (EnableOCService()) { + if (EnableOCService() || EnableSCService()) { RETURN_IF_EXCEPTION_OCCURS(checkAsyncTasksThread_ = std::make_unique([this, traceId]() { TraceGuard traceGuard = Trace::Instance().SetTraceNewID(traceId); CheckAsyncTasks(); @@ -1045,6 +1218,10 @@ Status WorkerOCServer::Shutdown() objCacheMasterSvc_->Shutdown(); } + if (streamCacheMasterSvc_) { + streamCacheMasterSvc_->Shutdown(); + } + if (etcdCM_) { Status rc = etcdCM_->Shutdown(); if (rc.IsError()) { @@ -1099,6 +1276,12 @@ void WorkerOCServer::AfterClientLostHandler(const std::string &clientId) LOG_IF_ERROR(objCacheClientWorkerSvc_->RefreshMeta(clientId), FormatString("Failed to RefreshMeta for client:%s", clientId)); } + if (streamCacheClientWorkerSvc_ != nullptr) { + // When a client is lost it uses forceMode true when closing the producers and consumers. + // Any errors that occur from this close are ignored. + LOG_IF_ERROR(streamCacheClientWorkerSvc_->ClosePubSubForClientLost(clientId), + FormatString("Failed to ClosePubSubForClient: %s ", clientId)); + } ClientManager::Instance().RemoveClient(clientId); } @@ -1149,9 +1332,13 @@ bool WorkerOCServer::IsAsyncTasksRunning() // check etcd and persistence async task if (etcdCM_->IsCurrentNodeMaster()) { return (objCacheClientWorkerSvc_ != nullptr && objCacheClientWorkerSvc_->HaveAsyncTasksRunning()) - || (objCacheMasterSvc_ != nullptr && objCacheMasterSvc_->HaveAsyncMetaRequest()); + || (objCacheMasterSvc_ != nullptr && objCacheMasterSvc_->HaveAsyncMetaRequest()) + || (streamCacheClientWorkerSvc_ != nullptr && streamCacheClientWorkerSvc_->HaveTasksToProcess()) + || (clusterStore_ != nullptr && !clusterStore_->IsAsyncQueueEmpty()); } - return (objCacheClientWorkerSvc_ != nullptr && objCacheClientWorkerSvc_->HaveAsyncTasksRunning()); + return (objCacheClientWorkerSvc_ != nullptr && objCacheClientWorkerSvc_->HaveAsyncTasksRunning()) + || (streamCacheClientWorkerSvc_ != nullptr && streamCacheClientWorkerSvc_->HaveTasksToProcess()) + || (clusterStore_ != nullptr && !clusterStore_->IsAsyncQueueEmpty()); } void WorkerOCServer::CheckAsyncTasks() diff --git a/src/datasystem/worker/worker_oc_server.h b/src/datasystem/worker/worker_oc_server.h index 9a8b2e1..6655470 100644 --- a/src/datasystem/worker/worker_oc_server.h +++ b/src/datasystem/worker/worker_oc_server.h @@ -35,12 +35,16 @@ #include "datasystem/master/object_cache/master_oc_service_impl.h" #include "datasystem/master/replica_manager.h" #include "datasystem/master/replication_service_impl.h" +#include "datasystem/master/stream_cache/master_sc_service_impl.h" #include "datasystem/server/common_server.h" #include "datasystem/worker/hash_ring/hash_ring.h" #include "datasystem/worker/cluster_manager/etcd_cluster_manager.h" #include "datasystem/worker/object_cache/master_worker_oc_service_impl.h" #include "datasystem/worker/object_cache/worker_oc_service_impl.h" #include "datasystem/worker/object_cache/worker_worker_oc_service_impl.h" +#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h" +#include 
"datasystem/worker/stream_cache/master_worker_sc_service_impl.h" +#include "datasystem/worker/stream_cache/worker_worker_sc_service_impl.h" #include "datasystem/worker/worker_liveness_check.h" #include "datasystem/worker/worker_service_impl.h" #ifdef WITH_TESTS @@ -165,6 +169,24 @@ private: */ Status InitMasterWorkerOCService(); + /** + * @brief Init stream cache service for client request. + * @return Status of the call. + */ + Status InitClientWorkerSCService(); + + /** + * @brief Init stream cache service for worker request. + * @return Status of the call. + */ + Status InitWorkerWorkerSCService(); + + /** + * @brief Init stream cache service for master request. + * @return Status of the call. + */ + Status InitMasterWorkerSCService(); + /** * @brief Init common service for client request. * @return Status of the call. @@ -183,6 +205,12 @@ private: */ Status InitMasterOCService(); + /** + * @brief Init stream cache service for worker request. + * @return Status of the call. + */ + Status InitMasterSCService(); + /** * @brief Init rocksdb replica service for worker request. * @return Status of the call. @@ -312,6 +340,12 @@ private: */ Status InitReplicaManager(); + /** + * @brief Check sc_encrypt_secret_key. + * @return Status of this call. + */ + Status CheckScEncryptSecretKey(); + /** * @brief Update cluster info in rocksdb. * @param[in] event The event watched from ETCD. @@ -393,6 +427,7 @@ private: std::shared_ptr akSkManager_{ nullptr }; HostPort masterAddr_; std::unique_ptr replicaManager_{ nullptr }; + std::shared_ptr rpcSessionManager_{ nullptr }; std::unique_ptr etcdCM_{ nullptr }; std::unique_ptr workerSvc_{ nullptr }; // Worker common service. WaitPost waitCond_; @@ -403,13 +438,21 @@ private: std::shared_ptr objCacheWorkerWkSvc_{ nullptr }; // Object cache rpc service for master request. std::shared_ptr objCacheWorkerMsSvc_{ nullptr }; + // Stream cache rpc service for client request. + std::shared_ptr streamCacheClientWorkerSvc_{ nullptr }; + // Stream cache rpc service for master request. + std::shared_ptr streamCacheMasterWorkerSvc_{ nullptr }; + // Stream cache rpc service for worker request. + std::shared_ptr streamCacheWorkerWorkerSvc_{ nullptr }; // Master services exist in the worker for Object cache compile mode. std::shared_ptr commonSvc_{ nullptr }; std::unique_ptr objCacheMasterSvc_{ nullptr }; + std::unique_ptr streamCacheMasterSvc_{ nullptr }; std::unique_ptr replicaSvc_{ nullptr }; std::future objCacheMasterSvcStatus_; std::future objCacheMasterAdSvcStatus_; + std::future streamCacheMasterSvcStatus_; // Check whether all asynchronous tasks are completed before the worker ends. std::unique_ptr checkAsyncTasksThread_{ nullptr }; diff --git a/src/datasystem/worker/worker_service_impl.cpp b/src/datasystem/worker/worker_service_impl.cpp index c5cc59d..9e08ee4 100644 --- a/src/datasystem/worker/worker_service_impl.cpp +++ b/src/datasystem/worker/worker_service_impl.cpp @@ -56,13 +56,14 @@ #include "datasystem/worker/client_manager/client_manager.h" DS_DECLARE_string(unix_domain_socket_dir); +DS_DECLARE_uint32(page_size); DS_DEFINE_bool(authorization_enable, false, "Indicates whether to enable the tenant authentication, default is false."); DS_DECLARE_bool(ipc_through_shared_memory); DS_DEFINE_uint32(max_client_num, 200, "Maximum number of clients that can be connected to a worker. 
Value range: [1, 10000]"); DS_DEFINE_validator(max_client_num, &Validator::ValidateClientNum); DS_DECLARE_uint32(node_timeout_s); -DS_DECLARE_string(az_name); +DS_DECLARE_string(cluster_name); DS_DECLARE_uint64(client_dead_timeout_s); DS_DECLARE_bool(enable_huge_tlb); DS_DEFINE_uint64(oc_shm_transfer_threshold_kb, 500u, @@ -256,6 +257,7 @@ Status WorkerServiceImpl::RegisterClient(const RegisterClientReqPb &req, Registe std::string id; RETURN_IF_NOT_OK_PRINT_ERROR_MSG(worker_->GetShmQueueUnit(lockId, fd, mmapSize, offset, id), "worker process get ShmQ unit failed"); + rsp.set_page_size(FLAGS_page_size); rsp.set_quorum_timeout_mult(timeoutMultiplier_); rsp.set_client_id(clientId); rsp.set_lock_id(lockId); diff --git a/example/src/python/hetero_cache/hetero_h2d_d2h_benchmark.py b/tests/benchmark/hetero_h2d_d2h_benchmark.py similarity index 100% rename from example/src/python/hetero_cache/hetero_h2d_d2h_benchmark.py rename to tests/benchmark/hetero_h2d_d2h_benchmark.py diff --git a/tests/kvconnector/patch/0001-implement-yuanrong-datasystem-connector.patch b/tests/kvconnector/patch/0001-implement-yuanrong-datasystem-connector.patch new file mode 100644 index 0000000..6d230a5 --- /dev/null +++ b/tests/kvconnector/patch/0001-implement-yuanrong-datasystem-connector.patch @@ -0,0 +1,1598 @@ +From e88d5484af1fd7afda5d26d24a647418dfd99b68 Mon Sep 17 00:00:00 2001 +From: yangsonglin +Date: Mon, 17 Nov 2025 15:01:54 +0800 +Subject: [PATCH] implement yuanrong-datasystem connector. + +--- + .github/workflows/vllm_ascend_test_pd.yaml | 6 +- + tests/e2e/pd_disaggreate/yuanrong/README.md | 155 ++++ + .../pd_disaggreate/yuanrong/clean_yuanrong.sh | 9 + + .../yuanrong/run_pd_instances.sh | 44 + + .../yuanrong/run_proxy_server.sh | 14 + + .../pd_disaggreate/yuanrong/run_yuanrong.sh | 24 + + .../yuanrong/simple_pd_proxy_server.py | 212 +++++ + .../yuanrong/test_yuanrong_connector.py | 141 +++ + vllm_ascend/core/scheduler.py | 38 + + vllm_ascend/distributed/__init__.py | 4 + + vllm_ascend/distributed/yuanrong_connector.py | 826 ++++++++++++++++++ + 11 files changed, 1472 insertions(+), 1 deletion(-) + create mode 100644 tests/e2e/pd_disaggreate/yuanrong/README.md + create mode 100644 tests/e2e/pd_disaggreate/yuanrong/clean_yuanrong.sh + create mode 100644 tests/e2e/pd_disaggreate/yuanrong/run_pd_instances.sh + create mode 100644 tests/e2e/pd_disaggreate/yuanrong/run_proxy_server.sh + create mode 100644 tests/e2e/pd_disaggreate/yuanrong/run_yuanrong.sh + create mode 100644 tests/e2e/pd_disaggreate/yuanrong/simple_pd_proxy_server.py + create mode 100644 tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py + create mode 100644 vllm_ascend/distributed/yuanrong_connector.py + +diff --git a/.github/workflows/vllm_ascend_test_pd.yaml b/.github/workflows/vllm_ascend_test_pd.yaml +index fee06bee..5c0919f6 100644 +--- a/.github/workflows/vllm_ascend_test_pd.yaml ++++ b/.github/workflows/vllm_ascend_test_pd.yaml +@@ -109,4 +109,8 @@ jobs: + - name: Run vllm-project/vllm-ascend PD Disaggregation edge test + run: | + git config --global --add safe.directory/__w/vllm-ascend/vllm-ascend +- bash tests/e2e/pd_disaggreate/run_edge_case_test.sh +\ No newline at end of file ++ bash tests/e2e/pd_disaggreate/run_edge_case_test.sh ++ ++ - name: Run vllm-project/vllm-ascend PD Disaggregation test with YuanRong Connector ++ run: | ++ pytest -sv tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py +diff --git a/tests/e2e/pd_disaggreate/yuanrong/README.md b/tests/e2e/pd_disaggreate/yuanrong/README.md +new file mode 100644 
+index 00000000..427919cf
+--- /dev/null
++++ b/tests/e2e/pd_disaggreate/yuanrong/README.md
+@@ -0,0 +1,155 @@
++# Overview: transfer KVCache through host memory with YuanRong connector
++
++### Dataflow
++
++
++|----------------------- prefill node ---------------------|   |---------------------- decode node ---------------------|
++
++Prefill Instance -----> YuanRongConnector -----> YuanRong Data Worker -----> YuanRongConnector -----> Decode Instance
++
++|----- kv on npu -----|   |----- kv offload to host -----|   |----- kv transfer by host net -----|   |----- kv load to npu -----|
++
++### Pros
++- Network jitter and failures are handled outside of the vLLM process, giving better isolation and fault tolerance
++- No communication buffers need to be allocated on the NPU, enabling larger batches and higher throughput
++- Works seamlessly with features that require offloading kvcache to host memory or SSD, like prefix cache, priority scheduling, RAG, etc.
++### Cons
++- Higher transfer latency compared with device-to-device transfer, not optimal for latency-sensitive scenarios
++
++
++
++
++
++# Installation
++
++## Install etcd
++#### 1. Download the latest binaries from [etcd github releases](https://github.com/etcd-io/etcd/releases)
++```
++ETCD_VERSION="v3.5.12"
++wget https://github.com/etcd-io/etcd/releases/download/${ETCD_VERSION}/etcd-${ETCD_VERSION}-linux-amd64.tar.gz
++```
++#### 2. Extract and install
++```
++tar -xvf etcd-${ETCD_VERSION}-linux-amd64.tar.gz
++cd etcd-${ETCD_VERSION}-linux-amd64
++# copy the binaries to the system path
++sudo cp etcd etcdctl /usr/local/bin/
++```
++#### 3. Verify the installation
++```
++etcd --version
++etcdctl version
++```
++
++
++## Install YR-DataSystem
++#### Install from pip (recommended):
++
++```
++pip install yr-datasystem
++```
++
++#### Or install from source:
++
++- Refer to the yr-datasystem documentation [here](https://gitee.com/openeuler/yuanrong-datasystem)
++
++
++
++# Deployment
++## Deploy etcd
++> Note: this is a minimal example of deploying etcd; more can be found at the [etcd official site](https://etcd.io/docs/current/op-guide/clustering/).
++
++#### Deploy a single-node etcd cluster at port 2379:
++```
++etcd \
++  --name etcd-single \
++  --data-dir /tmp/etcd-data \
++  --listen-client-urls http://0.0.0.0:2379 \
++  --advertise-client-urls http://0.0.0.0:2379 \
++  --listen-peer-urls http://0.0.0.0:2380 \
++  --initial-advertise-peer-urls http://0.0.0.0:2380 \
++  --initial-cluster etcd-single=http://0.0.0.0:2380
++```
++
++
++#### Parameters:
++- `--name`: cluster name
++- `--data-dir`: directory to store data
++- `--listen-client-urls`: address to listen on for clients (0.0.0.0 allows access from any IP address)
++- `--advertise-client-urls`: address advertised to clients
++- `--listen-peer-urls`: address to listen on for other nodes in the cluster
++- `--initial-advertise-peer-urls`: address advertised to other nodes in the cluster
++- `--initial-cluster`: initial nodes in the cluster (format: name1=peer_url1,name2=peer_url2,...)
++
++#### Try to access the etcd cluster with the `etcdctl` command:
++```
++etcdctl --endpoints "127.0.0.1:2379" put key "value"
++etcdctl --endpoints "127.0.0.1:2379" get key
++```
++The etcd cluster is deployed successfully if these commands work.
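++
++If you prefer to script this check, etcd also exposes a plain HTTP health endpoint. Below is a minimal sketch in Python (assuming the single-node deployment above and that the `requests` package is installed):
++```
++import requests
++
++# A healthy member answers {"health": "true"} on /health.
++resp = requests.get("http://127.0.0.1:2379/health", timeout=5)
++resp.raise_for_status()
++assert resp.json().get("health") == "true", "etcd is not healthy yet"
++print("etcd is up:", resp.json())
++```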
++
++## Deploy YR-DataSystem
++#### Deploy a single-node yr-datasystem cluster with the minimal config:
++```
++dscli start -w --worker_address "127.0.0.1:31501" --etcd_address "127.0.0.1:2379"
++# [INFO] [ OK ] Start worker service @ 127.0.0.1:31501 success, PID: 38100
++```
++yr-datasystem is deployed successfully once you see the `[ OK ]` output.
++
++#### To safely stop and clean up the yr-datasystem processes, run the command:
++```
++dscli stop -w --worker_address "127.0.0.1:31501"
++```
++#### Please refer to the [yr-datasystem gitee repo](https://gitee.com/openeuler/yuanrong-datasystem) for more information.
++
++# Run disaggregated prefill with vLLM v1
++
++> Note: an example script for 1P1D disaggregated prefill is available at: *vllm-ascend/tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py*
++
++#### 1. Provide the yr-datasystem worker address via an environment variable:
++
++`export DS_WORKER_ADDR=127.0.0.1:31501`
++
++YuanRongConnector reads the yr-datasystem address from this environment variable.
++
++#### 2. Start two vLLM instances with YuanRongConnector as the backend to form a 1P1D disaggregated cluster:
++```
++export VLLM_USE_V1=True
++
++# start a prefill instance on localhost:8100
++ASCEND_RT_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct \
++    --port 8100 \
++    --max-num-batched-tokens 45000 \
++    --gpu-memory-utilization 0.8 \
++    --trust-remote-code \
++    --enforce-eager \
++    --kv-transfer-config \
++    '{"kv_connector":"YuanRongConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
++
++# start a decode instance on localhost:8200
++ASCEND_RT_VISIBLE_DEVICES=1 vllm serve Qwen/Qwen2.5-7B-Instruct \
++    --port 8200 \
++    --max-num-batched-tokens 45000 \
++    --gpu-memory-utilization 0.8 \
++    --trust-remote-code \
++    --enforce-eager \
++    --kv-transfer-config \
++    '{"kv_connector":"YuanRongConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
++```
++
++#### 3. Start a proxy server to serve and route HTTP requests:
++```
++python vllm-ascend/tests/e2e/pd_disaggreate/yuanrong/simple_pd_proxy_server.py --prefiller-port 8100 --decoder-port 8200
++```
++
++#### 4. Send HTTP requests to the proxy server:
++```
++curl -X POST -s http://localhost:8000/v1/completions \
++-H "Content-Type: application/json" \
++-d '{
++"model": "Qwen/Qwen2.5-7B-Instruct",
++"prompt": "who is the president of the united states?",
++"max_tokens": 50,
++"temperature": 0
++}'
++```
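++
++The same request can be issued from Python for scripted smoke tests. Below is a minimal sketch (assuming the proxy from step 3 is listening on localhost:8000 and that the `requests` package is installed):
++```
++import requests
++
++payload = {
++    "model": "Qwen/Qwen2.5-7B-Instruct",
++    "prompt": "who is the president of the united states?",
++    "max_tokens": 50,
++    "temperature": 0,
++}
++# The proxy first lets the prefill instance build the KV cache, then streams
++# the completion back from the decode instance.
++resp = requests.post("http://localhost:8000/v1/completions", json=payload, timeout=600)
++resp.raise_for_status()
++print(resp.text)
++```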
++ python -m pip install yr-datasystem ++fi ++ ++wait_for_server() { ++ local port=$1 ++ timeout 1200 bash -c " ++ until curl -s ${HOST_IP}:${port}/v1/completions > /dev/null; do ++ sleep 1 ++ done" && return 0 || return 1 ++} ++ ++ASCEND_RT_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \ ++ --host ${HOST_IP} \ ++ --port ${PREFILL_PORT} \ ++ --max-num-batched-tokens 45000 \ ++ --gpu-memory-utilization 0.8 \ ++ --trust-remote-code \ ++ --enforce-eager \ ++ --kv-transfer-config \ ++ '{"kv_connector":"YuanRongConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & ++ ++ASCEND_RT_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \ ++ --host ${HOST_IP} \ ++ --port ${DECODE_PORT} \ ++ --max-num-batched-tokens 45000 \ ++ --gpu-memory-utilization 0.8 \ ++ --trust-remote-code \ ++ --enforce-eager \ ++ --kv-transfer-config \ ++ '{"kv_connector":"YuanRongConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & ++ ++wait_for_server ${PREFILL_PORT} ++wait_for_server ${DECODE_PORT} +diff --git a/tests/e2e/pd_disaggreate/yuanrong/run_proxy_server.sh b/tests/e2e/pd_disaggreate/yuanrong/run_proxy_server.sh +new file mode 100644 +index 00000000..879f4863 +--- /dev/null ++++ b/tests/e2e/pd_disaggreate/yuanrong/run_proxy_server.sh +@@ -0,0 +1,14 @@ ++#!/bin/bash ++PROXY_SERVER_SCRIPT=$1 ++HOST=$2 ++PORT=$3 ++PREFILL_PORT=$4 ++DECODE_PORT=$5 ++ ++python ${PROXY_SERVER_SCRIPT} \ ++ --host ${HOST} \ ++ --port ${PORT} \ ++ --prefiller-host ${HOST} \ ++ --prefiller-port ${PREFILL_PORT} \ ++ --decoder-host ${HOST} \ ++ --decoder-port ${DECODE_PORT} & +\ No newline at end of file +diff --git a/tests/e2e/pd_disaggreate/yuanrong/run_yuanrong.sh b/tests/e2e/pd_disaggreate/yuanrong/run_yuanrong.sh +new file mode 100644 +index 00000000..8f4e3d7f +--- /dev/null ++++ b/tests/e2e/pd_disaggreate/yuanrong/run_yuanrong.sh +@@ -0,0 +1,24 @@ ++#!/bin/bash ++ ++HOST_IP=$1 ++WORKER_PORT=$2 ++ETCD_PORT=$3 ++ ++MASTER_PORT=`expr ${WORKER_PORT} + 1` ++ETCD_PEER_PORT=`expr ${ETCD_PORT} + 1` ++ ++etcd \ ++ --name etcd-yuanrong \ ++ --data-dir /tmp/etcd-yuanrong \ ++ --listen-client-urls http://${HOST_IP}:${ETCD_PORT} \ ++ --advertise-client-urls http://${HOST_IP}:${ETCD_PORT} \ ++ --listen-peer-urls http://${HOST_IP}:${ETCD_PEER_PORT} \ ++ --initial-advertise-peer-urls http://${HOST_IP}:${ETCD_PEER_PORT} \ ++ --initial-cluster etcd-yuanrong=http://${HOST_IP}:${ETCD_PEER_PORT} & ++ ++ ++dscli start \ ++ -w \ ++ --worker_address ${HOST_IP}:${WORKER_PORT} \ ++ --master_address ${HOST_IP}:${MASTER_PORT} \ ++ --etcd_address ${HOST_IP}:${ETCD_PORT} & +\ No newline at end of file +diff --git a/tests/e2e/pd_disaggreate/yuanrong/simple_pd_proxy_server.py b/tests/e2e/pd_disaggreate/yuanrong/simple_pd_proxy_server.py +new file mode 100644 +index 00000000..c6b957cd +--- /dev/null ++++ b/tests/e2e/pd_disaggreate/yuanrong/simple_pd_proxy_server.py +@@ -0,0 +1,212 @@ ++import argparse ++import os ++import time ++from contextlib import asynccontextmanager ++from uuid import uuid4 ++ ++import httpx ++import numpy as np ++from fastapi import FastAPI, Request ++from fastapi.responses import StreamingResponse ++ ++ ++@asynccontextmanager ++async def lifespan(app: FastAPI): ++ """ ++ Lifespan context manager to handle startup and shutdown events. 
++ """ ++ # Startup: Initialize clients ++ prefiller_base_url = ( ++ f"http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1") ++ decoder_base_url = ( ++ f"http://{global_args.decoder_host}:{global_args.decoder_port}/v1") ++ ++ app.state.prefill_client = httpx.AsyncClient(timeout=None, ++ base_url=prefiller_base_url) ++ app.state.decode_client = httpx.AsyncClient(timeout=None, ++ base_url=decoder_base_url) ++ ++ yield ++ ++ # Shutdown: Close clients ++ await app.state.prefill_client.aclose() ++ await app.state.decode_client.aclose() ++ ++ ++# Update FastAPI app initialization to use lifespan ++app = FastAPI(lifespan=lifespan) ++ ++ ++class StatsCalculator: ++ ++ def __init__(self): ++ self._stats = [] ++ self._last_log_time = time.time() ++ ++ def add(self, value): ++ self._stats.append(value) ++ if time.time() - self._last_log_time > 5: ++ self._log_stats() ++ self._last_log_time = time.time() ++ ++ def _log_stats(self): ++ # Print average, median, and 99th percentile ++ np_arr = np.array(self._stats) ++ output_str = ( ++ f"\nNum requests: {len(self._stats)}" + ++ "\nPrefill node TTFT stats:" + ++ f"\n - Average (ms): {np.mean(np_arr)}" + ++ f"\n - Median (ms): {np.median(np_arr)}" + ++ f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n") ++ print( ++ "===============================", ++ output_str, ++ "===============================", ++ ) ++ ++ ++stats_calculator = StatsCalculator() ++counter = 0 ++ ++ ++def parse_args(): ++ parser = argparse.ArgumentParser() ++ ++ parser.add_argument("--port", type=int, default=8000) ++ parser.add_argument("--host", type=str, default="localhost") ++ parser.add_argument("--prefiller-host", type=str, default="localhost") ++ parser.add_argument("--prefiller-port", type=int, default=8100) ++ parser.add_argument("--decoder-host", type=str, default="localhost") ++ parser.add_argument("--decoder-port", type=int, default=8200) ++ args = parser.parse_args() ++ return args ++ ++ ++# Initialize variables to hold the persistent clients ++app.state.prefill_client = None ++app.state.decode_client = None ++ ++ ++async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, ++ req_data: dict, request_id: str): ++ """ ++ Send a request to a service using a persistent client. ++ """ ++ req_data = req_data.copy() ++ req_data["max_tokens"] = 1 ++ if "max_completion_tokens" in req_data: ++ req_data["max_completion_tokens"] = 1 ++ ++ headers = { ++ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", ++ "X-Request-Id": request_id ++ } ++ response = await client.post(endpoint, json=req_data, headers=headers) ++ response.raise_for_status() ++ return response ++ ++ ++async def stream_service_response(client: httpx.AsyncClient, endpoint: str, ++ req_data: dict, request_id: str): ++ """ ++ Asynchronously stream the response from a service using a persistent client. 
++ """ ++ headers = { ++ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", ++ "X-Request-Id": request_id ++ } ++ async with client.stream("POST", endpoint, json=req_data, ++ headers=headers) as response: ++ response.raise_for_status() ++ async for chunk in response.aiter_bytes(): ++ yield chunk ++ ++ ++@app.post("/v1/completions") ++async def handle_completions(request: Request): ++ global counter, stats_calculator ++ counter += 1 ++ ++ st = time.time() ++ try: ++ req_data = await request.json() ++ request_id = str(uuid4()) ++ ++ # Send request to prefill service, ignore the response ++ await send_request_to_service(app.state.prefill_client, "/completions", ++ req_data, request_id) ++ ++ et = time.time() ++ stats_calculator.add(et - st) ++ ++ # Stream response from decode service ++ async def generate_stream(): ++ async for chunk in stream_service_response(app.state.decode_client, ++ "/completions", ++ req_data, request_id): ++ yield chunk ++ ++ return StreamingResponse(generate_stream(), ++ media_type="text/event-stream") ++ ++ except Exception as e: ++ import sys ++ import traceback ++ ++ exc_info = sys.exc_info() ++ print( ++ "Error occurred in disagg prefill proxy server - completions endpoint" ++ ) ++ print(e) ++ print("".join(traceback.format_exception(*exc_info))) ++ raise ++ ++ ++@app.post("/v1/chat/completions") ++async def handle_chat_completions(request: Request): ++ global counter, stats_calculator ++ counter += 1 ++ ++ st = time.time() ++ try: ++ req_data = await request.json() ++ request_id = str(uuid4()) ++ ++ # Send request to prefill service, ignore the response ++ await send_request_to_service(app.state.prefill_client, ++ "/chat/completions", req_data, ++ request_id) ++ ++ et = time.time() ++ stats_calculator.add(et - st) ++ ++ # Stream response from decode service ++ async def generate_stream(): ++ async for chunk in stream_service_response(app.state.decode_client, ++ "/chat/completions", ++ req_data, request_id): ++ yield chunk ++ ++ return StreamingResponse(generate_stream(), ++ media_type="text/event-stream") ++ ++ except Exception as e: ++ import sys ++ import traceback ++ ++ exc_info = sys.exc_info() ++ print( ++ "Error occurred in disagg prefill proxy server - chat completions endpoint" ++ ) ++ print(e) ++ print("".join(traceback.format_exception(*exc_info))) ++ raise ++ ++ ++if __name__ == "__main__": ++ global global_args ++ global_args = parse_args() ++ ++ import uvicorn ++ ++ uvicorn.run(app, host=global_args.host, port=global_args.port) +diff --git a/tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py b/tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py +new file mode 100644 +index 00000000..aaf6da39 +--- /dev/null ++++ b/tests/e2e/pd_disaggreate/yuanrong/test_yuanrong_connector.py +@@ -0,0 +1,141 @@ ++import json ++import os ++import signal ++import subprocess ++import time ++ ++import pytest ++import requests ++ ++HOST_IP = "127.0.0.1" ++MODEL_NAME = "Qwen/Qwen2.5-7B" ++WORKSPACE_DIR = "./tests/e2e/pd_disaggreate/yuanrong/" ++ ++RUN_INSTANCES_SCRIPT = os.path.join(WORKSPACE_DIR, ++ "run_pd_with_yuanrong_connector.sh") ++RUN_PROXY_SERVER_SCRIPT = os.path.join(WORKSPACE_DIR, "run_proxy_server.sh") ++RUN_YUANRONG_SCRIPT = os.path.join(WORKSPACE_DIR, "run_yuanrong.sh") ++CLEAN_YUANRONG_SCRIPT = os.path.join(WORKSPACE_DIR, "clean_yuanrong.sh") ++PROXY_SERVER_SCRIPT = os.path.join(WORKSPACE_DIR, "simple_pd_proxy_server.py") ++PROXY_PORT = 8000 ++PREFILL_PORT = 8100 ++DECODE_PORT = 8200 ++WORKER_PORT = 31530 ++ETCD_PORT = 2411 ++ 
++PROMPT_ANSWER = { ++ "who is the president of the united states?": "?\nDonald Trump" ++} ++RUN_INSTANCE_KEYWORDS = "vllm serve" ++RUN_PROXY_SERVER_KEYWORDS = "simple_pd_proxy_server.py" ++ ++ ++def start_yuanrong(): ++ proc = subprocess.Popen([ ++ "bash", RUN_YUANRONG_SCRIPT, f"{HOST_IP}", f"{WORKER_PORT}", ++ f"{ETCD_PORT}" ++ ]) ++ proc.wait() ++ ++ ++def clean_yuanrong(): ++ proc = subprocess.Popen( ++ ["bash", CLEAN_YUANRONG_SCRIPT, f"{HOST_IP}", f"{WORKER_PORT}"]) ++ proc.wait() ++ ++ ++def start_instances(): ++ proc = subprocess.Popen([ ++ "bash", RUN_INSTANCES_SCRIPT, f"{MODEL_NAME}", f"{HOST_IP}", ++ f"{PREFILL_PORT}", f"{DECODE_PORT}" ++ ]) ++ proc.wait() ++ ++ ++def start_proxy_server(): ++ proc = subprocess.Popen([ ++ "bash", RUN_PROXY_SERVER_SCRIPT, PROXY_SERVER_SCRIPT, f"{HOST_IP}", ++ f"{PROXY_PORT}", f"{PREFILL_PORT}", f"{DECODE_PORT}" ++ ]) ++ proc.wait() ++ ++ ++def clean_instances_and_proxy_server(): ++ instance_pids = get_pids_by_keyword(RUN_INSTANCE_KEYWORDS) ++ proxy_pids = get_pids_by_keyword(RUN_PROXY_SERVER_KEYWORDS) ++ for pid in proxy_pids + instance_pids: ++ pid = int(pid) ++ try: ++ os.kill(pid, signal.SIGINT) ++ except ProcessLookupError: ++ print(f"No such process with PID {pid}") ++ except PermissionError: ++ print(f"Permission denied to send SIGINT to PID {pid}") ++ except Exception as e: ++ print(f"Error: {e}") ++ time.sleep(3) ++ pid = int(pid) ++ try: ++ os.kill(pid, signal.SIGKILL) ++ except ProcessLookupError: ++ print(f"No such process with PID {pid}") ++ except PermissionError: ++ print(f"Permission denied to send SIGKILL to PID {pid}") ++ except Exception as e: ++ print(f"Error: {e}") ++ ++ ++def send_post_request(url, data): ++ try: ++ response = requests.post(url, json=data, timeout=10) ++ response.raise_for_status() ++ return response.text ++ except requests.exceptions.RequestException as e: ++ return f"Request failed: {e}" ++ ++ ++def get_pids_by_keyword(keyword): ++ try: ++ # Run 'ps aux' to get all running processes ++ result = subprocess.run(['ps', 'aux'], ++ stdout=subprocess.PIPE, ++ text=True) ++ lines = result.stdout.strip().split('\n') ++ ++ matching_pids = [] ++ ++ for line in lines[1:]: # Skip the header line ++ if keyword in line: ++ parts = line.split() ++ pid = parts[1] # PID is the second column ++ matching_pids.append(pid) ++ ++ return matching_pids ++ except Exception as e: ++ return f"error occurred trying to get PIDs of processes containing keyword {keyword}, error: {e}" ++ ++ ++@pytest.fixture ++def setup_and_clean_cluster(): ++ start_yuanrong() ++ start_instances() ++ start_proxy_server() ++ time.sleep(3) ++ yield ++ clean_instances_and_proxy_server() ++ clean_yuanrong() ++ ++ ++def test_yuanrong_pd_dist(setup_and_clean_cluster): ++ proxy_url = f"http://{HOST_IP}:{PROXY_PORT}/v1/completions" ++ for prompt, answer in PROMPT_ANSWER.items(): ++ data = { ++ "model": MODEL_NAME, ++ "prompt": prompt, ++ "max_tokens": 50, ++ "temperature": 0 ++ } ++ response_str = send_post_request(proxy_url, data) ++ response_json = json.loads(response_str) ++ output = response_json["choices"][0]["text"] ++ assert output == answer, f"wrong response: {output}, expected: {answer}" +diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py +index f4c8cc73..267a9635 100644 +--- a/vllm_ascend/core/scheduler.py ++++ b/vllm_ascend/core/scheduler.py +@@ -32,6 +32,8 @@ from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.request import Request, RequestStatus + from vllm.v1.structured_output import StructuredOutputManager + ++import 
vllm_ascend.distributed
++
+ 
+ class AscendScheduler(Scheduler):
+     """This Scheduler extends vllm's original v1 scheduler
+@@ -362,10 +364,15 @@ class AscendScheduler(Scheduler):
+             req_index += 1
+             continue
+ 
++            num_draft_tokens = max(
++                num_new_tokens + request.num_computed_tokens -
++                request.num_tokens, 0)
++
+             while True:
+                 new_blocks = self.kv_cache_manager.allocate_slots(
+                     request,
+                     num_new_tokens,
++                    num_new_computed_tokens=num_draft_tokens,
+                     num_lookahead_tokens=self.num_lookahead_tokens)
+                 if new_blocks is None:
+                     # The request cannot be scheduled.
+@@ -585,3 +592,34 @@ class AscendScheduler(Scheduler):
+ 
+         return super().update_from_output(scheduler_output,
+                                           model_runner_output)
++
++    def _update_waiting_for_remote_kv(self, request: Request) -> bool:
++        """
++        KV Connector: check if the request_id is finished_recving.
++
++        The finished_recving_kv_req_ids list is populated
++        on the previous steps()'s update_from_output based
++        on the worker side connector.
++
++        When the kv transfer is ready, we cache the blocks
++        and the request state will be moved back to WAITING from
++        WAITING_FOR_REMOTE_KV.
++        """
++        assert self.connector is not None
++        if request.request_id not in self.finished_recving_kv_req_ids:
++            return False
++
++        # Now that the blocks are ready, actually cache them.
++        (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
++        num_computed_tokens = len(block_ids) * self.block_size
++        # Handle the case where the number of request tokens is less than one block.
++        num_computed_tokens = min(num_computed_tokens, request.num_tokens)
++        if num_computed_tokens == request.num_tokens:
++            num_computed_tokens -= 1
++
++        # Update the request state for scheduling.
++        request.num_computed_tokens = num_computed_tokens
++
++        # Return that we are ready.
++        self.finished_recving_kv_req_ids.remove(request.request_id)
++        return True
+diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py
+index 26ddd8f9..ae3957e7 100644
+--- a/vllm_ascend/distributed/__init__.py
++++ b/vllm_ascend/distributed/__init__.py
+@@ -31,3 +31,7 @@ KVConnectorFactory.register_connector(
+     "MooncakeConnectorStoreV1",
+     "vllm_ascend.distributed.mooncake.mooncake_store_connector_v1",
+     "MooncakeConnectorV1")
++
++KVConnectorFactory.register_connector(
++    "YuanRongConnector", "vllm_ascend.distributed.yuanrong_connector",
++    "YuanRongConnector")
+diff --git a/vllm_ascend/distributed/yuanrong_connector.py b/vllm_ascend/distributed/yuanrong_connector.py
+new file mode 100644
+index 00000000..52812cd6
+--- /dev/null
++++ b/vllm_ascend/distributed/yuanrong_connector.py
+@@ -0,0 +1,826 @@
++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
++#
++# Licensed under the Apache License, Version 2.0 (the "License");
++# you may not use this file except in compliance with the License.
++# You may obtain a copy of the License at
++#
++#     http://www.apache.org/licenses/LICENSE-2.0
++#
++# Unless required by applicable law or agreed to in writing, software
++# distributed under the License is distributed on an "AS IS" BASIS,
++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++# See the License for the specific language governing permissions and
++# limitations under the License.
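++
++# YuanRongConnector is a vLLM v1 KV-transfer connector backed by yr-datasystem's DsTensorClient.
++# KV blocks are keyed by a SHA-256 hash over the token prefix (see generate_hash_sha256), so a
++# kv_producer instance can save prefill KV blocks that a kv_consumer instance later loads.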
++ ++import os ++import enum ++import hashlib ++from dataclasses import dataclass ++from typing import TYPE_CHECKING, List, Optional, Any ++import threading ++from collections import defaultdict ++import asyncio ++ ++import numpy ++import torch ++from vllm.config import VllmConfig ++from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ++ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) ++from vllm.logger import init_logger ++from vllm.v1.attention.backends.mla.common import MLACommonMetadata ++from vllm.v1.core.sched.output import SchedulerOutput ++from vllm.distributed.parallel_state import get_tp_group ++ ++from datasystem import DsTensorClient, Future ++ ++ENABLE_PREFIX_CACHING = int(os.environ.get("USING_PREFIX_CONNECTOR", 1)) ++FUTURE_TIMEOUT = int(os.getenv("FUTURE_TIMEOUT", 10000)) ++SYNC_FUTURE_TIMEOUT = int(os.getenv("SYNC_FUTURE_TIMEOUT", 1)) ++SLEEP_TIMEOUT = 0.005 ++ ++if TYPE_CHECKING: ++ from vllm.attention.backends.abstract import AttentionMetadata ++ from vllm.forward_context import ForwardContext ++ from vllm.v1.request import Request ++ ++logger = init_logger(f"vllm.{__name__}") ++ ++ ++class RequestStatus(enum.IntEnum): ++ WAITING = enum.auto() ++ TIMEOUT = enum.auto() ++ FINISHED = enum.auto() ++ ++ ++@dataclass ++class RequestTracker: ++ request_id: str ++ token_ids: torch.Tensor ++ block_ids: list[int] ++ num_scheduled_tokens: int ++ ++ @staticmethod ++ def from_new_request(request_id, token_ids, block_ids, num_scheduled_tokens) -> "RequestTracker": ++ """ ++ Create the request tracker from a new request. ++ """ ++ return RequestTracker( ++ request_id=request_id, ++ token_ids=token_ids, ++ block_ids=block_ids, ++ num_scheduled_tokens=num_scheduled_tokens ++ ) ++ ++ def update( ++ self, ++ block_ids, ++ num_external_scheduled_tokens ++ ) -> None: ++ """ ++ Update the request tracker when a running request is ++ scheduled again ++ """ ++ self.block_ids[0].extend(block_ids[0]) ++ self.num_scheduled_tokens += num_external_scheduled_tokens ++ ++ ++@dataclass ++class ReqMeta: ++ request_id: str ++ token_ids: torch.Tensor ++ block_ids: list[int] ++ request_rank: int ++ skip_block_num: int ++ ds_cached_block_num: int ++ need_save: bool ++ ++ @staticmethod ++ def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], ++ block_size: int, request_rank: int, skip_block_num: int, ds_cached_block_num: int, need_save: bool) \ ++ -> "ReqMeta": ++ """make request meta""" ++ valid_num_tokens = align_to_block_size(len(token_ids), block_size) ++ valid_block_ids = valid_num_tokens // block_size ++ return ReqMeta( ++ request_id=request_id, ++ token_ids=numpy.array(token_ids), ++ block_ids=block_ids[0][:valid_block_ids], ++ request_rank=request_rank, ++ skip_block_num=skip_block_num, ++ ds_cached_block_num=ds_cached_block_num, ++ need_save=need_save ++ ) ++ ++ ++@dataclass ++class YuanRongConnectorMetadata(KVConnectorMetadata): ++ requests: list[ReqMeta] ++ ++ def __init__(self, tp_size, block_size): ++ self.requests = [] ++ self.tp_size = tp_size ++ self.request_rank = 0 ++ self._block_size = block_size ++ ++ def add_request( ++ self, ++ request_id: str, ++ token_ids: list[int], ++ block_ids: list[int], ++ skip_block_num: int, ++ ds_cached_block_num: int, ++ need_save: bool = True ++ ) -> None: ++ """add request meta""" ++ request_rank = self.request_rank % self.tp_size ++ self.requests.append( ++ ReqMeta.make_meta(request_id, token_ids, block_ids, self._block_size, request_rank, skip_block_num, ++ ds_cached_block_num, need_save)) ++ self.request_rank = 
request_rank + 1
++
++
++@dataclass
++class ReqState:
++    """Per-request state for tracking async transfers."""
++    num_pending: int = -1
++    finished: bool = False
++
++
++class AsyncHandler:
++    """Manage async saving/loading in separate thread."""
++
++    def __init__(self, role, task_list):
++        self._async_save_reqs = defaultdict[str, ReqState](ReqState)
++        self._async_load_reqs = defaultdict[str, ReqState](ReqState)
++        self._is_producer = role
++        self._finished_save_reqs = asyncio.Queue()
++        self._finished_load_reqs = asyncio.Queue()
++        self._future_save_list = asyncio.Queue()
++        self._future_load_list = asyncio.Queue()
++        if self._is_producer or ENABLE_PREFIX_CACHING:
++            task_list.append(asyncio.get_event_loop().create_task(self.get_save_futures_async()))
++        if not self._is_producer or ENABLE_PREFIX_CACHING:
++            task_list.append(asyncio.get_event_loop().create_task(self.get_load_futures_async()))
++
++    async def get_save_futures_async(self):
++        """async get save futures"""
++        while True:
++            try:
++                save_future_len = self._future_save_list.qsize()
++                for _ in range(save_future_len):
++                    request_id, future = self._future_save_list.get_nowait()
++                    res = get_future(future)
++                    req_state = self._async_save_reqs[request_id]
++                    if res == RequestStatus.FINISHED:
++                        logger.info(f"request: {request_id} is finished")
++                        req_state.num_pending -= 1
++                        if req_state.finished and not req_state.num_pending:
++                            self._finished_save_reqs.put_nowait(request_id)
++                            del self._async_save_reqs[request_id]
++                    elif res == RequestStatus.WAITING or not req_state.finished:
++                        self._future_save_list.put_nowait((request_id, future))
++                    else:
++                        logger.error(f"request:{request_id} get save future timeout, res:{res}")
++                        self._finished_save_reqs.put_nowait(request_id)
++                        del self._async_save_reqs[request_id]
++                await asyncio.sleep(SLEEP_TIMEOUT)
++            except Exception as e:
++                logger.error(f"get_futures_async fail, error:{e}")
++
++    async def get_load_futures_async(self):
++        """async get load futures"""
++        while True:
++            try:
++                load_future_len = self._future_load_list.qsize()
++                for _ in range(load_future_len):
++                    request_id, future = self._future_load_list.get_nowait()
++                    res = get_future(future)
++                    req_state = self._async_load_reqs[request_id]
++                    if res == RequestStatus.FINISHED:
++                        logger.info(f"request: {request_id} is finished")
++                        req_state.num_pending -= 1
++                        if not req_state.num_pending:
++                            self._finished_load_reqs.put_nowait(request_id)
++                            del self._async_load_reqs[request_id]
++                    elif res == RequestStatus.WAITING:
++                        self._future_load_list.put_nowait((request_id, future))
++                    else:
++                        logger.error(f"request:{request_id} get load future timeout, res:{res}")
++                        self._finished_load_reqs.put_nowait(request_id)
++                        del self._async_load_reqs[request_id]
++                await asyncio.sleep(SLEEP_TIMEOUT)
++            except Exception as e:
++                logger.error(f"get_futures_async fail, error:{e}")
++
++    def add_save_request(self, request: ReqMeta, future_num: int) -> None:
++        """add save request future count"""
++        self._async_save_reqs[request.request_id].num_pending = future_num
++
++    def add_load_request(self, request: ReqMeta, future_num: int) -> None:
++        """add load request future count"""
++        self._async_load_reqs[request.request_id].num_pending = future_num
++
++    def add_save_future(self, request: ReqMeta, future: Future) -> None:
++        """add save request future"""
++        self._future_save_list.put_nowait((request.request_id, future))
++
++    def add_load_future(self, request: ReqMeta, future: Future) -> None:
++        """add load request future"""
++        self._future_load_list.put_nowait((request.request_id, future))
++
++    def get_save_finished(self, finished_request_ids: set[str]) -> Optional[set[str]]:
++        """Finished saving request ids."""
++        finished_reqs = set()
++        for req_id in finished_request_ids:
++            req_state = self._async_save_reqs[req_id]
++            if req_state:
++                req_state.finished = True
++                if not req_state.num_pending:
++                    finished_reqs.add(req_id)
++                    del self._async_save_reqs[req_id]
++
++        while not self._finished_save_reqs.empty():
++            finished_reqs.add(self._finished_save_reqs.get_nowait())
++        if len(finished_reqs) != 0:
++            logger.debug(f"get_finished, finished_reqs:{finished_reqs}, length:{len(finished_reqs)}")
++        else:
++            finished_reqs = None
++        return finished_reqs
++
++    def get_load_finished(self) -> set[str]:
++        """Finished loading request ids."""
++        finished_reqs = set()
++        while not self._finished_load_reqs.empty():
++            finished_reqs.add(self._finished_load_reqs.get_nowait())
++        if len(finished_reqs) != 0:
++            logger.debug(f"get_finished, finished_reqs:{finished_reqs}, length:{len(finished_reqs)}")
++        else:
++            finished_reqs = None
++        return finished_reqs
++
++
++class YuanRongConnector(KVConnectorBase_V1):
++
++    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
++        super().__init__(vllm_config=vllm_config, role=role)
++        self._block_size = vllm_config.cache_config.block_size
++        # Values are (request, block_ids) tuples; see update_state_after_alloc.
++        self._requests_need_load: dict[str, tuple] = {}
++        self.config = vllm_config.kv_transfer_config
++        self.is_producer = self.config.is_kv_producer
++        self.do_async_save = int(os.getenv("ASYNC_SAVE", 1))
++        self.layer_name_list = []
++        self.kv_caches = []
++        self.key_caches = []
++        self.value_caches = []
++        self._skip_blocks: dict[str, int] = {}
++        self._ds_cached_blocks: dict[str, int] = {}
++        self._delay_save = {}
++        self._load_request_queue = asyncio.Queue()
++        self._save_request_queue = asyncio.Queue()
++        self.task_list = []
++        self.is_ms_non_mla_type = False
++        self.is_ms_mla = False
++        self.is_mla = False
++        self._async_handler = None
++
++        self.tp_size = vllm_config.parallel_config.tensor_parallel_size
++        ds_worker_addr = os.getenv("DS_WORKER_ADDR", "172.17.0.4:9000")
++        ip_port = ds_worker_addr.split(":")
++        ip = ip_port[0]
++        port = int(ip_port[1])
++
++        self.device = self.tp_rank = 0
++        if role == KVConnectorRole.WORKER:
++            self.tp_rank = get_tp_group().rank_in_group
++            self.tp_group = get_tp_group()
++            self.kvc_store = DsTensorClient(ip, port, self.device)
++            self.kvc_store.init()
++            if self.do_async_save:
++                self.loop = asyncio.get_event_loop()
++                self._async_handler = AsyncHandler(self.is_producer, self.task_list)
++                if ENABLE_PREFIX_CACHING or not self.is_producer:
++                    self.task_list.append(self.loop.create_task(self.consumer_request_task()))
++
++                if ENABLE_PREFIX_CACHING or self.is_producer:
++                    self.task_list.append(self.loop.create_task(self.producer_request_task()))
++
++                thread = threading.Thread(target=self.start_event_loop, daemon=True)
++                thread.start()
++        elif ENABLE_PREFIX_CACHING:
++            self.kvc_store = DsTensorClient(ip, port, self.device)
++            self.kvc_store.init()
++        else:
++            self.tp_group = None
++        logger.info(f"init datasystem ip = {ip}, port = {port}, device_id = {self.device}")
++
++    def start_event_loop(self):
++        """start event loop"""
++        current_thread = threading.current_thread()
++        logger.info(f"start_event_loop: {current_thread.ident}")
++        self.loop.run_until_complete(asyncio.gather(*self.task_list))
++        self.loop.close()
++
++    async def producer_request_task(self):
++        """producer request task"""
++        while 
True: ++ try: ++ save_request_len = self._save_request_queue.qsize() ++ for _ in range(save_request_len): ++ request = self._save_request_queue.get_nowait() ++ self.do_save_request(request) ++ await asyncio.sleep(SLEEP_TIMEOUT) ++ except Exception as e: ++ logger.error(f"producer_request_task fail, error:{e}") ++ self._save_request_queue.put_nowait(request) ++ await asyncio.sleep(SLEEP_TIMEOUT) ++ ++ async def consumer_request_task(self): ++ """consumer request task""" ++ while True: ++ try: ++ load_request_len = self._load_request_queue.qsize() ++ for _ in range(load_request_len): ++ request = self._load_request_queue.get_nowait() ++ self.do_load_kv(request) ++ await asyncio.sleep(SLEEP_TIMEOUT) ++ except Exception as e: ++ logger.error(f"consumer_request_task fail, error:{e}") ++ self._load_request_queue.put_nowait(request) ++ await asyncio.sleep(SLEEP_TIMEOUT) ++ ++ def generate_kv_cache_token_key( ++ self, ++ request: ReqMeta, ++ block_start_index: int, ++ block_end_index: int ++ ) -> List[str]: ++ """ ++ generate kv_cache token key. ++ """ ++ if not self.is_mla: ++ external_key = "-" + str(self.tp_rank) ++ else: ++ external_key = "-0" ++ ++ return generate_hash_sha256(block_start_index, block_end_index, request.token_ids, ++ self._block_size, external_key) ++ ++ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: ++ """ ++ Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. ++ ++ Args: ++ forward_context (ForwardContext): the forward context. ++ **kwargs: additional arguments for the load operation ++ ++ Note: ++ The number of elements in kv_caches and layer_names should be ++ the same. ++ """ ++ # effective only when prefix cache is disabled and the role is producer. ++ if self.is_producer and not ENABLE_PREFIX_CACHING: ++ return ++ ++ metadata: KVConnectorMetadata = self._get_connector_metadata() ++ if len(metadata.requests) == 0: ++ return ++ ++ if len(self.kv_caches) == 0: ++ self._init_kv_caches_from_forward_context(forward_context) ++ ++ for request in metadata.requests: ++ if self._async_handler is not None: ++ self._load_request_queue.put_nowait(request) ++ else: ++ self.do_load_kv(request) ++ ++ def get_finished( ++ self, finished_req_ids: set[str] ++ ) -> tuple[Optional[set[str]], Optional[set[str]]]: ++ """Finished (saving, loading) request ids.""" ++ finished_saved_req, finished_loaded_req = None, None ++ if self._async_handler is not None: ++ if self.is_producer or ENABLE_PREFIX_CACHING: ++ finished_saved_req = self._async_handler.get_save_finished(finished_req_ids) ++ ++ if not self.is_producer or ENABLE_PREFIX_CACHING: ++ finished_loaded_req = self._async_handler.get_load_finished() ++ ++ return finished_saved_req, finished_loaded_req ++ return None, None ++ ++ def get_sending_count(self): ++ """ ++ Return count of finished sending requests aggregated. ++ For mla model, just save kvc for tp rank = 0 ++ """ ++ if self.is_mla: ++ return 1 ++ return self.tp_size ++ ++ def do_load_kv(self, request) -> None: ++ """ ++ Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. ++ ++ Note: ++ The number of elements in kv_caches and layer_names should be the same. 
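++
++        In async mode, completion of the blockwise H2D mget futures is tracked by the
++        AsyncHandler instead of being awaited inline.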
++ """ ++ ds_cached_block_num = request.ds_cached_block_num ++ skip_block_num = request.skip_block_num ++ logger.debug(f"request:{request.request_id}, ds_cached_block_num: {ds_cached_block_num}, " ++ f"skip_block_num: {skip_block_num}") ++ if ds_cached_block_num == 0: ++ return ++ key_list = self.generate_kv_cache_token_key(request, skip_block_num, ds_cached_block_num) ++ block_id_list = request.block_ids ++ if not block_id_list or not key_list: ++ return ++ if not self.is_mla: ++ value_cache_key_list = [key + "-value" for key in key_list] ++ if len(key_list) != len(block_id_list): ++ logger.error(f"mget_tensors_h2d fail, request.request_id:{request.request_id}.") ++ ++ get_timeout = 10000 ++ future = self.kvc_store.mget_page_attn_blockwise_h2d(key_list, self.key_caches, block_id_list, get_timeout) ++ future_1 = self.kvc_store.mget_page_attn_blockwise_h2d(value_cache_key_list, self.value_caches, ++ block_id_list, get_timeout) ++ if not self.do_async_save: ++ get_future(future, SYNC_FUTURE_TIMEOUT) ++ get_future(future_1, SYNC_FUTURE_TIMEOUT) ++ else: ++ self._async_handler.add_load_request(request, 2) ++ self._async_handler.add_load_future(request, future) ++ self._async_handler.add_load_future(request, future_1) ++ logger.debug(f"mget_tensors_h2d success, request.request_id:{request.request_id}, " ++ f"key_list length:{len(key_list)}") ++ return ++ ++ future = self.kvc_store.mget_page_attn_blockwise_h2d(key_list, self.kv_caches, block_id_list) ++ if not self.do_async_save: ++ get_future(future, SYNC_FUTURE_TIMEOUT) ++ else: ++ self._async_handler.add_load_request(request, 1) ++ self._async_handler.add_load_future(request, future) ++ logger.debug(f"mget_tensors_h2d success, request.request_id:{request.request_id}, " ++ f"key_list length:{len(key_list)}") ++ ++ def wait_for_layer_load(self, layer_name: str) -> None: ++ """ ++ wait_for_layer_load ++ """ ++ return ++ ++ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, ++ attn_metadata: "AttentionMetadata", **kwargs) -> None: ++ """ ++ save_kv_layer ++ """ ++ if not ENABLE_PREFIX_CACHING and not self.is_producer: ++ return ++ ++ if layer_name not in self.layer_name_list: ++ self.layer_name_list.append(layer_name) ++ self.is_ms_non_mla_type = isinstance(kv_layer, tuple) and len(kv_layer) == 2 ++ self.is_ms_mla = os.getenv("vLLM_MODEL_BACKEND", None) == "MindFormers" and not self.is_ms_non_mla_type ++ self.is_mla = isinstance(attn_metadata, MLACommonMetadata) or self.is_ms_mla ++ if self.is_mla: ++ self.kv_caches.append(kv_layer) ++ else: ++ self.key_caches.append(kv_layer[0]) ++ self.value_caches.append(kv_layer[1]) ++ ++ def do_save_request(self, request) -> None: ++ """ ++ Start saving the KV cache of the layer from vLLM's paged buffer to the connector. 
++ """ ++ logger.debug(f"do_save_request, request:{request}") ++ if not self.is_producer or not request.need_save: ++ return ++ ++ if self.is_mla and self.tp_rank != request.request_rank: ++ return ++ ++ if not request.block_ids: ++ return ++ ++ token_key_list = self.generate_kv_cache_token_key(request, 0, len(request.block_ids)) ++ if not self.is_mla: ++ value_cache_key_list = [key + "-value" for key in token_key_list] ++ future = self.kvc_store.mset_page_attn_blockwise_d2h(token_key_list, self.key_caches, request.block_ids) ++ future_1 = self.kvc_store.mset_page_attn_blockwise_d2h(value_cache_key_list, self.value_caches, ++ request.block_ids) ++ if not self.do_async_save: ++ get_future(future, SYNC_FUTURE_TIMEOUT) ++ get_future(future_1, SYNC_FUTURE_TIMEOUT) ++ else: ++ self._async_handler.add_save_request(request, 2) ++ self._async_handler.add_save_future(request, future) ++ self._async_handler.add_save_future(request, future_1) ++ logger.debug(f"mset_tensors_d2h success, request.request_id:{request.request_id}, " ++ f"key_list length:{len(token_key_list)}") ++ return ++ ++ future = self.kvc_store.mset_page_attn_blockwise_d2h(token_key_list, self.kv_caches, request.block_ids) ++ if not self.do_async_save: ++ get_future(future, SYNC_FUTURE_TIMEOUT) ++ else: ++ self._async_handler.add_save_request(request, 1) ++ self._async_handler.add_save_future(request, future) ++ logger.debug(f"mset_tensors_d2h success, request.request_id:{request.request_id}, " ++ f"key_list length:{len(token_key_list)}.") ++ ++ def wait_for_save(self) -> None: ++ """ ++ wait_for_save ++ """ ++ if not self.is_producer: ++ return ++ connector_metadata = self._get_connector_metadata() ++ if not isinstance(connector_metadata, YuanRongConnectorMetadata): ++ raise ValueError("connector_metadata is not an instance of YuanRongConnectorMetadata") ++ ++ if not connector_metadata.requests: ++ return ++ ++ for request in connector_metadata.requests: ++ if self._async_handler is not None: ++ self._save_request_queue.put_nowait(request) ++ else: ++ self.do_save_request(request) ++ ++ def get_num_new_matched_tokens( ++ self, ++ request: "Request", ++ num_computed_tokens: int, ++ ) -> tuple[int, bool]: ++ """ ++ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. ++ ++ Args: ++ request (Request): the request object. ++ num_computed_tokens (int): the number of locally ++ computed tokens for this request ++ ++ Returns: ++ the number of tokens that can be loaded from the ++ external KV cache beyond what is already computed. 
++ """ ++ num_computed_blocks = num_computed_tokens // self._block_size ++ num_tokens_to_check = align_to_block_size(len(request.prompt_token_ids), self._block_size) ++ prompt_blocks = num_tokens_to_check // self._block_size ++ num_external_hit_tokens = 0 ++ if not self.is_producer: ++ self._skip_blocks[request.request_id] = num_computed_blocks ++ num_external_computed_tokens = len(request.prompt_token_ids) - num_computed_tokens - 1 ++ self._ds_cached_blocks[request.request_id] = prompt_blocks ++ if self.do_async_save and num_external_computed_tokens > 0: ++ logger.info(f"request_id:{request.request_id}, num_computed_tokens:{num_computed_tokens}, " ++ f"num_external_computed_tokens:{num_external_computed_tokens}") ++ return num_external_computed_tokens, True ++ ++ return num_external_computed_tokens, False ++ if ENABLE_PREFIX_CACHING: ++ tokens = request.prompt_token_ids ++ keys = generate_hash_sha256(num_computed_blocks, prompt_blocks, numpy.array(tokens), self._block_size, "-0") ++ if not keys: ++ logger.info( ++ "Reqid: %s, Total tokens %d, HBM hit tokens: %d, " ++ "need to load: 0", request.request_id, request.num_tokens, num_computed_tokens) ++ return 0, False ++ ++ try: ++ exists = self.kvc_store.exist(keys) + [False] ++ except RuntimeError: ++ logger.info( ++ "Reqid: %s, Total tokens %d, HBM hit tokens: %d, " ++ "need to load: 0", request.request_id, request.num_tokens, num_computed_tokens) ++ return 0, False ++ ++ num_external_hit_blocks = exists.index(False) ++ num_external_hit_tokens = num_external_hit_blocks * self._block_size ++ ++ self._skip_blocks[request.request_id] = num_computed_blocks ++ self._ds_cached_blocks[request.request_id] = num_external_hit_blocks + num_computed_blocks ++ ++ logger.info( ++ "Reqid: %s, Total tokens %d, HBM hit tokens: %d, " ++ "need to load: %d", request.request_id, request.num_tokens, num_computed_tokens, ++ num_external_hit_tokens) ++ ++ if self.do_async_save and num_external_hit_tokens > 0: ++ return num_external_hit_tokens, True ++ ++ return num_external_hit_tokens, False ++ ++ def update_state_after_alloc( ++ self, ++ request: "Request", ++ blocks: "KVCacheBlocks", ++ num_external_tokens: int ++ ) -> None: ++ """ ++ Update KVConnector state after block allocation. ++ ++ If blocks were allocated, add to _requests_need_load, ++ such that we load the KVs in the next forward pass. ++ """ ++ if num_external_tokens > 0: ++ block = blocks.get_unhashed_block_ids() ++ self._requests_need_load[request.request_id] = (request, [block]) ++ logger.debug(f"_requests_need_load add request_id: {request.request_id}") ++ ++ def build_connector_meta( ++ self, ++ scheduler_output: SchedulerOutput, ++ ) -> KVConnectorMetadata: ++ """ ++ Build the connector metadata for this step. ++ ++ This function should NOT modify any fields in the scheduler_output. ++ Also, calling this function will reset the state of the connector. ++ ++ Args: ++ scheduler_output (SchedulerOutput): the scheduler output object. 
++ """ ++ meta = YuanRongConnectorMetadata(self.tp_size, self._block_size) ++ total_need_load = 0 ++ for new_req in scheduler_output.scheduled_new_reqs: ++ if new_req.req_id in self._requests_need_load: ++ meta.add_request(request_id=new_req.req_id, ++ token_ids=new_req.prompt_token_ids, ++ block_ids=new_req.block_ids, ++ skip_block_num=self._skip_blocks.pop(new_req.req_id, 0), ++ ds_cached_block_num=self._ds_cached_blocks.pop(new_req.req_id, 0)) ++ total_need_load += 1 ++ else: ++ if self.is_producer: ++ num_scheduled_tokens = scheduler_output.num_scheduled_tokens.get(new_req.req_id) ++ num_scheduled_tokens += new_req.num_computed_tokens ++ if len(new_req.prompt_token_ids) > num_scheduled_tokens: ++ self._delay_save[new_req.req_id] = RequestTracker.from_new_request(new_req.req_id, ++ new_req.prompt_token_ids, ++ new_req.block_ids, ++ num_scheduled_tokens) ++ else: ++ meta.add_request(request_id=new_req.req_id, ++ token_ids=new_req.prompt_token_ids, ++ block_ids=new_req.block_ids, ++ skip_block_num=self._skip_blocks.pop(new_req.req_id, 0), ++ ds_cached_block_num=self._ds_cached_blocks.pop(new_req.req_id, 0)) ++ ++ cached_reqs = scheduler_output.scheduled_cached_reqs ++ for i, req_id in enumerate(cached_reqs.req_ids): ++ new_block_ids = cached_reqs.new_block_ids[i] ++ resumed_from_preemption = cached_reqs.resumed_from_preemption[i] ++ ++ # NOTE(rob): here we rely on the resumed requests being ++ # the first N requests in the list scheduled_cache_reqs. ++ if not resumed_from_preemption: ++ if req_id in self._delay_save: ++ request_tracker = self._delay_save.get(req_id) ++ num_external_scheduled_tokens = scheduler_output.num_scheduled_tokens.get(req_id) ++ request_tracker.update(new_block_ids, num_external_scheduled_tokens) ++ if len(request_tracker.token_ids) <= request_tracker.num_scheduled_tokens: ++ del self._delay_save[req_id] ++ logger.debug(f"add delay save request, request id:{request_tracker.request_id}") ++ meta.add_request(request_id=request_tracker.request_id, ++ token_ids=request_tracker.token_ids, ++ block_ids=request_tracker.block_ids, ++ skip_block_num=self._skip_blocks.pop(request_tracker.request_id, 0), ++ ds_cached_block_num=self._ds_cached_blocks.pop(request_tracker.request_id, 0)) ++ ++ if req_id in self._requests_need_load: ++ # NOTE(rob): cached_req_data does not have the full ++ # list of token ids (only new tokens). So we look it ++ # up in the actual request object. ++ request = self._requests_need_load[req_id] ++ token_ids = request.all_token_ids[:len(request.prompt_token_ids)] ++ logger.debug(f"request_id:{request.request_id} resumed from preemption") ++ # NOTE(rob): For resumed req, new_block_ids is all of the block_ids for the request. 
++                block_ids = new_block_ids
++                meta.add_request(request_id=req_id,
++                                 token_ids=token_ids,
++                                 block_ids=block_ids,
++                                 skip_block_num=self._skip_blocks.pop(req_id, 0),
++                                 ds_cached_block_num=self._ds_cached_blocks.pop(req_id, 0))
++                total_need_load += 1
++        if self.do_async_save:
++            for req_id, (req, block_ids) in self._requests_need_load.items():
++                if not block_ids:
++                    logger.debug(
++                        "Skipping adding request %s to ConnectorMetadata, "
++                        "as there are no remote blocks to pull", req_id)
++                    continue
++
++                meta.add_request(
++                    request_id=req_id,
++                    token_ids=req.prompt_token_ids,
++                    block_ids=block_ids,
++                    skip_block_num=self._skip_blocks.pop(req_id, 0),
++                    ds_cached_block_num=self._ds_cached_blocks.pop(req_id, 0),
++                    need_save=False)
++                total_need_load += 1
++
++        logger.debug(f"total_need_load:{total_need_load}, self._requests_need_load:{len(self._requests_need_load)}")
++        # Clear the list once workers start the transfers
++        if total_need_load != len(self._requests_need_load):
++            logger.error(f"total_need_load={total_need_load} "
++                         f"is not equal to requests_need_load={len(self._requests_need_load)}")
++            raise ValueError("total_need_load is not equal to requests_need_load")
++        self._requests_need_load.clear()
++        return meta
++
++    def request_finished(
++        self,
++        request: "Request",
++        block_ids: list[int],
++    ) -> tuple[bool, Optional[dict[str, Any]]]:
++        """
++        Called when a request has finished; returns whether KV saving
++        may still be happening asynchronously.
++        """
++        # Return True to indicate that saving may be happening asynchronously.
++        if self.is_producer:
++            return self.do_async_save, None
++
++        return False, None
++
++    def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
++        """
++        Initialize KV caches from forward_context.
++
++        Args:
++            forward_context: forward_context.
++        """
++        attn_metadata = forward_context.attn_metadata
++        for layer_name in forward_context.no_compile_layers:
++            attn_layer = forward_context.no_compile_layers[layer_name]
++            kv_layer = attn_layer.kv_cache[forward_context.virtual_engine]
++            self.is_ms_non_mla_type: bool = isinstance(kv_layer, tuple) and len(kv_layer) == 2
++            self.is_ms_mla = os.getenv("vLLM_MODEL_BACKEND", None) == "MindFormers" and not self.is_ms_non_mla_type
++            self.is_mla = isinstance(attn_metadata, MLACommonMetadata) or self.is_ms_mla
++            if layer_name not in self.layer_name_list:
++                self.layer_name_list.append(layer_name)
++                logger.debug(f"_init_kv_caches_from_forward_context, layer_name:{layer_name}")
++                if not self.is_mla:
++                    self.key_caches.append(kv_layer[0])
++                    self.value_caches.append(kv_layer[1])
++                elif self.is_ms_mla:
++                    self.kv_caches.append(kv_layer[0])
++                else:
++                    self.kv_caches.append(kv_layer)
++
++
++def extract_number(s: str) -> Optional[int]:
++    """extract number"""
++    parts = s.split('.')
++    for part in parts:
++        if part.isdigit():
++            return int(part)
++    return None
++
++
++def align_to_block_size(num_tokens: int, block_size: int) -> int:
++    """
++    Align the number of tokens to the block size.
++    """
++    return (num_tokens + block_size - 2) // block_size * block_size
++
++
++def generate_hash_sha256(
++    block_start_index: int,
++    block_end_index: int,
++    token_ids: numpy.ndarray,
++    block_size: int,
++    external_key: str
++) -> List[str]:
++    """
++    generate kv_cache token key.
++
++    Args:
++        block_start_index: index of the first block to hash
++        block_end_index: index one past the last block to hash
++ token_ids: token ids ++ block_size: block size of vllm ++ external_key: additional key ++ """ ++ hash_list = [] ++ for block_index in range(block_start_index, block_end_index): ++ end_index = (block_index + 1) * block_size ++ input_ids = token_ids[:end_index] ++ input_ids_bytes = input_ids.tobytes() ++ token_hash = hashlib.sha256(input_ids_bytes).hexdigest() ++ hash_list.append(token_hash + external_key) ++ return hash_list ++ ++ ++def get_future(fut: Future, timeout: int = FUTURE_TIMEOUT) -> RequestStatus: ++ """get future""" ++ try: ++ failed_list = fut.get(timeout) ++ except TimeoutError: ++ return RequestStatus.WAITING ++ ++ if len(failed_list) != 0: ++ logger.error(f"failed_list: {failed_list}") ++ return RequestStatus.TIMEOUT ++ ++ return RequestStatus.FINISHED +-- +2.33.0 + diff --git a/tests/python/prefetch_tests/start_worker.sh b/tests/python/prefetch_tests/start_worker.sh index b71e0ed..7594512 100755 --- a/tests/python/prefetch_tests/start_worker.sh +++ b/tests/python/prefetch_tests/start_worker.sh @@ -11,8 +11,7 @@ export PATH=$PATH:${BASE_DIR}/scripts/modules function run_example() { echo -e "---- Start Smoke Testing..." - bash "${DATASYSTEM_DIR}/example/run-example.sh" "off" "off" "on" || - (remove_running_pids && go_die "---- Smoke Testing failed!") + bash "${DATASYSTEM_DIR}/example/run-example.sh" || (remove_running_pids && go_die "---- Smoke Testing failed!") echo -e "---- Smoke Testing success!" } diff --git a/tests/python/test_ds_tensor_client.py b/tests/python/test_ds_tensor_client.py index b559e09..152111e 100644 --- a/tests/python/test_ds_tensor_client.py +++ b/tests/python/test_ds_tensor_client.py @@ -14,6 +14,8 @@ """ Test datasystem tensor client python interface. """ +import logging +from multiprocessing import Process, Barrier import json import os import random @@ -42,6 +44,12 @@ try: except ImportError: is_tensor_client_exist = False +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(message)s', +) +logger = logging.getLogger(__name__) + class TestDsTensorClient(unittest.TestCase): """ @@ -226,7 +234,7 @@ class TestDsTensorClient(unittest.TestCase): @unittest.skipUnless(is_mindspore_exist and is_tensor_client_exist, "Run when dependency is exist") def test_dev_mset_and_dev_mget_with_mindspore_tensor(self): """Test dev_mset and dev_mget device object.""" - src_device_id, dest_device_id = 6, 7 + src_device_id, dest_device_id = 0, 1 key_num = 1 keys = [self.random_str(10) for _ in range(key_num)] datas = [np.random.rand(2, 3) for _ in range(key_num)] @@ -528,3 +536,85 @@ class TestDsTensorClient(unittest.TestCase): self.assertEqual(len(failed_keys), 0) acl.finalize() + + @unittest.skipUnless(is_mindspore_exist and is_tensor_client_exist, "Run when dependency is exist") + def test_dev_d2d_dead_lock1(self): + """Test the d2d deadlock.""" + local_rank_num = 8 + dtype = ms.float32 + shape = (2, 3) + + def task(i, barrier, local_rank_num): + acl.init() + acl.rt.set_device(i) + ms.set_device(device_target="Ascend", device_id=i) + client = self.init_test_tensor_client(i) + keys = [f'{i}_{j}' for j in range(local_rank_num)] + send_tensors = [ms.Tensor(np.ones(shape), dtype) + 0 for i in range(local_rank_num)] + + failed_keys = client.dev_mset(keys, send_tensors) + assert len(failed_keys) == 0 + logger.info(f"device {i} set key_list:{keys} success") + barrier.wait() + + get_keys = [f'{j}_{i}' for j in range(local_rank_num)] + recv_tensors = [ms.Tensor(np.zeros(shape), dtype) + 0 for i in range(local_rank_num)] + failed_keys = client.dev_mget(get_keys, 
recv_tensors, 60 * 1000) + assert len(failed_keys) == 0 + + self.batch_tensors_check(recv_tensors, send_tensors) + logger.info(f"device {i} get key_list:{get_keys} success") + + barrier.wait() + failed_keys = client.dev_delete(keys) + assert len(failed_keys) == 0 + + processes = [] + barrier = Barrier(local_rank_num) + for i in range(local_rank_num): + p = Process(target=task, args=(i, barrier, local_rank_num)) + processes.append(p) + p.start() + + for p in processes: + p.join() + + @unittest.skipUnless(is_mindspore_exist and is_tensor_client_exist, "Run when dependency is exist") + def test_dev_d2d_dead_lock2(self): + """Test the d2d deadlock.""" + local_rank_num = 8 + dtype = ms.float32 + shape = (2, 3) + key_lists_formal = [f"device_id_{i}" for i in range(local_rank_num)] + array_lists_formal = [np.random.randn(*shape) for _ in range(local_rank_num)] + + def task(i, barrier, local_rank_num): + acl.init() + acl.rt.set_device(i) + ms.set_device(device_target="Ascend", device_id=i) + client = self.init_test_tensor_client(i) + send_tensors = [ms.Tensor(array_lists_formal[i], dtype) + 0] + failed_keys = client.dev_mset([key_lists_formal[i]], send_tensors) + assert len(failed_keys) == 0 + logger.info(f"device {i} set key_list:{key_lists_formal[i]} success") + barrier.wait() + + key_lists = key_lists_formal[0: i] + key_lists_formal[i + 1::] + recv_tensors = [ms.Tensor(np.ones(shape), dtype) + 0 for _ in range(local_rank_num - 1)] + failed_keys = client.dev_mget(key_lists, recv_tensors, 60 * 1000) + assert len(failed_keys) == 0 + logger.info(f"device {i} get key_list:{key_lists} success") + + barrier.wait() + failed_keys = client.dev_delete(key_lists_formal) + assert len(failed_keys) == 0 + + processes = [] + barrier = Barrier(local_rank_num) + for i in range(local_rank_num): + p = Process(target=task, args=(i, barrier, local_rank_num)) + processes.append(p) + p.start() + + for p in processes: + p.join() diff --git a/tests/python/test_oc_client.py b/tests/python/test_oc_client.py index 02a4fa4..eb59cbc 100644 --- a/tests/python/test_oc_client.py +++ b/tests/python/test_oc_client.py @@ -23,7 +23,7 @@ import threading import time import unittest -from datasystem.object_client import Buffer, ConsistencyType, ObjectClient, WriteMode +from datasystem.object_client import Buffer, ConsistencyType, ObjectClient class TestOcClientMethods(unittest.TestCase): @@ -90,7 +90,7 @@ class TestOcClientMethods(unittest.TestCase): client = self.init_test_client() object_key = self.random_str(10) value = bytes(self.random_str(50), encoding='utf8') - param = {"write_mode": WriteMode.NONE_L2_CACHE, "consistency_type": ConsistencyType.PRAM} + param = {"consistency_type": ConsistencyType.PRAM} client.put(object_key, value, param) buffer_list = client.get([object_key], 5) self.assertEqual(value, buffer_list[0].immutable_data()) @@ -102,7 +102,7 @@ class TestOcClientMethods(unittest.TestCase): client = self.init_test_client() object_key = self.random_str(10) value = bytes(self.random_str(50), encoding='utf8') - param = {"write_mode": WriteMode.NONE_L2_CACHE, "consistency_type": ConsistencyType.PRAM} + param = {"consistency_type": ConsistencyType.PRAM} client.put(object_key, value, param) buffer_list = client.get([object_key], 5) read_data = buffer_list[0].immutable_data() @@ -117,7 +117,7 @@ class TestOcClientMethods(unittest.TestCase): object_key = self.random_str(10) value = bytes(self.random_str(50), encoding='utf8') size = len(value) - param = {"write_mode": WriteMode.NONE_L2_CACHE, "consistency_type": 
ConsistencyType.PRAM}
+        param = {"consistency_type": ConsistencyType.PRAM}
         buffer = client.create(object_key, size, param)
         self.assertEqual(buffer.get_size(), size)
         buffer.wlatch()
@@ -360,7 +360,7 @@ class TestOcClientMethods(unittest.TestCase):
         self.assertEqual(client.query_global_ref_num(object2), 1)
         self.assertEqual(client.query_global_ref_num(object3), 1)
 
-        param = {"write_mode": WriteMode.NONE_L2_CACHE, "consistency_type": ConsistencyType.PRAM}
+        param = {"consistency_type": ConsistencyType.PRAM}
         client.put(object2, value, param)
         client.put(object3, value, param)
 
diff --git a/tests/python/test_sc_client.py b/tests/python/test_sc_client.py
new file mode 100644
index 0000000..53317cd
--- /dev/null
+++ b/tests/python/test_sc_client.py
@@ -0,0 +1,506 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Stream cache test case.
+"""
+from __future__ import absolute_import
+import json
+import logging
+import os
+import random
+import time
+import unittest
+
+from datasystem.object_client import ObjectClient
+from datasystem.stream_client import StreamClient, SubconfigType
+
+
+def wait_proc(proc):
+    """wait_proc"""
+    proc.wait()
+    stdout, stderr = proc.communicate()
+    logging.info(stdout)
+    logging.info(stderr)
+
+
+class TestScClientMethods(unittest.TestCase):
+    """
+    Features: Stream cache client python interface test.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        logging.info("********************sc_client test start*********************")
+        time.sleep(3)
+        root_dir = os.path.dirname(os.path.abspath('..'))
+        worker_env_path = os.path.join(root_dir, "output", "service", "worker_config.json")
+        with open(worker_env_path, "r") as f:
+            config = json.load(f)
+
+        work_address = config.get("worker_address", {})
+        TestScClientMethods.work_addr = work_address.get("value")
+        logging.info("TestScClientMethods.work_addr: %s", TestScClientMethods.work_addr)
+
+    @staticmethod
+    def multi_producer_and_consumer(client, element_datas):
+        """multi producer and consumer test case"""
+        length = len(element_datas)
+        stream_name = "stream_name_multi_" + str(0)
+        sub_name = "sub_name_multi_" + str(0)
+        producer_tmp = client.create_producer(stream_name, 50)
+
+        consumer_tmp = client.subscribe(stream_name, sub_name, SubconfigType.STREAM.value)
+
+        before_send = time.perf_counter()
+        for i in range(length):
+            # send bytes directly.
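+            # element payloads are raw bytes; the perf counters below log send, flush,
+            # and receive latency for the whole batch.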
+            producer_tmp.send(element_datas[i])
+
+        before_flush = time.perf_counter()
+        before_recv = time.perf_counter()
+        _ = consumer_tmp.receive(length, 0)
+        after_recv = time.perf_counter()
+
+        time_lst = [before_send, before_flush, before_recv, after_recv]
+        annotation_lst = ['send', 'flush', 'recv']
+        for i in range(len(time_lst) - 1):
+            logging.info('%s: %s ms', annotation_lst[i], 1000 * (time_lst[i + 1] - time_lst[i]))
+        producer_tmp.close()
+        consumer_tmp.close()
+        return True
+
+    @staticmethod
+    def test_client_delete_stream_success():
+        """delete stream test"""
+        stream_name = "stream_ds"
+        ip = TestScClientMethods.work_addr.split(":")
+        client_ds = StreamClient(ip[0], int(ip[1]))
+        client_ds.init()
+        producer_ds = client_ds.create_producer(stream_name)
+
+        consumer_ds = client_ds.subscribe(stream_name, "sub_name_d", SubconfigType.STREAM.value)
+
+        producer_ds.close()
+        consumer_ds.close()
+        client_ds.delete_stream(stream_name)
+
+    @staticmethod
+    def test_stream_set_default_size_success():
+        """stream set default size test"""
+        ip = TestScClientMethods.work_addr.split(":")
+        client_stream_sds = StreamClient(ip[0], int(ip[1]))
+        client_stream_sds.init()
+        stream_name = "stream_set_default_size"
+
+        client_stream_sds.create_producer(stream_name, 2)
+
+    @staticmethod
+    def test_multi_producer_set_success():
+        """multi producer set test"""
+        ip = TestScClientMethods.work_addr.split(":")
+        client_stream_psd = StreamClient(ip[0], int(ip[1]))
+        client_stream_psd.init()
+        stream_name = "stream_multi_producer"
+
+        base_size = 1024 * 4
+        for i in range(1, 6):
+            client_stream_psd.create_producer(stream_name, 1 + i // 3, base_size)
+        for _ in range(1, 6):
+            client_stream_psd.create_producer(stream_name, 5, base_size)
+
+    @staticmethod
+    def test_client_send_receive_data_success():
+        """client send and receive data test"""
+        arr = b'10101010'
+        stream_name = "stream_sr"
+        ip = TestScClientMethods.work_addr.split(":")
+        client_srd = StreamClient(ip[0], int(ip[1]))
+        client_srd.init()
+        producer_srd = client_srd.create_producer(stream_name)
+
+        consumer_srd = client_srd.subscribe(stream_name, "sub_name_s", SubconfigType.STREAM.value)
+        producer_srd.send(arr)
+
+        element_list = consumer_srd.receive(1, 0)
+
+        data_element = memoryview(element_list[-1])
+        logging.info("element type: %s", type(element_list))
+        logging.info("element size: %s", len(data_element) * data_element.itemsize)
+        logging.info("element content: %s", data_element.tobytes())
+        consumer_srd.ack(element_list[-1].get_id())
+
+        producer_srd.close()
+        consumer_srd.close()
+
+    @staticmethod
+    def test_client_send_with_blocking_support_receive_without_expected_num_data():
+        """client send with blocking, receive without expected num test"""
+        arr = b'10101010'
+        stream_name = "stream_sr"
+        ip = TestScClientMethods.work_addr.split(":")
+        client_srd = StreamClient(ip[0], int(ip[1]))
+        client_srd.init()
+        producer_srd = client_srd.create_producer(stream_name)
+
+        consumer_srd = client_srd.subscribe(stream_name, "sub_name_s", SubconfigType.STREAM.value)
+        producer_srd.send(arr, 1000)
+
+        element_list = consumer_srd.receive_any(0)
+
+        data_element = memoryview(element_list[-1])
+        logging.info("element type: %s", type(element_list))
+        logging.info("element size: %s", len(data_element) * data_element.itemsize)
+        logging.info("element content: %s", data_element.tobytes())
+        consumer_srd.ack(element_list[-1].get_id())
+
+        producer_srd.close()
+        consumer_srd.close()
+
+    @staticmethod
+    def test_stream_set_pagesize_success():
+        """stream set pagesize test"""
+        ip = 
TestScClientMethods.work_addr.split(":")
+        client_stream_spd = StreamClient(ip[0], int(ip[1]))
+        client_stream_spd.init()
+        stream_name = "stream_set_pagesize"
+
+        client_stream_spd.create_producer(stream_name, 1, 1024 * 4)
+
+    @staticmethod
+    def test_stream_client_work_success():
+        """stream case test"""
+        ip = TestScClientMethods.work_addr.split(":")
+        client_stream_os = StreamClient(ip[0], int(ip[1]))
+        client_stream_os.init()
+
+        stream_name = "stream_object"
+        producer_os = client_stream_os.create_producer(stream_name)
+
+        consumer_os = client_stream_os.subscribe(stream_name, "sub_object_stream", SubconfigType.STREAM.value)
+        arr = b'0001'
+        producer_os.send(arr)
+        element_list = consumer_os.receive(1, 0)
+
+        consumer_os.ack(element_list[-1].get_id())
+
+        producer_os.close()
+        consumer_os.close()
+
+    def random_str(self, slen=10):
+        """random string"""
+        seed = "1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+        sa = []
+        for _ in range(slen):
+            sa.append(random.choice(seed))
+        return ''.join(sa)
+
+    def generate_elements(self, element_num, element_size):
+        """generate elements"""
+        elements = []
+        seed = "1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+        tmp = []
+        for _ in range(element_num):
+            for _ in range(element_size):
+                tmp.append(random.choice(seed))
+            rus = ''.join(tmp)
+            elements.append(bytes(rus, encoding='utf8'))
+            tmp = []
+        return elements
+
+    def test_multi_send_recv_success(self):
+        """multi producer and consumer test"""
+        element_num = 4000
+        element_size = 1000
+        element_datas = self.generate_elements(element_num, element_size)
+        ip = TestScClientMethods.work_addr.split(":")
+        client_sr = StreamClient(ip[0], int(ip[1]))
+        client_sr.init()
+        self.multi_producer_and_consumer(client_sr, element_datas)
+
+    def test_query_stream_topo_success(self):
+        """query stream topo test"""
+        stream_name1 = "query_topo_stream1"
+        ip = TestScClientMethods.work_addr.split(":")
+        client_qst1 = StreamClient(ip[0], int(ip[1]))
+        client_qst1.init()
+        node1_producer_qst1 = client_qst1.create_producer(stream_name1)
+
+        node1_consumer_qst = client_qst1.subscribe(stream_name1, "sub_name1", SubconfigType.STREAM.value)
+
+        global_consumer_num = client_qst1.query_global_consumer_num(stream_name1)
+        self.assertEqual(global_consumer_num, 1)
+
+        global_producer_num = client_qst1.query_global_producer_num(stream_name1)
+        self.assertEqual(global_producer_num, 1)
+
+        ip = TestScClientMethods.work_addr.split(":")
+        client_qst2 = StreamClient(ip[0], int(ip[1]))
+        client_qst2.init()
+
+        node2_producer_qst = client_qst2.create_producer(stream_name1)
+
+        node2_consumer_qst = client_qst2.subscribe(stream_name1, "sub_name2", SubconfigType.STREAM.value)
+        global_consumer_num = client_qst2.query_global_consumer_num(stream_name1)
+        self.assertEqual(global_consumer_num, 2)
+
+        global_producer_num = client_qst1.query_global_producer_num(stream_name1)
+        self.assertEqual(global_producer_num, 1)
+
+        node1_producer_qst1.close()
+        node1_consumer_qst.close()
+        node2_producer_qst.close()
+        node2_consumer_qst.close()
+
+    def test_object_client_work_success(self):
+        """object case test"""
+        host, port = TestScClientMethods.work_addr.split(":")
+        port = int(port)
+        client_os = ObjectClient(host, port)
+        client_os.init()
+        object_key = self.random_str(10)
+        value = bytearray(self.random_str(50), encoding='utf8')
+        buffer = client_os.create(object_key, len(value))
+        buffer.wlatch()
+        buffer.memory_copy(value)
+        buffer.seal()
+        buffer.unwlatch()
+
+        object_key2 = self.random_str(10)
+        value2 = 
bytearray(self.random_str(100), encoding='utf8') + buffer2 = client_os.create(object_key2, len(value2)) + buffer2.wlatch() + buffer2.memory_copy(value2) + buffer2.seal() + buffer2.unwlatch() + + buffer_list = client_os.get([object_key, object_key2], 5) + + self.assertEqual(value, buffer_list[0].immutable_data()) + self.assertEqual(value2, buffer_list[1].immutable_data()) + + def test_stream_auto_send_success(self): + """stream auto send case test""" + ip = TestScClientMethods.work_addr.split(":") + client_stream_asd = StreamClient(ip[0], int(ip[1])) + client_stream_asd.init() + + stream_name = "stream_send_auto" + producer_asd = client_stream_asd.create_producer(stream_name) + consumer_asd = client_stream_asd.subscribe(stream_name, "sub_object_stream", SubconfigType.STREAM.value) + + arr = b'10101010' + producer_asd.send(arr) + + element_list = consumer_asd.receive(1, 10) + + consumer_asd.ack(element_list[-1].get_id()) + self.assertEqual(memoryview(element_list[-1]).tobytes(), arr) + + producer_asd.close() + consumer_asd.close() + + def test_stream_auto_send_without_delay_success(self): + """stream auto send without delay case test""" + ip = TestScClientMethods.work_addr.split(":") + client_stream_swd = StreamClient(ip[0], int(ip[1])) + client_stream_swd.init() + stream_name = "stream_send_without_delay" + producer_swd = client_stream_swd.create_producer(stream_name, 0) + consumer_swd = client_stream_swd.subscribe(stream_name, "sub_object_stream", SubconfigType.STREAM.value) + + arr = b'10101010' + producer_swd.send(arr) + + element_list = consumer_swd.receive(1, 10) + + consumer_swd.ack(element_list[-1].get_id()) + self.assertEqual(memoryview(element_list[-1]).tobytes(), arr) + + producer_swd.close() + consumer_swd.close() + + def test_stream_sametime_send_success(self): + """stream send sametime test""" + ip = TestScClientMethods.work_addr.split(":") + client_stream_ssd = StreamClient(ip[0], int(ip[1])) + client_stream_ssd.init() + stream_name = "stream_sametime_send" + producer_ssd = client_stream_ssd.create_producer(stream_name) + consumer_ssd = client_stream_ssd.subscribe(stream_name, "sub_object_stream", SubconfigType.STREAM.value) + + arr = b'10101010' + arr2 = b'01010101' + producer_ssd.send(arr) + + time.sleep(0.005) + producer_ssd.send(arr2) + + element_list = consumer_ssd.receive(2, 20) + + for _, element in enumerate(element_list): + consumer_ssd.ack(element.get_id()) + + self.assertEqual(memoryview(element_list[0]).tobytes(), arr) + self.assertEqual(memoryview(element_list[-1]).tobytes(), arr2) + + producer_ssd.close() + consumer_ssd.close() + + def test_stream_continuous_send_success(self): + """stream send continue test""" + ip = TestScClientMethods.work_addr.split(":") + client_stream_csd = StreamClient(ip[0], int(ip[1])) + client_stream_csd.init() + stream_name = "stream_continuous_send" + producer_csd = client_stream_csd.create_producer(stream_name) + consumer_csd = client_stream_csd.subscribe(stream_name, "sub_object_stream", SubconfigType.STREAM.value) + + arr = b'1010101010101010' + for i in range(1, 11): + tmparr = arr[0:i] + producer_csd.send(tmparr) + + element_list = consumer_csd.receive(15, 20) + + for i, element in enumerate(element_list): + consumer_csd.ack(element.get_id()) + self.assertEqual(memoryview(element).tobytes(), arr[0:(i + 1)]) + + producer_csd.close() + consumer_csd.close() + + def test_stream_producer_without_consumer(self): + """stream send test without consumer""" + ip = TestScClientMethods.work_addr.split(":") + for i in range(5): + client = 
StreamClient(ip[0], int(ip[1])) + client.init() + stream_name = 'test_dfx_streamcache_node_scale_004' + producer = client.create_producer(stream_name, delay_flush_time_ms=5, page_size_byte=1024 * 1024, + max_stream_size_byte=10 * 1024 * 1024, auto_cleanup=False) + for j in range(100000): + data = ('test' + str(i) + str(j)).encode() + producer.send(data) + producer.close() + client.delete_stream(stream_name) + + def test_stream_sametime_flush_success(self): + """stream flush sametime test""" + ip = TestScClientMethods.work_addr.split(":") + client_stream_sfd = StreamClient(ip[0], int(ip[1])) + client_stream_sfd.init() + stream_name = "stream_sametime_flush" + producer_sfd = client_stream_sfd.create_producer(stream_name) + consumer_sfd = client_stream_sfd.subscribe(stream_name, "sub_object_stream", SubconfigType.STREAM.value) + + arr = b'10101010' + producer_sfd.send(arr) + time.sleep(0.005) + element_list = consumer_sfd.receive(1, 100) + self.assertEqual(len(element_list), 1) + consumer_sfd.ack(element_list[-1].get_id()) + self.assertEqual(memoryview(element_list[-1]).tobytes(), arr) + + producer_sfd.close() + consumer_sfd.close() + + def test_long_max_limit(self): + """Test case in which the value of except_num exceeds the upper limit""" + ip = TestScClientMethods.work_addr.split(":") + client = StreamClient(ip[0], int(ip[1])) + client.init() + stream_name = "test_long_max_limit" + consumer = client.subscribe(stream_name, "sub_object_stream", SubconfigType.STREAM.value) + + int32_max_size = int("0x7FFFFFFF", 16) + + # test except_num + except_num = int32_max_size + 1 + self.assertRaises(RuntimeError, consumer.receive, except_num, 10) + + # test timeout_ms + timeout_ms = int32_max_size + 1 + self.assertRaises(RuntimeError, consumer.receive, 10, timeout_ms) + consumer.close() + + def test_param_max_limit(self): + """Test case in which the value size exceeds the upper limit""" + ip = TestScClientMethods.work_addr.split(":") + client = StreamClient(ip[0], int(ip[1])) + client.init() + stream_name = "test_send_max_limit" + delay_flush_time_ms = int("0x7FFFFFFFFFFFFFFF", 16) + 1 + self.assertRaises(RuntimeError, client.create_producer, stream_name, delay_flush_time_ms) + + def test_stream_producer_retain_for_num_consumer(self): + """stream test retain for consumer""" + ip = TestScClientMethods.work_addr.split(":") + client = StreamClient(ip[0], int(ip[1])) + client.init() + stream_name = 'test_retain' + producer = client.create_producer(stream_name, 5, 1024 * 1024, 10 * 1024 * 1024, False, 1) + + arr = b'10101010' + producer.send(arr) + + consumer = client.subscribe(stream_name, "test_retain_sub1", SubconfigType.STREAM.value) + element_list = consumer.receive(1, 0) + self.assertEqual(len(element_list), 1) + consumer.close() + + # We only retain data for 1 consumer. Second consumer should not get any + consumer2 = client.subscribe(stream_name, "test_retain_sub2", SubconfigType.STREAM.value) + element_list2 = consumer2.receive_any(0) + self.assertEqual(len(element_list2), 0) + consumer2.close() + + producer.close() + client.delete_stream(stream_name) + + def test_create_producer_reserve_size(self): + """test create producer reserve size""" + # This testcase intends to test that invalid reserve size will lead to CreateProducer failure. 
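+        # The rules checked below: reserve size may not exceed max_stream_size_byte,
+        # must be a multiple of page_size_byte, and 0 selects the one-page default.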
+ ip = TestScClientMethods.work_addr.split(":") + client = StreamClient(ip[0], int(ip[1])) + client.init() + stream_name = 'test_reserve_size' + delay_flush_time_ms = 5 + page_size_byte = 8 * 1024 + max_stream_size_byte = 64 * 1024 * 1024 + auto_cleanup = False + retain_for_num_consumers = 0 + encrypt_stream = False + + # Valid reserve size should be less than or equal to max stream size. + reserve_size = max_stream_size_byte + page_size_byte + self.assertRaises(RuntimeError, client.create_producer, stream_name, delay_flush_time_ms, + page_size_byte, max_stream_size_byte, auto_cleanup, retain_for_num_consumers, encrypt_stream, + reserve_size) + + # Valid reserve size should be a multiple of page size. + reserve_size = 12 * 1024 + self.assertRaises(RuntimeError, client.create_producer, stream_name, delay_flush_time_ms, + page_size_byte, max_stream_size_byte, auto_cleanup, retain_for_num_consumers, encrypt_stream, + reserve_size) + + # 0 is an acceptable input for reserve size, the default reserve size will then be the page size. + reserve_size = 0 + producer = client.create_producer(stream_name, delay_flush_time_ms, page_size_byte, + max_stream_size_byte, auto_cleanup, retain_for_num_consumers, encrypt_stream, reserve_size) + + global_producer_num = client.query_global_producer_num(stream_name) + self.assertEqual(global_producer_num, 1) + + producer.close() + client.delete_stream(stream_name) diff --git a/tests/st/CMakeLists.txt b/tests/st/CMakeLists.txt index fb09cd4..dc2d4f4 100644 --- a/tests/st/CMakeLists.txt +++ b/tests/st/CMakeLists.txt @@ -24,6 +24,10 @@ set(DS_ST_DEPEND_LIBS master_object_cache_store master_object_cache worker_object_cache + master_stream_cache + worker_stream_cache + master_stream_cache_store + master_stream_protos httpclient common_persistence_api) @@ -42,6 +46,9 @@ add_subdirectory(cluster) file(GLOB_RECURSE DS_TEST_ST_SRCS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp") +file(GLOB_RECURSE DS_ST_STREAM_CACHE_SRCS CONFIGURE_DEPENDS + "${CMAKE_CURRENT_SOURCE_DIR}/**/stream_cache/*.cpp") + file(GLOB_RECURSE DS_ST_OBJECT_CACHE_SRCS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/**/object_cache/*.cpp") @@ -53,6 +60,7 @@ list(FILTER DS_TEST_ST_SRCS EXCLUDE REGEX .*encrypt_util.cpp) list(FILTER DS_TEST_ST_SRCS EXCLUDE REGEX .*hashring_parser.cpp) list(FILTER DS_TEST_ST_SRCS EXCLUDE REGEX .*/device/.*) list(FILTER DS_TEST_ST_SRCS EXCLUDE REGEX .*/cluster/.*) +list(FILTER DS_TEST_ST_SRCS EXCLUDE REGEX .*/stream_cache/.*) list(FILTER DS_TEST_ST_SRCS EXCLUDE REGEX .*/object_cache/.*) list(FILTER DS_TEST_ST_SRCS EXCLUDE REGEX .*/kv_cache/.*) @@ -62,27 +70,33 @@ endif() if (NOT BUILD_GO_API) list(FILTER DS_TEST_ST_SRCS EXCLUDE REGEX .*/client_c_api/.*) + list(FILTER DS_ST_STREAM_CACHE_SRCS EXCLUDE REGEX .*/client_c_api/.*) list(FILTER DS_ST_OBJECT_CACHE_SRCS EXCLUDE REGEX .*/client_c_api/.*) list(FILTER DS_ST_KV_CACHE_SRCS EXCLUDE REGEX .*/client_c_api/.*) endif () set(ST_COMMON_SRCS test_main.cpp - st_oc_service_impl.cpp) + st_oc_service_impl.cpp + common/stream_cache/element_generator.cpp) +add_library(_ds_st_stream_cache_obj OBJECT ${DS_ST_STREAM_CACHE_SRCS} ${ST_COMMON_SRCS}) add_library(_ds_st_object_cache_obj OBJECT ${DS_ST_OBJECT_CACHE_SRCS} ${ST_COMMON_SRCS}) add_library(_ds_st_kv_cache_obj OBJECT ${DS_ST_KV_CACHE_SRCS} ${ST_COMMON_SRCS}) add_library(_ds_st_other_obj OBJECT ${DS_TEST_ST_SRCS} ${ST_COMMON_SRCS}) +target_link_libraries(_ds_st_stream_cache_obj nlohmann_json::nlohmann_json generic_service_protos ds_worker datasystem) 
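The reserve-size rules that test_create_producer_reserve_size exercises boil down to three checks. The sketch below is a reading aid only, under a hypothetical helper name (validate_reserve_size is not part of this patch); in the patch the equivalent validation happens inside create_producer, which raises RuntimeError on invalid input:

    def validate_reserve_size(reserve_size, page_size_byte, max_stream_size_byte):
        """Illustrative mirror of the reserve-size rules the test above exercises."""
        if reserve_size == 0:
            # 0 is accepted; the effective reservation defaults to one page.
            return page_size_byte
        if reserve_size > max_stream_size_byte:
            raise RuntimeError("reserve size must not exceed max stream size")
        if reserve_size % page_size_byte != 0:
            raise RuntimeError("reserve size must be a multiple of page size")
        return reserve_size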
 target_link_libraries(_ds_st_object_cache_obj nlohmann_json::nlohmann_json generic_service_protos ds_worker datasystem)
 target_link_libraries(_ds_st_kv_cache_obj nlohmann_json::nlohmann_json generic_service_protos ds_worker datasystem)
 target_link_libraries(_ds_st_other_obj nlohmann_json::nlohmann_json generic_service_protos ds_worker datasystem)
 
+add_executable(ds_st_stream_cache $<TARGET_OBJECTS:_ds_st_stream_cache_obj>)
 add_executable(ds_st_object_cache $<TARGET_OBJECTS:_ds_st_object_cache_obj>)
 add_executable(ds_st_kv_cache $<TARGET_OBJECTS:_ds_st_kv_cache_obj>)
 add_executable(ds_st $<TARGET_OBJECTS:_ds_st_other_obj>)
 
 target_link_libraries(ds_st PRIVATE ${DS_ST_DEPEND_LIBS})
+target_link_libraries(ds_st_stream_cache PRIVATE ${DS_ST_DEPEND_LIBS})
 target_link_libraries(ds_st_object_cache PRIVATE ${DS_ST_DEPEND_LIBS})
 target_link_libraries(ds_st_kv_cache PRIVATE ${DS_ST_DEPEND_LIBS})
 
@@ -223,6 +237,7 @@ add_custom_command(
         ${CMAKE_BINARY_DIR}/tests/st/data/client_plaintext_zmq_curve_test/worker_authorized_clients/worker.key)
 
 add_datasystem_test(ds_st TEST_ENVIRONMENTS ${TEST_ENVIRONMENT})
+add_datasystem_test(ds_st_stream_cache TEST_ENVIRONMENTS ${TEST_ENVIRONMENT})
 add_datasystem_test(ds_st_object_cache TEST_ENVIRONMENTS ${TEST_ENVIRONMENT})
 add_datasystem_test(ds_st_kv_cache TEST_ENVIRONMENTS ${TEST_ENVIRONMENT})
 
diff --git a/tests/st/client/kv_cache/kv_cache_client_expire_test.cpp b/tests/st/client/kv_cache/kv_cache_client_expire_test.cpp
index 27abc0b..d71fb57 100644
--- a/tests/st/client/kv_cache/kv_cache_client_expire_test.cpp
+++ b/tests/st/client/kv_cache/kv_cache_client_expire_test.cpp
@@ -221,11 +221,11 @@ void SetClusterSetupOptions(ExternalClusterOptions &opts) override
     opts.numWorkers = workerNum_;
     opts.addNodeTime = 3; // add node time is 3 sec
     opts.workerGflagParams = FormatString(
-        " -v=1 -node_timeout_s=%d -node_dead_timeout_s=%d -other_az_names=AZ1,AZ2 "
+        " -v=1 -node_timeout_s=%d -node_dead_timeout_s=%d -other_cluster_names=AZ1,AZ2 "
         "-cross_az_get_meta_from_worker=true -cross_az_get_data_from_worker=true",
         timeoutS_, deadTimeoutS_);
     for (size_t i = 0; i < workerNum_; i++) {
-        std::string param = "-az_name=" + azNames_[i % azNames_.size()];
+        std::string param = "-cluster_name=" + azNames_[i % azNames_.size()];
         opts.workerSpecifyGflagParams[i] += param;
     }
 }
diff --git a/tests/st/client/kv_cache/kv_cache_client_storage_test.cpp b/tests/st/client/kv_cache/kv_cache_client_storage_test.cpp
new file mode 100644
index 0000000..9d6d9f4
--- /dev/null
+++ b/tests/st/client/kv_cache/kv_cache_client_storage_test.cpp
@@ -0,0 +1,441 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +/** + * Description: state cache save to cloud storage test + */ + +#include +#include +#include +#include +#include + +#include "client/kv_cache/kv_client_common.h" +#include "common.h" +#include "client/object_cache/oc_client_common.h" +#include "datasystem/common/flags/flags.h" +#include "datasystem/common/kvstore/etcd/etcd_constants.h" +#include "datasystem/common/util/file_util.h" +#include "datasystem/common/util/thread_pool.h" +#include "datasystem/common/util/timer.h" + +namespace datasystem { +namespace st { +const uint32_t ASYNC_DELETE_TIME_MS = 200; + +class KVCacheClientStorageTest : virtual public OCClientCommon, public KVClientCommon { +public: + std::vector workerAddress_; + + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + constexpr uint8_t WORKER_NUM = 1; + opts.workerGflagParams = "-check_async_queue_empty_time_s=15"; + opts.numEtcd = 1; + opts.numOBS = 1; + opts.numWorkers = WORKER_NUM; + opts.enableDistributedMaster = "true"; + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + externalCluster_ = dynamic_cast(cluster_.get()); + } + + void InitConnectOptW(uint32_t workerIndex, ConnectOptions &connectOptions, int32_t timeoutMs = 60000) + { + HostPort workerAddress; + ASSERT_TRUE(workerIndex < cluster_->GetWorkerNum()); + DS_ASSERT_OK(cluster_->GetWorkerAddr(workerIndex, workerAddress)); + connectOptions = { .host = workerAddress.Host(), .port = workerAddress.Port(), .connectTimeoutMs = timeoutMs }; + connectOptions.accessKey = "QTWAOYTTINDUT2QVKYUC"; + connectOptions.secretKey = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; + connectOptions.tenantId = "tanfasd"; + } + + void InitTestKVClientWithTenant(uint32_t workerIndex, std::shared_ptr &client) + { + ConnectOptions connectOptions; + InitConnectOptW(workerIndex, connectOptions); + client = std::make_shared(connectOptions); + DS_ASSERT_OK(client->Init()); + } + +protected: + ExternalCluster *externalCluster_ = nullptr; +}; + +TEST_F(KVCacheClientStorageTest, TestSetWriteBackDel) +{ + std::shared_ptr client, client1; + InitTestKVClient(0, client); + std::string key = "key"; + std::string value = "value"; + SetParam param{ .writeMode = WriteMode::WRITE_BACK_L2_CACHE }; + ASSERT_EQ(client->Set(key, value, param), Status::OK()); + std::string valueGet; + ASSERT_EQ(client->Get(key, valueGet), Status::OK()); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "ObjectMetaStore.AsyncMetaOpToEtcdStorageHandler.Delay.MetaTable", + "call(5000)")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "global_cache_delete.delete_objects", "sleep(5000)")); + ASSERT_EQ(client->Del(key), Status::OK()); + sleep(1); + DS_ASSERT_NOT_OK(client->Get(key, valueGet)); +} + +TEST_F(KVCacheClientStorageTest, TestSetStorageRepeatly) +{ + std::shared_ptr client; + InitTestKVClient(0, client); + + std::string key1 = "key1"; + std::string value1 = "value1"; + std::string value2 = "value2"; + SetParam param1{ .writeMode = WriteMode::WRITE_BACK_L2_CACHE }; + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.before_pop_from_queue", "100*sleep(5)")); + Status setRes = client->Set(key1, value1, param1); + ASSERT_TRUE(setRes.IsOk()); + setRes = client->Set(key1, value2, param1); + ASSERT_TRUE(setRes.IsOk()); +} + +TEST_F(KVCacheClientStorageTest, TestSetStorageNotExit) +{ + std::shared_ptr client; + InitTestKVClient(0, client); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "persistence.service.del", "sleep(1000)")); + std::string value1 = "value1"; + SetParam param1{ .writeMode = 
WriteMode::WRITE_THROUGH_L2_CACHE };
+    for (int i = 0; i < 5; i++) { // obj num is 5
+        auto key1 = GetStringUuid();
+        DS_ASSERT_OK(client->Set(key1, value1, param1));
+    }
+}
+
+TEST_F(KVCacheClientStorageTest, LEVEL2_TestSetStorageWhenWorkerShuttingDown)
+{
+    std::shared_ptr<KVClient> client;
+    InitTestKVClient(0, client);
+
+    ThreadPool threadPool(1);
+    threadPool.Execute([this, &client]() {
+        std::string key = "key1";
+        std::string value = "value1";
+        SetParam param{ .writeMode = WriteMode::WRITE_BACK_L2_CACHE };
+        DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "persistence.service.save", "100*return(K_OK)"));
+        DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.before_pop_from_queue", "100*sleep(5)"));
+        int requestNum = 10;
+        int intervalTimeS = 2;
+        for (int i = 0; i < requestNum; i++) {
+            DS_ASSERT_OK(client->Set(key, value, param));
+            std::this_thread::sleep_for(std::chrono::seconds(intervalTimeS));
+        }
+        client.reset();
+    });
+    externalCluster_->ShutdownNode(WORKER, 0);
+}
+
+TEST_F(KVCacheClientStorageTest, LEVEL1_TestSetStorageMutable)
+{
+    std::shared_ptr<KVClient> client;
+    InitTestKVClient(0, client);
+
+    std::string key1 = "key1";
+    std::string value1 = "value1";
+    SetParam param1{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE };
+    Status setRes = client->Set(key1, value1, param1);
+    ASSERT_TRUE(setRes.IsOk());
+
+    Status setAgainRes = client->Set(key1, value1, param1);
+    ASSERT_TRUE(setAgainRes.IsOk());
+}
+
+TEST_F(KVCacheClientStorageTest, TestDelStorageObjNeedDelay)
+{
+    std::shared_ptr<KVClient> client;
+    InitTestKVClient(0, client);
+
+    std::string key1 = "key1";
+    std::string value1 = "value1";
+    SetParam param1{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE };
+    Status setRes = client->Set(key1, value1, param1);
+    ASSERT_TRUE(setRes.IsOk());
+
+    Status delRes = client->Del(key1);
+    ASSERT_TRUE(delRes.IsOk());
+}
+
+class KVCacheSpillTest : public KVCacheClientStorageTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts)
+    {
+        KVCacheClientStorageTest::SetClusterSetupOptions(opts);
+        const int workerCount = 2;
+        opts.numWorkers = workerCount;
+        opts.enableSpill = true;
+        opts.numEtcd = 1;
+        opts.workerGflagParams += " -shared_memory_size_mb=10 -v=1 -spill_size_limit=104857600";
+        opts.injectActions = "worker.Spill.Sync:return()";
+    }
+};
+
+TEST_F(KVCacheSpillTest, LEVEL1_TestSpillThenSpill)
+{
+    std::shared_ptr<KVClient> client1;
+    std::shared_ptr<KVClient> client2;
+    InitTestKVClient(0, client1);
+    InitTestKVClient(1, client2);
+
+    std::string value(1020 * 1024, 'a');
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.async_send.before_send", "8*sleep(1000)"));
+
+    SetParam param{ .writeMode = WriteMode::WRITE_BACK_L2_CACHE };
+    std::vector<std::string> objects;
+    const int loopCnt = 100;
+    for (int i = 0; i < loopCnt; i++) {
+        std::string key = "key-" + std::to_string(i);
+        DS_ASSERT_OK(client1->Set(key, value, param));
+        objects.emplace_back(std::move(key));
+    }
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(5000));
+    // try again.
+    for (int i = 0; i < loopCnt; i++) {
+        std::string key = "again-" + std::to_string(i);
+        DS_ASSERT_OK(client1->Set(key, value, param));
+        objects.emplace_back(std::move(key));
+    }
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+
+    for (size_t i = 0; i < objects.size(); i++) {
+        std::string id = objects[i];
+        std::string val;
+        if (i % 2 == 0) {
+            // 1. object evict again after delete spill data, will get from obs(mock).
+            // 2. object still in memory.
+ DS_ASSERT_OK(client1->Get(id, val)); + } else { + DS_ASSERT_OK(client2->Get(id, val)); + } + ASSERT_TRUE(val == value); + } + + for (const auto &id : objects) { + DS_ASSERT_OK(client1->Del(id)); + } +} + +TEST_F(KVCacheSpillTest, LEVEL1_TestWorkerShutdown) +{ + std::shared_ptr client; + InitTestKVClient(0, client); + const size_t dataSize = 1020 * 1024; // 1MB + std::string value(dataSize, 'a'); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "persistence.service.save", "return(K_OK)")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "master.ProcessDeleteObjects", "sleep(5000)")); + SetParam param{ .writeMode = WriteMode::WRITE_BACK_L2_CACHE }; + std::vector objects; + const int loopCnt = 100; + for (int i = 0; i < loopCnt; i++) { + std::string key = "key-" + std::to_string(i); + DS_ASSERT_OK(client->Set(key, value, param)); + objects.emplace_back(std::move(key)); + } + + for (const auto &id : objects) { + DS_ASSERT_OK(client->Del(id)); + } + std::this_thread::sleep_for(std::chrono::seconds(1)); +} + +TEST_F(KVCacheSpillTest, LEVEL1_TestSpillSpaceFull) +{ + std::shared_ptr client; + InitTestKVClient(0, client); + + const size_t testSize = 1024 * 1024; + std::string value(testSize, 'a'); + + std::vector objects; + const int loopCnt = 50; + const int threads = 5; + ThreadPool pool(threads); + for (int i = 0; i < threads; i++) { + pool.Execute([i, &client, value] { + for (int j = 0; j < loopCnt; j++) { + std::string key = "key-" + std::to_string(i) + std::to_string(j); + SetParam param{ .writeMode = WriteMode::WRITE_BACK_L2_CACHE }; + DS_ASSERT_OK(client->Set(key, value, param)); + } + }); + } +} + +class KVCacheNoMetaClientStorageTest : public KVCacheClientStorageTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.workerGflagParams = "-check_async_queue_empty_time_s=15 -oc_io_from_l2cache_need_metadata=false"; + opts.numEtcd = 1; + opts.numWorkers = WORKER_NUM; + opts.numOBS = 1; + opts.enableDistributedMaster = "true"; + } + uint8_t WORKER_NUM = 2; +}; + +TEST_F(KVCacheNoMetaClientStorageTest, LEVEL2_TestDeleteThenPutAndGetAfterRestart) +{ + std::shared_ptr client; + int timeOutMs = 5000; + InitTestKVClient(0, client, timeOutMs); // timeout is 5000 ms + + std::string key = "key"; + std::string targetPath = GetTestCaseDataDir() + "/OBS/test/" + key; + DS_ASSERT_OK(CreateDir(targetPath, true)); + std::ofstream outfile(targetPath + "/100"); + ASSERT_TRUE(outfile.is_open()); + outfile << "hello"; + outfile.close(); + + for (int i = 0; i < WORKER_NUM; i++) { + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "worker.DelPersistence.delay", "call(3)")); + } + + DS_ASSERT_OK(client->Del(key)); + sleep(1); + SetParam param{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE }; + DS_ASSERT_OK(client->Set(key, "value", param)); + + const int sleepTime = 5; // 5s; + sleep(sleepTime); + for (int i = 0; i < WORKER_NUM; i++) { + DS_ASSERT_OK(cluster_->QuicklyShutdownWorker(i)); + DS_ASSERT_OK(externalCluster_->StartWorker(i, HostPort(), " -client_reconnect_wait_s=1")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, i)); + } + std::string val; + auto func = [&client, &key, &val] { return client->Get(key, val); }; + auto waitTime = 15; + DS_ASSERT_OK(cluster_->WaitForExpectedResult(func, waitTime, K_OK)); + ASSERT_EQ(val, "value"); +} + +TEST_F(KVCacheNoMetaClientStorageTest, DISABLED_TestDeleteVersionBetweenPutAndDel) +{ + std::shared_ptr client; + InitTestKVClient(0, client); + + std::string key = "key"; + SetParam param{ .writeMode = 
WriteMode::WRITE_THROUGH_L2_CACHE }; + DS_ASSERT_OK(client->Set(key, "data az1", param)); + + // simulate other az put data to L2 cache. + std::string targetPath = GetTestCaseDataDir() + "/OBS/test/" + key; + DS_ASSERT_OK(CreateDir(targetPath, true)); + uint64_t version = GetSystemClockTimeStampUs(); + std::ofstream outfile(targetPath + "/" + std::to_string(version)); + ASSERT_TRUE(outfile.is_open()); + outfile << "data az2"; + outfile.close(); + + DS_ASSERT_OK(client->Del(key)); + std::string val; + DS_ASSERT_NOT_OK(client->Get(key, val)); +} + +TEST_F(KVCacheNoMetaClientStorageTest, LEVEL2_TestDeleteL2AfterWorkerRestart) +{ + std::shared_ptr client; + InitTestKVClient(0, client); + + SetParam param{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE }; + int keyCount = 32; + std::vector keys; + for (int i = 0; i < keyCount; i++) { + auto val = "data-" + std::to_string(i); + auto key = client->Set(val, param); + ASSERT_TRUE(!key.empty()); + keys.emplace_back(key); + } + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.DelPersistenceObj.beforeDel", "return(K_TRY_AGAIN)")); + for (auto &key : keys) { + DS_ASSERT_OK(client->Del(key)); + } + DS_ASSERT_OK(cluster_->ShutdownNode(ClusterNodeType::WORKER, 0)); + DS_ASSERT_OK(externalCluster_->StartWorker(0, HostPort())); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0)); + int delayMs = 3000; + std::this_thread::sleep_for(std::chrono::milliseconds(delayMs)); + for (auto &key : keys) { + std::string val; + DS_ASSERT_NOT_OK(client->Get(key, val)); + } +} + +class KVCacheClientStorageTestMultiNode : public KVCacheClientStorageTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + KVCacheClientStorageTest::SetClusterSetupOptions(opts); + opts.workerGflagParams += " -object_del_retry_delay_sec=1"; + opts.numEtcd = 1; + opts.numWorkers = 3; // start 3 nodes. + opts.enableDistributedMaster = "true"; + opts.disableRocksDB = false; + } +}; + +TEST_F(KVCacheClientStorageTestMultiNode, TestGlobalCacheTableNotLeakAfterRestart) +{ + std::shared_ptr client; + InitTestKVClient(0, client); + std::string value = "value"; + SetParam param{ .writeMode = WriteMode::WRITE_BACK_L2_CACHE }; + auto key = client->Set(value, param); + ASSERT_NE(key, ""); + DS_ASSERT_OK(cluster_->SetInjectAction( + WORKER, 0, "ObjectMetaStore.AsyncMetaOpToEtcdStorageHandler.Delay.GlobalCacheTable.PassAdd", "call(10000)")); + DS_ASSERT_OK(client->Del(key)); + int waitTimeSec = 2; + sleep(waitTimeSec); // Now the data in L2 cache has been deleted, but the GlobalCacheTable has not been deleted. 
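+    // Restarting the worker with SIGKILL forces recovery to reconcile the leftover
+    // GlobalCacheTable entries in etcd; the polling loop below asserts that they are
+    // eventually cleaned up.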
+ DS_ASSERT_OK(externalCluster_->RestartWorkerAndWaitReadyOneByOne( + { 0 }, SIGKILL)); // Do not wait for asynchronous tasks to complete + + InitTestEtcdInstance(); + std::string tableName = ETCD_GLOBAL_CACHE_TABLE_PREFIX; + tableName.pop_back(); + DS_ASSERT_OK(db_->CreateTable(tableName, tableName)); + const int maxWaitTimeSec = 5; + Timer timer; + std::vector> outKeyValues; + while (timer.ElapsedMilliSecond() < maxWaitTimeSec * SECS_TO_MS) { + outKeyValues.clear(); + DS_ASSERT_OK(db_->GetAll(tableName, outKeyValues)); + if (outKeyValues.empty()) { + break; + } + const int intervalMs = 100; + std::this_thread::sleep_for(std::chrono::milliseconds(intervalMs)); + } + ASSERT_TRUE(outKeyValues.empty()); +} + +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/kv_cache/kv_cache_client_test.cpp b/tests/st/client/kv_cache/kv_cache_client_test.cpp index 8e52ea4..6e75807 100644 --- a/tests/st/client/kv_cache/kv_cache_client_test.cpp +++ b/tests/st/client/kv_cache/kv_cache_client_test.cpp @@ -35,6 +35,7 @@ #include "client/object_cache/oc_client_common.h" #include "cluster/base_cluster.h" #include "common.h" +#include "common_distributed_ext.h" #include "datasystem/client/object_cache/client_worker_api.h" #include "datasystem/common/inject/inject_point.h" #include "datasystem/common/metrics/res_metric_collector.h" @@ -44,6 +45,7 @@ #include "datasystem/common/util/thread_pool.h" #include "datasystem/kv/read_only_buffer.h" #include "datasystem/kv_client.h" +#include "datasystem/object/object_enum.h" #include "datasystem/utils/connection.h" #include "datasystem/utils/status.h" #include "datasystem/common/flags/flags.h" @@ -942,6 +944,7 @@ TEST_F(KVCacheClientTest, TestQueryMetaRetry) InitTestKVClient(1, client1, timeoutMs); std::string valueGet; auto rc = client1->Get(key, valueGet); + // ZMQ should be successful, because of the dispatch mode, the test point for uRPC is retry times. 
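+    // In other words, a ZMQ-based Get is expected to succeed, and the error branch
+    // below only applies to uRPC, where the message must report "RPC unavailable * 2".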
if (rc.IsError()) { std::string errMsg = rc.ToString(); std::string checkStr = "RPC unavailable * 2"; @@ -1772,11 +1775,11 @@ public: void InitClients() { InitTestKVClient(0, client_, - [&](ConnectOptions &opts) { opts.SetAkSkAuth(accessKey_, secretKey_, tenantId_); }); + [&](ConnectOptions &opts) { opts.SetAkSkAuth(accessKey_, secretKey_, tenantId_); }); InitTestKVClient(0, client1_, - [&](ConnectOptions &opts) { opts.SetAkSkAuth(accessKey_, secretKey_, tenantId1_); }); + [&](ConnectOptions &opts) { opts.SetAkSkAuth(accessKey_, secretKey_, tenantId1_); }); InitTestKVClient(1, client2_, - [&](ConnectOptions &opts) { opts.SetAkSkAuth(accessKey_, secretKey_, tenantId2_); }); + [&](ConnectOptions &opts) { opts.SetAkSkAuth(accessKey_, secretKey_, tenantId2_); }); } std::shared_ptr client_; @@ -1898,6 +1901,341 @@ TEST_F(KVClientQuerySizeTest, TestRPCError) ASSERT_EQ(outSizes.capacity(), keyCount); } +class KVClientWriteRocksdbTest : public OCClientCommon, public CommonDistributedExt { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + const int workerCount = 2; + opts.numEtcd = 1; + opts.numOBS = 1; + opts.numWorkers = workerCount; + opts.enableDistributedMaster = "true"; + opts.workerGflagParams = " -v=2 -log_monitor=true "; + opts.disableRocksDB = false; + } + + void SetUp() override + { + CommonTest::SetUp(); + DS_ASSERT_OK(Init()); + ASSERT_TRUE(cluster_ != nullptr); + DS_ASSERT_OK(cluster_->StartEtcdCluster()); + DS_ASSERT_OK(cluster_->StartOBS()); + externalCluster_ = dynamic_cast(cluster_.get()); + } + + void TearDown() override + { + ExternalClusterTest::TearDown(); + } + + BaseCluster *GetCluster() override + { + return cluster_.get(); + } + + void VoluntaryScaleDownInject(int workerIdx) + { + std::string checkFilePath = FLAGS_log_dir.c_str(); + std::string client = "client"; + checkFilePath = checkFilePath.substr(0, checkFilePath.length() - client.length()) + "/worker" + + std::to_string(workerIdx) + "/log/worker-status"; + std::ofstream ofs(checkFilePath); + if (!ofs.is_open()) { + LOG(ERROR) << "Can not open worker status file in " << checkFilePath + << ", voluntary scale in will not start, errno: " << errno; + } else { + ofs << "voluntary scale in\n"; + } + ofs.close(); + kill(cluster_->GetWorkerPid(workerIdx), SIGTERM); + } + + void StartWorkerAndWaitReady(std::initializer_list indexes, + const std::unordered_map &workerFlags = {}, int maxWaitTimeSec = 20) + { + for (auto i : indexes) { + std::string flags; + auto iter = workerFlags.find(i); + if (iter != workerFlags.end()) { + flags = " " + iter->second; + } + ASSERT_TRUE(externalCluster_->StartWorker(i, HostPort(), flags).IsOk()) << i; + } + for (auto i : indexes) { + ASSERT_TRUE(cluster_->WaitNodeReady(WORKER, i, maxWaitTimeSec).IsOk()) << i; + } + for (auto i : indexes) { + // When the scale-in scenario is tested, the scale-in failure may not be determined correctly. + // Therefore, the scale-in failure is directly exited. 
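+            // The inject below turns any "Hashring.Scaletask.Fail" hit into an abort(),
+            // so a failed scale-in exits immediately instead of hanging the test.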
+ DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "Hashring.Scaletask.Fail", "abort()")); + } + InitWorkersInfoMap(indexes); + } + + void StartWorkerAndWaitReady(std::initializer_list indexes, const std::string &flags, int maxWaitTimeSec = 20) + { + std::unordered_map workerFlags; + for (auto i : indexes) { + workerFlags.emplace(i, flags); + } + StartWorkerAndWaitReady(indexes, workerFlags, maxWaitTimeSec); + } + +protected: + ExternalCluster *externalCluster_ = nullptr; +}; + +TEST_F(KVClientWriteRocksdbTest, TestNoneModeNoneL2Cache) +{ + StartWorkerAndWaitReady({ 0, 1 }, "-rocksdb_write_mode=none"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + uint64_t size = 128; + std::string data = GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + DS_ASSERT_OK(client1->Set(key1, data)); + DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 0)); + DS_ASSERT_OK(cluster_->StartNode(WORKER, 0, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); + InitTestKVClient(0, client1); + std::string val; + ASSERT_EQ(client1->Get(key1, val).GetCode(), StatusCode::K_NOT_FOUND); +} + +TEST_F(KVClientWriteRocksdbTest, TestNoneModeL2Cache) +{ + StartWorkerAndWaitReady({ 0, 1 }, "-rocksdb_write_mode=none"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + uint64_t size = 128; + std::string data = GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + SetParam param1; + param1.writeMode = WriteMode::WRITE_THROUGH_L2_CACHE; + DS_ASSERT_OK(client1->Set(key1, data, param1)); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "master.before_sub_async_send_etcd_req", "1*return(K_OK)")); + std::string key2; + (void)client1->GenerateKey("", key2); + SetParam param2; + param2.writeMode = WriteMode::WRITE_BACK_L2_CACHE; + DS_ASSERT_OK(client1->Set(key2, data, param2)); + DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 0)); + DS_ASSERT_OK(cluster_->StartNode(WORKER, 0, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); + InitTestKVClient(0, client1); + std::string val; + DS_ASSERT_OK(client1->Get(key1, val)); + ASSERT_EQ(val, data); + ASSERT_EQ(client1->Get(key2, val).GetCode(), StatusCode::K_NOT_FOUND); +} + +TEST_F(KVClientWriteRocksdbTest, TestNoneModeVoluntaryScaleDown) +{ + StartWorkerAndWaitReady({ 0, 1 }, + "-node_timeout_s=5 -node_dead_timeout_s=8 -enable_lossless_data_exit_mode=true " + "-rocksdb_write_mode=none"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + uint64_t size = 128; + std::string data = GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + DS_ASSERT_OK(client1->Set(key1, data)); + + VoluntaryScaleDownInject(0); + sleep(3); // Wait 3 seconds for voluntary scale down finished + + std::string val; + DS_ASSERT_OK(client1->Get(key1, val)); + ASSERT_EQ(val, data); +} + +TEST_F(KVClientWriteRocksdbTest, TestNoneModeScaleUp) +{ + StartWorkerAndWaitReady({ 0 }, "-rocksdb_write_mode=none"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + uint64_t size = 128; + std::string data = GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + DS_ASSERT_OK(client1->Set(key1, data)); + StartWorkerAndWaitReady({ 1 }); + std::string val; + DS_ASSERT_OK(client1->Get(key1, val)); + ASSERT_EQ(val, data); +} + +TEST_F(KVClientWriteRocksdbTest, TestSyncModeNoneL2Cache) +{ + StartWorkerAndWaitReady({ 0, 1 }, "-rocksdb_write_mode=sync"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + uint64_t size = 128; + std::string data = 
GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + DS_ASSERT_OK(client1->Set(key1, data)); + DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 0)); + DS_ASSERT_OK(cluster_->StartNode(WORKER, 0, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); + InitTestKVClient(0, client1); + std::string val; + ASSERT_EQ(client1->Get(key1, val).GetCode(), StatusCode::K_RUNTIME_ERROR); +} + +TEST_F(KVClientWriteRocksdbTest, TestSyncModeL2Cache) +{ + StartWorkerAndWaitReady({ 0, 1 }, "-rocksdb_write_mode=sync"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + uint64_t size = 128; + std::string data = GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + SetParam param1; + param1.writeMode = WriteMode::WRITE_THROUGH_L2_CACHE; + DS_ASSERT_OK(client1->Set(key1, data, param1)); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "master.before_sub_async_send_etcd_req", "1*return(K_OK)")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "persistence.service.save", "return(K_OK)")); + std::string key2; + (void)client1->GenerateKey("", key2); + SetParam param2; + param2.writeMode = WriteMode::WRITE_BACK_L2_CACHE; + DS_ASSERT_OK(client1->Set(key2, data, param2)); + DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 0)); + DS_ASSERT_OK(cluster_->StartNode(WORKER, 0, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); + InitTestKVClient(0, client1); + std::string val; + DS_ASSERT_OK(client1->Get(key1, val)); + ASSERT_EQ(val, data); + ASSERT_EQ(client1->Get(key2, val).GetCode(), StatusCode::K_RUNTIME_ERROR); + ASSERT_EQ(val, data); +} +TEST_F(KVClientWriteRocksdbTest, TestSyncModeVoluntaryScaleDown) +{ + StartWorkerAndWaitReady({ 0, 1 }, + "-node_timeout_s=5 -node_dead_timeout_s=8 -enable_lossless_data_exit_mode=true " + "-rocksdb_write_mode=sync"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + uint64_t size = 128; + std::string data = GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + DS_ASSERT_OK(client1->Set(key1, data)); + + VoluntaryScaleDownInject(0); + sleep(3); // Wait 3 seconds for voluntary scale down finished + + std::string val; + DS_ASSERT_OK(client1->Get(key1, val)); + ASSERT_EQ(val, data); +} + +TEST_F(KVClientWriteRocksdbTest, TestSyncModeScaleUp) +{ + StartWorkerAndWaitReady({ 0 }, "-rocksdb_write_mode=sync"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + uint64_t size = 128; + std::string data = GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + DS_ASSERT_OK(client1->Set(key1, data)); + StartWorkerAndWaitReady({ 1 }); + std::string val; + DS_ASSERT_OK(client1->Get(key1, val)); + ASSERT_EQ(val, data); +} + +TEST_F(KVClientWriteRocksdbTest, TestASyncModeNoneL2Cache) +{ + StartWorkerAndWaitReady({ 0, 1 }, "-rocksdb_write_mode=async"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + uint64_t size = 128; + std::string data = GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + DS_ASSERT_OK(client1->Set(key1, data)); + DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 0)); + DS_ASSERT_OK(cluster_->StartNode(WORKER, 0, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); + InitTestKVClient(0, client1); + std::string val; + ASSERT_EQ(client1->Get(key1, val).GetCode(), StatusCode::K_RUNTIME_ERROR); +} + +TEST_F(KVClientWriteRocksdbTest, TestASyncModeL2Cache) +{ + StartWorkerAndWaitReady({ 0, 1 }, "-rocksdb_write_mode=async"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + 
uint64_t size = 128; + std::string data = GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + SetParam param1; + param1.writeMode = WriteMode::WRITE_THROUGH_L2_CACHE; + DS_ASSERT_OK(client1->Set(key1, data, param1)); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "master.before_sub_async_send_etcd_req", "1*return(K_OK)")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "persistence.service.save", "return(K_OK)")); + std::string key2; + (void)client1->GenerateKey("", key2); + SetParam param2; + param2.writeMode = WriteMode::WRITE_BACK_L2_CACHE; + DS_ASSERT_OK(client1->Set(key2, data, param2)); + DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 0)); + DS_ASSERT_OK(cluster_->StartNode(WORKER, 0, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); + InitTestKVClient(0, client1); + std::string val; + DS_ASSERT_OK(client1->Get(key1, val)); + ASSERT_EQ(val, data); + ASSERT_EQ(client1->Get(key2, val).GetCode(), StatusCode::K_RUNTIME_ERROR); +} + +TEST_F(KVClientWriteRocksdbTest, TestASyncModeVoluntaryScaleDown) +{ + StartWorkerAndWaitReady({ 0, 1 }, + "-node_timeout_s=5 -node_dead_timeout_s=8 -enable_lossless_data_exit_mode=true " + "-rocksdb_write_mode=async"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + uint64_t size = 128; + std::string data = GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + DS_ASSERT_OK(client1->Set(key1, data)); + + VoluntaryScaleDownInject(0); + sleep(3); // Wait 3 seconds for voluntary scale down finished + + std::string val; + DS_ASSERT_OK(client1->Get(key1, val)); + ASSERT_EQ(val, data); +} + +TEST_F(KVClientWriteRocksdbTest, TestASyncModeScaleUp) +{ + StartWorkerAndWaitReady({ 0 }, "-rocksdb_write_mode=async"); + std::shared_ptr client1; + InitTestKVClient(0, client1); + uint64_t size = 128; + std::string data = GenRandomString(size); + std::string key1; + (void)client1->GenerateKey("", key1); + DS_ASSERT_OK(client1->Set(key1, data)); + StartWorkerAndWaitReady({ 1 }); + std::string val; + DS_ASSERT_OK(client1->Get(key1, val)); + ASSERT_EQ(val, data); +} } // namespace st } // namespace datasystem diff --git a/tests/st/client/kv_cache/kv_client_common.h b/tests/st/client/kv_cache/kv_client_common.h new file mode 100644 index 0000000..96ef1b2 --- /dev/null +++ b/tests/st/client/kv_cache/kv_client_common.h @@ -0,0 +1,57 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Util function for kv client tests. 
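+ *              Provides InitTestEtcdInstance(), which points the test's etcd handle at the
+ *              cluster's etcd endpoints and pre-creates the hash-ring and cluster tables.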
+ */ +#ifndef DATASYSTEM_TEST_ST_CLIENT_KV_CACHE_KV_CLIENT_COMMON_H +#define DATASYSTEM_TEST_ST_CLIENT_KV_CACHE_KV_CLIENT_COMMON_H + +#include "client/object_cache/oc_client_common.h" +namespace datasystem { +namespace st { +class KVClientCommon : virtual public OCClientCommon { +public: + void InitTestEtcdInstance(std::vector otherAzNames = {}) + { + if (db_ != nullptr) { + return; + } + std::string etcdAddress; + for (size_t i = 0; i < cluster_->GetEtcdNum(); ++i) { + std::pair addrs; + cluster_->GetEtcdAddrs(i, addrs); + if (!etcdAddress.empty()) { + etcdAddress += ","; + } + etcdAddress += addrs.first.ToString(); + } + FLAGS_etcd_address = etcdAddress; + db_ = std::make_unique(etcdAddress); + DS_ASSERT_OK(db_->Init()); + (void)db_->CreateTable(ETCD_RING_PREFIX, ETCD_RING_PREFIX); + (void)db_->CreateTable(ETCD_CLUSTER_TABLE, "/" + std::string(ETCD_CLUSTER_TABLE)); + for (const auto &otherAzName : otherAzNames) { + auto otherAzRingStr = "/" + otherAzName + ETCD_RING_PREFIX; + (void)db_->CreateTable(otherAzRingStr, otherAzRingStr); + } + } +protected: + std::unique_ptr db_; +}; +} // namespace st +} // namespace datasystem +#endif // DATASYSTEM_TEST_ST_CLIENT_KV_CACHE_KV_CLIENT_COMMON_H diff --git a/tests/st/client/kv_cache/kv_client_cross_az_test.cpp b/tests/st/client/kv_cache/kv_client_cross_az_test.cpp index c1f1d04..30caff2 100644 --- a/tests/st/client/kv_cache/kv_client_cross_az_test.cpp +++ b/tests/st/client/kv_cache/kv_client_cross_az_test.cpp @@ -73,14 +73,14 @@ public: opts.enableDistributedMaster = "true"; opts.numOBS = 1; std::string OBSGflag = FormatString( - "-other_az_names=AZ1,AZ2,AZ3 " + "-other_cluster_names=AZ1,AZ2,AZ3 " "-v=2 " "-cross_az_get_meta_from_worker=%s -inject_actions=TryGetObjectFromRemote.NoRetry:call() ", crossAzGetMetaFromWorker_) + appendCmd_; opts.workerGflagParams = OBSGflag; for (size_t i = 0; i < workerNum_; i++) { - std::string param = "-az_name="; + std::string param = "-cluster_name="; if (i % MASTER_NUM == 0) { param.append(AZ1); } else { @@ -436,7 +436,7 @@ public: opts.enableDistributedMaster = "true"; opts.numOBS = 1; std::string obsGflag = - "-other_az_names=AZ1,AZ2,AZ3 " + "-other_cluster_names=AZ1,AZ2,AZ3 " "-system_access_key=datasystem_ak " "-system_secret_key=datasystem_ak " "-authorization_enable=true " @@ -445,7 +445,7 @@ public: opts.workerGflagParams = obsGflag; for (size_t i = 0; i < DEFAULT_WORKER_NUM; i++) { - std::string param = "-az_name="; + std::string param = "-cluster_name="; if (i % MASTER_NUM == 0) { param.append(AZ1); } else { @@ -479,14 +479,14 @@ public: opts.addNodeTime = SCALE_RESTART_ADD_TIME; std::string obsGflag = "-shared_memory_size_mb=5120 -node_timeout_s=3 -node_dead_timeout_s=8 -auto_del_dead_node=false " - "-other_az_names=AZ1,AZ2 -v=1 -log_monitor=true"; + "-other_cluster_names=AZ1,AZ2 -v=1 -log_monitor=true"; FLAGS_v = 1; opts.workerGflagParams = obsGflag; for (size_t i = 0; i < DEFAULT_WORKER_NUM; i++) { opts.workerConfigs.emplace_back(HOST_IP, GetFreePort()); workerAddress_.emplace_back(opts.workerConfigs.back().ToString()); - std::string param = "-az_name="; + std::string param = "-cluster_name="; if (i < EACH_AZ_WORKER_NUM) { param.append(AZ1); } else { @@ -1021,14 +1021,14 @@ public: opts.numOBS = 1; std::string gflag = " -v=2 -shared_memory_size_mb=5120 -node_timeout_s=3 -node_dead_timeout_s=8 -auto_del_dead_node=true " - "-other_az_names=AZ1,AZ2,AZ3,AZ4 -cross_az_get_meta_from_worker=true -v=2"; + "-other_cluster_names=AZ1,AZ2,AZ3,AZ4 -cross_az_get_meta_from_worker=true -v=2"; opts.workerGflagParams = 
gflag; std::vector otherAzNames = { "AZ1", "AZ2", "AZ3", "AZ4" }; for (size_t i = 0; i < workerNum_; i++) { opts.workerConfigs.emplace_back(HOST_IP, GetFreePort()); workerAddress_.emplace_back(opts.workerConfigs.back().ToString()); - std::string param = "-az_name="; + std::string param = "-cluster_name="; param.append(otherAzNames[i]); opts.workerSpecifyGflagParams[i] = param; } @@ -1159,11 +1159,12 @@ public: opts.numOBS = 1; std::string gflag = " -v=2 -shared_memory_size_mb=5120 -node_timeout_s=3 -node_dead_timeout_s=8 -auto_del_dead_node=true " - "-other_az_names=AZ1,AZ2 -cross_az_get_meta_from_worker=true -oc_io_from_l2cache_need_metadata=false -v=2"; + "-other_cluster_names=AZ1,AZ2 -cross_az_get_meta_from_worker=true -oc_io_from_l2cache_need_metadata=false " + "-v=2"; opts.workerGflagParams = gflag; - opts.workerSpecifyGflagParams[0] += " -az_name=AZ1 "; - opts.workerSpecifyGflagParams[1] += " -az_name=AZ2 "; + opts.workerSpecifyGflagParams[0] += " -cluster_name=AZ1 "; + opts.workerSpecifyGflagParams[1] += " -cluster_name=AZ2 "; std::vector otherAzNames = { "AZ1", "AZ2" }; for (size_t i = 0; i < workerNum_; i++) { opts.workerConfigs.emplace_back(HOST_IP, GetFreePort()); @@ -1236,7 +1237,7 @@ public: for (size_t i = 0; i < workerNum_; i++) { opts.workerConfigs.emplace_back(HOST_IP, GetFreePort()); workerAddress_.emplace_back(opts.workerConfigs.back().ToString()); - std::string param = "-az_name="; + std::string param = "-cluster_name="; param.append(otherAzNames_[i % otherAzNames_.size()]); opts.workerSpecifyGflagParams[i] = param; } @@ -1273,7 +1274,7 @@ protected: const std::vector otherAzNames_ = { "AZ1", "AZ2", "AZ3", "AZ4" }; std::string gflag_ = " -v=2 -shared_memory_size_mb=5120 -node_timeout_s=3 -node_dead_timeout_s=8 -auto_del_dead_node=true " - "-other_az_names=AZ1,AZ2,AZ3,AZ4 -cross_az_get_meta_from_worker=true"; + "-other_cluster_names=AZ1,AZ2,AZ3,AZ4 -cross_az_get_meta_from_worker=true"; }; TEST_F(KVClientCrossAzGetMetaAndDataTwoWorkerPerAz, LEVEL2_TestParallelCrossAzSet) @@ -1485,12 +1486,12 @@ public: opts.disableRocksDB = false; std::string gflag = " -v=1 -shared_memory_size_mb=512 -node_timeout_s=3 -node_dead_timeout_s=8 -auto_del_dead_node=true " - "-other_az_names=AZ1,AZ2 -cross_az_get_meta_from_worker=true "; + "-other_cluster_names=AZ1,AZ2 -cross_az_get_meta_from_worker=true "; opts.workerGflagParams = gflag; for (size_t i = 0; i < workerNum_; i++) { auto azName = azNames_[i % azNames_.size()]; - std::string param = "-az_name=" + azName; + std::string param = "-cluster_name=" + azName; opts.workerSpecifyGflagParams[i] = param; } } diff --git a/tests/st/client/kv_cache/kv_client_etcd_dfx_test.cpp b/tests/st/client/kv_cache/kv_client_etcd_dfx_test.cpp index 24be4bb..bbe2098 100644 --- a/tests/st/client/kv_cache/kv_client_etcd_dfx_test.cpp +++ b/tests/st/client/kv_cache/kv_client_etcd_dfx_test.cpp @@ -324,11 +324,11 @@ public: opts.enableDistributedMaster = enableDistributedMaster_; opts.disableRocksDB = false; opts.workerGflagParams = FormatString( - " -v=1 -node_timeout_s=%d -node_dead_timeout_s=%d -other_az_names=AZ1,AZ2 " + " -v=1 -node_timeout_s=%d -node_dead_timeout_s=%d -other_cluster_names=AZ1,AZ2 " "-cross_az_get_meta_from_worker=true", timeoutS_, deadTimeoutS_); for (size_t i = 0; i < workerNum_; i++) { - std::string param = "-az_name=" + azNames_[i % azNames_.size()]; + std::string param = "-cluster_name=" + azNames_[i % azNames_.size()]; opts.workerSpecifyGflagParams[i] += param; } } @@ -371,11 +371,11 @@ public: opts.addNodeTime = SCALE_UP_ADD_TIME; 
 opts.enableDistributedMaster = enableDistributedMaster_;
 opts.workerGflagParams = FormatString(
-    " -v=1 -node_timeout_s=%d -node_dead_timeout_s=%d -other_az_names=AZ1,AZ2 "
+    " -v=1 -node_timeout_s=%d -node_dead_timeout_s=%d -other_cluster_names=AZ1,AZ2 "
     "-cross_az_get_meta_from_worker=true",
     timeoutS_, deadTimeoutS_);
 for (size_t i = 0; i < workerNum_; i++) {
-    std::string param = "-az_name=" + azNames_[i % azNames_.size()];
+    std::string param = "-cluster_name=" + azNames_[i % azNames_.size()];
     opts.workerSpecifyGflagParams[i] += param;
 }
 }
diff --git a/tests/st/client/kv_cache/kv_client_mset_test.cpp b/tests/st/client/kv_cache/kv_client_mset_test.cpp
index a51482e..0292f87 100644
--- a/tests/st/client/kv_cache/kv_client_mset_test.cpp
+++ b/tests/st/client/kv_cache/kv_client_mset_test.cpp
@@ -106,6 +106,83 @@ protected:
     std::shared_ptr<KVClient> client0_, client1_, client2_;
 };
 
+class KVClientMSetPerfTest : public KVClientMSetTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.waitWorkerReady = false;
+        opts.numEtcd = 1;
+        opts.numOBS = 1;
+        opts.numWorkers = DEFAULT_WORKER_NUM;
+        opts.enableDistributedMaster = "false";
+        opts.workerGflagParams = "-shared_memory_size_mb=3000 -v=0";
+    }
+
+    void SetUp() override
+    {
+        CommonTest::SetUp();
+        DS_ASSERT_OK(Init());
+        ASSERT_TRUE(cluster_ != nullptr);
+        DS_ASSERT_OK(cluster_->StartEtcdCluster());
+        DS_ASSERT_OK(cluster_->StartOBS());
+        DS_ASSERT_OK(cluster_->StartWorkers());
+        for (size_t i = 0; i < DEFAULT_WORKER_NUM; i++) {
+            DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, i));
+        }
+        InitTestKVClient(0, client0_, 20000); // Init client0 to worker 0 with 20000ms timeout
+        InitTestKVClient(1, client1_, 20000); // Init client1 to worker 1 with 20000ms timeout
+        InitTestKVClient(2, client2_, 20000); // Init client2 to worker 2 with 20000ms timeout
+    }
+
+    void TearDown() override
+    {
+        client0_.reset();
+        client1_.reset();
+        client2_.reset();
+        ExternalClusterTest::TearDown();
+    }
+
+protected:
+    std::shared_ptr<KVClient> client0_, client1_, client2_;
+};
+
+TEST_F(KVClientMSetPerfTest, MsetNtxSmallObj)
+{
+    MSetParam param;
+    param.existence = ExistenceOpt::NX;
+    std::vector sizeNames{ "2KB", "64KB", "128KB" };
+    std::vector dataSizes{ 2048, 64 * 1024, 128 * 1024 }, maxElementSizes{ 64, 256, 1024 };
+    auto repeatNum = 3u;
+    std::vector<std::string> res, failedKeys;
+    for (auto i = 0u; i < dataSizes.size(); i++) {
+        for (auto k = 0ul; k < maxElementSizes.size(); k++) {
+            std::vector costs;
+            for (auto n = 0u; n < repeatNum; n++) {
+                std::vector<std::string> keys, vals;
+                std::vector values;
+                auto dataSize = dataSizes[i];
+                GenerateKeyValues(keys, vals, maxElementSizes[k], dataSize);
+                for (const auto &val : vals) {
+                    values.emplace_back(val);
+                }
+                std::vector<std::string> failedKeys;
+                Timer t;
+                DS_ASSERT_OK(client2_->MSet(keys, values, failedKeys, param));
+                costs.emplace_back(t.ElapsedMicroSecond());
+                DS_ASSERT_OK(client2_->Del(keys, failedKeys));
+            }
+            res.emplace_back(FormatString("data_size:%s, key_num:%ld, repeat_num:%ld ---------> avg cost: %d us ",
+                                          sizeNames[i], maxElementSizes[k], repeatNum,
+                                          std::accumulate(costs.begin(), costs.end(), 0) / costs.size()));
+        }
+    }
+    LOG(INFO) << "--------------------------TEST RESULT--------------------------";
+    for (const auto &item : res) {
+        LOG(INFO) << item;
+    }
+    LOG(INFO) << "-------------------------- END --------------------------";
+}
+
 TEST_F(KVClientMSetTest, CheckPrameterValidation)
 {
     MSetParam param;
@@ -560,6 +637,27 @@ TEST_F(KVClientMSetTest, MsetNtxExistenceNx)
     }
 }
 
+TEST_F(KVClientMSetTest,
MsetNtxBigObj) +{ + MSetParam param; + std::vector keys, vals; + std::vector values; + param.existence = ExistenceOpt::NX; + size_t maxElementSize = 20; + auto dataSize = 3000000; + GenerateKeyValues(keys, vals, maxElementSize, dataSize); + for (const auto &val : vals) { + values.emplace_back(val); + } + std::vector failedKeys; + DS_ASSERT_OK(client2_->MSet(keys, values, failedKeys, param)); + for (size_t i = 0; i < maxElementSize; i++) { + std::string val; + DS_ASSERT_OK(client1_->Get(keys[i], val)); + ASSERT_EQ(val, vals[i]); + } +} + TEST_F(KVClientMSetTest, LEVEL1_SetAndAsyncMset) { MSetParam param; @@ -957,8 +1055,6 @@ TEST_F(KVClientMSetTest, MSetNtxInvalidParam) size_t maxElementSize = 2000; keys = std::vector(maxElementSize, ""); ret = client0_->MSet(keys, values, outFailedKeys, param); - str = ret.ToString(); - ASSERT_TRUE(str.find("The maximum size of keys in single operation is less than 2000") != std::string::npos); maxElementSize = 10; // batch set keys size is 10. keys = std::vector(maxElementSize, ""); @@ -975,13 +1071,6 @@ TEST_F(KVClientMSetTest, MSetNtxInvalidParam) keys[i] = "key" + std::to_string(i); } - values.pop_back(); - values.emplace_back(randomData_.GetRandomString(SHM_THRESHOLD)); - ret = client0_->MSet(keys, values, outFailedKeys, param); - str = ret.ToString(); - ASSERT_TRUE(str.find(FormatString("The size for the val must be less than %d Byte", SHM_THRESHOLD)) - != std::string::npos); - values.pop_back(); values.emplace_back("test"); DS_ASSERT_OK(client0_->MSet(keys, values, outFailedKeys, param)); diff --git a/tests/st/client/kv_cache/kv_client_offset_read_one_host_test.cpp b/tests/st/client/kv_cache/kv_client_offset_read_one_host_test.cpp index 6af17bf..02f1137 100644 --- a/tests/st/client/kv_cache/kv_client_offset_read_one_host_test.cpp +++ b/tests/st/client/kv_cache/kv_client_offset_read_one_host_test.cpp @@ -65,8 +65,8 @@ public: { opts.enableSpill = true; opts.workerGflagParams = - "-shared_memory_size_mb=12 -log_async=false -log_monitor=true -shared_disk_directory=./" + GetStringUuid() - + " -shared_disk_size_mb=100 -v=1 -spill_size_limit=" + std::to_string(maxSize_); + "-shared_memory_size_mb=12 -log_async=false -log_monitor=true -spill_size_limit=" + + std::to_string(maxSize_); opts.numEtcd = 1; opts.numWorkers = 1; opts.numOBS = 1; @@ -654,9 +654,8 @@ public: void SetClusterSetupOptions(ExternalClusterOptions &opts) override { opts.enableSpill = true; - opts.workerGflagParams = - "-shared_memory_size_mb=12 -log_async=false -log_monitor=true -shared_disk_directory=./" + GetStringUuid() - + " -shared_disk_size_mb=1024 -v=1 -spill_size_limit=" + std::to_string(maxSize_); + opts.workerGflagParams = "-shared_memory_size_mb=12 -log_async=false -log_monitor=true -spill_size_limit=" + + std::to_string(maxSize_); opts.numEtcd = 1; opts.numWorkers = 2; // worker num is 2 opts.numOBS = 1; @@ -952,8 +951,7 @@ public: void SetClusterSetupOptions(ExternalClusterOptions &opts) override { opts.workerGflagParams = - "-shared_memory_size_mb=1024 -log_async=false -log_monitor=true -shared_disk_directory=./" + GetStringUuid() - + " -shared_disk_size_mb=1024 -v=1 " + std::to_string(maxSize_); + "-shared_memory_size_mb=1024 -log_async=false -log_monitor=true"; opts.numEtcd = 1; opts.numWorkers = 2; // worker num is 2 opts.numOBS = 1; @@ -1261,12 +1259,12 @@ datasystem::SetParam ConstructParam(CacheType type) } INSTANTIATE_TEST_SUITE_P(ReadOffsetParamTest, KVClientOffsetReadOneHostTest, - ::testing::Values(ConstructParam(CacheType::MEMORY), 
ConstructParam(CacheType::DISK))); + ::testing::Values(ConstructParam(CacheType::MEMORY))); INSTANTIATE_TEST_SUITE_P(ReadOffsetParamTest, KVClientOffsetReadRemoteTest, - ::testing::Values(ConstructParam(CacheType::MEMORY), ConstructParam(CacheType::DISK))); + ::testing::Values(ConstructParam(CacheType::MEMORY))); INSTANTIATE_TEST_SUITE_P(ReadOffsetParamTest, KVClientOffsetReadRemoteParallelTest, - ::testing::Values(ConstructParam(CacheType::MEMORY), ConstructParam(CacheType::DISK))); + ::testing::Values(ConstructParam(CacheType::MEMORY))); } // namespace st } // namespace datasystem \ No newline at end of file diff --git a/tests/st/client/kv_cache/kv_client_replica_test.cpp b/tests/st/client/kv_cache/kv_client_replica_test.cpp index 3418e19..b1b4808 100644 --- a/tests/st/client/kv_cache/kv_client_replica_test.cpp +++ b/tests/st/client/kv_cache/kv_client_replica_test.cpp @@ -806,7 +806,7 @@ public: opts.enableDistributedMaster = "true"; std::string gflag = " -v=1 -shared_memory_size_mb=5120 -node_timeout_s=3 -node_dead_timeout_s=8 -auto_del_dead_node=true " - "-other_az_names=AZ1,AZ2,AZ3 -cross_az_get_meta_from_worker=true -enable_meta_replica=true " + "-other_cluster_names=AZ1,AZ2,AZ3 -cross_az_get_meta_from_worker=true -enable_meta_replica=true " "-oc_io_from_l2cache_need_metadata=true"; opts.workerGflagParams = gflag; @@ -814,7 +814,7 @@ public: for (size_t i = 0; i < WORKER_NUM; i++) { opts.workerConfigs.emplace_back("127.0.0.1", GetFreePort()); workerAddress_.emplace_back(opts.workerConfigs.back().ToString()); - std::string param = "-az_name="; + std::string param = "-cluster_name="; param.append(otherAzNames_[i % otherAzNames_.size()]); opts.workerSpecifyGflagParams[i] = param; } diff --git a/tests/st/client/kv_cache/kv_client_scale_common.h b/tests/st/client/kv_cache/kv_client_scale_common.h index 8c5aa24..5cadc33 100644 --- a/tests/st/client/kv_cache/kv_client_scale_common.h +++ b/tests/st/client/kv_cache/kv_client_scale_common.h @@ -28,6 +28,7 @@ #include +#include "client/kv_cache/kv_client_common.h" #include "client/object_cache/oc_client_common.h" #include "common.h" #include "datasystem/common/util/hash_algorithm.h" @@ -45,33 +46,8 @@ namespace st { constexpr int SCALE_UP_ADD_TIME = 3; constexpr int SCALE_DOWN_ADD_TIME = 5; constexpr int WORKER_RECEIVE_DELAY = 1; -class KVClientScaleCommon : virtual public OCClientCommon { +class KVClientScaleCommon : virtual public OCClientCommon, public KVClientCommon { public: - void InitTestEtcdInstance(std::vector otherAzNames = {}) - { - if (db_ != nullptr) { - return; - } - std::string etcdAddress; - for (size_t i = 0; i < cluster_->GetEtcdNum(); ++i) { - std::pair addrs; - cluster_->GetEtcdAddrs(i, addrs); - if (!etcdAddress.empty()) { - etcdAddress += ","; - } - etcdAddress += addrs.first.ToString(); - } - FLAGS_etcd_address = etcdAddress; - db_ = std::make_unique(etcdAddress); - DS_ASSERT_OK(db_->Init()); - (void)db_->CreateTable(ETCD_RING_PREFIX, ETCD_RING_PREFIX); - (void)db_->CreateTable(ETCD_CLUSTER_TABLE, "/" + std::string(ETCD_CLUSTER_TABLE)); - for (const auto &otherAzName : otherAzNames) { - auto otherAzRingStr = "/" + otherAzName + ETCD_RING_PREFIX; - (void)db_->CreateTable(otherAzRingStr, otherAzRingStr); - } - } - void AssertAllNodesJoinIntoHashRing(int num) { if (!db_) { @@ -190,7 +166,6 @@ public: } protected: - std::unique_ptr db_; const static uint64_t shutdownTimeoutMs = 60'000; // 1min }; diff --git a/tests/st/client/kv_cache/kv_client_scale_test.cpp b/tests/st/client/kv_cache/kv_client_scale_test.cpp index 
6617c8f..415ee7b 100644 --- a/tests/st/client/kv_cache/kv_client_scale_test.cpp +++ b/tests/st/client/kv_cache/kv_client_scale_test.cpp @@ -41,7 +41,7 @@ DS_DECLARE_string(etcd_address); DS_DECLARE_string(master_address); -DS_DECLARE_string(az_name); +DS_DECLARE_string(cluster_name); namespace datasystem { namespace st { @@ -1507,7 +1507,7 @@ TEST_F(STCScaleDownTest, ShutdownWorkerAndDelKeyInEtcdTest) StartWorkerAndWaitReady({ 0 }); HostPort w0; DS_ASSERT_OK(cluster_->GetWorkerAddr(0, w0)); - std::string key = FLAGS_az_name + "/" + ETCD_CLUSTER_TABLE + "/" + w0.ToString(); + std::string key = FLAGS_cluster_name + "/" + ETCD_CLUSTER_TABLE + "/" + w0.ToString(); RangeSearchResult res; DS_ASSERT_OK(db_->RawGet(key, res)); DS_ASSERT_OK(externalCluster_->ShutdownNode(WORKER, 0)); diff --git a/tests/st/client/kv_cache/kv_client_voluntary_scale_down_test.cpp b/tests/st/client/kv_cache/kv_client_voluntary_scale_down_test.cpp index dad0b71..ad079a5 100644 --- a/tests/st/client/kv_cache/kv_client_voluntary_scale_down_test.cpp +++ b/tests/st/client/kv_cache/kv_client_voluntary_scale_down_test.cpp @@ -692,6 +692,35 @@ TEST_F(KVClientVoluntaryScaleDownTest, LEVEL1_UuidObjectSetGetDelAndVoluntarySca DS_ASSERT_OK(cluster_->StartNode(WORKER, 2, "")); } +TEST_F(KVClientVoluntaryScaleDownTest, VoluntaryWorkersOneByOne) +{ + int objectCnt = 50; + int objectCnt1 = 10; + for (size_t i = 1; i < DEFAULT_WORKER_NUM; i++) { + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "InspectAndProcessPeriodically.skip", "return()")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "OCMetadataManager.ReplacePrimary", "1*sleep(5000)")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "worker.migrate_service.return", "1*return(K_NOT_READY)")); + } + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "WorkerOCServiceImpl.MigrateData.Delay", "call(3000)")); + std::vector objectKey(objectCnt); + std::vector objectKey1(objectCnt1); + std::vector data(objectCnt); + std::vector data1(objectCnt1); + SetNormalObject(client0_, 0, objectKey, data, WriteMode::NONE_L2_CACHE); + SetUuidObject(client0_, 0, objectKey1, data1, WriteMode::NONE_L2_CACHE); + VoluntaryScaleDownInject(0); + sleep(2); // wait 2s for voluntary worker 1 + VoluntaryScaleDownInject(1); + sleep(10); // Wait 10 seconds for voluntary scale down finished + AssertWorkerNum(2); // The number of worker is 2 + for (int i = 0; i < objectCnt; ++i) { + std::string getValue; + DS_ASSERT_OK(client2_->Get(objectKey[i], getValue)); + ASSERT_EQ(data[i], getValue); + DS_ASSERT_OK(client2_->Del(objectKey[i])); + } +} + TEST_F(KVClientVoluntaryScaleDownTest, DISABLED_MasterAsyncTaskRecover) { int objectCnt = 15; diff --git a/tests/st/client/object_cache/client_dfx_test.cpp b/tests/st/client/object_cache/client_dfx_test.cpp index 9f1d00f..508e201 100644 --- a/tests/st/client/object_cache/client_dfx_test.cpp +++ b/tests/st/client/object_cache/client_dfx_test.cpp @@ -53,7 +53,7 @@ public: void SetClusterSetupOptions(ExternalClusterOptions &opts) override { opts.workerGflagParams = - "-client_reconnect_wait_s=1 -ipc_through_shared_memory=true -node_timeout_s=1 " + "-client_reconnect_wait_s=1 -ipc_through_shared_memory=true -sc_stream_socket_num=0 -node_timeout_s=1 " "-heartbeat_interval_ms=500"; opts.numWorkers = 3; opts.masterIdx = 1; @@ -286,7 +286,7 @@ TEST_F(WorkerDfxTest, DISABLED_TestWorkerNotFirstGincreaseAndMasterCrash) std::string objectKey = NewObjectKey(); std::string data = "123456"; int dataSize = data.size(); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, 
.consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; std::shared_ptr buffer1; DS_ASSERT_OK(client2->Create(objectKey, dataSize, param, buffer1)); DS_ASSERT_OK(buffer1->MemoryCopy(data.data(), dataSize)); @@ -717,7 +717,7 @@ TEST_F(WorkerDfxTest, LEVEL1_TestChangePrimaryCopy) DS_ASSERT_OK(client1->GIncreaseRef({ objectKey }, failObjects)); int64_t dataSize = 10; std::string value1 = GenRandomString(dataSize); - CreateParam param = { .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param = { .consistencyType = ConsistencyType::CAUSAL }; DS_ASSERT_OK(client1->Put(objectKey, reinterpret_cast(value1.data()), value1.size(), param)); std::vector> buffers; DS_ASSERT_OK(client3->Get({ objectKey }, 1'000, buffers)); @@ -851,7 +851,7 @@ TEST_F(MasterDfxTest, LEVEL1_TestMasterCrashAndGet) LOG(INFO) << "Test master crash and get"; InitTestClients(); std::string objKey = "Stuart"; - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; size_t objSize = 600 * 1024ul; std::shared_ptr buffer; DS_ASSERT_OK(objClient0_->Create(objKey, objSize, param, buffer)); @@ -1989,7 +1989,7 @@ TEST_F(WorkerKillTest, LEVEL1_GetAfterGDecrease) size_t size = 600 * 1024; std::string data = GenRandomString(size); std::shared_ptr buffer; - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::PRAM }; + CreateParam param{ .consistencyType = ConsistencyType::PRAM }; DS_ASSERT_OK(client1->Create(objectKey, size, param, buffer)); buffer->MemoryCopy((void *)data.data(), size); buffer->Publish(); @@ -2026,7 +2026,7 @@ TEST_F(WorkerKillTest, DISABLED_ClearPrimaryAfterWorkerRestart) std::string objectKey = "obj1"; int64_t dataSize = 10; std::string value1 = GenRandomString(dataSize); - CreateParam param = { .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param = {}; DS_ASSERT_OK(client1->Put(objectKey, reinterpret_cast(value1.data()), value1.size(), param)); sleep(1); DS_ASSERT_OK(externalCluster_->KillWorker(1)); diff --git a/tests/st/client/object_cache/client_get_test.cpp b/tests/st/client/object_cache/client_get_test.cpp index eb87abe..825cc0b 100644 --- a/tests/st/client/object_cache/client_get_test.cpp +++ b/tests/st/client/object_cache/client_get_test.cpp @@ -1711,7 +1711,7 @@ TEST_F(OCClientRemoteGetTest4, DISABLED_TestObjectPutAndGetConcurrency) std::string objKey = "2-6-shame"; std::string val = RandomData().GetRandomString(1024ul * 1024ul); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; DS_ASSERT_OK(client0->Put(objKey, (uint8_t *)val.c_str(), val.size(), param)); std::vector failedObjectKeys; DS_ASSERT_OK(client1->GIncreaseRef({ objKey }, failedObjectKeys)); @@ -1768,7 +1768,7 @@ TEST_F(OCClientRemoteGetTest5, TestGetSameObjectConcurrency) std::string objKey = "2-6-shame"; std::string val = RandomData().GetRandomString(1024ul * 1024ul); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; DS_ASSERT_OK(client0->Put(objKey, (uint8_t *)val.c_str(), val.size(), param)); DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "worker.after_query_meta", "1*sleep(5000)")); @@ -1811,7 +1811,7 @@ TEST_F(OCClientRemoteGetTest5, 
TestRemoteGetAndRemoveLocationFailedThenPut) std::string objKey = "Ugly_iPhone15"; std::string val = RandomData().GetRandomString(1024ul * 1024ul); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; DS_ASSERT_OK(client0->Put(objKey, (uint8_t *)val.c_str(), val.size(), param)); std::vector> buffers; diff --git a/tests/st/client/object_cache/client_update_test.cpp b/tests/st/client/object_cache/client_update_test.cpp index 48ab2ca..5e727fe 100644 --- a/tests/st/client/object_cache/client_update_test.cpp +++ b/tests/st/client/object_cache/client_update_test.cpp @@ -220,7 +220,7 @@ TEST_F(OCClientUpdateTest, BufferInvalidedTest) std::string objectKey = NewObjectKey(); std::shared_ptr data; int bufferSize = 8; - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; DS_ASSERT_OK(client->Create(objectKey, bufferSize, param, data)); std::string test = "abcdefg"; DS_ASSERT_OK(data->MemoryCopy(const_cast(test.data()), test.length())); @@ -250,7 +250,7 @@ TEST_F(OCClientUpdateTest, MultiUpdateTest) std::string objectKey = NewObjectKey(); std::shared_ptr workerOBuffer; int bufferSize = 8; - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; DS_ASSERT_OK(client->Create(objectKey, bufferSize, param, workerOBuffer)); std::string worker0_data = "abcdefg"; DS_ASSERT_OK(workerOBuffer->MemoryCopy(const_cast(worker0_data.data()), worker0_data.length())); @@ -316,7 +316,7 @@ TEST_F(OCClientUpdateTest, UpdateDataTest) std::string objectKey = NewObjectKey(); std::shared_ptr data; int bufferSize = 8; - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; DS_ASSERT_OK(client1->Create(objectKey, bufferSize, param, data)); std::string worker0_data = "abcdefg"; DS_ASSERT_OK(data->MemoryCopy(const_cast(worker0_data.data()), worker0_data.length())); diff --git a/tests/st/client/object_cache/hetero_client_mock_test.cpp b/tests/st/client/object_cache/hetero_client_mock_test.cpp index 45548bd..0bbc69e 100644 --- a/tests/st/client/object_cache/hetero_client_mock_test.cpp +++ b/tests/st/client/object_cache/hetero_client_mock_test.cpp @@ -95,36 +95,36 @@ TEST_F(HeteroClientMockTest, TestGetP2PMeta) auto lifetime = LifetimeType::REFERENCE; auto bufferInfo = std::make_shared(devObjKey, deviceIdx, lifetime, true, TransferType::P2P); size_t dataSize = 1024; - DataInfo info{ nullptr, DataType::DATA_TYPE_INT8, dataSize }; + Blob info{ nullptr, dataSize }; DS_ASSERT_OK(workerClient1->PutP2PMeta(bufferInfo, { info })); std::vector> bufferInfoList; auto getBufferInfo = std::make_shared(devObjKey, deviceIdx, lifetime, true, TransferType::P2P); - std::vector> dataInfoStorageList; - DataInfo getInfo{ nullptr, DataType::DATA_TYPE_INT8, dataSize }; + std::vector devBlobStorageList; + Blob getInfo{ nullptr, dataSize }; bufferInfoList.emplace_back(getBufferInfo); - std::vector listData; + std::vector listData; listData.emplace_back(getInfo); - dataInfoStorageList.emplace_back(listData); + devBlobStorageList.emplace_back(DeviceBlobList{ listData, deviceIdx }); GetP2PMetaRspPb resp; const int64_t timeoutMs = 1000; - DS_ASSERT_OK(workerClient2->GetP2PMeta(bufferInfoList, 
dataInfoStorageList, resp, timeoutMs)); + DS_ASSERT_OK(workerClient2->GetP2PMeta(bufferInfoList, devBlobStorageList, resp, timeoutMs)); - std::vector dataInfos; - DS_ASSERT_OK(workerClient1->GetDataInfo(devObjKey, timeoutMs, dataInfos)); - ASSERT_EQ(dataInfos.size(), 1); - dataInfos.clear(); - DS_ASSERT_OK(workerClient2->GetDataInfo(devObjKey, timeoutMs, dataInfos)); - ASSERT_EQ(dataInfos.size(), 1); + std::vector blobs; + DS_ASSERT_OK(workerClient1->GetBlobsInfo(devObjKey, timeoutMs, blobs)); + ASSERT_EQ(blobs.size(), 1); + blobs.clear(); + DS_ASSERT_OK(workerClient2->GetBlobsInfo(devObjKey, timeoutMs, blobs)); + ASSERT_EQ(blobs.size(), 1); auto notExitId = GetStringUuid(); auto notExitBufferInfo = std::make_shared(notExitId, deviceIdx, lifetime, true, TransferType::P2P); bufferInfoList.clear(); bufferInfoList.emplace_back(notExitBufferInfo); - DS_ASSERT_NOT_OK(workerClient2->GetP2PMeta(bufferInfoList, dataInfoStorageList, resp, timeoutMs)); + DS_ASSERT_NOT_OK(workerClient2->GetP2PMeta(bufferInfoList, devBlobStorageList, resp, timeoutMs)); } TEST_F(HeteroClientMockTest, TestRecvRootInfo) @@ -144,7 +144,8 @@ TEST_F(HeteroClientMockTest, TestRecvRootInfo) req.set_dst_device_id(deviceId); req.set_src_client_id(localClientId2); req.set_src_device_id(deviceId); - DS_ASSERT_OK(workerClient1->SendRootInfo(req)); + SendRootInfoRspPb resp; + DS_ASSERT_OK(workerClient1->SendRootInfo(req, resp)); RecvRootInfoReqPb rootInfoReq; rootInfoReq.set_dst_client_id(localClientId1); diff --git a/tests/st/client/object_cache/object_client_replica_test.cpp b/tests/st/client/object_cache/object_client_replica_test.cpp index 90ac850..8d8a8b8 100644 --- a/tests/st/client/object_cache/object_client_replica_test.cpp +++ b/tests/st/client/object_cache/object_client_replica_test.cpp @@ -429,57 +429,6 @@ public: } }; -TEST_F(ObjUpdateToReplicaTest, LEVEL1_TestUpdateToReplicaEnable) -{ - StartWorkerAndWaitReady({ 0, 1 }, FormatString(" -v=2")); - InitClients({ 0, 1 }); - std::vector ids, vals; - int num = 5; - for (int i = 0; i < num; i++) { - std::string objectKey = NewObjectKey(); - ids.emplace_back(objectKey); - std::string data = GenRandomString(10); - vals.emplace_back(data); - for (size_t i = 0; i < clients_.size(); i++) { - std::string objectKey; - DS_ASSERT_OK(clients_[i]->GenerateObjectKey("", objectKey)); - ids.emplace_back(objectKey); - std::string data = GenRandomString(10); - vals.emplace_back(data); - } - } - std::vector failedObjectKeys; - CreateParam param{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE }; - DS_ASSERT_OK(clients_[idx0]->GIncreaseRef(ids, failedObjectKeys)); - ASSERT_EQ(failedObjectKeys.size(), size_t(0)); - for (size_t i = 0; i < ids.size(); i++) { - DS_ASSERT_OK(clients_[idx0]->Put(ids[i], (uint8_t *)(vals[i].c_str()), vals[i].size(), param)); - std::vector> buffers; - DS_ASSERT_OK(clients_[idx0]->Get({ ids[i] }, 0, buffers)); - ASSERT_TRUE(NotExistsNone(buffers)); - AssertBufferEqual(*buffers[idx0], vals[i]); - } - std::vector> buffers; - DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 0)); - DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 1)); - DS_ASSERT_OK(cluster_->StartNode(WORKER, 1, "-enable_meta_replica=true")); - DS_ASSERT_OK(cluster_->StartNode(WORKER, 0, "-enable_meta_replica=true")); - DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 1)); - DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); - WaitReplicaLocationMatch({ 0, 1 }); - for (size_t i = 0; i < ids.size(); i++) { - std::vector> buffers; - DS_ASSERT_OK(clients_[idx0]->Get({ ids[i] }, 0, buffers)); - auto ref = 
clients_[idx0]->QueryGlobalRefNum(ids[i]); - ASSERT_EQ(ref, 1); - ASSERT_TRUE(NotExistsNone(buffers)); - AssertBufferEqual(*buffers[idx0], vals[i]); - } - failedObjectKeys.clear(); - DS_ASSERT_OK(clients_[idx0]->GDecreaseRef(ids, failedObjectKeys)); - ASSERT_EQ(failedObjectKeys.size(), size_t(0)); -} - class ObjReplicaScaleUpTest : public ObjectClientReplicaTest { public: void SetClusterSetupOptions(ExternalClusterOptions &opts) override diff --git a/tests/st/client/object_cache/object_client_scale_test.cpp b/tests/st/client/object_cache/object_client_scale_test.cpp index 5cec6be..b223712 100644 --- a/tests/st/client/object_cache/object_client_scale_test.cpp +++ b/tests/st/client/object_cache/object_client_scale_test.cpp @@ -55,7 +55,7 @@ DS_DECLARE_string(etcd_address); DS_DECLARE_string(master_address); -DS_DECLARE_string(az_name); +DS_DECLARE_string(cluster_name); DS_DECLARE_string(log_dir); namespace datasystem { @@ -307,7 +307,7 @@ TEST_F(OCScaleDownTest, TestRefsScaleDownWithoutL2) InitTestClient(2, client2); // client index is 2 std::string objectPrefix = "objecttest_"; std::vector objectKeys; - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; uint64_t timeout = 2000; int objNum = 30; std::string value = "data"; @@ -387,7 +387,7 @@ TEST_F(OCScaleUpTest, TestSubscribeScaleUp) StartWorkerAndWaitReady({ 2 }); WaitForScaleUpFinished(60, 3); // worker index is 3, timeout is 60 - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; std::string value = "data"; std::vector failObjects; for (const auto &id : objectKeys) { @@ -409,7 +409,7 @@ TEST_F(OCScaleUpTest, TestNestedObjectScaleUp) std::shared_ptr client, client1; InitTestClient(0, client); InitTestClient(1, client1); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; std::string value = "data"; std::string objectPrefix = "objecttest_"; std::string nestedObjectPrefix = "nestobjecttest_"; @@ -456,7 +456,7 @@ TEST_F(OCScaleUpTest, TestNestedObjectScaleUpRedirect) std::shared_ptr client, client1; InitTestClient(0, client); InitTestClient(1, client1); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; std::string value = "data"; std::string objectPrefix = "objecttest_"; std::string nestedObjectPrefix = "nestobjecttest_"; @@ -664,16 +664,27 @@ public: } void SetObjOnWorker(const int &workerIdx, std::shared_ptr client, const std::string &data, - WriteMode mode, std::vector &objectKey) + std::vector &objectKey) { for (uint32_t i = 0; i < objectKey.size(); ++i) { objectKey[i] = "a_key_hash_to_" + std::to_string(workerHashValue_[workerIdx] - i); - CreateParam param{ .writeMode = mode }; + CreateParam param{}; DS_ASSERT_OK( client->Put(objectKey[i], reinterpret_cast(data.data()), data.size(), param, {})); } } + void SetObjOnWorker(const int &workerIdx, std::shared_ptr client, const std::string &data, + WriteMode mode, std::vector &objectKey) + { + for (uint32_t i = 0; i < objectKey.size(); ++i) { + objectKey[i] = "a_key_hash_to_" + std::to_string(workerHashValue_[workerIdx] - i); + SetParam param{ .writeMode = mode }; + DS_ASSERT_OK( + client->Set(objectKey[i], data, param)); + } + } + Status InitInstanceBase() { hostPort_.ParseString("127.0.0.1:" + std::to_string(GetFreePort())); @@ -747,7 +758,7 @@ TEST_F(OCVoluntaryScaleDownTest, VoluntaryWorkerScaleDownFinalStageLeaving) InitTestClient(0, client); InitTestClient(1, client1); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; 
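// The two SetObjOnWorker overloads above split along parameter style: the CreateParam-based
// path relies on the default write mode, while the SetParam-based path still takes an explicit
// WriteMode. A minimal sketch of the intended call shapes, assuming (not confirmed by this
// diff, which only drops the field) that a default-constructed CreateParam now means
// NONE_L2_CACHE:
//
//     CreateParam plain{};                                                      // default write mode
//     CreateParam causal{ .consistencyType = ConsistencyType::CAUSAL };         // override one field
//     SetParam explicitMode{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE };  // KV Set keeps the knob
//
// so tests that care about L2-cache behaviour go through the WriteMode overload, and everything
// else takes the default.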
std::string objectPrefix = "objecttest_"; std::vector objectKeys; @@ -798,7 +809,7 @@ TEST_F(OCVoluntaryScaleDownTest, LEVEL2_VoluntaryWorkerScaleDownLeaving) InitTestClient(1, client1); InitTestClient(2, client2); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; std::string value = "data"; std::string objectPrefix = "objecttest_"; std::vector objectKeys; @@ -849,7 +860,7 @@ TEST_F(OCVoluntaryScaleDownTest, VoluntaryWorkerScaleDownAvailableSpaceRatio40) InitTestClient(1, client1); InitTestClient(2, client2); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; std::string value(kObjectSize, 'x'); std::string objectPrefix = "objecttest_"; std::vector objectKeys; @@ -890,7 +901,7 @@ TEST_F(OCVoluntaryScaleDownTest, LEVEL2_VoluntaryWorkerScaleDown) InitTestClient(0, client); InitTestClient(1, client1); InitTestClient(2, client2); // worker index is 2 - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; std::string value = "data"; std::string objectPrefix = "objecttest_"; std::string nestedObjectPrefix = "nestobjecttest_"; @@ -930,7 +941,7 @@ TEST_F(OCVoluntaryScaleDownTest, LEVEL2_VoluntaryWorkerScaleDown1) InitTestClient(0, client); InitTestClient(1, client1); InitTestClient(2, client2); // worker index is 2 - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; std::string value = "data"; std::string objectPrefix = "objecttest_"; std::string nestedObjectPrefix = "nestobjecttest_"; @@ -966,18 +977,18 @@ TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownWorker1NoneL2EvictNoCopy) InitTestEtcdInstance(); SetWorkerHashInjection(); GetHashOnWorker(3); // worker num is 3 - std::shared_ptr client, client1, client2; - InitTestClient(0, client); - InitTestClient(1, client1); - InitTestClient(2, client2); // worker index is 2 + std::shared_ptr client, client1, client2; + InitTestKVClient(0, client); + InitTestKVClient(1, client1); + InitTestKVClient(2, client2); // worker index is 2 std::vector objs(400); // obj num is 400 std::string value = "data"; SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE_EVICT, objs); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 - std::vector> buffers; + std::vector buffers; WaitForVoluntaryDownFinished(20, 2, worker0Address_.ToString()); // timeout is 20, left num is 2 - DS_ASSERT_NOT_OK(client2->Get(objs, 0, buffers)); + DS_ASSERT_NOT_OK(client2->Get(objs, buffers)); } TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownWorker1NoneL2EvictWithCopy) @@ -987,15 +998,15 @@ TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownWorker1NoneL2EvictWithCopy) SetWorkerHashInjection(); GetHashOnWorker(3); // worker num is 3 DS_ASSERT_OK(InitInstanceBase()); - std::shared_ptr client, client1, client2; - InitTestClient(0, client); - InitTestClient(1, client1); - InitTestClient(2, client2); // worker index is 2 + std::shared_ptr client, client1, client2; + InitTestKVClient(0, client); + InitTestKVClient(1, client1); + InitTestKVClient(2, client2); // worker index is 2 std::vector objs(400); // obj num is 400 std::string value = "data"; SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE_EVICT, objs); - std::vector> buffers; - DS_ASSERT_OK(client2->Get(objs, 0, buffers)); + std::vector buffers; + DS_ASSERT_OK(client2->Get(objs, buffers)); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 WaitForVoluntaryDownFinished(20, 2, worker0Address_.ToString()); // timeout is 20, left num is 2 @@ -1024,7 +1035,7 @@ 
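// (Sketch of the two Get shapes being swapped in these hunks; signatures as used in this file,
// variable names here illustrative only:
//
//     std::vector<std::shared_ptr<Buffer>> buffers;
//     objectClient->Get(keys, /*timeoutMs=*/0, buffers);   // ObjectClient: zero-copy buffers
//
//     std::vector<std::string> values;
//     kvClient->Get(keys, values);                         // KVClient: materialized values
//
// note the KV overload drops the per-call timeout argument.)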
TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownWorker1NoneL2CacheWithCopy) InitTestClient(2, client2); // worker index is 2 std::vector objs(400); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, objs); + SetObjOnWorker(0, client, value, objs); std::vector> buffers; DS_ASSERT_OK(client2->Get(objs, 0, buffers)); client.reset(); @@ -1050,15 +1061,15 @@ TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownWorker1WriteBackWithCopy) SetWorkerHashInjection(); GetHashOnWorker(3); // worker num is 3 DS_ASSERT_OK(InitInstanceBase()); - std::shared_ptr client, client1, client2; - InitTestClient(0, client); - InitTestClient(1, client1); - InitTestClient(2, client2); // worker index is 2 + std::shared_ptr client, client1, client2; + InitTestKVClient(0, client); + InitTestKVClient(1, client1); + InitTestKVClient(2, client2); // worker index is 2 std::vector objs(400); // obj num is 400 std::string value = "data"; SetObjOnWorker(0, client, value, WriteMode::WRITE_BACK_L2_CACHE, objs); - std::vector> buffers; - DS_ASSERT_OK(client2->Get(objs, 0, buffers)); + std::vector buffers; + DS_ASSERT_OK(client2->Get(objs, buffers)); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 WaitForVoluntaryDownFinished(20, 2, worker0Address_.ToString()); // timeout is 20, left num is 2 @@ -1074,7 +1085,7 @@ TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownWorker1WriteBackWithCopy) ASSERT_EQ(metaNum, 400); // obj is 400 } -TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownWorker1Worker2Failed) +TEST_F(OCVoluntaryScaleDownTest, DISABLED_VoluntaryDownWorker1Worker2Failed) { DS_ASSERT_OK(cluster_->StartOBS()); StartWorkerAndWaitReady({ 0, 1, 2 }); @@ -1082,15 +1093,15 @@ TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownWorker1Worker2Failed) SetWorkerHashInjection(); GetHashOnWorker(3); // worker num is 3 DS_ASSERT_OK(InitInstanceBase()); - std::shared_ptr client, client1, client2; - InitTestClient(0, client); - InitTestClient(1, client1); - InitTestClient(2, client2); // worker index is 2 + std::shared_ptr client, client1, client2; + InitTestKVClient(0, client); + InitTestKVClient(1, client1); + InitTestKVClient(2, client2); // worker index is 2 std::vector objs(20); // obj num is 20 std::string value = "data"; SetObjOnWorker(0, client, value, WriteMode::WRITE_THROUGH_L2_CACHE, objs); - std::vector> buffers; - DS_ASSERT_OK(client1->Get(objs, 0, buffers)); + std::vector buffers; + DS_ASSERT_OK(client1->Get(objs, buffers)); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 2)); // worker index is 2 @@ -1122,10 +1133,10 @@ TEST_F(OCVoluntaryScaleDownTest, DISABLED_VoluntaryDownTwoWorkers) InitTestClient(2, client2); // worker index is 2 std::vector objs(400); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, objs); + SetObjOnWorker(0, client, value, objs); std::vector> buffers; DS_ASSERT_OK(client2->Get(objs, 0, buffers)); - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, objs); + SetObjOnWorker(0, client, value, objs); client.reset(); client1.reset(); client2.reset(); @@ -1154,19 +1165,19 @@ TEST_F(OCVoluntaryScaleDownTest, LEVEL2_VoluntaryDownWorkerTwoWrokers) SetWorkerHashInjection(); GetHashOnWorker(3); // worker num is 3 DS_ASSERT_OK(InitInstanceBase()); - std::shared_ptr client, client1, client2; - InitTestClient(0, client); - InitTestClient(1, client1); - InitTestClient(2, client2); // worker index is 2 + std::shared_ptr client, client1, 
client2; + InitTestKVClient(0, client); + InitTestKVClient(1, client1); + InitTestKVClient(2, client2); // worker index is 2 std::vector objs(200); // obj num is 200 std::vector objs1(400); // obj num is 400 std::string value = "data"; SetObjOnWorker(0, client, value, WriteMode::WRITE_THROUGH_L2_CACHE, objs); SetObjOnWorker(1, client, value, WriteMode::WRITE_THROUGH_L2_CACHE, objs1); - std::vector> buffers; - DS_ASSERT_OK(client2->Get(objs, 0, buffers)); + std::vector buffers; + DS_ASSERT_OK(client2->Get(objs, buffers)); buffers.clear(); - DS_ASSERT_OK(client2->Get(objs1, 0, buffers)); + DS_ASSERT_OK(client2->Get(objs1, buffers)); DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "ProcessVoluntaryScaledown", "1*call()")); DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "BatchMigrateMetadata.delay.left", "1*call(3)")); client1.reset(); @@ -1205,8 +1216,8 @@ TEST_F(OCVoluntaryScaleDownTest, LEVEL2_VoluntaryDownMigrateRateLimit) std::vector Objects(count); std::vector Objects1(count); std::string value = std::string(10 * 1024ul, 'a'); - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, Objects); - SetObjOnWorker(2, client, value, WriteMode::NONE_L2_CACHE, Objects1); // worker index is 2 + SetObjOnWorker(0, client, value, Objects); + SetObjOnWorker(2, client, value, Objects1); // worker index is 2 client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 VoluntaryScaleDownInject(2); // worker index is 2 @@ -1239,7 +1250,7 @@ TEST_F(OCVoluntaryScaleDownTest, LEVEL1_TestVoluntaryDownMigrateSmallObjectsData uint64_t count = 400; std::vector noneL2CacheObjects(count); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, noneL2CacheObjects); + SetObjOnWorker(0, client, value, noneL2CacheObjects); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 @@ -1279,7 +1290,7 @@ TEST_F(OCVoluntaryScaleDownTest, TestVoluntaryDownMigrateWhenMetaAddressIsEmpty) uint64_t count = 400; std::vector noneL2CacheObjects(count); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, noneL2CacheObjects); + SetObjOnWorker(0, client, value, noneL2CacheObjects); client.reset(); DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "WorkerOcService.MigrateData.GetMasterAddr", "1*call()")); VoluntaryScaleDownInject(0); // worker index is 0 @@ -1326,7 +1337,7 @@ TEST_F(OCVoluntaryScaleDownTest, TestVoluntaryDownMigrateToSpillDir) uint64_t count = 400; std::vector noneL2CacheObjects(count); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, noneL2CacheObjects); + SetObjOnWorker(0, client, value, noneL2CacheObjects); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 @@ -1369,7 +1380,7 @@ TEST_F(OCVoluntaryScaleDownTest, TestVoluntaryDownMigrateMeetsNoSpaceError) uint64_t count = 400; std::vector noneL2CacheObjects(count); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, noneL2CacheObjects); + SetObjOnWorker(0, client, value, noneL2CacheObjects); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 @@ -1415,7 +1426,7 @@ TEST_F(OCVoluntaryScaleDownTest, TestMigrateDataAndMeetObjectUpdate) std::vector noneL2CacheObjects(count); // obj num is 400 std::string value = "data"; std::string newValue = "Stop the world"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, noneL2CacheObjects); + SetObjOnWorker(0, client, value, noneL2CacheObjects); 
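// (The inject-action strings used in these tests appear to follow a "[count*]action(args)"
// convention; this interpretation is inferred from usage in this file rather than from the
// inject framework itself:
//
//     cluster_->SetInjectAction(WORKER, 0, "SomePoint", "1*sleep(5000)");          // once: sleep 5 s
//     cluster_->SetInjectAction(WORKER, 0, "SomePoint", "1*return(K_NOT_READY)");  // once: inject status
//     cluster_->SetInjectAction(WORKER, 0, "SomePoint", "return()");               // every hit
//
// "SomePoint" is a placeholder; real point names such as "VoluntaryScaledown.MigrateData.Delay"
// appear throughout this file.)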
client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 @@ -1488,7 +1499,7 @@ TEST_F(OCVoluntaryScaleDownTest, TestMigrateDataAndPartOfObjectsFailed) uint64_t count = 1000; std::vector noneL2CacheObjects(count); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, noneL2CacheObjects); + SetObjOnWorker(0, client, value, noneL2CacheObjects); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 @@ -1529,7 +1540,7 @@ TEST_F(OCVoluntaryScaleDownTest, LEVEL2_VoluntaryDownTwoWorkersAndMigrateData) uint64_t count = 400; std::vector noneL2CacheObjects(count); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, noneL2CacheObjects); + SetObjOnWorker(0, client, value, noneL2CacheObjects); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 std::this_thread::sleep_for(std::chrono::milliseconds(1000)); // time interval is 1000 @@ -1568,7 +1579,7 @@ TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownTwoWorkersAndMigrateData2) uint64_t count = 400; std::vector noneL2CacheObjects(count); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, noneL2CacheObjects); + SetObjOnWorker(0, client, value, noneL2CacheObjects); client.reset(); client1.reset(); VoluntaryScaleDownInject(0); // worker index is 0 @@ -1598,7 +1609,7 @@ TEST_F(OCVoluntaryScaleDownTest, LEVEL1_VoluntaryDownOnlyOneWorkerLeft) std::string value = "data"; for (uint32_t i = 0; i < noneL2CacheObjects.size(); ++i) { noneL2CacheObjects[i] = "a_key_hash_to_" + std::to_string(i); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; DS_ASSERT_OK(client->Put(noneL2CacheObjects[i], reinterpret_cast(value.data()), value.size(), param, {})); } @@ -1622,7 +1633,7 @@ TEST_F(OCVoluntaryScaleDownTest, LEVEL1_VoluntaryDownOneWorkerWhenDestFailed) uint64_t count = 300; std::vector noneL2CacheObjects(count); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, noneL2CacheObjects); + SetObjOnWorker(0, client, value, noneL2CacheObjects); DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "ScaleUpTask.NotRunVoluntaryDownTask", "1*sleep(5000)")); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 @@ -1645,7 +1656,7 @@ TEST_F(OCVoluntaryScaleDownTest, LEVEL1_VoluntaryDownOneWorkerWhenMigrateDataDes uint64_t count = 300; std::vector noneL2CacheObjects(count); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, noneL2CacheObjects); + SetObjOnWorker(0, client, value, noneL2CacheObjects); DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "VoluntaryScaledown.MigrateData.Delay", "1*sleep(5000)")); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 @@ -1673,7 +1684,7 @@ TEST_F(OCVoluntaryScaleDownTest, LEVEL2_TestMigrateDataFailAndGet) uint64_t count = 400; std::vector noneL2CacheObjects(count); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, WriteMode::NONE_L2_CACHE, noneL2CacheObjects); + SetObjOnWorker(0, client, value, noneL2CacheObjects); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 sleep(2); // sleep 2 seconds @@ -1707,7 +1718,7 @@ TEST_F(OCVoluntaryScaleDownTest, VoluntaryDownMigrateDataMultiType) std::vector memoryObjects(objCount); const size_t objSize = 5 * 1024ul * 1024ul; // 5 MB std::string value = GenRandomString(objSize); - SetObjOnWorker(0, client, 
value, {}, memoryObjects); + SetObjOnWorker(0, client, value, memoryObjects); DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StandbyWorkerNotSame", "return()")); DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "hashring.finishaddnodeinfo", "sleep(2000)")); client.reset(); @@ -1801,7 +1812,7 @@ TEST_F(OCVoluntaryScaleDownNoSpillTest, VoluntaryWorkerMigrateDataFillUp) InitTestClient(2, client2); InitTestClient(client3Index, client3); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; std::string objectPrefix = "objecttest_"; std::vector objectKeys; std::string value(valueSize, 'x'); @@ -1883,7 +1894,7 @@ TEST_F(OCVScaleDownDiskTest, LEVEL1_VoluntaryDownMigrateData) uint64_t objCount = 400; std::vector diskObjects(objCount); // obj num is 400 std::string value = "data"; - SetObjOnWorker(0, client, value, { WriteMode::NONE_L2_CACHE, ConsistencyType::PRAM, CacheType::DISK }, diskObjects); + SetObjOnWorker(0, client, value, { ConsistencyType::PRAM, CacheType::DISK }, diskObjects); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 @@ -1927,7 +1938,7 @@ TEST_F(OCVScaleDownDiskTest, VoluntaryDownMigrateDataMultiType) std::vector memoryObjects(objCount); const size_t objSize = 5 * 1024ul * 1024ul; // 5 MB std::string value = GenRandomString(objSize); - SetObjOnWorker(0, client, value, { WriteMode::NONE_L2_CACHE, ConsistencyType::PRAM, CacheType::DISK }, diskObjects); + SetObjOnWorker(0, client, value, { ConsistencyType::PRAM, CacheType::DISK }, diskObjects); SetObjOnWorker(0, client, value, {}, memoryObjects); client.reset(); VoluntaryScaleDownInject(0); // worker index is 0 diff --git a/tests/st/client/object_cache/object_client_tenant_test.cpp b/tests/st/client/object_cache/object_client_tenant_test.cpp index 2d84673..e80e43e 100644 --- a/tests/st/client/object_cache/object_client_tenant_test.cpp +++ b/tests/st/client/object_cache/object_client_tenant_test.cpp @@ -225,7 +225,7 @@ TEST_F(ObjectClientTenantTest, PutAgainInOtherClient) std::vector failedObjectKeys; DS_ASSERT_OK(client1->GIncreaseRef(objectKeys, failedObjectKeys)); ASSERT_TRUE(failedObjectKeys.empty()); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; DS_ASSERT_OK(client1->Put(objectKey, reinterpret_cast(data.data()), data.length(), param)); client1 = nullptr; @@ -519,7 +519,7 @@ TEST_F(TenantResourceAutoReleaseTest, DISABLED_TestResourceRelease) InitTestClient(1, client2_, [this](ConnectOptions &opts) { opts.SetAkSkAuth(accessKey_, secretKey_, "tenant1"); }); DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "PreReleaseTenantResourceInfo.IsExpired", "call(4000)")); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; std::string data1 = GenRandomString(SHM_SIZE); int waitTimeS = 15; @@ -538,7 +538,7 @@ TEST_F(TenantResourceAutoReleaseTest, TestResourceNotRelease) InitTestClient(0, client1_, [this](ConnectOptions &opts) { opts.SetAkSkAuth(accessKey_, secretKey_, "tenant1"); }); InitTestClient(1, client2_, [this](ConnectOptions &opts) { opts.SetAkSkAuth(accessKey_, secretKey_, "tenant1"); }); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; int waitTimeS = 1; int loopTimes = 5; diff --git a/tests/st/client/object_cache/object_client_test.cpp b/tests/st/client/object_cache/object_client_test.cpp index 3559b9c..6320508 100644 --- a/tests/st/client/object_cache/object_client_test.cpp +++ b/tests/st/client/object_cache/object_client_test.cpp @@ -260,14 +260,14 @@ 
TEST_F(ObjectClientTest, DISABLED_LEVEL1_TestAsynPublishAndShutdown) TEST_F(ObjectClientTest, CreateShmBufferSuccess) { // Shared memory, non-Keep scenario - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; CreateBufferSuccess(SHM_SIZE, param); } TEST_F(ObjectClientTest, CreateNonShmBufferSuccess) { // Non-shared memory, Keep scenario - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; CreateBufferSuccess(NON_SHM_SIZE, param); } @@ -421,7 +421,7 @@ TEST_F(ObjectClientTest, LEVEL1_InvalidateBufferAndRePublishSuccess) const int32_t timeoutMs = 1'000; InitTestClient(0, client, timeoutMs); int dataSize = SHM_SIZE; - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; DS_ASSERT_OK(client->Create(objectKey, dataSize, param, buffer)); ASSERT_NE(buffer, nullptr); ASSERT_EQ(dataSize, buffer->GetSize()); @@ -443,7 +443,7 @@ TEST_F(ObjectClientTest, LEVEL2_InvalidateBufferAndRemoteGetFailed) std::shared_ptr client; InitTestClient(0, client); int dataSize = SHM_SIZE; - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; DS_ASSERT_OK(client->Create(objectKey, dataSize, param, buffer)); ASSERT_NE(buffer, nullptr); ASSERT_EQ(dataSize, buffer->GetSize()); @@ -466,7 +466,7 @@ TEST_F(ObjectClientTest, InvalidateBufferAfterRemoteGet) InitTestClient(1, client2); std::string objectKey = NewObjectKey(); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; int size = 10; std::shared_ptr buffer; @@ -652,7 +652,7 @@ TEST_F(ObjectClientTest, LatestObjectGetTest) std::shared_ptr workerOBuffer; int bufferSize = SHM_SIZE; // Causal consistency triggers synchronous invalidation. - CreateParam createParam = { .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam createParam = { .consistencyType = ConsistencyType::CAUSAL }; DS_ASSERT_OK(client1->Create(objectKey, bufferSize, createParam, workerOBuffer)); std::string worker0_data = GenRandomString(bufferSize); DS_ASSERT_OK(workerOBuffer->MemoryCopy(const_cast(worker0_data.data()), worker0_data.length())); @@ -727,7 +727,7 @@ TEST_F(ObjectClientTest, ExpireObjectUpdate) int bufferSize = SHM_SIZE; // Causal consistency triggers synchronous invalidation. 
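// (Gloss on the two consistency levels exercised in this file, hedged: CAUSAL makes Publish
// synchronously invalidate stale copies, so a later Get on any client observes the update,
// which is what LatestObjectGetTest relies on below; PRAM only orders a single client's writes,
// so TestConsistencyPRAM tolerates briefly stale cross-client reads. Field usage as elsewhere
// in this file:
//
//     CreateParam strict{ .consistencyType = ConsistencyType::CAUSAL };
//     CreateParam relaxed{ .consistencyType = ConsistencyType::PRAM };
// )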
- CreateParam createParam = { .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam createParam = { .consistencyType = ConsistencyType::CAUSAL }; DS_ASSERT_OK(client1->Create(objectKey, bufferSize, createParam, workerOBuffer)); std::string worker0_data = GenRandomString(SHM_SIZE); DS_ASSERT_OK(workerOBuffer->MemoryCopy(const_cast(worker0_data.data()), worker0_data.length())); @@ -924,7 +924,7 @@ TEST_F(ObjectClientTest, TestConsistencyPRAM) InitTestClient(0, client); // create obj - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::PRAM }; + CreateParam param{ .consistencyType = ConsistencyType::PRAM }; std::string objectKey = NewObjectKey(); std::string data = GenRandomString(SHM_SIZE); std::shared_ptr buffer; @@ -988,7 +988,7 @@ TEST_F(ObjectClientTest, TestConsistencyCAUSAL) InitTestClient(0, client); // client create obj and publish - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; std::string objectKey = NewObjectKey(); std::string data = GenRandomString(SHM_SIZE); std::shared_ptr buffer; @@ -1039,7 +1039,7 @@ TEST_F(ObjectClientTest, TestObjectsPRAM) InitTestClient(0, client); // create obj - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::PRAM }; + CreateParam param{ .consistencyType = ConsistencyType::PRAM }; std::string objectKey = NewObjectKey(); std::string data = GenRandomString(SHM_SIZE); std::shared_ptr buffer; @@ -1102,7 +1102,7 @@ TEST_F(ObjectClientTest, TestTwoObjectsRRAM) InitTestClient(0, client); // client create obj1 - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::PRAM }; + CreateParam param{ .consistencyType = ConsistencyType::PRAM }; std::string objectKey = NewObjectKey(); std::string data = GenRandomString(SHM_SIZE); std::shared_ptr buffer; @@ -1331,7 +1331,7 @@ TEST_F(ObjectClientTest, AsyncGetAndDelete) const int loopTimes = 50; for (int i = 0; i < loopTimes; i++) { std::string objectKey = NewObjectKey(); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; char data[] = { '1', '2', '3' }; std::shared_ptr buffer; DS_ASSERT_OK(client1->Create(objectKey, sizeof(data), param, buffer)); @@ -1367,7 +1367,7 @@ TEST_F(ObjectClientTest, TestDeleteAfterCreate) std::shared_ptr client; InitTestClient(0, client); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; std::string objectKey = NewObjectKey(); std::string data = "123"; std::shared_ptr buffer; @@ -1388,7 +1388,7 @@ TEST_F(ObjectClientTest, GRefAsyncPublishAndDelete) const int loopTimes = 100; for (int i = 0; i < loopTimes; i++) { std::string objectKey = NewObjectKey(); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; char data[] = { '1', '2', '3' }; std::shared_ptr buffer; DS_ASSERT_OK(client->Create(objectKey, sizeof(data), param, buffer)); @@ -1516,7 +1516,7 @@ TEST_F(ObjectClientTest, TestPramMultiClientPublishGet) } // client1 create obj - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; std::string objectKey = NewObjectKey(); std::string data1 = GenRandomString(SHM_SIZE); std::shared_ptr buffer1; @@ -1791,7 +1791,7 @@ TEST_F(ObjectClientTest, PutDifferentMetaSize) std::string data1 = 
GenRandomString(size1); std::string data2 = GenRandomString(size2); std::string objectKey = GetStringUuid(); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; DS_ASSERT_OK( client1->Put(objectKey, reinterpret_cast(const_cast(data1.data())), data1.size(), param)); diff --git a/tests/st/client/object_cache/object_client_with_token_test.cpp b/tests/st/client/object_cache/object_client_with_token_test.cpp index 377ed61..dc874dd 100644 --- a/tests/st/client/object_cache/object_client_with_token_test.cpp +++ b/tests/st/client/object_cache/object_client_with_token_test.cpp @@ -221,7 +221,7 @@ void ObjectClientWithTokenTest::GetMultiObjectSuccess(int64_t size) TEST_F(ObjectClientWithTokenTest, CreateShmBufferSuccess) { // Shared memory, non-Keep scenario - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; CreateBufferSuccess(SHM_SIZE, param); } diff --git a/tests/st/client/object_cache/oc_client_dist_master_test.cpp b/tests/st/client/object_cache/oc_client_dist_master_test.cpp index ef6a357..c623efd 100644 --- a/tests/st/client/object_cache/oc_client_dist_master_test.cpp +++ b/tests/st/client/object_cache/oc_client_dist_master_test.cpp @@ -775,7 +775,7 @@ TEST_F(OCClientDistMasterTest, LEVEL1_AsyncDeleteNestedInOtherMaster) std::string objectKey3 = GenRandomString(); std::string value1 = GenRandomString(); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE }; + CreateParam param{}; uint64_t timeout = 4; // GIncr objectKey1/objectKey2/objectKey3, put DS_ASSERT_OK(client1->GIncreaseRef({ objectKey1 }, failObjects)); diff --git a/tests/st/client/object_cache/oc_client_publish_test.cpp b/tests/st/client/object_cache/oc_client_publish_test.cpp index 340b096..b839c64 100644 --- a/tests/st/client/object_cache/oc_client_publish_test.cpp +++ b/tests/st/client/object_cache/oc_client_publish_test.cpp @@ -111,8 +111,7 @@ protected: // Attention: size_t shmSz = 20 * 1024ul * 1024ul; is not ok because of jemalloc. 
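// (Hedged reading of the note above: jemalloc rounds each allocation up to a size class and
// adds its own overhead, so asking for exactly 20 * 1024 * 1024 bytes can overshoot the
// worker's shared-memory budget; 20 * 1000 * 1000 is close enough for the test while staying
// under it.)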
size_t shmSz = 20 * 1000ul * 1000ul; size_t nonShmSz = 100'000; - CreateParam createParam = { .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE, - .consistencyType = ConsistencyType::CAUSAL }; + CreateParam createParam = { .consistencyType = ConsistencyType::CAUSAL }; int64_t timeOut = 1'000; }; @@ -407,7 +406,7 @@ protected: std::string objKey = "objKey"; size_t sz = 10'000'000; size_t nonShmSz = 100; - CreateParam createParam = { .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam createParam = { .consistencyType = ConsistencyType::CAUSAL }; int64_t timeOut = 1'000; }; @@ -889,8 +888,8 @@ TEST_F(OCClientPublishTest, CreateUpdatePubAfterSeal) int64_t dataSz = 100'000; std::shared_ptr bufferPtr1; std::shared_ptr bufferPtr2; - CreateParam createParam1 = { .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; - CreateParam createParam2 = { .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam createParam1 = { .consistencyType = ConsistencyType::CAUSAL }; + CreateParam createParam2 = { .consistencyType = ConsistencyType::CAUSAL }; ASSERT_EQ(Status::OK(), client1->Create(objKey, dataSz, createParam1, bufferPtr1)); ASSERT_EQ(Status::OK(), client1->Create(objKey, dataSz, createParam2, bufferPtr2)); ASSERT_EQ(Status::OK(), bufferPtr1->Seal()); @@ -903,7 +902,7 @@ TEST_F(OCClientPublishTest, ConcurrentSealAfterPub) InitTestClient(0, client1); int64_t dataSz1 = 100'000; int64_t dataSz2 = 200'000; - CreateParam createParam1 = { .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam createParam1 = { .consistencyType = ConsistencyType::CAUSAL }; std::shared_ptr bufferPtr1; std::shared_ptr bufferPtr2; ASSERT_EQ(Status::OK(), client1->Create(objKey, dataSz1, createParam1, bufferPtr1)); @@ -951,7 +950,7 @@ TEST_F(OCClientPublishTest, CreateUpdateSingleWorkerSealVarySzSeal) InitTestClient(2, client3); int64_t dataSz = 100'000; int64_t bigDataSz = 600'000; - CreateParam createParam = { .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam createParam = { .consistencyType = ConsistencyType::CAUSAL }; std::shared_ptr bufferPtr1; std::shared_ptr bufferPtr2; diff --git a/tests/st/client/object_cache/oc_client_ref_test.cpp b/tests/st/client/object_cache/oc_client_ref_test.cpp index c4642f7..ca3efc1 100644 --- a/tests/st/client/object_cache/oc_client_ref_test.cpp +++ b/tests/st/client/object_cache/oc_client_ref_test.cpp @@ -231,7 +231,7 @@ TEST_F(OCClientRefTest, TestClientDisconnectFailed) InitTestClient(1, client2); // the worker index is 1 int timeout = 10; std::string data = "111111111"; - CreateParam param{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE }; + CreateParam param{}; std::string testcasename = "trouble_ticket_"; std::vector failObjects; int objNum = 10; @@ -515,7 +515,7 @@ TEST_F(OCClientRefTest, TestNoShmNested) { std::shared_ptr client; InitTestClient(0, client); - CreateParam param{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE }; + CreateParam param{}; std::string objectKey1 = NewObjectKey(); std::string objectKey2 = NewObjectKey(); std::string objectKey3 = NewObjectKey(); @@ -562,7 +562,7 @@ TEST_F(OCClientRefTest, TestPutNested) { std::shared_ptr client; InitTestClient(0, client); - CreateParam param{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE }; + CreateParam param{}; std::string objectKey1 = NewObjectKey(); std::string objectKey2 = NewObjectKey(); std::string objectKey3 = 
NewObjectKey(); @@ -656,7 +656,7 @@ TEST_F(OCClientRefTest, TestPutNested3) std::shared_ptr client2; InitTestClient(0, client); InitTestClient(1, client2); - CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL }; + CreateParam param{ .consistencyType = ConsistencyType::CAUSAL }; std::string objectKey1 = "key_1"; std::string objectKey2 = "key_2"; std::string objectKey3 = "key_3"; @@ -905,8 +905,8 @@ TEST_F(OCClientRefTest, NestedChildTest) InitTestClient(1, client2); std::shared_ptr client3; InitTestClient(2, client3); - CreateParam param1{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE }; - CreateParam param2{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE }; + CreateParam param1{}; + CreateParam param2{}; std::string objectKey1 = NewObjectKey(); std::string objectKey2 = NewObjectKey(); diff --git a/tests/st/client/object_cache/oc_service_disable_test.cpp b/tests/st/client/object_cache/oc_service_disable_test.cpp index 0902499..3f74cab 100644 --- a/tests/st/client/object_cache/oc_service_disable_test.cpp +++ b/tests/st/client/object_cache/oc_service_disable_test.cpp @@ -21,6 +21,7 @@ #include "datasystem/common/util/random_data.h" #include "datasystem/object_client.h" #include "datasystem/kv_client.h" +#include "datasystem/stream_client.h" #include #include @@ -47,6 +48,17 @@ void OcOp(const std::shared_ptr &client, bool success) } } +void ScOp(const std::shared_ptr &client, bool success) +{ + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + if (success) { + DS_ASSERT_OK(client->Subscribe("test1", config, consumer)); + } else { + ASSERT_EQ(client->Subscribe("test1", config, consumer).GetCode(), StatusCode::K_RUNTIME_ERROR); + } +} + void KvOp(const std::shared_ptr &client, bool success) { std::string key = "ikun_again"; @@ -65,6 +77,7 @@ public: opts.numWorkers = 1; opts.numRpcThreads = 0; opts.numEtcd = 1; + opts.workerGflagParams = "-sc_stream_socket_num=1 -sc_regular_socket_num=1"; } }; @@ -79,6 +92,9 @@ TEST_F(OcServiceDisableTest, TestInit) auto kVClient = std::make_shared(opts); DS_ASSERT_OK(kVClient->Init()); KvOp(kVClient, false); + auto scClient = std::make_shared(opts); + DS_ASSERT_OK(scClient->Init()); + ScOp(scClient, true); } class ScServiceDisableTest : public OCClientCommon { @@ -101,6 +117,9 @@ TEST_F(ScServiceDisableTest, TestInit) auto kVClient = std::make_shared(opts); DS_ASSERT_OK(kVClient->Init()); KvOp(kVClient, true); + auto scClient = std::make_shared(opts); + DS_ASSERT_OK(scClient->Init()); + ScOp(scClient, false); } class CommonServiceTest : public OCClientCommon { @@ -109,6 +128,7 @@ public: { opts.numWorkers = 1; opts.numEtcd = 1; + opts.workerGflagParams = "-sc_stream_socket_num=1 -sc_regular_socket_num=1"; } }; @@ -123,6 +143,9 @@ TEST_F(CommonServiceTest, TestInit) auto kVClient = std::make_shared(opts); DS_ASSERT_OK(kVClient->Init()); KvOp(kVClient, true); + auto scClient = std::make_shared(opts); + DS_ASSERT_OK(scClient->Init()); + ScOp(scClient, true); } class CommonServiceDisableTest : public OCClientCommon { @@ -147,6 +170,9 @@ TEST_F(CommonServiceDisableTest, TestInit) auto kVClient = std::make_shared(opts); DS_ASSERT_OK(kVClient->Init()); KvOp(kVClient, false); + auto scClient = std::make_shared(opts); + DS_ASSERT_OK(scClient->Init()); + ScOp(scClient, false); } } // namespace st } // namespace datasystem diff --git a/tests/st/client/object_cache/shm_threshold_test.cpp b/tests/st/client/object_cache/shm_threshold_test.cpp index d93d9f7..cf57189 100644 --- 
a/tests/st/client/object_cache/shm_threshold_test.cpp +++ b/tests/st/client/object_cache/shm_threshold_test.cpp @@ -44,8 +44,7 @@ protected: std::string objKey0 = "objKey0"; std::string objKey1 = "objKey1"; size_t shmSz = 20 * 1000ul * 1000ul; - CreateParam createParam = { .writeMode = WriteMode::NONE_L2_CACHE, - .consistencyType = ConsistencyType::CAUSAL }; + CreateParam createParam = { .consistencyType = ConsistencyType::CAUSAL }; }; TEST_F(ShmThresholdTest, DISABLED_LEVEL1_AllocationFailedForThreshold) diff --git a/tests/st/client/object_cache/urma_object_client_test.cpp b/tests/st/client/object_cache/urma_object_client_test.cpp index 2c400a0..2dda532 100644 --- a/tests/st/client/object_cache/urma_object_client_test.cpp +++ b/tests/st/client/object_cache/urma_object_client_test.cpp @@ -60,6 +60,9 @@ public: opts.workerGflagParams += " -arena_per_tenant=1 -enable_urma=true "; #else opts.workerGflagParams += " -arena_per_tenant=1 -enable_urma=false "; +#endif +#ifdef URMA_OVER_UB + opts.workerGflagParams += " -urma_mode=UB "; #endif } @@ -75,6 +78,22 @@ public: } }; +class UrmaObjectClientAuthorizationTest : public UrmaObjectClientTest { + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + UrmaObjectClientTest::SetClusterSetupOptions(opts); + opts.workerGflagParams += " -authorization_enable=true"; + opts.systemAccessKey = accessKey_; + opts.systemSecretKey = secretKey_; + } + +protected: + std::string tenantId1_ = "tenant1"; + std::string tenantId2_ = "tenant2"; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; +}; + TEST_F(UrmaObjectClientTest, UrmaPutGetDeleteShmTest) { std::shared_ptr client; @@ -93,7 +112,8 @@ TEST_F(UrmaObjectClientTest, UrmaPutGetDeleteShmTest) ASSERT_EQ(failedObjectKeys.size(), size_t(0)); } -TEST_F(UrmaObjectClientTest, TestBatchRemoteGet) +// bus error happen in aarch64 +TEST_F(UrmaObjectClientTest, DISABLED_TestBatchRemoteGet1) { // Test that the batch get path in urma case is working as expected. std::shared_ptr client1; @@ -129,6 +149,131 @@ TEST_F(UrmaObjectClientTest, TestBatchRemoteGet) } } +TEST_F(UrmaObjectClientTest, TestBatchRemoteGet2) +{ + // Test specifically batch get for 8KB * 1024, so it needs multiple batches when allocating in URMA case. 
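// (Arithmetic behind the comment above: 1024 values at 8 KiB each is 8 MiB of payload in one
// multi-get, more than a single URMA allocation batch covers, so the client-side allocator must
// split the request across several batches. The assertions below check only end-to-end
// correctness; the batching is exercised implicitly.)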
+ std::shared_ptr client1; + std::shared_ptr client2; + InitTestKVClient(0, client1); + InitTestKVClient(1, client2); + + const int numKV = 1024; + const uint64_t objectSize = 8 * 1024; + std::vector keys; + std::vector values; + std::vector valuesForVer; + std::vector> kvPairs; + for (int i = 0; i < numKV; i++) { + keys.emplace_back("keys_" + std::to_string(i)); + valuesForVer.emplace_back(GenRandomString(objectSize)); + values.emplace_back(valuesForVer.back()); + } + + std::vector failedKeys; + DS_ASSERT_OK(client2->MSet(keys, values, failedKeys)); + ASSERT_TRUE(failedKeys.empty()); + + std::vector valuesGet; + std::vector> buffers; + DS_ASSERT_OK(client1->Get(keys, valuesGet)); + DS_ASSERT_OK(client1->Get(keys, buffers)); + ASSERT_TRUE(NotExistsNone(valuesGet)); + ASSERT_EQ(keys.size(), valuesGet.size()); + ASSERT_EQ(keys.size(), buffers.size()); + + for (size_t i = 0; i < keys.size(); i++) { + ASSERT_EQ(valuesForVer[i], std::string(valuesGet[i].data(), valuesGet[i].size())); + ASSERT_EQ(valuesForVer[i], + std::string(reinterpret_cast(buffers[i]->ImmutableData()), buffers[i]->GetSize())); + } +} + +TEST_F(UrmaObjectClientTest, TestBatchRemoteGet3) +{ + // Test that with big objects (>= 1M), the logic batches all the other small objects, + // and allocates memory separate for the big objects. + std::shared_ptr client1; + std::shared_ptr client2; + InitTestKVClient(0, client1); + InitTestKVClient(1, client2); + + const int numKV = 1024; + const uint64_t objectSize = 8 * 1024; + std::vector keys; + std::vector values; + std::vector> kvPairs; + const uint64_t bigSize = 1024 * 1024; + keys.emplace_back("big_data1"); + values.emplace_back(GenRandomString(bigSize)); + const int bigIndex = 300; + for (int i = 0; i < numKV; i++) { + keys.emplace_back("keys_" + std::to_string(i)); + values.emplace_back(GenRandomString(objectSize)); + if (i == bigIndex) { + keys.emplace_back("big_data2"); + values.emplace_back(GenRandomString(bigSize)); + } + } + keys.emplace_back("big_data3"); + values.emplace_back(GenRandomString(bigSize)); + + for (size_t i = 0; i < keys.size(); i++) { + DS_ASSERT_OK(client2->Set(keys[i], values[i])); + } + + std::vector valuesGet; + std::vector> buffers; + DS_ASSERT_OK(client1->Get(keys, valuesGet)); + DS_ASSERT_OK(client1->Get(keys, buffers)); + ASSERT_TRUE(NotExistsNone(valuesGet)); + ASSERT_EQ(keys.size(), valuesGet.size()); + ASSERT_EQ(keys.size(), buffers.size()); + + for (size_t i = 0; i < keys.size(); i++) { + ASSERT_EQ(values[i], std::string(valuesGet[i].data(), valuesGet[i].size())); + ASSERT_EQ(values[i], + std::string(reinterpret_cast(buffers[i]->ImmutableData()), buffers[i]->GetSize())); + } +} + +TEST_F(UrmaObjectClientAuthorizationTest, TestBatchRemoteGet4) +{ + // Test that with tenant authorization enabled, + // the logic still batches the allocation, and the tenant id is selected correctly. 
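// (Context for the tenant case below: UrmaObjectClientTest::SetClusterSetupOptions passes
// -arena_per_tenant=1, so each tenant is served from its own shared-memory arena. Both clients
// authenticate with SetAkSkAuth(..., tenantId2_), so the batched allocations must all resolve
// to tenant2's arena rather than the default one.)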
+ std::shared_ptr client1; + std::shared_ptr client2; + InitTestKVClient(0, client1, [&](ConnectOptions &opts) { opts.SetAkSkAuth(accessKey_, secretKey_, tenantId2_); }); + InitTestKVClient(1, client2, [&](ConnectOptions &opts) { opts.SetAkSkAuth(accessKey_, secretKey_, tenantId2_); }); + + const int numKV = 1024; + const uint64_t objectSize = 8 * 1024; + std::vector keys; + std::vector values; + std::vector> kvPairs; + for (int i = 0; i < numKV; i++) { + keys.emplace_back("keys_" + std::to_string(i)); + values.emplace_back(GenRandomString(objectSize)); + } + + for (size_t i = 0; i < keys.size(); i++) { + DS_ASSERT_OK(client2->Set(keys[i], values[i])); + } + + std::vector valuesGet; + std::vector> buffers; + DS_ASSERT_OK(client1->Get(keys, valuesGet)); + DS_ASSERT_OK(client1->Get(keys, buffers)); + ASSERT_TRUE(NotExistsNone(valuesGet)); + ASSERT_EQ(keys.size(), valuesGet.size()); + ASSERT_EQ(keys.size(), buffers.size()); + + for (size_t i = 0; i < keys.size(); i++) { + ASSERT_EQ(values[i], std::string(valuesGet[i].data(), valuesGet[i].size())); + ASSERT_EQ(values[i], + std::string(reinterpret_cast(buffers[i]->ImmutableData()), buffers[i]->GetSize())); + } +} + TEST_F(UrmaObjectClientTest, TestBatchRemoteGetErrorCode1) { // Test the error handling in urma batch get logic. diff --git a/tests/st/client/stream_cache/client_crash_test.cpp b/tests/st/client/stream_cache/client_crash_test.cpp new file mode 100644 index 0000000..61b6cf2 --- /dev/null +++ b/tests/st/client/stream_cache/client_crash_test.cpp @@ -0,0 +1,1815 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <gtest/gtest.h>
+#include "datasystem/common/util/format.h"
+#include "datasystem/stream_client.h"
+
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include "common.h"
+#include "datasystem/utils/status.h"
+#include "sc_client_common.h"
+#include "common/stream_cache/element_generator.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/client/stream_cache/stream_client_impl.h"
+#include "datasystem/stream/stream_config.h"
+#include "datasystem/common/inject/inject_point.h"
+
+using namespace datasystem::client::stream_cache;
+namespace datasystem {
+namespace st {
+class ChildPidWrapper {
+public:
+    explicit ChildPidWrapper(pid_t pid) : pid_(pid)
+    {
+    }
+    ~ChildPidWrapper()
+    {
+        int status;
+        waitpid(pid_, &status, 0);
+    }
+
+    ChildPidWrapper(const ChildPidWrapper &) = delete;
+    ChildPidWrapper &operator=(const ChildPidWrapper &) = delete;
+
+private:
+    pid_t pid_;
+};
+
+template <typename Func>
+std::unique_ptr<ChildPidWrapper> RunInChildProcess(Func &&func)
+{
+    pid_t pid = fork();
+    if (pid == 0) {
+        func();
+        return nullptr;
+    }
+    return std::make_unique<ChildPidWrapper>(pid);
+}
+
+constexpr int K_TWENTY = 20;
+class ClientCrashTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numWorkers = 2;
+        opts.masterIdx = 0;
+        opts.numRpcThreads = 0;
+        opts.numEtcd = 1;
+        opts.workerGflagParams = "-client_dead_timeout_s=2 -v=2 -log_monitor=true";
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+protected:
+    Status InitClient(int index, std::shared_ptr<StreamClient> &client)
+    {
+        InitStreamClient(index, client);
+        return Status::OK();
+    }
+
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+class ClientSC1 {
+public:
+    explicit ClientSC1(std::string streamName) : streamName_(std::move(streamName))
+    {
+    }
+    ~ClientSC1() = default;
+
+    Status InitTestClient(const std::string &ip, const int &port, int timeout = 60000);
+
+    Status CreateProducer(std::shared_ptr<Producer> &producer);
+
+    Status CreateProducer(std::shared_ptr<Producer> &producer, ProducerConf conf);
+
+    Status CreateProducer(std::shared_ptr<Producer> &producer, int64_t delayFlushTime);
+
+    Status Subscribe(const std::string &subName, std::shared_ptr<Consumer> &consumer);
+
+    Status QueryTotalProducerNum(uint64_t &totalProducerNum);
+
+    Status QueryTotalConsumerNum(uint64_t &totalConsumerNum);
+
+    Status Shutdown()
+    {
+        return client_->ShutDown();
+    }
+
+private:
+    std::string streamName_;
+    std::unique_ptr<StreamClient> client_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+Status ClientSC1::InitTestClient(const std::string &ip, const int &port, int timeout)
+{
+    ConnectOptions connectOptions;
+    connectOptions.host = ip;
+    connectOptions.port = port;
+    connectOptions.connectTimeoutMs = timeout;
+    connectOptions.SetAkSkAuth(accessKey_, secretKey_, "");
+    client_ = std::make_unique<StreamClient>(connectOptions);
+    return client_->Init();
+}
+
+Status ClientSC1::CreateProducer(std::shared_ptr<Producer> &producer)
+{
+    const uint64_t maxStreamSize = 20 * 1024 * 1024;  // The max stream size is 20M
+    const int64_t pageSize = 4 * 1024;                // The page size is 4096 bytes
+    ProducerConf conf;
+    conf.delayFlushTime = -1;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = pageSize;
+    return client_->CreateProducer(streamName_, producer, conf);
+}
+
+Status ClientSC1::CreateProducer(std::shared_ptr<Producer> &producer, ProducerConf conf)
+{
+    return client_->CreateProducer(streamName_, producer, conf);
+}
+
+Status ClientSC1::CreateProducer(std::shared_ptr<Producer> &producer, int64_t delayFlushTime)
+{
+    const uint64_t maxStreamSize = 20 * 1024 * 1024;  // The max stream size is 20M
+    const int64_t pageSize = 4 * 1024;                // The page size is 4096 bytes
+    ProducerConf conf;
+    conf.delayFlushTime = delayFlushTime;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = pageSize;
+    return client_->CreateProducer(streamName_, producer, conf);
+}
+
+Status ClientSC1::Subscribe(const std::string &subName, std::shared_ptr<Consumer> &consumer)
+{
+    SubscriptionConfig config(subName, SubscriptionType::STREAM);
+    return client_->Subscribe(streamName_, config, consumer);
+}
+
+Status ClientSC1::QueryTotalProducerNum(uint64_t &totalProducerNum)
+{
+    return client_->QueryGlobalProducersNum(streamName_, totalProducerNum);
+}
+
+Status ClientSC1::QueryTotalConsumerNum(uint64_t &totalConsumerNum)
+{
+    return client_->QueryGlobalConsumersNum(streamName_, totalConsumerNum);
+}
+
+TEST_F(ClientCrashTest, DISABLED_TestProdClientCloseWhileReceiveSameHost)
+{
+    FLAGS_v = SC_INTERNAL_LOG_LEVEL;
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+    // start client 1
+    LOG(INFO) << "start create client 1";
+    auto client1 = std::make_unique<ClientSC1>("SameHostProdClientClose");
+    DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(client1->CreateProducer(producer1));
+
+    // start client 2
+    LOG(INFO) << "start create client 2";
+    auto client2 = std::make_unique<ClientSC1>("SameHostProdClientClose");
+    DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+    std::shared_ptr<Consumer> consumer1;
+    DS_ASSERT_OK(client2->Subscribe("subName1", consumer1));
+
+    // Wait till the consumer finds the producer.
+    uint64_t producerCount = 0;
+    while (producerCount == 0) {
+        DS_ASSERT_OK(client2->QueryTotalProducerNum(producerCount));
+    }
+
+    ThreadPool threadPool(1);
+
+    auto res = threadPool.Submit([consumer1]() {
+        // Create a pending receive call.
+        std::vector<Element> outElements;
+        return consumer1->Receive(1, 20 * 1000, outElements);
+    });
+
+    // Wait a short grace period to let the other thread start.
+    std::this_thread::sleep_for(std::chrono::seconds(1));
+    // client 1 abnormal close
+    DS_ASSERT_OK(client1->Shutdown());
+    auto rc = res.get();
+    ASSERT_EQ(rc.GetCode(), K_SC_PRODUCER_NOT_FOUND);
+}
+
+TEST_F(ClientCrashTest, DISABLED_TestProdClientCloseWhileReceiveDiffHost)
+{
+    ThreadPool threadPool(1);
+
+    threadPool.Submit([this]() {
+        // start client 2
+        HostPort workerAddress;
+        LOG(INFO) << "start create client 2";
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(1, workerAddress));
+        auto client2 = std::make_unique<ClientSC1>("DiffHostProdClientClose");
+        DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+        std::shared_ptr<Consumer> consumer1;
+        DS_ASSERT_OK(client2->Subscribe("subName1", consumer1));
+
+        // Wait till the consumer finds the producer.
+        uint64_t producerCount = 0;
+        while (producerCount == 0) {
+            DS_ASSERT_OK(client2->QueryTotalProducerNum(producerCount));
+        }
+
+        // Create a pending receive call.
+ std::vector outElements; + DS_ASSERT_OK(consumer1->Receive(1, 16 * 1000, outElements)); + ASSERT_EQ(outElements.size(), (size_t)1); + Status rc = consumer1->Receive(1, 16 * 1000, outElements); + DS_ASSERT_NOT_OK(rc); + ASSERT_EQ(rc.GetCode(), K_SC_PRODUCER_NOT_FOUND); + }); + + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + // start client 1 + LOG(INFO) << "start create client 1"; + auto client1 = std::make_unique("DiffHostProdClientClose"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer1; + DS_ASSERT_OK(client1->CreateProducer(producer1)); + + // Wait till consumer finds the producer. + uint64_t consumerCount = 0; + while (consumerCount == 0) { + DS_ASSERT_OK(client1->QueryTotalConsumerNum(consumerCount)); + } + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer1->Send(element)); + + // wait a safer period to let the other thread call the second Receive. + std::this_thread::sleep_for(std::chrono::seconds(1)); + + DS_ASSERT_OK(client1->Shutdown()); +} + +TEST_F(ClientCrashTest, TestClientCrashWhenCloseProducer) +{ + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(1, workerAddress)); + auto client = std::make_unique("Client1CrashWhenCloseProducer"); + DS_ASSERT_OK(client->InitTestClient(workerAddress.Host(), workerAddress.Port())); + + std::vector> producerList; + + const int producerCount = 50; + for (int i = 0; i < producerCount; i++) { + std::shared_ptr producer; + DS_ASSERT_OK(client->CreateProducer(producer)); + producerList.emplace_back(producer); + } + + std::thread t1([&producerList] { + for (auto producer : producerList) { + Status rc = producer->Close(); + if (rc.IsError()) { + LOG(ERROR) << "Close failed:" << rc.GetMsg(); + } + } + }); + + client = nullptr; + t1.join(); + + // Producers are implicitly closed, calling close again will be a no-op and return OK + for (auto producer : producerList) { + DS_ASSERT_OK(producer->Close()); + } + + // connect to worker again. 
+ auto client2 = std::make_unique("Client2CrashWhenCloseProducer"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port())); +} + +TEST_F(ClientCrashTest, LEVEL2_TestClientCrashWhenCreateProducer) +{ + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + + DS_ASSERT_OK(cluster_->SetInjectAction( + ClusterNodeType::WORKER, 0, "ClientWorkerSCServiceImpl.CloseProducerImplForceClose.sleep", "1*sleep(1000)")); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + auto pid = fork(); + // Client 1, CreateProducer, then crash + if (pid == 0) { + auto client1 = std::make_unique("testCrashWhenCreateProd"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer1; + DS_ASSERT_OK(client1->CreateProducer(producer1)); + // Fake a crash point within producer after it holds the lock + datasystem::inject::Set("producer_obtained_lock", "1*abort()"); + DS_ASSERT_NOT_OK(producer1->Send(element)); + _exit(0); + } + ASSERT_TRUE(pid > 0); + // Client 2, CreateProducer, then sleep before adding to local + auto client2 = std::make_unique("testCrashWhenCreateProd"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port())); + DS_ASSERT_OK(cluster_->SetInjectAction( + ClusterNodeType::WORKER, 0, "ClientWorkerSCServiceImpl.CreateProducerImpl.WaitBeforeAdd", "1*sleep(2000)")); + std::shared_ptr producer2; + DS_ASSERT_OK(client2->CreateProducer(producer2)); + + // Wait for cleanup to finish. its set to 2secs above + const int sleepTime = 3; + sleep(sleepTime); + + int status; + waitpid(pid, &status, 0); + // When we close producer created by client2 it should not error out + DS_ASSERT_OK(producer2->Close()); +} + +TEST_F(ClientCrashTest, TestClientCrashWhenCloseConsumer) +{ + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + auto client = std::make_unique("CrashWhenCloseProd"); + DS_ASSERT_OK(client->InitTestClient(workerAddress.Host(), workerAddress.Port())); + + std::vector> consumerList; + + const int consumerCount = 50; + for (int i = 0; i < consumerCount; i++) { + std::shared_ptr consumer; + DS_ASSERT_OK(client->Subscribe("sub-" + std::to_string(i), consumer)); + consumerList.emplace_back(consumer); + } + + std::thread t1([&consumerList] { + for (auto consumer : consumerList) { + Status rc = consumer->Close(); + if (rc.IsError()) { + LOG(ERROR) << "Close failed:" << rc.GetMsg(); + } + } + }); + + client = nullptr; + t1.join(); + + // Consumers are implicitly closed, calling close again will be a no-op and return OK + for (auto consumer : consumerList) { + DS_ASSERT_OK(consumer->Close()); + } + + // connect to worker again. 
+ auto client2 = std::make_unique("CrashWhenCloseProd"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port())); +} + +TEST_F(ClientCrashTest, TestProducerCrash1) +{ + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + auto pid = fork(); + if (pid == 0) { + auto client1 = std::make_unique("ProducerCrash1"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer1; + DS_ASSERT_OK(client1->CreateProducer(producer1)); + // Fake a crash point within producer after it holds the lock + datasystem::inject::Set("producer_obtained_lock", "1*abort()"); + DS_ASSERT_NOT_OK(producer1->Send(element)); + _exit(0); + } + ASSERT_TRUE(pid > 0); + auto client2 = std::make_unique("ProducerCrash1"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr consumer2; + DS_ASSERT_OK(client2->Subscribe("sub", consumer2)); + + int status; + waitpid(pid, &status, 0); + + std::shared_ptr producer2; + DS_ASSERT_OK(client2->CreateProducer(producer2)); + datasystem::inject::Clear("producer_obtained_lock"); + // Consumer is always lock free and can read + std::vector out; + const uint32_t timeoutMs = 100; + DS_ASSERT_OK(consumer2->Receive(timeoutMs, out)); + DS_ASSERT_TRUE(out.size(), 0); + // Worker should be able to clean up the lock + const int64_t TWO_MINUTES = 120'000; + DS_ASSERT_OK(producer2->Send(element, TWO_MINUTES)); + DS_ASSERT_OK(consumer2->Receive(timeoutMs, out)); + DS_ASSERT_TRUE(out.size(), 1); + DS_ASSERT_OK(consumer2->Close()); +} + +TEST_F(ClientCrashTest, TestDownLevelProducerCrash1) +{ + DS_ASSERT_OK(datasystem::inject::Set("ClientBaseImpl.force_downlevel_client", "call()")); + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + auto pid = fork(); + if (pid == 0) { + auto client1 = std::make_unique("DownLevelProducerCrash1"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer1; + DS_ASSERT_OK(client1->CreateProducer(producer1)); + // Fake a crash point within producer after it holds the lock + datasystem::inject::Set("producer_obtained_lock", "1*abort()"); + DS_ASSERT_NOT_OK(producer1->Send(element)); + _exit(0); + } + ASSERT_TRUE(pid > 0); + auto client2 = std::make_unique("DownLevelProducerCrash1"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr consumer2; + DS_ASSERT_OK(client2->Subscribe("sub", consumer2)); + + int status; + waitpid(pid, &status, 0); + + std::shared_ptr producer2; + DS_ASSERT_OK(client2->CreateProducer(producer2)); + datasystem::inject::Clear("producer_obtained_lock"); + // Consumer is always lock free and can read + std::vector out; + const uint32_t timeoutMs = 100; + DS_ASSERT_OK(consumer2->Receive(timeoutMs, out)); + DS_ASSERT_TRUE(out.size(), 0); + // Worker should be able to clean up the lock + const int64_t TWO_MINUTES = 120'000; + DS_ASSERT_OK(producer2->Send(element, TWO_MINUTES)); + DS_ASSERT_OK(consumer2->Receive(timeoutMs, out)); + DS_ASSERT_TRUE(out.size(), 1); + DS_ASSERT_OK(consumer2->Close()); +} + +TEST_F(ClientCrashTest, TestProducerCrash2) +{ + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + + std::string data = 
"Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + auto pid = fork(); + if (pid == 0) { + auto client1 = std::make_unique("ProducerCrash2"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer1; + DS_ASSERT_OK(client1->CreateProducer(producer1)); + // Fake a crash point within producer after it holds the lock + // and update the slot count + datasystem::inject::Set("producer_update_pending_slot_count_holding_lock", "1*abort()"); + DS_ASSERT_NOT_OK(producer1->Send(element)); + _exit(0); + } + ASSERT_TRUE(pid > 0); + auto client2 = std::make_unique("ProducerCrash2"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr consumer2; + DS_ASSERT_OK(client2->Subscribe("sub", consumer2)); + int status; + waitpid(pid, &status, 0); + + std::shared_ptr producer2; + DS_ASSERT_OK(client2->CreateProducer(producer2)); + datasystem::inject::Clear("producer_update_pending_slot_count_holding_lock"); + // Consumer is always lock free and can read + std::vector out; + const uint32_t timeoutMs = 100; + DS_ASSERT_OK(consumer2->Receive(timeoutMs, out)); + DS_ASSERT_TRUE(out.size(), 0); + // Worker should be able to clean up the lock + const int64_t TWO_MINUTES = 120'000; + DS_ASSERT_OK(producer2->Send(element, TWO_MINUTES)); + DS_ASSERT_OK(consumer2->Receive(timeoutMs, out)); + DS_ASSERT_TRUE(out.size(), 1); + DS_ASSERT_OK(consumer2->Close()); +} + +TEST_F(ClientCrashTest, TestProducerCrash3) +{ + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + auto pid = fork(); + if (pid == 0) { + auto client1 = std::make_unique("ProducerCrash3"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer1; + DS_ASSERT_OK(client1->CreateProducer(producer1)); + // Fake a crash point within producer after it holds the lock + // and update the free space + datasystem::inject::Set("producer_update_free_space", "1*abort()"); + DS_ASSERT_NOT_OK(producer1->Send(element)); + _exit(0); + } + ASSERT_TRUE(pid > 0); + auto client2 = std::make_unique("ProducerCrash3"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr consumer2; + DS_ASSERT_OK(client2->Subscribe("sub", consumer2)); + + int status; + waitpid(pid, &status, 0); + + std::shared_ptr producer2; + DS_ASSERT_OK(client2->CreateProducer(producer2)); + datasystem::inject::Clear("producer_update_free_space"); + // Consumer is always lock free and can read + std::vector out; + const uint32_t timeoutMs = 100; + DS_ASSERT_OK(consumer2->Receive(timeoutMs, out)); + DS_ASSERT_TRUE(out.size(), 0); + // Worker should be able to clean up the lock + const int64_t TWO_MINUTES = 120'000; + DS_ASSERT_OK(producer2->Send(element, TWO_MINUTES)); + DS_ASSERT_OK(consumer2->Receive(timeoutMs, out)); + DS_ASSERT_TRUE(out.size(), 1); + DS_ASSERT_OK(consumer2->Close()); +} + +TEST_F(ClientCrashTest, DISABLED_TestProducerCrash4) +{ + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + auto pid = fork(); + if (pid == 0) { + auto client1 = std::make_unique("ProducerCrash4"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer1; + 
+        DS_ASSERT_OK(client1->CreateProducer(producer1));
+        // Insert two rows first.
+        // Fake a crash point within the producer after it holds the lock
+        // and updates the slot directory
+        DS_ASSERT_OK(producer1->Send(element));
+        DS_ASSERT_OK(producer1->Send(element));
+        datasystem::inject::Set("producer_update_slot_directory", "1*abort()");
+        DS_ASSERT_NOT_OK(producer1->Send(element));
+        _exit(0);
+    }
+    ASSERT_TRUE(pid > 0);
+    auto client2 = std::make_unique<ClientSC1>("ProducerCrash4");
+    DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+    std::shared_ptr<Consumer> consumer2;
+    DS_ASSERT_OK(client2->Subscribe("sub", consumer2));
+
+    int status;
+    waitpid(pid, &status, 0);
+
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_OK(client2->CreateProducer(producer2));
+    datasystem::inject::Clear("producer_update_slot_directory");
+    // The consumer is always lock free and can read
+    std::vector<Element> out;
+    const uint32_t timeoutMs = 100;
+    DS_ASSERT_OK(consumer2->Receive(timeoutMs, out));
+    const int32_t expected = 2;
+    DS_ASSERT_TRUE(out.size(), expected);
+    // The worker should be able to clean up the lock
+    const int64_t TWO_MINUTES = 120'000;
+    DS_ASSERT_OK(producer2->Send(element, TWO_MINUTES));
+    DS_ASSERT_OK(consumer2->Receive(timeoutMs, out));
+    DS_ASSERT_TRUE(out.size(), 1);
+    DS_ASSERT_OK(consumer2->Close());
+}
+
+TEST_F(ClientCrashTest, TestProducerCrash5)
+{
+    HostPort workerAddress1;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress1));
+    HostPort workerAddress2;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(1, workerAddress2));
+
+    // Create a subscriber on worker2 so that there is a Worker1-to-Worker2 flush
+    auto client2 = std::make_unique<ClientSC1>("streamCrash5");
+    DS_ASSERT_OK(client2->InitTestClient(workerAddress2.Host(), workerAddress2.Port()));
+    std::shared_ptr<Consumer> consumer2;
+    DS_ASSERT_OK(client2->Subscribe("sub", consumer2));
+    // Create a producer on worker1 and crash it after the slot count is updated.
+    // This should just discard all data from the producer, and the remote worker should get nothing.
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamDataPool.SendElementsToRemote.wait", "sleep(16000)"));
+    auto pid = fork();
+    if (pid == 0) {
+        // make the scan eval thread wait for the client timeout and clearAllRemoteConsumer
+        auto client1 = std::make_unique<ClientSC1>("streamCrash5");
+        DS_ASSERT_OK(client1->InitTestClient(workerAddress1.Host(), workerAddress1.Port()));
+        std::shared_ptr<Producer> producer1;
+        DS_ASSERT_OK(client1->CreateProducer(producer1));
+        // Fake a crash point within the producer after leaving the lock
+        DS_ASSERT_OK(producer1->Send(element));
+        datasystem::inject::Set("producer_update_pending_slot_count_without_lock", "1*abort()");
+        DS_ASSERT_OK(producer1->Send(element));
+        _exit(0);
+    } else {
+        ASSERT_TRUE(pid > 0);
+        // Wait for the producer to end
+        int status;
+        waitpid(pid, &status, 0);
+
+        // wait for the client timeout
+        DS_ASSERT_OK(
+            cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "ClientManager.Init.heartbeatInterval", "call(500)"));
+        DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                               "ClientManager.IsClientLost.heartbeatThreshold", "call(1)"));
+        // The remote consumer should not get any data because the producer crashed
+        std::vector<Element> out;
+        const uint32_t timeoutMs = 100;
+        DS_ASSERT_OK(consumer2->Receive(timeoutMs, out));
+        DS_ASSERT_TRUE(out.size(), 0);
+        DS_ASSERT_OK(consumer2->Close());
+    }
+}
+
+TEST_F(ClientCrashTest, DISABLED_TestProducerCrash6)
+{
+    // Constructed based on
TestProducerCrash1, while added more producer send + // to trigger more of Ack related logic + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + + const uint32_t timeoutMs = 10000; + const size_t elementSize = 2000; + const int numEle = 20; + auto writeElement = RandomData().RandomBytes(elementSize); + Element element = Element(reinterpret_cast(writeElement.data()), writeElement.size()); + auto pid = fork(); + if (pid == 0) { + auto client1 = std::make_unique("ProducerCrash6"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer1; + DS_ASSERT_OK(client1->CreateProducer(producer1)); + // Send a few pages of elements to populate ackChain + for (int i = 0; i < numEle; i++) { + DS_ASSERT_OK(producer1->Send(element)); + } + // Fake a crash point within producer after it holds the lock + datasystem::inject::Set("producer_obtained_lock", "1*abort()"); + DS_ASSERT_NOT_OK(producer1->Send(element)); + _exit(0); + } + ASSERT_TRUE(pid > 0); + auto client2 = std::make_unique("ProducerCrash6"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr consumer; + DS_ASSERT_OK(client2->Subscribe("sub", consumer)); + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(numEle, timeoutMs, outElements)); + DS_ASSERT_TRUE(outElements.size(), numEle); + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + + int status; + waitpid(pid, &status, 0); + + std::shared_ptr producer2; + DS_ASSERT_OK(client2->CreateProducer(producer2)); + datasystem::inject::Clear("producer_obtained_lock"); + // Worker should be able to clean up the lock + const int64_t ONE_MINUTE = 60'000; + DS_ASSERT_OK(producer2->Send(element, ONE_MINUTE)); + DS_ASSERT_OK(consumer->Receive(timeoutMs, outElements)); + DS_ASSERT_TRUE(outElements.size(), 1); + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + DS_ASSERT_OK(consumer->Close()); +} + +TEST_F(ClientCrashTest, TestProducerCrash7) +{ + // Constructed based on TestProducerCrash6, while included big elements + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + + const uint32_t timeoutMs = 10000; + const size_t elementSize = 1500; + const size_t bigElementSize = 1024 * 10; + const int numRound = 19; + const int numEle = 59; + auto writeElement = RandomData().RandomBytes(elementSize); + Element element = Element(reinterpret_cast(writeElement.data()), writeElement.size()); + auto writeBigElement = RandomData().RandomBytes(bigElementSize); + Element bigElement = Element(reinterpret_cast(writeBigElement.data()), writeBigElement.size()); + auto pid = fork(); + if (pid == 0) { + auto client1 = std::make_unique("ProducerCrash7"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer1; + DS_ASSERT_OK(client1->CreateProducer(producer1)); + // Send mixture of elements and big elements + for (int i = 0; i < numRound; i++) { + DS_ASSERT_OK(producer1->Send(element)); + DS_ASSERT_OK(producer1->Send(bigElement)); + DS_ASSERT_OK(producer1->Send(element)); + } + DS_ASSERT_OK(producer1->Send(element)); + DS_ASSERT_OK(producer1->Send(bigElement)); + // Fake a crash point within producer after it holds the lock + datasystem::inject::Set("producer_obtained_lock", "1*abort()"); + DS_ASSERT_NOT_OK(producer1->Send(element)); + _exit(0); + } + ASSERT_TRUE(pid > 0); + auto client2 = std::make_unique("ProducerCrash7"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), 
workerAddress.Port())); + std::shared_ptr consumer; + DS_ASSERT_OK(client2->Subscribe("sub", consumer)); + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(numEle, timeoutMs, outElements)); + DS_ASSERT_TRUE(outElements.size(), numEle); + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + + int status; + waitpid(pid, &status, 0); + + std::shared_ptr producer2; + DS_ASSERT_OK(client2->CreateProducer(producer2)); + datasystem::inject::Clear("producer_obtained_lock"); + // Worker should be able to clean up the lock + const int64_t ONE_MINUTE = 60'000; + DS_ASSERT_OK(producer2->Send(element, ONE_MINUTE)); + DS_ASSERT_OK(consumer->Receive(timeoutMs, outElements)); + DS_ASSERT_TRUE(outElements.size(), 1); + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + DS_ASSERT_OK(consumer->Close()); +} + +TEST_F(ClientCrashTest, DISABLED_TestProducerCrash8) +{ + // README + // numElem reduced from 30000 for CI runtime purposes + // This testcase intends to construct the producer crash problem involving 2 clients + // And the first client to recover is not the page lock holder. + HostPort workerAddress1; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress1)); + HostPort workerAddress2; + DS_ASSERT_OK(cluster_->GetWorkerAddr(1, workerAddress2)); + + const size_t elementSize = 180; + const int numEle = 10000; + auto writeElement = RandomData().RandomBytes(elementSize); + Element element = Element(reinterpret_cast(writeElement.data()), writeElement.size()); + auto client1Pid = fork(); + if (client1Pid == 0) { + auto client1 = std::make_unique("ProducerCrash8"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress1.Host(), workerAddress1.Port())); + auto client2 = std::make_unique("ProducerCrash8"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress1.Host(), workerAddress1.Port())); + std::shared_ptr producer1; + DS_ASSERT_OK(client1->CreateProducer(producer1)); + for (int i = 0; i < numEle; i++) { + DS_ASSERT_OK(producer1->Send(element)); + } + std::shared_ptr producer2; + DS_ASSERT_OK(client2->CreateProducer(producer2)); + // Fake a crash point within producer after it holds the lock + datasystem::inject::Set("producer_obtained_lock", "1*abort()"); + DS_ASSERT_NOT_OK(producer2->Send(element)); + _exit(0); + } + ASSERT_TRUE(client1Pid > 0); + auto client3 = std::make_unique("ProducerCrash8"); + DS_ASSERT_OK(client3->InitTestClient(workerAddress2.Host(), workerAddress2.Port())); + std::shared_ptr consumer; + DS_ASSERT_OK(client3->Subscribe("sub", consumer)); + + int status; + waitpid(client1Pid, &status, 0); + + const uint64_t clientDeadTimeoutSec = 15; + sleep(clientDeadTimeoutSec + 1); + DS_ASSERT_OK(consumer->Close()); + auto client4 = std::make_unique("StreamNameTest"); + DS_ASSERT_OK(client4->InitTestClient(workerAddress1.Host(), workerAddress1.Port())); + std::shared_ptr producer3; + DS_ASSERT_OK(client4->CreateProducer(producer3)); +} + +TEST_F(ClientCrashTest, TestProducerCrashFixPage) +{ + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + + // Create 2 producers for the same stream + // Set element size and page size almost same + // This means every Send() triggers a CreateShmPage + // Once these conditions are there + // Let one of the producer die at the lock + // Now other will send to the same and get stuck + + // Page size is 4KB so one element should fit one page + const size_t elementSize = 4000; + auto writeElement = RandomData().RandomBytes(elementSize); + Element element = Element(reinterpret_cast(writeElement.data()), 
writeElement.size()); + auto pid = fork(); + if (pid == 0) { + auto client1 = std::make_unique("CrashFixPage"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer1; + DS_ASSERT_OK(client1->CreateProducer(producer1)); + DS_ASSERT_OK(producer1->Send(element)); + DS_ASSERT_OK(producer1->Send(element)); + // Fake a crash point within producer while holding shared lock to read last page + datasystem::inject::Set("producer_crash_getview", "1*abort()"); + DS_ASSERT_NOT_OK(producer1->Send(element)); + _exit(0); + } + ASSERT_TRUE(pid > 0); + auto client2 = std::make_unique("CrashFixPage"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer2; + DS_ASSERT_OK(client2->CreateProducer(producer2)); + int status; + waitpid(pid, &status, 0); + // Other producer should not be stuck at create shm page + DS_ASSERT_OK(producer2->Send(element)); +} + +TEST_F(ClientCrashTest, DISABLED_TestConsumerCrash1) +{ + // Test that consumer crash with ref count will not block further release and ack + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + const size_t maxStreamSize = 2 * 1024 * 1024; + const size_t pageSize = 1024 * 1024; + const size_t dataSize = 200 * 1024; + const size_t numEle = 20; + std::string data = RandomData().GetRandomString(dataSize); + Element element(reinterpret_cast(&data.front()), data.size()); + auto pid = fork(); + if (pid == 0) { + auto client2 = std::make_unique("ConsumerCrash1"); + DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port())); + uint64_t producerCount = 0; + while (producerCount == 0) { + DS_ASSERT_OK(client2->QueryTotalProducerNum(producerCount)); + } + std::shared_ptr consumer2; + DS_ASSERT_OK(client2->Subscribe("sub", consumer2)); + const uint32_t timeoutMs = 5000; + std::vector outElements; + // Guarantee that the page is fetched, so the ref count is incremented with Worker EyeCatcher V0. + DS_ASSERT_OK(consumer2->Receive(1, timeoutMs, outElements)); + ASSERT_EQ(outElements.size(), 1); + + std::abort(); + } + ASSERT_TRUE(pid > 0); + auto client1 = std::make_unique("ConsumerCrash1"); + DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port())); + std::shared_ptr producer1; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + conf.pageSize = pageSize; + conf.retainForNumConsumers = 1; + DS_ASSERT_OK(client1->CreateProducer(producer1, conf)); + DS_ASSERT_OK(producer1->Send(element)); + + int status; + waitpid(pid, &status, 0); + + const uint64_t clientDeadTimeoutSec = 15; + sleep(clientDeadTimeoutSec + 1); + std::shared_ptr consumer1; + DS_ASSERT_OK(client1->Subscribe("sub", consumer1)); + const uint32_t timeoutMs = 5000; + std::vector outElements; + // Continuously send until about 2 times max stream size, + // to verify that pages are acked and released. 
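+    // Rough arithmetic behind the comment above: numEle (20) * dataSize (200 KB)
+    // is about 4 MB, roughly twice maxStreamSize (2 MB), so the loop can only
+    // finish if the worker reclaims the pages pinned by the crashed consumer.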
+    for (size_t i = 0; i < numEle; i++) {
+        DS_ASSERT_OK(producer1->Send(element, timeoutMs));
+        DS_ASSERT_OK(consumer1->Receive(1, timeoutMs, outElements));
+        ASSERT_EQ(outElements.size(), 1);
+        DS_ASSERT_OK(consumer1->Ack(outElements.back().id));
+    }
+}
+
+TEST_F(ClientCrashTest, AckPointLastRow)
+{
+    HostPort workerAddress, workerAddress2;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(1, workerAddress2));
+
+    int timeOut = 1000;
+    // Send enough elements so that the starting cursor picks up from the previous consumer's
+    // last ack point:
+    // 4096 plus a few extra elements to get lastRecvCursor to be equal
+    // to the lastAck cursor
+    int elementCount = 4100;
+    std::string data = "A";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    std::vector<Element> outElements;
+
+    auto client1 = std::make_unique<ClientSC1>("testAckPointLastRow");
+    DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+    auto client2 = std::make_unique<ClientSC1>("testAckPointLastRow");
+    DS_ASSERT_OK(client2->InitTestClient(workerAddress2.Host(), workerAddress2.Port()));
+
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client1->CreateProducer(producer));
+
+    std::shared_ptr<Consumer> consumer1, consumer2;
+    DS_ASSERT_OK(client2->Subscribe("sub", consumer1));
+
+    // the producer sends data
+    for (int i = 0; i < elementCount; i++) {
+        producer->Send(element);
+    }
+
+    // consume and ack all data so that the cursor is on the last row
+    for (int i = 0; i < elementCount; i++) {
+        DS_ASSERT_OK(consumer1->Receive(1, timeOut, outElements));
+        ASSERT_EQ(outElements.size(), 1);
+        DS_ASSERT_OK(consumer1->Ack(i + 1));
+    }
+
+    DS_ASSERT_OK(consumer1->Close());
+    DS_ASSERT_OK(client2->Subscribe("sub2", consumer2));
+
+    // Assert that it can still send, even with lastRecvCursor == ackCursor
+    DS_ASSERT_OK(producer->Send(element));
+    DS_ASSERT_OK(consumer2->Receive(1, timeOut, outElements));
+    DS_ASSERT_TRUE(outElements.size(), 1);
+    DS_ASSERT_OK(consumer2->Close());
+}
+
+TEST_F(ClientCrashTest, ClientResetTest)
+{
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+
+    auto client = std::make_unique<ClientSC1>("testClientReset");
+    DS_ASSERT_OK(client->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client->CreateProducer(producer));
+
+    std::shared_ptr<Consumer> consumer;
+    DS_ASSERT_OK(client->Subscribe("sub", consumer));
+
+    client.reset();
+
+    int elementCount = 100;
+    int dataSize = 1024;
+    std::string data(dataSize, 'a');
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    // the producer sends data
+    for (int i = 0; i < elementCount; i++) {
+        producer->Send(element);
+    }
+
+    int timeout = 3000;
+    for (int i = 0; i < elementCount; i++) {
+        std::vector<Element> outElements;
+        DS_ASSERT_OK(consumer->Receive(1, timeout, outElements));
+        ASSERT_EQ(outElements.size(), 1);
+        DS_ASSERT_OK(consumer->Ack(i + 1));
+    }
+
+    producer.reset();
+    consumer.reset();
+}
+
+TEST_F(ClientCrashTest, ClientShutdownWhenOOMTest)
+{
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+
+    const int timeout = 10000;
+    auto client = std::make_unique<ClientSC1>("testClientShutdownWhenOOM");
+    DS_ASSERT_OK(client->InitTestClient(workerAddress.Host(), workerAddress.Port(), timeout));
+
+    std::shared_ptr<Producer> producer;
+    const uint64_t maxStreamSize = 20 * 1024 * 1024;  // The max stream size is 20M
+    const int64_t pageSize = 4 * 1024;                // The page size is 4096
bytes + ProducerConf conf; + conf.delayFlushTime = -1; + conf.maxStreamSize = maxStreamSize; + conf.pageSize = pageSize; + conf.autoCleanup = true; + DS_ASSERT_OK(client->CreateProducer(producer, conf)); + + std::shared_ptr consumer; + DS_ASSERT_OK(client->Subscribe("sub", consumer)); + + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.CheckHadEnoughMem", "return(K_OUT_OF_MEMORY)")); + DS_ASSERT_OK(inject::Set("client.CreateWritePage", "call()")); + + Timer timer; + std::thread t([&client, &timer] { + const int delay = 3000; + std::this_thread::sleep_for(std::chrono::milliseconds(delay)); + LOG_IF_ERROR(client->Shutdown(), "shutdown"); + timer.Reset(); + }); + int elementCount = 100; + int dataSize = 1024; + std::string data(dataSize, 'a'); + Element element(reinterpret_cast(&data.front()), data.size()); + + // producer sends data + for (int i = 0; i < elementCount; i++) { + Status rc = producer->Send(element, timeout); + if (rc.IsError()) { + LOG(INFO) << rc.ToString(); + break; + } + } + + // waiting worker call HandleBlockedCreateTimeout after timeout. + while (timer.ElapsedMicroSecond() <= timeout) { + const int interval = 100; + std::this_thread::sleep_for(std::chrono::milliseconds(interval)); + } + + producer.reset(); + consumer.reset(); + t.join(); +} + +TEST_F(ClientCrashTest, DISABLED_TestForceCloseLocalProducersSameWorker) +{ + // We have two clients on same worker + std::shared_ptr client1; + DS_ASSERT_OK(InitClient(0, client1)); + const uint64_t maxStreamSize = 1024 * 1024; + const int64_t pageSize = 4 * 1024; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + conf.pageSize = pageSize; + std::string streamName = "testForceCloseProdSameWorker"; + + // client1 have 5 producers for same stream + std::vector> producers; + const int num_producers = 5; + for (int i = 0; i < num_producers; i++) { + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + producers.emplace_back(producer); + } + + DS_ASSERT_OK( + cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "ClientManager.Init.heartbeatInterval", "call(500)")); + auto pid = fork(); + if (pid == 0) { + // client2 have 5 producers for same stream + std::shared_ptr client2; + DS_ASSERT_OK(InitClient(0, client2)); + std::vector> producers2; + for (int i = 0; i < num_producers; i++) { + std::shared_ptr producer; + DS_ASSERT_OK(client2->CreateProducer(streamName, producer, conf)); + producers2.emplace_back(producer); + } + // client2 crashes after producer creation + std::abort(); + ASSERT_TRUE(false); + } + ASSERT_TRUE(pid > 0); + + int status; + waitpid(pid, &status, 0); + + // Wait for cleanup to finish. 
it's set to 2 secs above
+    const int sleepTime = 3;
+    sleep(sleepTime);
+
+    // The other client will close all its producers
+    for (auto &producer : producers) {
+        DS_ASSERT_OK(producer->Close());
+    }
+    // The master metadata should be cleared and
+    // we should be able to delete the stream
+    DS_ASSERT_OK(client1->DeleteStream(streamName));
+}
+
+TEST_F(ClientCrashTest, DISABLED_TestForceCloseLocalProducersDifferentWorker)
+{
+    // We have two clients on different workers
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+    const uint64_t maxStreamSize = 1024 * 1024;
+    const int64_t pageSize = 4 * 1024;
+    ProducerConf conf;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = pageSize;
+    std::string streamName = "testForceCloseProdDiffWorker";
+
+    // client1 has 5 producers for the same stream
+    std::vector<std::shared_ptr<Producer>> producers;
+    const int num_producers = 5;
+    for (int i = 0; i < num_producers; i++) {
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+        producers.emplace_back(producer);
+    }
+
+    DS_ASSERT_OK(
+        cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "ClientManager.Init.heartbeatInterval", "call(500)"));
+    auto pid = fork();
+    if (pid == 0) {
+        // client2 has 5 producers for the same stream on a different worker
+        std::shared_ptr<StreamClient> client2;
+        DS_ASSERT_OK(InitClient(1, client2));
+        std::vector<std::shared_ptr<Producer>> producers2;
+        for (int i = 0; i < num_producers; i++) {
+            std::shared_ptr<Producer> producer;
+            DS_ASSERT_OK(client2->CreateProducer(streamName, producer, conf));
+            producers2.emplace_back(producer);
+        }
+        // client2 crashes after producer creation
+        std::abort();
+        ASSERT_TRUE(false);
+    }
+    ASSERT_TRUE(pid > 0);
+
+    int status;
+    waitpid(pid, &status, 0);
+
+    // Wait for the cleanup to finish; it's set to 2 secs above.
+    const int sleepTime = 3;
+    sleep(sleepTime);
+
+    // The other client will close all its producers
+    for (auto &producer : producers) {
+        DS_ASSERT_OK(producer->Close());
+    }
+    // The master metadata should be cleared and
+    // we should be able to delete the stream
+    DS_ASSERT_OK(client1->DeleteStream(streamName));
+}
+
+TEST_F(ClientCrashTest, LEVEL2_TestForceCloseDeadlock)
+{
+    // Test that, with a large number of streams, force close can create a logical deadlock
+    // if too many const_accessors are held at the same time.
+    const int streamNum = 1000;
+    const uint64_t maxStreamSize = 1024 * 1024;
+    const int64_t pageSize = 4 * 1024;
+    ProducerConf conf;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = pageSize;
+    conf.autoCleanup = true;
+    auto pid = fork();
+    if (pid == 0) {
+        std::shared_ptr<StreamClient> client;
+        DS_ASSERT_OK(InitClient(0, client));
+        std::vector<std::shared_ptr<Producer>> producers;
+        for (int i = 0; i < streamNum; i++) {
+            std::string streamName = "Stream_" + std::to_string(i);
+            std::shared_ptr<Producer> producer;
+            DS_ASSERT_OK(client->CreateProducer(streamName, producer, conf));
+            producers.emplace_back(producer);
+        }
+        std::abort();
+        ASSERT_TRUE(false);
+    }
+    ASSERT_TRUE(pid > 0);
+
+    int status;
+    waitpid(pid, &status, 0);
+
+    // Wait for the MasterWorkerSCService thread pool to be occupied.
+    const int sleepTime = 20;
+    sleep(sleepTime);
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+    std::shared_ptr<StreamClient> client2;
+    DS_ASSERT_OK(InitClient(1, client2));
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+}
+
+TEST_F(ClientCrashTest, TestForceEarlyReturn)
+{
+    // Test that force close skips sending the CloseProducer request to the master
+    // for some streams after a previous manual close fails.
+    const int streamNum = 2;
+    const uint64_t maxStreamSize = 1024 * 1024;
+    const int64_t pageSize = 4 * 1024;
+    ProducerConf conf;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = pageSize;
+    conf.autoCleanup = true;
+    DS_ASSERT_OK(
+        cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "ClientManager.Init.heartbeatInterval", "call(500)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "CloseProducer.TimeoutInMaster",
+                                           "1*return(K_RPC_UNAVAILABLE)"));
+    pid_t pid = fork();
+    if (pid == 0) {
+        std::shared_ptr<StreamClient> client1;
+        DS_ASSERT_OK(InitClient(0, client1));
+        std::vector<std::shared_ptr<Producer>> producers;
+        for (int i = 0; i < streamNum; i++) {
+            std::string streamName = "Stream_" + std::to_string(i);
+            std::shared_ptr<Producer> producer;
+            DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+            producers.emplace_back(producer);
+        }
+        // construct the case where the manual CloseProducer succeeds on the master but times out on the worker.
+        DS_ASSERT_NOT_OK(producers[0]->Close());
+        std::abort();
+    }
+
+    ASSERT_TRUE(pid > 0);
+
+    int status;
+    waitpid(pid, &status, 0);
+
+    // Wait for the force close
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "ClientManager.IsClientLost.heartbeatThreshold",
+                                           "call(1)"));
+    sleep(streamNum);
+}
+
+TEST_F(ClientCrashTest, TestConsumerBadFnCallCrashWithLock)
+{
+    size_t maxPageCount = 2;
+    size_t pageSize = 1024 * 1024;
+    int streamCount = 2;
+    int oomTimeout = 3;
+
+    const size_t elementSize = 10240;  // 10k.
+ size_t nums = pageSize * maxPageCount * 2 / elementSize - 10; + + std::string streamNameCrash = "CrashConsumer"; + auto wrapper = RunInChildProcess([&] { + std::shared_ptr client; + DS_ASSERT_OK(InitClient(1, client)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client->Subscribe(streamNameCrash, config, consumer, true)); + + size_t recvNum = 0; + while (recvNum < nums) { + std::vector outElements; + const int recvTimeout = 1000; + ASSERT_EQ(consumer->Receive(1, recvTimeout, outElements), Status::OK()); + if (outElements.empty()) { + continue; + } + if (recvNum == 0) { + DS_ASSERT_OK(inject::Set("SharedMemViewImpl.GetView", "1*call()")); + } + recvNum += outElements.size(); + LOG(INFO) << "stream:" << streamNameCrash << ", SubProcess Recv num:" << recvNum; + } + std::abort(); + }); + if (wrapper == nullptr) { + return; + } + + std::shared_ptr client1; + std::shared_ptr client2; + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + std::string data(elementSize, 'a'); + Element element((uint8_t *)data.data(), data.size()); + + ProducerConf conf; + conf.pageSize = pageSize; + conf.maxStreamSize = pageSize * maxPageCount; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::vector clients = { client1.get(), client2.get() }; + std::vector> consumers; + std::vector> producers; + for (int index = 0; index < streamCount; index++) { + std::string streamName = "TestStream-" + std::to_string(index); + std::shared_ptr producer; + std::shared_ptr consumer; + ASSERT_EQ(clients[index % clients.size()]->CreateProducer(streamName, producer, conf), Status::OK()); + ASSERT_EQ(clients[(index + 1) % clients.size()]->Subscribe(streamName, config, consumer, true), Status::OK()); + consumers.emplace_back(std::move(consumer)); + producers.emplace_back(std::move(producer)); + } + + DS_ASSERT_OK(inject::Set("SharedMemViewImpl.SetView", "1*call()")); + ASSERT_EQ(producers[0]->Send(element).GetCode(), K_RUNTIME_ERROR); + // Send. 
+ for (int index = 0; index < streamCount; index++) { + for (size_t i = 0; i < nums; i++) { + Status rc = producers[index]->Send(element); + Timer timer; + const int maxTimeout = 10; + while (rc.GetCode() == K_OUT_OF_MEMORY && timer.ElapsedSecond() < maxTimeout) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + rc = producers[index]->Send(element); + } + DS_ASSERT_OK(rc); + LOG(INFO) << FormatString("Stream index %zu, send count: %zu", streamCount, i); + } + } + + std::thread t([=] { + ProducerConf conf; + conf.pageSize = pageSize; + conf.maxStreamSize = pageSize * maxPageCount; + conf.autoCleanup = true; + std::shared_ptr producer; + while (true) { + size_t gConsumerNum = 0; + DS_ASSERT_OK(client1->QueryGlobalConsumersNum(streamNameCrash, gConsumerNum)); + if (gConsumerNum > 0) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + ASSERT_EQ(client1->CreateProducer(streamNameCrash, producer, conf), Status::OK()); + for (size_t i = 0; i < nums; i++) { + Status rc = producer->Send(element); + Timer timer; + const int maxTimeout = 10; + while (rc.GetCode() == K_OUT_OF_MEMORY && timer.ElapsedSecond() < maxTimeout) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + rc = producer->Send(element); + } + DS_ASSERT_OK(rc); + LOG(INFO) << FormatString("stream: %s,Stream index %zu,: send count: %zu", streamNameCrash, streamCount, i); + } + }); + + sleep(oomTimeout); + + for (int index = 0; index < streamCount; index++) { + size_t recvNum = 0; + while (recvNum < nums) { + std::vector outElements; + const int recvTimeout = 1000; + ASSERT_EQ(consumers[index]->Receive(nums, recvTimeout, outElements), Status::OK()); + recvNum += outElements.size(); + LOG(INFO) << "Recv num:" << recvNum; + } + ASSERT_EQ(recvNum, nums); + } + t.join(); +} + +class ClientCrashWithLockTest : public ClientCrashTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numWorkers = 1; + opts.masterIdx = 0; + opts.numRpcThreads = 0; + opts.numEtcd = 1; + opts.workerGflagParams = FormatString("-client_dead_timeout_s=%zu -v=2 -log_monitor=true", clientDeadTimeout); + SCClientCommon::SetClusterSetupOptions(opts); + } + +protected: + int clientDeadTimeout = 5; +}; + +TEST_F(ClientCrashWithLockTest, TestProducerCrashWithPageMemViewLock) +{ + size_t maxPageCount = 4; + size_t pageSize = 1024 * 1024; + + const size_t elementSize = 10240; // 10k. + std::string data(elementSize, 'a'); + Element element((uint8_t *)data.data(), data.size()); + + std::string streamNameCrash = "CrashConsumer"; + auto wrapper = RunInChildProcess([&] { + std::shared_ptr client; + DS_ASSERT_OK(InitClient(0, client)); + std::shared_ptr producer; + ProducerConf conf; + conf.pageSize = pageSize; + conf.maxStreamSize = pageSize * maxPageCount; + while (true) { + size_t gConsumerNum = 0; + DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamNameCrash, gConsumerNum)); + if (gConsumerNum > 0) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + ASSERT_EQ(client->CreateProducer(streamNameCrash, producer, conf), Status::OK()); + DS_ASSERT_OK(inject::Set("client.Producer.beforeCheckNewPage", "2*off->1*call()")); + // send untill crash. 
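+        // Reading the inject spec "2*off->1*call()" above (an assumption based on
+        // the "N*action" syntax used elsewhere in these tests): the first two
+        // sends pass untouched and the third one fires the crash hook while the
+        // page MemView lock is still held.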
+ size_t sendCount = 0; + while (true) { + Status rc = producer->Send(element); + if (rc.GetCode() == K_OUT_OF_MEMORY) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + DS_ASSERT_OK(rc); + sendCount++; + LOG(INFO) << "stream:" << streamNameCrash << ", SubProcess send num:" << sendCount; + } + std::abort(); + }); + if (wrapper == nullptr) { + return; + } + + std::shared_ptr client; + DS_ASSERT_OK(InitClient(0, client)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client->Subscribe(streamNameCrash, config, consumer, true)); + + int delayBeforeRecvMs = 1000; + std::this_thread::sleep_for(std::chrono::milliseconds(delayBeforeRecvMs)); + int testTimeMs = 5000; + size_t recvNum = 0; + Timer timer; + while (timer.ElapsedMilliSecond() < testTimeMs) { + std::vector outElements; + const int recvTimeout = 1000; + ASSERT_EQ(consumer->Receive(1, recvTimeout, outElements), Status::OK()); + if (outElements.empty()) { + continue; + } + recvNum += outElements.size(); + LOG(INFO) << "stream:" << streamNameCrash << ", Recv num:" << recvNum; + } +} + +TEST_F(ClientCrashWithLockTest, LEVEL1_TestProducerCrashWithCursorMemViewLock) +{ + size_t maxPageCount = 5; + size_t pageSize = 1024 * 1024; + + const size_t elementSize = 10240; // 10k. + std::string data(elementSize, 'a'); + Element element((uint8_t *)data.data(), data.size()); + + std::string streamNameCrash = "CrashConsumer"; + auto wrapper = RunInChildProcess([&] { + std::shared_ptr client; + DS_ASSERT_OK(InitClient(0, client)); + std::shared_ptr producer; + ProducerConf conf; + conf.pageSize = pageSize; + conf.maxStreamSize = pageSize * maxPageCount; + while (true) { + size_t gConsumerNum = 0; + DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamNameCrash, gConsumerNum)); + if (gConsumerNum > 0) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + ASSERT_EQ(client->CreateProducer(streamNameCrash, producer, conf), Status::OK()); + DS_ASSERT_OK(inject::Set("client.Producer.beforeCheckCursor", "1*call()")); + DS_ASSERT_OK(producer->Send(element)); + std::abort(); + }); + if (wrapper == nullptr) { + return; + } + int delayBeforeSendMs = 1000; + std::this_thread::sleep_for(std::chrono::milliseconds(delayBeforeSendMs)); + + std::shared_ptr client; + DS_ASSERT_OK(InitClient(0, client)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client->Subscribe(streamNameCrash, config, consumer, true)); + + std::shared_ptr producer; + ProducerConf conf; + conf.pageSize = pageSize; + conf.maxStreamSize = pageSize * maxPageCount; + ASSERT_EQ(client->CreateProducer(streamNameCrash, producer, conf), Status::OK()); + // send again. + size_t sendCount = 0; + int testTimeMs = 5000; + Timer timer; + while (timer.ElapsedMilliSecond() < testTimeMs) { + Status rc = producer->Send(element); + if (rc.GetCode() == K_OUT_OF_MEMORY) { + LOG(INFO) << "send finish with:" << rc.ToString(); + break; + } + DS_ASSERT_OK(rc); + sendCount++; + LOG(INFO) << "stream:" << streamNameCrash << ", Main Process send num:" << sendCount; + } +} + +TEST_F(ClientCrashWithLockTest, DISABLED_TestConsuemrCrashWithPageMemViewLock) +{ + // consumer crash with MemView in Page, worker and producer will block. + size_t maxPageCount = 2; + size_t pageSize = 1024 * 1024; + + const size_t elementSize = 10240; // 10k. 
+ const int testTimeMs = 5000; + + std::string streamNameCrash = "CrashConsumer"; + auto wrapper = RunInChildProcess([&] { + std::shared_ptr client; + DS_ASSERT_OK(InitClient(0, client)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client->Subscribe(streamNameCrash, config, consumer, true)); + + DS_ASSERT_OK(inject::Set("SharedMemViewImpl.GetView", "abort")); + size_t recvNum = 0; + Timer timer; + // recv until crash + while (timer.ElapsedMilliSecond() < testTimeMs) { + std::vector outElements; + const int recvTimeout = 1000; + ASSERT_EQ(consumer->Receive(1, recvTimeout, outElements), Status::OK()); + if (outElements.empty()) { + continue; + } + recvNum += outElements.size(); + LOG(INFO) << "stream:" << streamNameCrash << ", SubProcess Recv num:" << recvNum; + } + std::abort(); + }); + if (wrapper == nullptr) { + return; + } + + std::string data(elementSize, 'a'); + Element element((uint8_t *)data.data(), data.size()); + + std::shared_ptr client; + DS_ASSERT_OK(InitClient(0, client)); + std::shared_ptr producer; + ProducerConf conf; + conf.pageSize = pageSize; + conf.maxStreamSize = pageSize * maxPageCount; + while (true) { + size_t gConsumerNum = 0; + DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamNameCrash, gConsumerNum)); + if (gConsumerNum > 0) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + ASSERT_EQ(client->CreateProducer(streamNameCrash, producer, conf), Status::OK()); + size_t sendCount = 0; + Timer timer; + while (timer.ElapsedMilliSecond() < testTimeMs) { + Status rc = producer->Send(element); + if (rc.GetCode() == K_OUT_OF_MEMORY) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + DS_ASSERT_OK(rc); + sendCount++; + LOG(INFO) << "stream:" << streamNameCrash << ", SubProcess send num:" << sendCount; + } +} + +class ClientLockVersionTest : public ClientCrashTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numWorkers = 1; + opts.masterIdx = 0; + opts.numRpcThreads = 0; + opts.numEtcd = 1; + opts.workerGflagParams = FormatString("-client_dead_timeout_s=%zu -v=2 -log_monitor=true", clientDeadTimeout); + SCClientCommon::SetClusterSetupOptions(opts); + } + +protected: + int clientDeadTimeout = 5; + + Status StreamSendRecvTest(bool newProducer, bool newConsumer, bool newWorker) + { + const size_t maxPageCount = 3; + const size_t pageSize = 100 * 1024; + const size_t testElementCount = 1000; + const size_t producerCount = 2; + const size_t elementSize = 10240; // 10k. 
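+        // Summary of the three flags used below: each "false" arms the
+        // "MemView.Lock.OldVersion" inject for that role, pinning the producer
+        // process, the consumer process, or the worker to the old lock layout,
+        // so SendRecvTest1-6 cover the mixed-version combinations end to end.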
+ + std::string testStreamName = "SendRecvTestConsumer"; + auto wrapper = RunInChildProcess([&] { + if (!newProducer) { + LOG_IF_ERROR(inject::Set("MemView.Lock.OldVersion", "return"), "inject set failed"); + } + std::shared_ptr client; + DS_ASSERT_OK(InitClient(0, client)); + ProducerConf conf; + conf.pageSize = pageSize; + conf.maxStreamSize = pageSize * maxPageCount; + while (true) { + size_t gConsumerNum = 0; + DS_ASSERT_OK(client->QueryGlobalConsumersNum(testStreamName, gConsumerNum)); + if (gConsumerNum > 0) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + std::vector> producers; + for (size_t n = 0; n < producerCount; n++) { + std::shared_ptr producer; + ASSERT_EQ(client->CreateProducer(testStreamName, producer, conf), Status::OK()); + producers.emplace_back(std::move(producer)); + } + + std::vector threads; + for (size_t n = 0; n < producerCount; n++) { + threads.emplace_back([n, &producers, &testStreamName] { + size_t sendCount = 0; + while (sendCount < testElementCount) { + char ch = sendCount % INT8_MAX; + std::string data(elementSize, ch); + Element element((uint8_t *)data.data(), data.size()); + + Status rc = producers[n]->Send(element); + if (rc.GetCode() == K_OUT_OF_MEMORY) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + DS_ASSERT_OK(rc); + sendCount++; + LOG(INFO) << "stream:" << testStreamName << ", SubProcess send num:" << sendCount; + } + }); + } + + for (auto &t : threads) { + t.join(); + } + std::abort(); + }); + if (wrapper == nullptr) { + return Status(K_INVALID, "invalid"); + } + + if (!newWorker) { + RETURN_IF_NOT_OK(cluster_->SetInjectAction(WORKER, 0, "MemView.Lock.OldVersion", "return")); + } + + if (!newConsumer) { + RETURN_IF_NOT_OK(inject::Set("MemView.Lock.OldVersion", "return")); + } + + std::shared_ptr client; + RETURN_IF_NOT_OK(InitClient(0, client)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + RETURN_IF_NOT_OK(client->Subscribe(testStreamName, config, consumer, true)); + + int testTimeOutMs = 20000; + size_t expectRecvElementCount = testElementCount * producerCount; + size_t recvNum = 0; + Timer timer; + while (timer.ElapsedMilliSecond() < testTimeOutMs && recvNum < expectRecvElementCount) { + std::vector outElements; + const int recvTimeout = 1000; + RETURN_IF_NOT_OK(consumer->Receive(1, recvTimeout, outElements)); + if (outElements.empty()) { + continue; + } + std::string expectData(elementSize, static_cast(outElements[0].ptr[0])); + std::string recvData((char *)(outElements[0].ptr), outElements[0].size); + if (recvData != expectData) { + const int printCount = 100; + LOG(ERROR) << "expectData:" << expectData.substr(0, printCount) + << ", recvData:" << recvData.substr(0, printCount); + return Status(K_RUNTIME_ERROR, "invalid data"); + } + recvNum += outElements.size(); + LOG(INFO) << "stream:" << testStreamName << ", Recv num:" << recvNum; + } + + if (recvNum != expectRecvElementCount) { + return Status(K_RUNTIME_ERROR, + FormatString("Recv count %zu, expect count %zu", recvNum, expectRecvElementCount)); + } + return Status::OK(); + } +}; + +TEST_F(ClientLockVersionTest, SendRecvTest1) +{ + DS_ASSERT_OK(StreamSendRecvTest(false, false, true)); +} + +TEST_F(ClientLockVersionTest, SendRecvTest2) +{ + DS_ASSERT_OK(StreamSendRecvTest(true, false, true)); +} + +TEST_F(ClientLockVersionTest, LEVEL2_SendRecvTest3) +{ + DS_ASSERT_OK(StreamSendRecvTest(false, true, true)); +} + +TEST_F(ClientLockVersionTest, SendRecvTest4) +{ + 
+    DS_ASSERT_OK(StreamSendRecvTest(true, true, false));
+}
+
+TEST_F(ClientLockVersionTest, SendRecvTest5)
+{
+    DS_ASSERT_OK(StreamSendRecvTest(true, false, false));
+}
+
+TEST_F(ClientLockVersionTest, SendRecvTest6)
+{
+    DS_ASSERT_OK(StreamSendRecvTest(false, true, false));
+}
+
+class ClientCrashShortTimeoutTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numWorkers = 2;
+        opts.masterIdx = 0;
+        opts.numRpcThreads = 0;
+        opts.numEtcd = 1;
+        opts.workerGflagParams = FormatString("-client_dead_timeout_s=%zu -v=2", clientDeadTimeoutSec);
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+protected:
+    const uint64_t clientDeadTimeoutSec = 15;
+};
+
+TEST_F(ClientCrashShortTimeoutTest, DISABLED_LEVEL1_TestResourceClearDuration)
+{
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    auto pid = fork();
+    if (pid == 0) {
+        auto client1 = std::make_unique("testResourceClear");
+        DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+        std::shared_ptr<Producer> producer;
+        std::shared_ptr<Consumer> consumer;
+        DS_ASSERT_OK(client1->CreateProducer(producer));
+        DS_ASSERT_OK(client1->Subscribe("sub1", consumer));
+        // Fake a crash point within the producer after it holds the lock.
+        datasystem::inject::Set("producer_obtained_lock", "1*abort()");
+        DS_ASSERT_NOT_OK(producer->Send(element));
+        _exit(0);
+    }
+    ASSERT_TRUE(pid > 0);
+
+    int status;
+    waitpid(pid, &status, 0);
+
+    auto client2 = std::make_unique("testResourceClear");
+    DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+
+    uint64_t totalConsumerNum;
+    uint64_t totalProducerNum;
+    DS_ASSERT_OK(client2->QueryTotalConsumerNum(totalConsumerNum));
+    DS_ASSERT_OK(client2->QueryTotalProducerNum(totalProducerNum));
+    ASSERT_EQ(totalConsumerNum, 1ul);
+    ASSERT_EQ(totalProducerNum, 1ul);
+
+    sleep(clientDeadTimeoutSec + 1);
+    DS_ASSERT_OK(client2->QueryTotalConsumerNum(totalConsumerNum));
+    DS_ASSERT_OK(client2->QueryTotalProducerNum(totalProducerNum));
+    ASSERT_EQ(totalConsumerNum, 0ul);
+    ASSERT_EQ(totalProducerNum, 0ul);
+}
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/client_worker_heartbeat_test.cpp b/tests/st/client/stream_cache/client_worker_heartbeat_test.cpp
new file mode 100644
index 0000000..ce8a6db
--- /dev/null
+++ b/tests/st/client/stream_cache/client_worker_heartbeat_test.cpp
@@ -0,0 +1,477 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Test for stream cache client-worker heartbeat.
+ */
+#include <gtest/gtest.h>
+
+#include "common.h"
+#include "common/stream_cache/stream_common.h"
+#include "sc_client_common.h"
+#include "datasystem/stream_client.h"
+#include "datasystem/client/stream_cache/stream_client_impl.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/client/stream_cache/client_worker_api.h"
+#include "datasystem/client/mmap_manager.h"
+#include "datasystem/common/inject/inject_point.h"
+
+using namespace datasystem::client::stream_cache;
+
+namespace datasystem {
+namespace st {
+class ClientWorkerSCHeartbeatTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numWorkers = 2;
+        opts.masterIdx = 1;
+        opts.numEtcd = 1;
+        datasystem::inject::Set("ListenWorker.CheckHeartbeat.interval", "call(500)");
+        datasystem::inject::Set("ListenWorker.CheckHeartbeat.heartbeat_interval_ms", "call(500)");
+        datasystem::inject::Set("ClientWorkerCommonApi.SendHeartbeat.timeoutMs", "call(500)");
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+    }
+
+    void TearDown() override
+    {
+        ExternalClusterTest::TearDown();
+    }
+};
+
+class ClientSC {
+public:
+    explicit ClientSC(std::string streamName) : streamName_(std::move(streamName))
+    {
+    }
+    ~ClientSC() = default;
+
+    Status InitTestClient(const std::string &ip, const int &port, int timeout = 60000);
+
+    Status CreateProducer(std::shared_ptr<Producer> &producer);
+
+    Status Subscribe(const std::string &subName, std::shared_ptr<Consumer> &consumer);
+
+    Status QueryLocalProducerNum(uint64_t &localProducerNum);
+
+    Status QueryLocalConsumerNum(uint64_t &localConsumerNum);
+
+private:
+    std::string streamName_;
+    std::unique_ptr<StreamClient> client_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+Status ClientSC::InitTestClient(const std::string &ip, const int &port, int timeout)
+{
+    ConnectOptions connectOptions;
+    connectOptions.host = ip;
+    connectOptions.port = port;
+    connectOptions.connectTimeoutMs = timeout;
+    connectOptions.SetAkSkAuth(accessKey_, secretKey_, "");
+    client_ = std::make_unique<StreamClient>(connectOptions);
+    return client_->Init();
+}
+
+Status ClientSC::CreateProducer(std::shared_ptr<Producer> &producer)
+{
+    ProducerConf conf;
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    return client_->CreateProducer(streamName_, producer, conf);
+}
+
+Status ClientSC::Subscribe(const std::string &subName, std::shared_ptr<Consumer> &consumer)
+{
+    SubscriptionConfig config(subName, SubscriptionType::STREAM);
+    return client_->Subscribe(streamName_, config, consumer);
+}
+
+Status ClientSC::QueryLocalProducerNum(uint64_t &localProducerNum)
+{
+    return client_->QueryGlobalProducersNum(streamName_, localProducerNum);
+}
+
+Status ClientSC::QueryLocalConsumerNum(uint64_t &localConsumerNum)
+{
+    return client_->QueryGlobalConsumersNum(streamName_, localConsumerNum);
+}
+
+TEST_F(ClientWorkerSCHeartbeatTest, TestOneClientCrash)
+{
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+
+    LOG(INFO) << "start create client 1";
+    auto client1 = std::make_unique<ClientSC>("OneClientCrash1");
+    DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(client1->CreateProducer(producer1));
+    std::shared_ptr<Consumer> consumer1;
+    std::shared_ptr<Consumer> consumer2;
+    DS_ASSERT_OK(client1->Subscribe("sub1", consumer1));
+    DS_ASSERT_OK(client1->Subscribe("sub2", consumer2));
+
+    uint64_t queryRet = 0;
+    DS_ASSERT_OK(client1->QueryLocalProducerNum(queryRet));
+    ASSERT_EQ(queryRet, size_t(1));
+    queryRet = 0;
+    DS_ASSERT_OK(client1->QueryLocalConsumerNum(queryRet));
+    ASSERT_EQ(queryRet, size_t(2));
+    producer1.reset();
+    consumer1.reset();
+    consumer2.reset();
+    client1.reset();
+
+    // Sleep 1s so the worker notices the released client.
+    usleep(1'000'000);
+    LOG(INFO) << "start create client 2";
+    auto client2 = std::make_unique<ClientSC>("OneClientCrash2");
+    DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_OK(client2->CreateProducer(producer2));
+    std::shared_ptr<Consumer> consumer3;
+    DS_ASSERT_OK(client2->Subscribe("sub3", consumer3));
+    queryRet = 0;
+    DS_ASSERT_OK(client2->QueryLocalProducerNum(queryRet));
+    ASSERT_EQ(queryRet, size_t(1));
+    queryRet = 0;
+    DS_ASSERT_OK(client2->QueryLocalConsumerNum(queryRet));
+    ASSERT_EQ(queryRet, size_t(1));
+    producer2->Close();
+    consumer3->Close();
+}
+
+TEST_F(ClientWorkerSCHeartbeatTest, TestWorkerCrashAndClientCanIdentify)
+{
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+
+    const int timeout = 2000;
+    auto client = std::make_unique<ClientSC>("ClientCanIdentify");
+    DS_ASSERT_OK(client->InitTestClient(workerAddress.Host(), workerAddress.Port(), timeout));
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client->CreateProducer(producer));
+    std::shared_ptr<Consumer> consumer;
+    DS_ASSERT_OK(client->Subscribe("subName", consumer));
+
+    // Shut down the worker.
+    LOG(INFO) << "shutdown worker " << WORKER;
+    cluster_->ShutdownNodes(WORKER);
+    sleep(1); // The heartbeat interval is 0.5s, so a lost worker is detected within 1s at most.
+
+    // Since the worker already exited, allow the client to exit quickly.
+    DS_ASSERT_OK(datasystem::inject::Set("ClientWorkerCommonApi.Disconnect.ShutdownQuickily", "call(200)"));
+    DS_ASSERT_NOT_OK(producer->Close());
+}
+
+TEST_F(ClientWorkerSCHeartbeatTest, TestNormalWorkerCrash)
+{
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+    LOG(INFO) << "start create client";
+    auto client1 = std::make_unique<ClientSC>("NormalWorkerCrash1");
+    DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(client1->CreateProducer(producer1));
+
+    // Shut down the worker.
+    cluster_->ShutdownNode(WORKER, 0);
+    // The old client's requests should now fail.
+    uint64_t queryRet = 0;
+    auto rc = client1->QueryLocalProducerNum(queryRet);
+    ASSERT_EQ(rc.IsError(), true);
+    // Restart the worker.
+    cluster_->StartNode(WORKER, 0, "");
+    cluster_->WaitNodeReady(WORKER, 0);
+
+    // A new producer and consumer should work against the restarted worker.
+    auto client2 = std::make_unique<ClientSC>("NormalWorkerCrash2");
+    DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_OK(client2->CreateProducer(producer2));
+    std::shared_ptr<Consumer> consumer;
+    DS_ASSERT_OK(client2->Subscribe("sub2", consumer));
+
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size(), ULONG_MAX);
+    DS_ASSERT_OK(producer2->Send(element));
+
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(1));
+    std::string actualData(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data, actualData);
+}
+
+TEST_F(ClientWorkerSCHeartbeatTest, TestWorkerCrashAndClientReadEleNormally)
+{
+    HostPort workerAddress;
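+    // Already-received elements should stay readable after the worker goes down: the Receive
+    // below completes first, the worker is then killed, and the element data is validated
+    // only afterwards.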
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+
+    const int timeout = 2000;
+    auto client = std::make_unique<ClientSC>("ClientReadNormally");
+    DS_ASSERT_OK(client->InitTestClient(workerAddress.Host(), workerAddress.Port(), timeout));
+
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client->CreateProducer(producer));
+    std::shared_ptr<Consumer> consumer;
+    DS_ASSERT_OK(client->Subscribe("subName", consumer));
+
+    std::string data = "abc";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size(), ULONG_MAX);
+    DS_ASSERT_OK(producer->Send(element));
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(1, 10000, outElements), Status::OK());
+
+    // Shut down the worker.
+    LOG(INFO) << "shutdown worker " << WORKER;
+    cluster_->QuicklyShutdownWorker(0);
+
+    // Since the worker already exited, allow the client to exit quickly.
+    DS_ASSERT_OK(datasystem::inject::Set("ClientWorkerCommonApi.Disconnect.ShutdownQuickily", "call(200)"));
+    LOG(INFO) << "After closing the worker, read the element again";
+    ASSERT_EQ(outElements[0].id, size_t(1));
+    std::string actualData(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    ASSERT_EQ(data, actualData);
+}
+
+TEST_F(ClientWorkerSCHeartbeatTest, LEVEL1_TestSignalTermWorker)
+{
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+    auto client1 = std::make_unique<ClientSC>("testSignalTermWorker");
+    DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+
+    pid_t pid = cluster_->GetWorkerPid(0);
+    ASSERT_NE(pid, -1);
+    cluster_->ShutdownNodes(WORKER);
+
+    uint64_t queryRet = 0;
+    auto rc = client1->QueryLocalProducerNum(queryRet);
+    ASSERT_EQ(rc.IsError(), true);
+}
+
+TEST_F(ClientWorkerSCHeartbeatTest, TestOneClientCrashAndReceive)
+{
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+
+    LOG(INFO) << "start create client 1";
+    auto client1 = std::make_unique<ClientSC>("testOneClientCrashAndRecv");
+    DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(client1->CreateProducer(producer1));
+    std::shared_ptr<Consumer> consumer1;
+    DS_ASSERT_OK(client1->Subscribe("sub1", consumer1));
+
+    LOG(INFO) << "start create client 2";
+    auto client2 = std::make_unique<ClientSC>("testOneClientCrashAndRecv");
+    DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+    std::shared_ptr<Consumer> consumer2;
+    DS_ASSERT_OK(client2->Subscribe("sub2", consumer2));
+
+    uint64_t queryRet = 0;
+    DS_ASSERT_OK(client1->QueryLocalProducerNum(queryRet));
+    ASSERT_EQ(queryRet, size_t(1));
+    queryRet = 0;
+    DS_ASSERT_OK(client1->QueryLocalConsumerNum(queryRet));
+    ASSERT_EQ(queryRet, size_t(2));
+
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size(), ULONG_MAX);
+    DS_ASSERT_OK(producer1->Send(element));
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer1->Receive(1, 0, outElements), Status::OK());
+    std::string actualData1(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data, actualData1);
+    outElements.clear();
+    ASSERT_EQ(consumer2->Receive(1, 0, outElements), Status::OK());
+    std::string actualData2(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data, actualData2);
+
+    consumer2.reset();
+    client2.reset();
+    outElements.clear();
+
+    DS_ASSERT_OK(producer1->Send(element));
+    ASSERT_EQ(consumer1->Receive(1, 0, outElements), Status::OK());
+    std::string actualData3(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data, actualData3);
+}
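+
+// A minimal sketch (not called by the tests in this file): instead of a fixed sleep, poll
+// the producer count until it reaches the expected value or a deadline passes. The helper
+// name and the 100 ms poll interval are arbitrary assumptions for illustration only.
+[[maybe_unused]] static bool WaitForProducerNum(ClientSC &client, uint64_t expect, int timeoutMs)
+{
+    Timer timer;
+    while (timer.ElapsedMilliSecond() < timeoutMs) {
+        uint64_t num = 0;
+        // Keep polling through transient query failures until the deadline.
+        if (client.QueryLocalProducerNum(num).IsOk() && num == expect) {
+            return true;
+        }
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    }
+    return false;
+}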
+
+TEST_F(ClientWorkerSCHeartbeatTest, TestOneClientCrashWhenOtherClientCreate)
+{
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+    std::string streamName = "testOneClientCrashWhenOtherCreate";
+    std::thread clientCrash([&workerAddress, streamName]() {
+        int crashTimes = 2;
+        while (crashTimes > 0) {
+            auto client1 = std::make_unique<ClientSC>(streamName);
+            DS_ASSERT_OK(client1->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+            std::shared_ptr<Consumer> consumer1;
+            DS_ASSERT_OK(client1->Subscribe("threadSub", consumer1));
+            client1.reset();
+            crashTimes--;
+        }
+    });
+
+    auto client2 = std::make_unique<ClientSC>(streamName);
+    DS_ASSERT_OK(client2->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    std::vector<std::string> subNameList{ "mainSub1", "mainSub2", "mainSub3", "mainSub4", "mainSub5" };
+    for (int i = 0; i < 5; i++) {
+        std::shared_ptr<Consumer> consumer;
+        DS_ASSERT_OK(client2->Subscribe(subNameList[i], consumer));
+        consumers.emplace_back(consumer);
+    }
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client2->CreateProducer(producer));
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size(), ULONG_MAX);
+    DS_ASSERT_OK(producer->Send(element));
+    std::vector<Element> outElements;
+    for (int i = 0; i < 5; i++) {
+        outElements.clear();
+        ASSERT_EQ(consumers[i]->Receive(1, 0, outElements), Status::OK());
+        std::string actualData(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+        EXPECT_EQ(data, actualData);
+    }
+
+    clientCrash.join();
+    uint64_t queryRet = 0;
+    DS_ASSERT_OK(client2->QueryLocalProducerNum(queryRet));
+    ASSERT_EQ(queryRet, size_t(1));
+    queryRet = 0;
+    DS_ASSERT_OK(client2->QueryLocalConsumerNum(queryRet));
+    ASSERT_EQ(queryRet, size_t(5));
+}
+
+TEST_F(ClientWorkerSCHeartbeatTest, TestOneClientCrashWhenOtherProducerFlush)
+{
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+    std::string streamName = "testOneClientCrashWhenOtherProdFlush";
+    int sendCount = 30;
+    ThreadPool pool(4);
+    pool.Submit([&workerAddress, sendCount, streamName]() {
+        auto client = std::make_unique<ClientSC>(streamName);
+        DS_ASSERT_OK(client->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+        std::shared_ptr<Consumer> consumer;
+        DS_ASSERT_OK(client->Subscribe("mainSub", consumer));
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(client->CreateProducer(producer));
+        std::this_thread::sleep_for(std::chrono::milliseconds(9));
+        std::string data[30];
+        for (int i = 0; i < sendCount; i++) {
+            data[i] = "Hello World" + std::to_string(i);
+            uint64_t id = i + 1;
+            Element element(reinterpret_cast<uint8_t *>(&data[i].front()), data[i].size());
+            element.id = id;
+            DS_ASSERT_OK(producer->Send(element));
+        }
+        std::vector<Element> outElements;
+        ASSERT_EQ(consumer->Receive(sendCount, 1000, outElements), Status::OK());
+        ASSERT_EQ(outElements.size(), size_t(30));
+        for (int i = 0; i < sendCount; i++) {
+            ASSERT_EQ(outElements[i].id, static_cast<uint64_t>(i + 1));
+            std::string actualData(reinterpret_cast<char *>(outElements[i].ptr), outElements[i].size);
+            ASSERT_EQ(data[i], actualData);
+        }
+    });
+    pool.Submit([&workerAddress, sendCount, streamName]() {
+        auto client = std::make_unique<ClientSC>(streamName);
+        DS_ASSERT_OK(client->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+        std::shared_ptr<Consumer> consumer;
+        DS_ASSERT_OK(client->Subscribe("remoteSub", consumer));
+        std::vector<Element> outElements;
+        ASSERT_EQ(consumer->Receive(sendCount, 1000, outElements), Status::OK());
+        ASSERT_EQ(outElements.size(), size_t(30));
+        for (int i = 0; i < sendCount; i++) {
+            ASSERT_EQ(outElements[i].id, static_cast<uint64_t>(i + 1));
+            std::string actualData(reinterpret_cast<char *>(outElements[i].ptr), outElements[i].size);
+            std::string curData = "Hello World" + std::to_string(i);
+            ASSERT_EQ(curData, actualData);
+        }
+    });
+    pool.Submit([&workerAddress, streamName]() {
+        auto client = std::make_unique<ClientSC>(streamName);
+        DS_ASSERT_OK(client->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+        std::shared_ptr<Consumer> consumer;
+        DS_ASSERT_OK(client->Subscribe("threadSub1", consumer));
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(client->CreateProducer(producer));
+        std::this_thread::sleep_for(std::chrono::milliseconds(11));
+        client.reset();
+    });
+    pool.Submit([&workerAddress, streamName]() {
+        auto client = std::make_unique<ClientSC>(streamName);
+        DS_ASSERT_OK(client->InitTestClient(workerAddress.Host(), workerAddress.Port()));
+        std::shared_ptr<Consumer> consumer;
+        DS_ASSERT_OK(client->Subscribe("threadSub2", consumer));
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(client->CreateProducer(producer));
+        std::this_thread::sleep_for(std::chrono::milliseconds(13));
+        client.reset();
+    });
+}
+
+TEST_F(ClientWorkerSCHeartbeatTest, TestClientWorkerTimeout)
+{
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+
+    // Create a client with a user-defined timeout.
+    int32_t timeout = 1000;
+    ConnectOptions connectOptions = { .host = workerAddress.Host(),
+                                      .port = workerAddress.Port(),
+                                      .connectTimeoutMs = timeout };
+    connectOptions.accessKey = "QTWAOYTTINDUT2QVKYUC";
+    connectOptions.secretKey = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+
+    auto client = std::make_shared<StreamClient>(connectOptions);
+    client->Init();
+
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "ClientWorkerSCServiceImpl.CreateProducerImpl.sleep",
+                                           "1*sleep(2000)"));
+    Timer timer;
+
+    // Create a producer.
+    std::shared_ptr<Producer> producer;
+    ProducerConf defaultProducerConf;
+    defaultProducerConf.maxStreamSize = 67108864; // 64 MB
+
+    DS_ASSERT_NOT_OK(client->CreateProducer("testClientWorkerTimeout", producer, defaultProducerConf));
+    auto timeCost = static_cast<uint64_t>(timer.ElapsedMilliSecond());
+    LOG(INFO) << "time cost: " << timeCost;
+
+    // The call should fail within the 1s client timeout (plus slack), not the 2s injected delay.
+    ASSERT_TRUE(timeCost < 1200);
+    std::this_thread::sleep_for(std::chrono::seconds(1));
+}
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/consumer_large_page_test.cpp b/tests/st/client/stream_cache/consumer_large_page_test.cpp
new file mode 100644
index 0000000..0344238
--- /dev/null
+++ b/tests/st/client/stream_cache/consumer_large_page_test.cpp
@@ -0,0 +1,134 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Unit test for stream cache
+ */
+#include <gtest/gtest.h>
+#include <string>
+
+#include "common.h"
+#include "common/stream_cache/stream_common.h"
+#include "common/stream_cache/element_generator.h"
+#include "sc_client_common.h"
+#include "datasystem/stream_client.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream/consumer.h"
+#include "client/stream_cache/pub_sub_utils.h"
+
+using namespace datasystem::client::stream_cache;
+
+namespace datasystem {
+namespace st {
+class ConsumerLargePageTest : public SCClientCommon {
+public:
+    explicit ConsumerLargePageTest(int pageSize = 1024 * 1024) : pageSize_(pageSize)
+    {
+    }
+
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.workerGflagParams = " -page_size=" + std::to_string(pageSize_);
+        opts.numRpcThreads = 0;
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+        InitTest();
+    }
+
+    void TearDown() override
+    {
+        client_ = nullptr;
+        ExternalClusterTest::TearDown();
+    }
+
+    using InputStreamInfo = mock::InputStreamInfo;
+    using OutputStreamInfo = mock::OutputStreamInfo;
+
+    Status CreateProducersAndConsumers(std::unordered_map<std::string, InputStreamInfo> &input,
+                                       std::unordered_map<std::string, OutputStreamInfo> &output)
+    {
+        return datasystem::st::CreateProducersAndConsumers(client_, input, output);
+    }
+
+    std::vector<Element> GenerateElements(int elementNum, uint64_t elementSize, std::string &outData)
+    {
+        outData = RandomData().GetRandomString(elementNum * elementSize);
+        std::vector<Element> ret;
+        ret.reserve(elementNum);
+        for (int i = 1; i <= elementNum; i++) {
+            Element element((uint8_t *)(outData.data()), elementSize, ULONG_MAX);
+            ret.push_back(element);
+        }
+        return ret;
+    }
+
+protected:
+    void InitTest()
+    {
+        InitStreamClient(0, client_);
+    }
+
+    std::shared_ptr<StreamClient> client_ = nullptr;
+    uint64_t pageSize_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+TEST_F(ConsumerLargePageTest, SendRecvManyElements)
+{
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer;
+    std::string streamName = "testSendRecvManyEle";
+
+    const uint64_t maxStreamSize = 64 * 1024 * 1024; // 64 MB
+    const uint64_t pageSize = 1024 * 1024;           // 1 MB
+    DS_ASSERT_OK(client_->CreateProducer(
+        streamName, producer, { .delayFlushTime = -1, .pageSize = pageSize, .maxStreamSize = maxStreamSize }));
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer));
+    ElementGenerator generator(1024, 1024);
+    auto strs = generator.GenElements("producer", 4000);
+    for (int round = 0; round < 100; round++) {
+        for (int i = 0; i < 4000; i++) {
+            ASSERT_EQ(producer->Send(Element{ (uint8_t *)(strs[i].c_str()), strs[i].size() }), Status::OK());
+        }
+        std::vector<Element> elements;
+        consumer->Receive(4000, 0, elements);
+        for (auto &e : elements) {
+            ASSERT_EQ(Status::OK(), ElementView(std::string((char *)e.ptr, e.size)).VerifyIntegrity());
+        }
+        consumer->Ack(elements.back().id);
+    }
+}
+
+TEST_F(ConsumerLargePageTest, PageSizeExceedsStreamSize)
+{
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer;
+    std::string streamName = "testPgSzExceedsStreamSz";
+
+    const uint64_t maxStreamSize = 999 * 1024 * 1024; // 999 MB
+    const uint64_t pageSize = 1024 * 1024 * 1024;     // 1 GB
+    ASSERT_EQ(datasystem::StatusCode::K_INVALID,
+              (client_->CreateProducer(streamName, producer,
+                                       { .delayFlushTime = -1, .pageSize = pageSize, .maxStreamSize = maxStreamSize }))
+                  .GetCode());
config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer)); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/consumer_test.cpp b/tests/st/client/stream_cache/consumer_test.cpp new file mode 100644 index 0000000..943dc1e --- /dev/null +++ b/tests/st/client/stream_cache/consumer_test.cpp @@ -0,0 +1,1977 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Unit test for stream cache + */ +#include +#include + +#include "common.h" +#include "datasystem/common/encrypt/secret_manager.h" +#include "common/stream_cache/stream_common.h" +#include "sc_client_common.h" +#include "zmq_curve_test_common.h" +#include "common/stream_cache/element_generator.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/util/timer.h" +#include "datasystem/stream_client.h" +#include "datasystem/client/stream_cache/stream_client_impl.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream/consumer.h" +#include "client/stream_cache/pub_sub_utils.h" + +using namespace datasystem::client::stream_cache; +namespace datasystem { +namespace st { + +constexpr int K_2 = 2, K_5 = 5, K_10 = 10, K_20 = 20, K_100 = 100, + K_1000 = 1000, K_5000 = 5000; + +class ConsumerTest : public SCClientCommon { +public: + explicit ConsumerTest(int pageSize = 4096) : pageSize_(pageSize) + { + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + defaultProducerConf_.pageSize = pageSize_; + } + + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numEtcd = 1; + opts.workerGflagParams = "-v=2 -page_size=" + std::to_string(pageSize_); + opts.numRpcThreads = 0; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + InitTest(); + } + + void TearDown() override + { + client_ = nullptr; + ExternalClusterTest::TearDown(); + } + + using InputStreamInfo = mock::InputStreamInfo; + using OutputStreamInfo = mock::OutputStreamInfo; + + Status CreateProducersAndConsumers(std::unordered_map &input, + std::unordered_map &output) + { + return datasystem::st::CreateProducersAndConsumers(client_, input, output); + } + + std::vector GenerateElements(int elementNum, uint64_t elementSize, std::string &outData) + { + outData = RandomData().GetRandomString(elementNum * elementSize); + std::vector ret; + ret.reserve(elementSize); + for (int i = 1; i <= elementNum; i++) { + Element element(reinterpret_cast(&outData.front()), elementSize); + ret.push_back(element); + } + return ret; + } + + void SendHelper(std::shared_ptr producer, Element element) + { + const int DEFAULT_SLEEP_TIME = 300; + int retryLimit = 30; + datasystem::Status rc = producer->Send(element); + if (rc.IsError()) { + while (rc.GetCode() == K_OUT_OF_MEMORY && retryLimit-- > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME)); + rc = 
+        DS_ASSERT_OK(rc);
+    }
+
+    void ReceiveHelper(std::shared_ptr<Consumer> consumer, size_t numElements)
+    {
+        size_t remaining = numElements;
+        int round = 0;
+        const int roundLimit = 100;
+        const int PER_RECEIVE_NUM = 500;
+        const int DEFAULT_WAIT_TIME = 5000;
+        while (remaining > 0 && round < roundLimit) {
+            std::vector<Element> outElements;
+            DS_ASSERT_OK(consumer->Receive(PER_RECEIVE_NUM, DEFAULT_WAIT_TIME, outElements));
+            LOG(INFO) << "receive num : " << outElements.size() << " ;" << round++;
+            if (!outElements.empty()) {
+                remaining -= outElements.size();
+                DS_ASSERT_OK(consumer->Ack(outElements.back().id));
+            }
+        }
+    }
+
+    /**
+     * @brief Creates a stream client against the given worker with the given timeout.
+     * @param[in] workerNum The index of the worker to connect to.
+     * @param[in] timeout Timeout for RPC requests, in milliseconds.
+     * @param[out] spClient Shared pointer to the stream client.
+     * @return Status of the call.
+     */
+    Status CreateClient(int workerNum, int32_t timeout, std::shared_ptr<StreamClient> &spClient)
+    {
+        HostPort workerAddress;
+        RETURN_IF_NOT_OK(cluster_->GetWorkerAddr(workerNum, workerAddress));
+        // Create a client with a user-defined timeout.
+        ConnectOptions connectOptions = { .host = workerAddress.Host(),
+                                          .port = workerAddress.Port(),
+                                          .connectTimeoutMs = timeout };
+        connectOptions.accessKey = accessKey_;
+        connectOptions.secretKey = secretKey_;
+        spClient = std::make_shared<StreamClient>(connectOptions);
+        RETURN_IF_NOT_OK(spClient->Init());
+        return Status::OK();
+    }
+
+protected:
+    void InitTest()
+    {
+        InitStreamClient(0, client_);
+    }
+
+    std::shared_ptr<StreamClient> client_ = nullptr;
+    uint64_t pageSize_;
+    ProducerConf defaultProducerConf_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+class ConsumerDataVerificationTest : public ConsumerTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        ConsumerTest::SetClusterSetupOptions(opts);
+        opts.workerGflagParams += " -enable_stream_data_verification=true ";
+    }
+};
+
+TEST_F(ConsumerDataVerificationTest, SendRecvOneBigElement)
+{
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("testSendRecvBigEle", producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("testSendRecvBigEle", config, consumer));
+
+    RandomData rand;
+    auto data = rand.GetRandomString(defaultProducerConf_.pageSize * 2);
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(1, K_100, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(1));
+    ASSERT_EQ(outElements[0].id, size_t(1));
+    std::string actualData(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data, actualData);
+}
+
+TEST_F(ConsumerDataVerificationTest, SendRecvOneSmallElement)
+{
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("testSendRecvSmallEle", producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("testSendRecvSmallEle", config, consumer));
+
+    std::string data = "H";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(1, K_100, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(1));
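+    // With -enable_stream_data_verification the worker additionally verifies element data
+    // integrity, so the payload must still match byte for byte after the round trip.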
+    ASSERT_EQ(outElements[0].id, size_t(1));
+    std::string actualData(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data, actualData);
+}
+
+TEST_F(ConsumerTest, GetElementsWhenProducerClosesWithPage0)
+{
+    const int RECEIVE_WAIT_TIME = 19000;
+    const int RECEIVE_TIME_COST = 12000;
+    const int THREAD_WAIT_TIME = 1000;
+    ThreadPool threadPool(1);
+
+    threadPool.Submit([this]() {
+        std::shared_ptr<StreamClient> cli;
+        InitStreamClient(0, cli);
+
+        std::shared_ptr<Consumer> consumer;
+        SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+        DS_ASSERT_OK(client_->Subscribe("testRecvWhenProdCloseWithPg0", config, consumer));
+
+        Timer timer;
+        std::vector<Element> outElements;
+        DS_ASSERT_OK(consumer->Receive(RECEIVE_WAIT_TIME, outElements));
+        auto timeCost = static_cast<int>(timer.ElapsedMilliSecond());
+        LOG(INFO) << "time cost: " << timeCost;
+        ASSERT_TRUE(timeCost < RECEIVE_TIME_COST);
+        ASSERT_EQ(outElements.size(), size_t(1));
+    });
+
+    // Wait a safe period to let the other thread call Receive first.
+    std::this_thread::sleep_for(std::chrono::milliseconds(THREAD_WAIT_TIME));
+
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("testRecvWhenProdCloseWithPg0", producer, defaultProducerConf_));
+
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+    DS_ASSERT_OK(producer->Close());
+}
+
+TEST_F(ConsumerTest, LEVEL2_TestCreateConsumerLongTimeout)
+{
+    // The request should not time out when the client timeout (10 minutes) exceeds the master's delay.
+    std::shared_ptr<StreamClient> client1;
+    const int32_t timeoutMs = 1000 * 60 * 10; // 10 minutes
+    ASSERT_EQ(CreateClient(0, timeoutMs, client1), Status::OK());
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+
+    // Make the master wait for 1 minute; the request should still not time out.
+    // We don't actually know which worker is the master, so inject into both.
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "SCMetadataManager.Subscribe.wait", "1*sleep(60000)"));
+
+    // This request should not time out, as the client timeout is 10 minutes.
+    DS_ASSERT_OK(client1->Subscribe("CreateConLongTimeout", config, consumer));
+}
+
+TEST_F(ConsumerTest, LEVEL2_TestCloseConsumerLongTimeout)
+{
+    // The request should not time out when the client timeout (10 minutes) exceeds the master's delay.
+    std::shared_ptr<StreamClient> client1;
+    const int32_t timeoutMs = 1000 * 60 * 10; // 10 minutes
+    ASSERT_EQ(CreateClient(0, timeoutMs, client1), Status::OK());
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+
+    // Make the master wait for 1 minute; the request should still not time out.
+    // We don't actually know which worker is the master, so inject into both.
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "SCMetadataManager.CloseConsumer.wait", "1*sleep(60000)"));
+
+    // This request should not time out, as the client timeout is 10 minutes.
+ DS_ASSERT_OK(client1->Subscribe("CloseConLongTimeout", config, consumer)); + DS_ASSERT_OK(consumer->Close()); +} + +TEST_F(ConsumerTest, EmptyCaseValidation) +{ + std::string workerip; + int workerport = 0; + + // Test empty workerip is invalid (size 0 is invalid) + LOG(INFO) << "workerip: " << workerip; + LOG(INFO) << "workerport: " << workerport; + std::shared_ptr client1; + ConnectOptions options; + options.host = workerip; + options.port = workerport; + client1 = std::make_shared(options); + DS_ASSERT_NOT_OK(client1->Init()); + + // Test invalid port number + workerip = "0.0.0.0:0"; + LOG(INFO) << "workerip: " << workerip; + LOG(INFO) << "workerport: " << workerport; + std::shared_ptr client2; + options.host = workerip; + options.port = workerport; + client2 = std::make_shared(options); + DS_ASSERT_NOT_OK(client2->Init()); +} + +TEST_F(ConsumerTest, InvalidRecvParams) +{ + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("InvalidRecvParams", config, consumer)); + + // Test invalid parameters + std::vector outElements; + ASSERT_EQ(consumer->Receive(0, 0, outElements).GetCode(), StatusCode::K_INVALID); + ASSERT_EQ(consumer->Receive(0, K_10, outElements).GetCode(), StatusCode::K_INVALID); + + consumer->Close(); +} + +TEST_F(ConsumerTest, NoElementNoProducer) +{ + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("NoEleNoProducer", config, consumer)); + std::vector outElements; + + // Test no element to read + DS_ASSERT_OK(consumer->Receive(1, 0, outElements)); + ASSERT_EQ(outElements.size(), size_t(0)); + + DS_ASSERT_OK(consumer->Receive(1, K_10, outElements)); + ASSERT_EQ(outElements.size(), size_t(0)); + consumer->Close(); +} + +TEST_F(ConsumerTest, NoElementWithProducer) +{ + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("NoEleWithProducer", producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("NoEleWithProducer", config, consumer)); + std::vector outElements; + + // Test no element to read + DS_ASSERT_OK(consumer->Receive(1, 0, outElements)); + ASSERT_EQ(outElements.size(), size_t(0)); + DS_ASSERT_OK(consumer->Receive(1, K_10, outElements)); + ASSERT_EQ(outElements.size(), size_t(0)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + + // Read the element + ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), size_t(1)); + + // No element to read when no blocking + DS_ASSERT_OK(consumer->Receive(1, 0, outElements)); + ASSERT_EQ(outElements.size(), size_t(0)); + // No element to read with blocking + DS_ASSERT_OK(consumer->Receive(1, K_10, outElements)); + ASSERT_EQ(outElements.size(), size_t(0)); + consumer->Close(); +} + +TEST_F(ConsumerTest, RollbackInvalidElement) +{ + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("testRollbackInvalidEle", config, consumer)); + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("testRollbackInvalidEle", producer, defaultProducerConf_)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + Element elementInvalid(reinterpret_cast(&data.front()), data.size()); + + // 1. 
+    datasystem::inject::Set("HugeMemoryCopy", "1*return(K_RUNTIME_ERROR)");
+    DS_ASSERT_NOT_OK(producer->Send(elementInvalid));
+    datasystem::inject::Clear("HugeMemoryCopy");
+
+    const uint32_t timeoutMs = 1000;
+    std::vector<Element> outElements;
+    // 2. Expect that the invalid element is never received.
+    DS_ASSERT_OK(consumer->Receive(1, timeoutMs, outElements));
+    ASSERT_TRUE(outElements.empty());
+}
+
+TEST_F(ConsumerTest, NoExpectedElementNum)
+{
+    std::shared_ptr<Producer> producer;
+    std::string streamName = "testNoExpectedEleNum";
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer));
+    std::vector<Element> outElements;
+    // No element to read yet.
+    DS_ASSERT_OK(consumer->Receive(0, outElements));
+    ASSERT_EQ(outElements.size(), size_t(0));
+    DS_ASSERT_OK(consumer->Receive(K_10, outElements));
+    ASSERT_EQ(outElements.size(), size_t(0));
+    std::string data1 = "Hello World1";
+    std::string data2 = "Hello World2";
+    Element element1(reinterpret_cast<uint8_t *>(&data1.front()), data1.size());
+    DS_ASSERT_OK(producer->Send(element1));
+    Element element2(reinterpret_cast<uint8_t *>(&data2.front()), data2.size());
+    DS_ASSERT_OK(producer->Send(element2));
+    // Read one element out of two; keep the remaining element in the client queue.
+    ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(1));
+    // With no expectNum, the Receive API returns all elements already stored in the client-local
+    // queue. The producer pushed 2 elements, 1 already read, 1 in the local queue: read that one.
+    DS_ASSERT_OK(consumer->Receive(0, outElements));
+    ASSERT_EQ(outElements.size(), size_t(1));
+    std::string actualData2(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data2, actualData2);
+    // A timeout of 0 does not wait for more elements; the local queue is empty, so the client gets nothing.
+    DS_ASSERT_OK(consumer->Receive(0, outElements));
+    ASSERT_EQ(outElements.size(), size_t(0));
+
+    std::string data3 = "Hello World3";
+    std::string data4 = "Hello World4";
+    Element element3(reinterpret_cast<uint8_t *>(&data3.front()), data3.size());
+    DS_ASSERT_OK(producer->Send(element3));
+    Element element4(reinterpret_cast<uint8_t *>(&data4.front()), data4.size());
+    DS_ASSERT_OK(producer->Send(element4));
+
+    // As expectNum is not set, receive both elements sent by the producer.
+    DS_ASSERT_OK(consumer->Receive(K_10, outElements));
+    ASSERT_EQ(outElements.size(), size_t(K_2));
+    std::string actualData3(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data3, actualData3);
+    std::string actualData4(reinterpret_cast<char *>(outElements[1].ptr), outElements[1].size);
+    EXPECT_EQ(data4, actualData4);
+}
+
+TEST_F(ConsumerTest, DISABLED_ReceiveTimeoutOverwriteRpcTimeout)
+{
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("test1", config, consumer));
+
+    DS_ASSERT_OK(inject::Set("client.StreamReceive.overwriteRpcTimeout", "call()"));
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(1, UINT32_MAX, outElements), Status::OK());
+}
+
+TEST_F(ConsumerTest, ConsumerCloseSecondConsumerRecv)
+{
+    // Create the first producer.
+    std::shared_ptr<Producer> producer;
+    const uint64_t maxStreamSize = 1024 * 1024;
+    defaultProducerConf_.maxStreamSize = maxStreamSize;
+    std::string streamName = "OneConCloseOtherConRecv";
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer, defaultProducerConf_));
+
+    // Create the first consumer.
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    const int cacheCapacity = 192;
+    config.cacheCapacity = cacheCapacity; // Use a low value for the cache.
+    // Switch on auto-ack.
+    DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer, true));
+
+    // Generate data until the cache is full.
+    int elementCount = 0;
+    std::string data;
+    const uint64_t numElements = 1024;
+    const uint64_t size = 1024;
+    std::vector<Element> elements = GenerateElements(numElements, size, data);
+    for (auto &element : elements) {
+        Status rc = producer->Send(element);
+        if (rc.GetCode() == K_OUT_OF_MEMORY) {
+            break;
+        } else if (rc.IsOk()) {
+            ++elementCount;
+        }
+    }
+
+    // Get all the data and close.
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(elementCount, K_100, outElements));
+    DS_ASSERT_OK(consumer->Close());
+
+    // Create another consumer.
+    std::shared_ptr<Consumer> consumer1;
+    SubscriptionConfig config1("sub2", SubscriptionType::STREAM);
+    config1.cacheCapacity = cacheCapacity;
+    // Switch on auto-ack.
+    DS_ASSERT_OK(client_->Subscribe(streamName, config1, consumer1, true));
+
+    elementCount = 0;
+    sleep(K_2); // Wait for the producer to get unblocked.
+    // Send data to the stream until we get OOM.
+    for (auto &element : elements) {
+        Status rc = producer->Send(element);
+        if (rc.GetCode() == K_OUT_OF_MEMORY) {
+            break;
+        } else if (rc.IsOk()) {
+            ++elementCount;
+        }
+    }
+    // Start receiving elements now; this should ack and clear all elements.
+    outElements.clear();
+    DS_ASSERT_OK(consumer1->Receive(K_10, K_100, outElements));
+    outElements.clear();
+    DS_ASSERT_OK(consumer1->Receive(K_10, K_100, outElements));
+
+    // Now the producer should have space for at least one element.
+    sleep(K_2); // Wait for the producer to get unblocked.
+    DS_ASSERT_OK(producer->Send(elements[0]));
+    DS_ASSERT_OK(producer->Send(elements[0]));
+    DS_ASSERT_OK(producer->Send(elements[0]));
+    DS_ASSERT_OK(producer->Send(elements[0]));
+}
+
+TEST_F(ConsumerTest, SendRecvOneSmallElement)
+{
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("SendRecvSmallEle", producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("SendRecvSmallEle", config, consumer));
+
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
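+    // Element only wraps the caller's buffer; the bytes are copied out during Send (cf. the
+    // HugeMemoryCopy inject point above), so `data` must stay valid until Send() returns.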
+    DS_ASSERT_OK(producer->Send(element));
+
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(1, K_100, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(1));
+    ASSERT_EQ(outElements[0].id, size_t(1));
+    std::string actualData(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data, actualData);
+}
+
+TEST_F(ConsumerTest, TestAutoAck1)
+{
+    // Test that auto-ack works for new consumers.
+    const int DEFAULT_WAIT_TIME = 5000;
+    const int DEFAULT_ELEMENT_SIZE = 180;
+    const bool ENABLE_AUTO_ACK = true;
+    std::shared_ptr<Producer> producer1;
+    std::string streamName = "testAutoAck1";
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer1, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer1;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer1, ENABLE_AUTO_ACK));
+    const int elementNum = 10000;
+
+    auto func = [](std::shared_ptr<Producer> &producer, std::shared_ptr<Consumer> &consumer) {
+        std::thread producerThrd([&producer]() {
+            std::string data = RandomData().GetRandomString(DEFAULT_ELEMENT_SIZE);
+            Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+            for (int i = 0; i < elementNum; i++) {
+                DS_ASSERT_OK(producer->Send(element));
+            }
+            DS_ASSERT_OK(producer->Close());
+        });
+        std::vector<Element> outElements;
+        int received = 0;
+        while (received < elementNum) {
+            DS_ASSERT_OK(consumer->Receive(1, DEFAULT_WAIT_TIME, outElements));
+            ASSERT_EQ(outElements.size(), size_t(1));
+            received++;
+        }
+        DS_ASSERT_OK(consumer->Close());
+        producerThrd.join();
+    };
+
+    // First round of produce and consume.
+    func(producer1, consumer1);
+
+    // Create a new producer and consumer for a second round.
+    std::shared_ptr<Consumer> consumer2;
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config2, consumer2, ENABLE_AUTO_ACK));
+    // Make sure it does not see residue elements.
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer2->Receive(1, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(0));
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer2, defaultProducerConf_));
+
+    // Second round of produce and consume.
+    func(producer2, consumer2);
+
+    // Create a new consumer only.
+    std::shared_ptr<Consumer> consumer3;
+    SubscriptionConfig config3("sub3", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config3, consumer3, ENABLE_AUTO_ACK));
+
+    // Make sure it does not see residue elements.
+    DS_ASSERT_OK(consumer3->Receive(1, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(0));
+}
+
+TEST_F(ConsumerTest, TestAutoAck2)
+{
+    // Test that auto-ack works even if the next receive gets no elements.
+    const int DEFAULT_WAIT_TIME = 5000;
+    const int DEFAULT_ELEMENT_SIZE = 500 * KB;
+    const int DEFAULT_MAX_STREAM_SIZE = 2 * MB;
+    const bool ENABLE_AUTO_ACK = true;
+    std::string streamName = "AutoAck2Test";
+    std::shared_ptr<Producer> producer1;
+    ProducerConf conf;
+    conf.maxStreamSize = DEFAULT_MAX_STREAM_SIZE;
+    conf.pageSize = 1 * MB;
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer1, conf));
+    std::shared_ptr<Consumer> consumer1;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer1, ENABLE_AUTO_ACK));
+    const int elementNum = 10;
+
+    std::thread producerThrd([this, &producer1]() {
+        std::string data = RandomData().GetRandomString(DEFAULT_ELEMENT_SIZE);
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+        for (int i = 0; i < elementNum; i++) {
+            SendHelper(producer1, element);
+        }
+        DS_ASSERT_OK(producer1->Close());
+    });
+    std::vector<Element> outElements;
+    int received = 0;
+    const int MAX_EXPECT_NUM = DEFAULT_MAX_STREAM_SIZE / DEFAULT_ELEMENT_SIZE;
+    while (received < elementNum) {
+        DS_ASSERT_OK(consumer1->Receive(MAX_EXPECT_NUM, DEFAULT_WAIT_TIME, outElements));
+        LOG(INFO) << "Received element num: " << outElements.size();
+        ASSERT_GT(outElements.size(), 0);
+        received += outElements.size();
+    }
+    DS_ASSERT_OK(consumer1->Close());
+    producerThrd.join();
+}
+
+TEST_F(ConsumerTest, DISABLED_SendRecvOneSmallElementWithTwoFlush)
+{
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("test1", producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("test1", config, consumer));
+    const int K_4 = 4;
+
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(K_4, K_1000, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(1));
+    ASSERT_EQ(outElements[0].id, size_t(1));
+    std::string actualData(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data, actualData);
+
+    std::string data2 = "Test";
+    Element element2(reinterpret_cast<uint8_t *>(&data2.front()), data2.size());
+    DS_ASSERT_OK(producer->Send(element2));
+    outElements.clear();
+    ASSERT_EQ(consumer->Receive(K_4, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(1));
+    ASSERT_EQ(outElements[0].id, size_t(K_2));
+    std::string actualData2(reinterpret_cast<char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data2, actualData2);
+}
+
+TEST_F(ConsumerTest, TestReceiveProducerIdle)
+{
+    // Test consumer receive while one producer is idle.
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    const int numElement = 5000;
+    size_t testSize = 4 * KB;
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(testSize);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    const int maxStreamSize = 10 * MB;
+    ProducerConf conf;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = 1 * MB;
+
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Producer> producer2;
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer, conf));
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer2, conf));
+    DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer));
+
+    std::thread producerThrd([&]() {
+        for (int i = 0; i < numElement; i++) {
+            SendHelper(producer, element);
+        }
+        for (int i = 0; i < numElement + numElement; i++) {
+            SendHelper(producer2, element);
+        }
+    });
+
+    ReceiveHelper(consumer, numElement + numElement + numElement);
+
+    producerThrd.join();
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(producer2->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client_->DeleteStream(streamName));
+}
+
+TEST_F(ConsumerTest, OneSubMultiConsumers1)
+{
+    // Create one producer and one consumer.
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("testOSMC1", producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("testOSMC1", config, consumer));
+
+    // Write data that is less than one page.
+    uint64_t elementSize = 8;
+    uint64_t onePageElementNum = pageSize_ / elementSize;
+    std::string data = RandomData().GetRandomString(pageSize_);
+    for (uint64_t i = 1; i < onePageElementNum; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), elementSize);
+        ASSERT_EQ(producer->Send(element), Status::OK());
+    }
+    // Receive data that is less than one page.
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(onePageElementNum - 1, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), onePageElementNum - 1);
+    ASSERT_EQ(outElements.front().id, size_t(1));
+    ASSERT_EQ(outElements.back().id, onePageElementNum - 1);
+    ASSERT_EQ(consumer->Ack(outElements.back().id), Status::OK());
+}
+
+TEST_F(ConsumerTest, OneSubMultiConsumers2)
+{
+    // Create one producer and one consumer.
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("testOSMC2", producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("testOSMC2", config, consumer));
+
+    // Write less than two pages.
+    uint64_t elementSize = 8;
+    uint64_t twoPageElementNum = 2 * pageSize_ / elementSize;
+    std::string data = RandomData().GetRandomString(pageSize_);
+    for (uint64_t i = 1; i < twoPageElementNum; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), elementSize);
+        ASSERT_EQ(producer->Send(element), Status::OK());
+    }
+    // Receive one page of data and ack.
+    uint64_t onePageElementNum = twoPageElementNum / K_2;
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(onePageElementNum, 0, outElements), Status::OK());
+    ASSERT_EQ(consumer->Ack(outElements.back().id), Status::OK());
+    ASSERT_EQ(outElements.size(), onePageElementNum);
+    ASSERT_EQ(outElements.front().id, size_t(1));
+    ASSERT_EQ(outElements.back().id, onePageElementNum);
+}
+
+TEST_F(ConsumerTest, OneSubMultiConsumers3)
+{
+    // Create one producer and two consumers.
+    std::string streamName("OneSubMultiConsumers3");
+    std::vector<std::string> subNameList{ "sub1", "sub2", "sub3" };
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer1, consumer2;
+    SubscriptionConfig config1(subNameList[0], SubscriptionType::STREAM);
+    SubscriptionConfig config2(subNameList[1], SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config1, consumer1));
+    DS_ASSERT_OK(client_->Subscribe(streamName, config2, consumer2));
+
+    // Write two pages of data.
+    uint64_t elementSize = 8;
+    uint64_t twoPageElementNum = K_2 * pageSize_ / elementSize;
+    std::string data = RandomData().GetRandomString(pageSize_);
+    for (uint64_t i = 1; i <= twoPageElementNum; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), elementSize);
+        ASSERT_EQ(producer->Send(element), Status::OK());
+    }
+
+    // Consumer1 receives one page of data and acks.
+    uint64_t onePageElementNum = twoPageElementNum / K_2;
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer1->Receive(onePageElementNum, 0, outElements), Status::OK());
+    ASSERT_EQ(consumer1->Ack(outElements.back().id), Status::OK());
+    ASSERT_EQ(outElements.size(), onePageElementNum);
+    ASSERT_EQ(outElements.front().id, size_t(1));
+    ASSERT_EQ(outElements.back().id, onePageElementNum);
+
+    // Consumer2 receives one page of data but does not ack.
+    outElements.clear();
+    ASSERT_EQ(consumer2->Receive(onePageElementNum, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), onePageElementNum);
+    ASSERT_EQ(outElements.front().id, size_t(1));
+    ASSERT_EQ(outElements.back().id, onePageElementNum);
+
+    // Create consumer3; it should be able to receive all the data.
+    LOG(INFO) << "Start to create consumer3.";
+    std::shared_ptr<Consumer> consumer3;
+    SubscriptionConfig config3(subNameList[2], SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config3, consumer3));
+
+    outElements.clear();
+    ASSERT_EQ(consumer3->Receive(onePageElementNum, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), onePageElementNum);
+    ASSERT_EQ(outElements.front().id, size_t(1));
+    ASSERT_EQ(outElements.back().id, onePageElementNum);
+}
+
+TEST_F(ConsumerTest, RecvWithCache)
+{
+    std::shared_ptr<Producer> producer;
+    int dataNum = 21;
+    std::string stream0{ "RecvWithCache" };
+    ProducerConf conf;
+    conf.delayFlushTime = 10;
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    DS_ASSERT_OK(client_->CreateProducer(stream0, producer, conf));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub0", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(stream0, config, consumer));
+
+    std::vector<std::string> dataList;
+    for (int i = 0; i < dataNum; ++i) {
+        std::string data = "Test-Data" + std::to_string(i);
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+        DS_ASSERT_OK(producer->Send(element));
+        dataList.emplace_back(data);
+    }
+    std::vector<Element> outElements;
+    uint64_t cursor = 1;
+    Timer timer;
+    while (timer.ElapsedSecond() <= 1) {
+        std::vector<Element> oneTimeOutElements;
+        consumer->Receive(K_5, K_100, oneTimeOutElements);
+        if (!oneTimeOutElements.empty()) {
+            LOG(INFO) << FormatString("Receive %d element", oneTimeOutElements.size());
+            for (int i = 0; i < static_cast<int>(oneTimeOutElements.size()); ++i) {
+                ASSERT_EQ(oneTimeOutElements[i].id, cursor + i);
+            }
+            outElements.insert(outElements.end(), oneTimeOutElements.begin(), oneTimeOutElements.end());
+            cursor += oneTimeOutElements.size();
+        }
+    }
+    ASSERT_EQ(outElements.size(), static_cast<size_t>(dataNum));
+    ASSERT_EQ(outElements[0].id, size_t(1));
+    for (int i = 0; i < dataNum; ++i) {
+        std::string receivedData(reinterpret_cast<char *>(outElements[i].ptr), outElements[i].size);
+        EXPECT_EQ(dataList[i], receivedData);
+        LOG(INFO) << FormatString("No:%d, Data:%s", i, receivedData);
+    }
+}
+
+TEST_F(ConsumerTest, MultiSubsMultiConsumers)
+{
+    // Create one producer and two subscriptions with one consumer each.
+    std::string streamName("MultiSubMultiCon");
+    std::string sub1("sub1");
+    std::string sub2("sub2");
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer1;
+    DS_ASSERT_OK(client_->Subscribe(streamName, SubscriptionConfig(sub1, SubscriptionType::STREAM), consumer1));
+    std::shared_ptr<Consumer> consumer2;
+    DS_ASSERT_OK(client_->Subscribe(streamName, SubscriptionConfig(sub2, SubscriptionType::STREAM), consumer2));
+
+    // Write two pages of elements.
+    uint64_t elementSize = 8;
+    uint64_t twoPageElementNum = K_2 * pageSize_ / elementSize;
+    std::string data = RandomData().GetRandomString(pageSize_);
+    for (uint64_t i = 1; i <= twoPageElementNum; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), elementSize);
+        ASSERT_EQ(producer->Send(element), Status::OK());
+    }
+
+    // Consumer1 receives one page of data and acks.
+    uint64_t onePageElementNum = twoPageElementNum / K_2;
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer1->Receive(onePageElementNum, 0, outElements), Status::OK());
+    ASSERT_EQ(consumer1->Ack(outElements.back().id), Status::OK());
+    ASSERT_EQ(outElements.size(), onePageElementNum);
+    ASSERT_EQ(outElements.front().id, size_t(1));
+
+TEST_F(ConsumerTest, MultiSubsMultiConsumers)
+{
+    // First create one producer and two subscriptions, each with its own consumer
+    std::string streamName("MultiSubMultiCon");
+    std::string sub1("sub1");
+    std::string sub2("sub2");
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer1;
+    DS_ASSERT_OK(client_->Subscribe(streamName, SubscriptionConfig(sub1, SubscriptionType::STREAM), consumer1));
+    std::shared_ptr<Consumer> consumer2;
+    DS_ASSERT_OK(client_->Subscribe(streamName, SubscriptionConfig(sub2, SubscriptionType::STREAM), consumer2));
+
+    // Write two pages of elements
+    uint64_t elementSize = 8;
+    uint64_t twoPageElementNum = K_2 * pageSize_ / elementSize;
+    std::string data = RandomData().GetRandomString(pageSize_);
+    for (uint64_t i = 1; i <= twoPageElementNum; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), elementSize);
+        ASSERT_EQ(producer->Send(element), Status::OK());
+    }
+
+    // Consumer1 receives one page of data and acks it
+    uint64_t onePageElementNum = twoPageElementNum / K_2;
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer1->Receive(onePageElementNum, 0, outElements), Status::OK());
+    ASSERT_EQ(consumer1->Ack(outElements.back().id), Status::OK());
+    ASSERT_EQ(outElements.size(), onePageElementNum);
+    ASSERT_EQ(outElements.front().id, size_t(1));
+    ASSERT_EQ(outElements.back().id, onePageElementNum);
+
+    // Consumer2 receives one page of data but does not ack
+    outElements.clear();
+    ASSERT_EQ(consumer2->Receive(onePageElementNum, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), onePageElementNum);
+    ASSERT_EQ(outElements.front().id, size_t(1));
+    ASSERT_EQ(outElements.back().id, onePageElementNum);
+}
+
+TEST_F(ConsumerTest, SingleConsumerReceiveNotEnoughElementAfterTimeout)
+{
+    // First create one producer and one subscription with one consumer
+    std::string streamName("SingleConRecvNotEnough");
+    std::string sub1("sub0");
+    std::unordered_map info;
+    info[streamName].producerNum = 1;
+    info[streamName].subscriptions[sub1] = std::make_pair(SubscriptionType::STREAM, 1);
+    std::unordered_map output;
+    DS_ASSERT_OK(CreateProducersAndConsumers(info, output));
+
+    // Write elements
+    uint64_t elementSize = 8;
+    int elementNum = 10;
+    std::string data = RandomData().GetRandomString(K_100 * elementSize);
+    for (int i = 1; i <= elementNum; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), elementSize);
+        ASSERT_EQ(output[streamName].producers[0]->Send(element), Status::OK());
+    }
+
+    // The consumer asks for 20 elements at once but only gets 10 after the timeout
+    Consumer *consumer = output[streamName].consumers[sub1][0].get();
+    int expectNum = 20;
+    std::vector<Element> outElements;
+    LOG(INFO) << FormatString("Consumer Start recv %d elements.", expectNum);
+    ASSERT_EQ(consumer->Receive(expectNum, K_1000, outElements), Status::OK());
+    LOG(INFO) << FormatString("Consumer Received %d elements.", outElements.size());
+    ASSERT_EQ(consumer->Ack(outElements.back().id), Status::OK());
+    ASSERT_EQ(outElements.size(), static_cast<size_t>(elementNum));
+    ASSERT_EQ(outElements.front().id, size_t(1));
+    ASSERT_EQ(outElements.back().id, static_cast<size_t>(elementNum));
+}
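SingleConsumerReceiveNotEnoughElementAfterTimeout above and the multi-consumer variants below fix the semantics of the `timeoutMs` argument: `Receive` blocks until either `expectNum` elements are available or the timeout expires, and a short count is still `Status::OK()`. Callers that need an exact count therefore have to loop, as in this hedged helper (`ReceiveExactly` is an illustrative addition, not an API in the patch):

```cpp
// Illustrative sketch: keep calling Receive() until `want` elements have
// arrived or a call times out with nothing new.
Status ReceiveExactly(Consumer *consumer, size_t want, int perCallTimeoutMs,
                      std::vector<Element> &out)
{
    while (out.size() < want) {
        std::vector<Element> batch;
        RETURN_IF_NOT_OK(consumer->Receive(want - out.size(), perCallTimeoutMs, batch));
        if (batch.empty()) {
            break;  // timed out with nothing new; the caller checks out.size()
        }
        out.insert(out.end(), batch.begin(), batch.end());
    }
    return Status::OK();
}
```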
+
+TEST_F(ConsumerTest, MultiConsumerReceiveNotEnoughElementAfterTimeout)
+{
+    // First create one producer and two subscriptions with one consumer each
+    std::string streamName("MultiConRecvNotEnough");
+    std::string sub1("sub0");
+    std::string sub2("sub1");
+    std::unordered_map info;
+    info[streamName].producerNum = 1;
+    info[streamName].subscriptions[sub1] = std::make_pair(SubscriptionType::STREAM, 1);
+    info[streamName].subscriptions[sub2] = std::make_pair(SubscriptionType::STREAM, 1);
+    std::unordered_map output;
+    DS_ASSERT_OK(CreateProducersAndConsumers(info, output));
+
+    // Write elements
+    uint64_t elementSize = 8;
+    int elementNum = 10;
+    std::string data = RandomData().GetRandomString(K_100 * elementSize);
+    for (int i = 1; i <= elementNum; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), elementSize);
+        ASSERT_EQ(output[streamName].producers[0]->Send(element), Status::OK());
+    }
+
+    // Both consumers ask for 20 elements at the same time but only get 10 after the timeout
+    std::vector<std::thread> threads;
+    for (size_t i = 0; i < output[streamName].consumers.size(); i++) {
+        threads.emplace_back([streamName, elementNum, i, &output]() {
+            auto subName = "sub" + std::to_string(i);
+            Consumer *consumer = output[streamName].consumers[subName][0].get();
+            int expectNum = 20;
+            std::vector<Element> outElements;
+            LOG(INFO) << FormatString("Consumer %d Start recv %d elements.", i, expectNum);
+            ASSERT_EQ(consumer->Receive(expectNum, K_1000, outElements), Status::OK());
+            LOG(INFO) << FormatString("Consumer %d Received %d elements.", i, outElements.size());
+            ASSERT_EQ(outElements.size(), static_cast<size_t>(elementNum));
+            ASSERT_EQ(outElements.front().id, size_t(1));
+            ASSERT_EQ(outElements.back().id, static_cast<size_t>(elementNum));
+        });
+    }
+    for (size_t i = 0; i < threads.size(); i++) {
+        threads[i].join();
+    }
+}
+
+TEST_F(ConsumerTest, ReceiveEnoughElementBeforeTimeout)
+{
+    // First create one producer and two subscriptions with one consumer each
+    std::string streamName("RecvEleBeforeTimeout");
+    std::string sub1("sub0");
+    std::string sub2("sub1");
+    std::unordered_map info;
+    info[streamName].producerNum = 1;
+    info[streamName].subscriptions[sub1] = std::make_pair(SubscriptionType::STREAM, 1);
+    info[streamName].subscriptions[sub2] = std::make_pair(SubscriptionType::STREAM, 1);
+    std::unordered_map output;
+    DS_ASSERT_OK(CreateProducersAndConsumers(info, output));
+
+    // Write elements
+    uint64_t elementSize = 8;
+    int elementNum = 10;
+    std::string data = RandomData().GetRandomString(K_100 * elementSize);
+    for (int i = 1; i <= elementNum; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), elementSize);
+        ASSERT_EQ(output[streamName].producers[0]->Send(element), Status::OK());
+    }
+    LOG(INFO) << "Start to flush first 10 elements.";
+
+    // Both consumers ask for 20 elements at the same time and get all 20 before the timeout
+    std::vector<std::thread> threads;
+    for (size_t i = 0; i < output[streamName].consumers.size(); i++) {
+        threads.emplace_back([streamName, i, &output]() {
+            auto subName = "sub" + std::to_string(i);
+            Consumer *consumer = output[streamName].consumers[subName][0].get();
+            int expectNum = K_20;
+            std::vector<Element> outElements;
+            LOG(INFO) << FormatString("Consumer %d Start recv %d elements.", i, expectNum);
+            ASSERT_EQ(consumer->Receive(expectNum, K_5000, outElements), Status::OK());
+            LOG(INFO) << FormatString("Consumer %d Received %d elements.", i, outElements.size());
+            ASSERT_EQ(outElements.size(), static_cast<size_t>(expectNum));
+            ASSERT_EQ(outElements.front().id, size_t(1));
+            ASSERT_EQ(outElements.back().id, size_t(K_20));
+            ASSERT_EQ(outElements.back().id, static_cast<size_t>(expectNum));
+        });
+    }
+    sleep(K_2);
+    for (int i = 1; i <= elementNum; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), elementSize);
+        ASSERT_EQ(output[streamName].producers[0]->Send(element), Status::OK());
+    }
+    LOG(INFO) << "Start to flush second 10 elements.";
+    for (size_t i = 0; i < threads.size(); i++) {
+        threads[i].join();
+    }
+}
+
+TEST_F(ConsumerTest, InvalidAck)
+{
+    std::string streamName("TestInvalidAck");
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer1, consumer2;
+    ASSERT_EQ(client_->CreateProducer(streamName, producer, defaultProducerConf_), Status::OK());
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config1, consumer1));
+
+    uint64_t elementSize = 8;
+    int elementNum = 10;
+    int invalidAckNum = 20;
+    std::string data = RandomData().GetRandomString(K_100 * elementSize);
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), elementSize);
+    for (int i = 1; i <= elementNum; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer1->Receive(elementNum, K_1000, outElements));
+    ASSERT_EQ(outElements.size(), size_t(elementNum));
+    DS_ASSERT_NOT_OK(consumer1->Ack(invalidAckNum));
+    DS_ASSERT_OK(consumer1->Ack(elementNum));
+    DS_ASSERT_OK(consumer1->Close());
+
+    outElements.clear();
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config2, consumer2));
+    DS_ASSERT_OK(consumer2->Receive(K_100, outElements));
+    ASSERT_EQ(outElements.size(), size_t(0));
+}
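The `DefaultPageTest` fixture that follows re-runs the consumer setup with a 1 MB page, which suggests the `ConsumerTest` constructor argument is the page size in bytes (an assumption; the base class is defined outside this section). Under that assumption, a variant exercising a smaller page would look like this illustrative sketch:

```cpp
// Hypothetical fixture sketch, assuming ConsumerTest(pageSizeBytes) as
// DefaultPageTest below suggests; not part of the original patch.
class SmallPageTest : public ConsumerTest {
public:
    SmallPageTest() : ConsumerTest(4 * 1024)  // 4 KB pages
    {
    }
};
```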
+
+class DefaultPageTest : public ConsumerTest {
+public:
+    DefaultPageTest() : ConsumerTest(1024 * 1024)
+    {
+    }
+};
+
+TEST_F(DefaultPageTest, InvalidAck2)
+{
+    LOG(INFO) << "Producer sends multiple small elements (each smaller than 1/16 of the buffer).";
+    // create pub
+    std::string streamName = "stream_001_01";
+    std::shared_ptr<Producer> producer;
+    ASSERT_EQ(client_->CreateProducer(streamName, producer, defaultProducerConf_), Status::OK());
+
+    LOG(INFO) << "Create producer successfully.";
+    // create sub and consumer
+    std::string subName = "stream_001_01_sub_03";
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config(subName, SubscriptionType::STREAM);
+    ASSERT_EQ(client_->Subscribe(streamName, config, consumer), Status::OK());
+    LOG(INFO) << "Create consumer successfully.";
+    // generate random buffer
+    size_t testSize1 = 4096;
+    std::string writeBuffer = RandomData().GetRandomString(testSize1);
+    int elementNum = 257;
+    const int K_157 = 157, K_158 = 158;
+    for (int i = 0; i < elementNum; i++) {
+        // Write and flush one element
+        Element element(reinterpret_cast<uint8_t *>(&writeBuffer.front()), writeBuffer.size());
+        ASSERT_EQ(producer->Send(element), Status::OK());
+    }
+
+    // Consumer Receive
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(K_157, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(K_157));
+    ASSERT_EQ(outElements.front().id, size_t(1));
+    ASSERT_EQ(outElements.back().id, size_t(K_157));
+
+    // Producer Close
+    ASSERT_EQ(producer->Close(), Status::OK());
+    LOG(INFO) << "Close producer successfully.";
+
+    // Acking past the last received element (158) is invalid; acking the last one (157) succeeds
+    ASSERT_EQ(consumer->Ack(K_158).GetCode(), StatusCode::K_INVALID);
+    ASSERT_EQ(consumer->Ack(K_157), Status::OK());
+    std::vector<Element> outElements2;
+    ASSERT_EQ(consumer->Receive(K_100, 0, outElements2), Status::OK());
+}
+
+TEST_F(ConsumerTest, DelayFlushSendRecvOneSmallElement)
+{
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("auto_flush_test", producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("auto_flush_test", config, consumer));
+
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(1, K_10, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(1));
+    std::string actualData(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+    LOG(INFO) << "Received data is: " << actualData;
+    ASSERT_EQ(data, actualData);
+}
+
+TEST_F(ConsumerTest, SendElementAndAutoFlushWithoutDelay)
+{
+    ProducerConf conf;
+    conf.delayFlushTime = 0;
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("auto_no_delay_test", producer, conf));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("auto_no_delay_test", config, consumer));
+
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(1, K_100, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(1));
+    ASSERT_EQ(outElements[0].id, size_t(1));
+    std::string actualData(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data, actualData);
+}
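Every test here builds an `Element` from a raw pointer and a size, which means the `Element` only borrows the caller's buffer. The assumption, consistent with the tests keeping their strings alive across `Send`, is that the backing storage must remain valid at least until `Send()` returns; the helper below is an illustrative addition, not an API from the patch.

```cpp
// Illustrative helper: wrap a std::string payload in an Element. The Element
// borrows the pointer, so (assumption) the string must outlive the Send() call.
Element MakeElement(std::string &payload)
{
    return Element(reinterpret_cast<uint8_t *>(&payload.front()), payload.size());
}
```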
+
+TEST_F(ConsumerTest, ContinuousSendRecvSmallElement)
+{
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("continuous_send_test", producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("continuous_send_test", config, consumer));
+
+    std::string data[K_10];
+    for (int i = 0; i < K_10; i++) {
+        data[i] = "Hello World" + std::to_string(i);
+        uint8_t id = i + 1;
+        Element element(reinterpret_cast<uint8_t *>(&data[i].front()), data[i].size());
+        element.id = id;
+        DS_ASSERT_OK(producer->Send(element));
+    }
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(K_10, K_1000, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(K_10));
+    for (int i = 0; i < K_10; i++) {
+        ASSERT_EQ(outElements[i].id, static_cast<size_t>(i + 1));
+        std::string actualData(reinterpret_cast<const char *>(outElements[i].ptr), outElements[i].size);
+        ASSERT_EQ(data[i], actualData);
+    }
+}
+
+TEST_F(ConsumerTest, SendElementWhenFlushSmallElement)
+{
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("sametime_send_test", producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("sametime_send_test", config, consumer));
+    const int K_15 = 15;
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    std::string data2 = "Test";
+    Element element2(reinterpret_cast<uint8_t *>(&data2.front()), data2.size());
+    DS_ASSERT_OK(producer->Send(element));
+    std::this_thread::sleep_for(std::chrono::milliseconds(K_5));
+    DS_ASSERT_OK(producer->Send(element2));
+
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(K_2, K_15, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(K_2));
+    ASSERT_EQ(outElements[0].id, size_t(1));
+    ASSERT_EQ(outElements[1].id, size_t(K_2));
+    std::string actualData(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+    EXPECT_EQ(data, actualData);
+    std::string actualData2(reinterpret_cast<const char *>(outElements[1].ptr), outElements[1].size);
+    EXPECT_EQ(data2, actualData2);
+}
+
+TEST_F(ConsumerTest, FlushAgainWhenAutoFlushSmallElement)
+{
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("sametime_flush_test", producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("sametime_flush_test", config, consumer));
+    const int K_5 = 5;
+
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+    std::this_thread::sleep_for(std::chrono::milliseconds(K_5));
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(1));
+    ASSERT_EQ(outElements[0].id, size_t(1));
+    std::string actualData(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+    ASSERT_EQ(data, actualData);
+}
+
+TEST_F(ConsumerTest, ReceiveCacheTest)
+{
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer("cache_test", producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("cache_test", config, consumer));
+
+    std::string data = "a";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    for (int i = 0; i < K_20; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    std::vector<Element> outElements;
+    for (int i = 0; i < K_20; i++) {
+        outElements.clear();
+        ASSERT_EQ(consumer->Receive(1, K_1000, outElements), Status::OK());
ASSERT_EQ(outElements.size(), size_t(1)); + std::string actualData(reinterpret_cast(outElements[0].ptr), outElements[0].size); + ASSERT_EQ(data, actualData); + outElements.clear(); + } + + // Receive expect greater than cache. Should be no more elements to recv + DS_ASSERT_OK(consumer->Receive(1, K_1000, outElements)); + ASSERT_EQ(outElements.size(), size_t(0)); +} + +TEST_F(ConsumerTest, TestInfiniteWaitRecv) +{ + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("infinite_wait_test", producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("infinite_wait_test", config, consumer)); + + auto producerThread = std::make_unique([&producer]() { + size_t testSize1 = 4096; + std::string writeBuffer = RandomData().GetRandomString(testSize1); + sleep(K_5); + Element element(reinterpret_cast(&writeBuffer.front()), writeBuffer.size()); + ASSERT_EQ(producer->Send(element), Status::OK()); + }); + + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(1, -1, outElements)); + producerThread->join(); +} + +TEST_F(ConsumerTest, TestSpecialOrder) +{ + std::shared_ptr producer; + std::shared_ptr consumer; + std::string streamName = "specialOrderTest"; + const int K_2 = 2; + + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer)); + // after worker write data to client, client send the next receive before worker remove the pending receive. + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.stream.after_send_pending", "sleep(500)")); + // worker execute the pending receive timer before add the pending receive task. + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.stream.before_add_pending", "sleep(500)")); + + for (int i = 0; i < K_2; i++) { + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(K_10, 1, outElements)); + } +} + +TEST_F(ConsumerTest, GetStatisticsMessage1) +{ + // Create one producer and one consumer + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("testGetStatMsg1", producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("testGetStatMsg1", config, consumer)); + + // Write data that less than one page + uint64_t elementSize = 8; + uint64_t onePageElementNum = pageSize_ / elementSize; + std::string data = RandomData().GetRandomString(pageSize_); + for (uint64_t i = 1; i < onePageElementNum; i++) { + Element element(reinterpret_cast(&data.front()), elementSize); + ASSERT_EQ(producer->Send(element), Status::OK()); + } + + std::vector outElements; + ASSERT_EQ(consumer->Receive(onePageElementNum - 1, 0, outElements), Status::OK()); + uint64_t recEle; + uint64_t notProcEle; + consumer->GetStatisticsMessage(recEle, notProcEle); + ASSERT_EQ(recEle, outElements.size()); + ASSERT_EQ(notProcEle, recEle); + ASSERT_EQ(consumer->Ack(outElements.back().id), Status::OK()); + consumer->GetStatisticsMessage(recEle, notProcEle); + ASSERT_EQ(recEle, outElements.size()); + ASSERT_EQ(notProcEle, 0u); +} + +TEST_F(ConsumerTest, GetStatisticsMessage2) +{ + // Create one producer and one consumer + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("testGetStatMsgTwo", producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("testGetStatMsgTwo", config, consumer)); + + 
// Write data that less than one page + uint64_t elementSize = 8; + uint64_t onePageElementNum = pageSize_ / elementSize; + std::string data = RandomData().GetRandomString(pageSize_); + for (uint64_t i = 1; i < onePageElementNum; i++) { + Element element(reinterpret_cast(&data.front()), elementSize); + ASSERT_EQ(producer->Send(element), Status::OK()); + } + + std::vector outElements; + ASSERT_EQ(consumer->Receive(onePageElementNum - 1, 0, outElements), Status::OK()); + uint64_t recEle; + uint64_t notProcEle; + consumer->GetStatisticsMessage(recEle, notProcEle); + ASSERT_EQ(recEle, outElements.size()); + ASSERT_EQ(notProcEle, recEle); + + for (auto ele : outElements) { + ASSERT_EQ(consumer->Ack(ele.id), Status::OK()); + consumer->GetStatisticsMessage(recEle, notProcEle); + ASSERT_EQ(recEle, outElements.size()); + ASSERT_EQ(notProcEle, recEle - ele.id); + } +} + + +TEST_F(ConsumerTest, GetStatisticsMessage3) +{ + // Test GetStatistics with AutoAck and more than one page + // Create one producer and one consumer + std::shared_ptr producer; + ProducerConf conf = {.delayFlushTime = 5, .pageSize = K_2 * K_2 * KB}; + DS_ASSERT_OK(client_->CreateProducer("testGetStatMsgThree", producer, conf)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("testGetStatMsgThree", config, consumer, true)); + + const int DEFAULT_WAIT_TIME = 5000; + uint64_t recEle, notProcEle; + std::string data = RandomData().GetRandomString(1 * KB); + Element element(reinterpret_cast(&data.front()), data.size()); + data = RandomData().GetRandomString(1); + Element element2(reinterpret_cast(&data.front()), data.size()); + + consumer->GetStatisticsMessage(recEle, notProcEle); + ASSERT_EQ(recEle, 0); + ASSERT_EQ(notProcEle, 0); + // producer sends data (3.5 pages of elements) + // First 3 pages have 3x 1KB elements, 4th page has 1x 1KB, 1000x 1B elements + for (int i = 0; i < K_10; i++) { + SendHelper(producer, element); + } + for (int i = 0; i < K_1000; i++) { + SendHelper(producer, element2); + } + + std::vector outElements; + for (int i = 0; i < K_10 + K_1000; i++) { + DS_ASSERT_OK(consumer->Receive(1, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), 1); + } + // Last Receive to trigger auto-ack on last element + DS_ASSERT_OK(consumer->Receive(1, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), 0); + consumer->GetStatisticsMessage(recEle, notProcEle); + ASSERT_EQ(recEle, K_10 + K_1000); + ASSERT_EQ(notProcEle, 0); +} + + +class SCClientZmqCurveTest : public ConsumerTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numWorkers = workerCount; + opts.enableLivenessProbe = true; + ConsumerTest::SetClusterSetupOptions(opts); + // use default configurations for all the other zmq curve gflags settings + opts.numEtcd = 1; + opts.vLogLevel = defaultLogLevel; + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + } + + void TearDown() override + { + client1_.reset(); + client2_.reset(); + ExternalClusterTest::TearDown(); + } + +protected: + Status InitTest() + { + InitStreamClient(0, client1_); + InitStreamClient(1, client2_); + return Status::OK(); + } + std::shared_ptr client1_ = nullptr; + std::shared_ptr client2_ = nullptr; + const uint32_t workerCount = 2; + const uint32_t defaultLogLevel = 3; +}; + +/* +On same node. Producer created before consumer. Consumer calls receive before producer send. +Consumer should be able to receive all data from producer. 
+*/ +TEST_F(ConsumerTest, ReceiveThenSend) +{ + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("testRecvThenSend", producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("testRecvThenSend", config, consumer)); + + int timeOut = 10000; + std::vector outElements; + std::thread receiveThread([&]() { ASSERT_EQ(consumer->Receive(1, timeOut, outElements), Status::OK()); }); + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + sleep(1); + receiveThread.join(); + ASSERT_EQ(outElements.size(), size_t(1)); + ASSERT_EQ(outElements[0].id, size_t(1)); +} + +/* +On same node. Producer created and calls send before consumer creation. Consumer should not +be able to receive any data because of late create/subscribe. +*/ +TEST_F(ConsumerTest, SendBeforeConsumerCreate) +{ + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("SendBeforeConCreate", producer, defaultProducerConf_)); + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + + int timeOut = 100; + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("SendBeforeConCreate", config, consumer)); + std::vector outElements; + ASSERT_EQ(consumer->Receive(1, timeOut, outElements), Status::OK()); + // Cannot receive sent element when Consumer create after send + ASSERT_EQ(outElements.size(), size_t(0)); +} + +/* +On same node. Consumer created before producer. Producer sends data before consumer receive. +Consumer should be able to receive all data from producer. +*/ +TEST_F(ConsumerTest, CreateConsumerBeforeProducer) +{ + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("CreateConBeforeProd", config, consumer)); + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("CreateConBeforeProd", producer, defaultProducerConf_)); + + int timeOut = 100; + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + std::vector outElements; + ASSERT_EQ(consumer->Receive(1, timeOut, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), size_t(1)); + ASSERT_EQ(outElements[0].id, size_t(1)); +} + +/* +On same node. Consumer created before producer. Consumer receives before producer sends data. +Consumer should be able to receive all data from producer. 
+*/ +TEST_F(ConsumerTest, CreateConsumerFirstReceiveThenSend) +{ + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("testRecvBeforeSend", config, consumer)); + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("testRecvBeforeSend", producer, defaultProducerConf_)); + + int timeOut = 10000; + std::vector outElements; + std::thread receiveThread([&]() { ASSERT_EQ(consumer->Receive(1, timeOut, outElements), Status::OK()); }); + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + sleep(1); + receiveThread.join(); + ASSERT_EQ(outElements.size(), size_t(1)); + ASSERT_EQ(outElements[0].id, size_t(1)); +} + +TEST_F(ConsumerTest, TestDoubleClose) +{ + // Test that if CloseConsumer RPC fails in ConsumerImpl::Close(), + // the implicit Close triggered by the destructor will not double release the page. + const int maxStreamSize = 10 * MB; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + conf.pageSize = 1 * MB; + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("testDoubleClose", producer, conf)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("testDoubleClose", config, consumer)); + + const size_t DEFAULT_ELEMENT_SIZE = 500 * KB; + std::string data = RandomData().GetRandomString(DEFAULT_ELEMENT_SIZE); + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + std::vector outElements; + ASSERT_EQ(consumer->Receive(1, K_1000, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), size_t(1)); + // Call close and expect the injected failure + datasystem::inject::Set("ConsumerImpl.CloseConsumerRPC.Fail", "1*return(K_RPC_UNAVAILABLE)"); + DS_ASSERT_NOT_OK(consumer->Close()); + consumer.reset(); + const int numElement = 10; + for (int i = 0; i < numElement; i++) { + SendHelper(producer, element); + } +} + +TEST_F(ConsumerTest, TestIdempotentClose) +{ + // Test that if producer/consumer is closed, calling Close() again will return OK + const int maxStreamSize = 10 * MB; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + conf.pageSize = 1 * MB; + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("testIdempotentClose", producer, conf)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("testIdempotentClose", config, consumer)); + + const size_t DEFAULT_ELEMENT_SIZE = 500 * KB; + std::string data = RandomData().GetRandomString(DEFAULT_ELEMENT_SIZE); + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + std::vector outElements; + ASSERT_EQ(consumer->Receive(1, K_1000, outElements), Status::OK()); + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + ASSERT_EQ(outElements.size(), size_t(1)); + // First close + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(consumer->Close()); + // Other methods should return error + DS_ASSERT_NOT_OK(producer->Send(element)); + DS_ASSERT_NOT_OK(consumer->Receive(1, K_1000, outElements)); + DS_ASSERT_NOT_OK(consumer->Ack(outElements.back().id)); + // Second close should return OK + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(consumer->Close()); +} + +class SPMCTest : public ConsumerTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numWorkers = 
numWorkers;
+        opts.numEtcd = numEtcd;
+        opts.numRpcThreads = numRpcThreads;
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+        InitTest();
+    }
+
+    void TearDown() override
+    {
+        client1_ = nullptr;
+        client2_ = nullptr;
+        client3_ = nullptr;
+        ExternalClusterTest::TearDown();
+    }
+
+    /*
+    Helper to set up a test-case configuration where producers/consumers are located on
+    the same node or on different nodes.
+    */
+    void CreateProducerAndConsumerHelper(int workerNum, std::shared_ptr<Producer> &p1, std::shared_ptr<Consumer> &c1,
+                                         std::shared_ptr<Consumer> &c2, std::string stream)
+    {
+        SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+        SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+
+        if (workerNum == ONE_WORKER) {
+            DS_ASSERT_OK(client1_->CreateProducer(stream, p1, defaultProducerConf_));
+            DS_ASSERT_OK(client1_->Subscribe(stream, config1, c1));
+            DS_ASSERT_OK(client1_->Subscribe(stream, config2, c2));
+        } else if (workerNum == TWO_WORKER) {
+            DS_ASSERT_OK(client1_->CreateProducer(stream, p1, defaultProducerConf_));
+            DS_ASSERT_OK(client2_->Subscribe(stream, config1, c1));
+            DS_ASSERT_OK(client2_->Subscribe(stream, config2, c2));
+        } else if (workerNum == THREE_WORKER) {
+            DS_ASSERT_OK(client1_->CreateProducer(stream, p1, defaultProducerConf_));
+            DS_ASSERT_OK(client2_->Subscribe(stream, config1, c1));
+            DS_ASSERT_OK(client3_->Subscribe(stream, config2, c2));
+        }
+    }
+
+protected:
+    void InitTest()
+    {
+        uint32_t workerIndex = 0;
+        HostPort workerAddress1;
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(workerIndex++, workerAddress1));
+        HostPort workerAddress2;
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(workerIndex++, workerAddress2));
+        HostPort workerAddress3;
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(workerIndex, workerAddress3));
+        InitStreamClient(0, client1_);
+        InitStreamClient(1, client2_);
+        InitStreamClient(2, client3_); // worker index is 2
+        defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE;
+    }
+    std::shared_ptr<StreamClient> client1_ = nullptr;
+    std::shared_ptr<StreamClient> client2_ = nullptr;
+    std::shared_ptr<StreamClient> client3_ = nullptr;
+    ProducerConf defaultProducerConf_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+
+    const int ONE_WORKER = 1;
+    const int TWO_WORKER = 2;
+    const int THREE_WORKER = 3;
+
+    // cluster config
+    int numWorkers = 3;
+    int numEtcd = 1;
+    int numRpcThreads = 0;
+};
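The helper above is what lets each SPMC test run the same scenario in three placements: everything on one worker, producer and consumers split across two workers, or spread over three. A brief usage sketch (the stream name is illustrative; the fixture members come from this patch):

```cpp
// Illustrative: cross-node single-producer/two-consumer topology.
std::shared_ptr<Producer> producer;
std::shared_ptr<Consumer> consumer1;
std::shared_ptr<Consumer> consumer2;
// TWO_WORKER places the producer on worker 0 (client1_) and both consumers
// on worker 1 (client2_).
CreateProducerAndConsumerHelper(TWO_WORKER, producer, consumer1, consumer2, "spmc_demo");
```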
+
+/*
+Create one producer and two consumers. Start a send thread for the producer and a receive thread for
+consumer2. Consumer1 closes while data is being sent and received. After the threads finish, consumer2
+should still have received all data sent by the producer.
+*/
+TEST_F(SPMCTest, CloseOneConsumerDuringSendAndReceive)
+{
+    for (int workers = 1; workers <= numWorkers; workers++) {
+        LOG(INFO) << "Start test with configuration: " << workers;
+        std::shared_ptr<Producer> producer;
+        std::shared_ptr<Consumer> consumer1;
+        std::shared_ptr<Consumer> consumer2;
+        std::string streamName = "Close1ConDuringSendRecv" + std::to_string(workers);
+        CreateProducerAndConsumerHelper(workers, producer, consumer1, consumer2, streamName);
+
+        int timeOut = 10000;
+        int numElement = 10;
+        std::string data[numElement];
+        std::vector<Element> outElements;
+        LOG(INFO) << "outElements size: " << outElements.size();
+        std::thread sendThread([&]() {
+            for (int i = 0; i < numElement; i++) {
+                data[i] = "Hello World" + std::to_string(i);
+                uint8_t id = i + 1;
+                Element element(reinterpret_cast<uint8_t *>(&data[i].front()), data[i].size());
+                element.id = id;
+                DS_ASSERT_OK(producer->Send(element));
+            }
+        });
+        std::thread receiveThread(
+            [&]() { ASSERT_EQ(consumer2->Receive(numElement, timeOut, outElements), Status::OK()); });
+        DS_ASSERT_OK(consumer1->Close());
+        sendThread.join();
+        receiveThread.join();
+
+        ASSERT_EQ(outElements.size(), size_t(numElement));
+        LOG(INFO) << "outElements size: " << outElements.size();
+        LOG(INFO) << "End test with configuration: " << workers;
+    }
+}
+
+/*
+Create one producer and two consumers. Start a send thread for the producer. Both consumer1 and
+consumer2 are closed during the send. Nothing should be received, so outElements stays empty.
+The producer keeps sending normally.
+*/
+TEST_F(SPMCTest, CloseBothConsumerDuringSend)
+{
+    for (int workers = 1; workers <= numWorkers; workers++) {
+        LOG(INFO) << "Start test with configuration: " << workers;
+        std::shared_ptr<Producer> producer;
+        std::shared_ptr<Consumer> consumer1;
+        std::shared_ptr<Consumer> consumer2;
+        std::string streamName = "Close2ConDuringSendRecv" + std::to_string(workers);
+        CreateProducerAndConsumerHelper(workers, producer, consumer1, consumer2, streamName);
+
+        int numElement = 10;
+        std::string data[numElement];
+        std::vector<Element> outElements;
+        std::thread sendThread([&]() {
+            for (int i = 0; i < numElement; i++) {
+                data[i] = "Hello World" + std::to_string(i);
+                uint8_t id = i + 1;
+                Element element(reinterpret_cast<uint8_t *>(&data[i].front()), data[i].size());
+                element.id = id;
+                DS_ASSERT_OK(producer->Send(element));
+            }
+        });
+        DS_ASSERT_OK(consumer1->Close());
+        DS_ASSERT_OK(consumer2->Close());
+        sendThread.join();
+
+        ASSERT_EQ(outElements.size(), size_t(0));
+        LOG(INFO) << "End test with configuration: " << workers;
+    }
+}
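With both consumers closed, nothing acks the stream, so a producer that kept writing past `maxStreamSize` would eventually see `K_OUT_OF_MEMORY` from `Send`; the `SendHelper` used elsewhere in this patch retries for exactly that case. A standalone sketch of the same idiom follows (the retry count and sleep interval mirror `SendHelper`, but `SendWithRetry` itself is an illustrative addition):

```cpp
// Illustrative retry-on-backpressure wrapper around Producer::Send().
Status SendWithRetry(const std::shared_ptr<Producer> &producer, Element element)
{
    const int sleepMs = 300;
    int retryLimit = 30;
    Status rc = producer->Send(element);
    while (rc.GetCode() == K_OUT_OF_MEMORY && retryLimit-- > 0) {
        std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs));
        rc = producer->Send(element);
    }
    return rc;  // still K_OUT_OF_MEMORY after ~9 s of sustained backpressure
}
```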
+
+/*
+Create one producer and two consumers, plus three threads: one receive thread per consumer and one
+send thread for producer1. Close producer1 during receive and send, then create a new producer2 to
+continue sending. Both consumers should be able to receive all elements from both producers.
+*/
+TEST_F(SPMCTest, DISABLED_NewProducerSend)
+{
+    for (int workers = 1; workers <= numWorkers; workers++) {
+        LOG(INFO) << "Start test with configuration: " << workers;
+        std::shared_ptr<Producer> producer1;
+        std::shared_ptr<Consumer> consumer1;
+        std::shared_ptr<Consumer> consumer2;
+        std::string streamName = "testNewProducerSend" + std::to_string(workers);
+        CreateProducerAndConsumerHelper(workers, producer1, consumer1, consumer2, streamName);
+
+        int timeOut = 10000;
+        int numElement = 10;
+        int totalElement = 20;
+        std::string data[numElement];
+        std::vector<Element> outElements1;
+        std::vector<Element> outElements2;
+
+        std::thread receiveThread1(
+            [&]() { ASSERT_EQ(consumer1->Receive(totalElement, timeOut, outElements1), Status::OK()); });
+        std::thread receiveThread2(
+            [&]() { ASSERT_EQ(consumer2->Receive(totalElement, timeOut, outElements2), Status::OK()); });
+        std::thread sendThread([&]() {
+            for (int i = 0; i < numElement; i++) {
+                data[i] = "Hello World" + std::to_string(i);
+                uint8_t id = i + 1;
+                Element element(reinterpret_cast<uint8_t *>(&data[i].front()), data[i].size());
+                element.id = id;
+                DS_ASSERT_OK(producer1->Send(element));
+            }
+        });
+        const int SLEEP_TIME = 10;
+        std::this_thread::sleep_for(std::chrono::milliseconds(SLEEP_TIME));
+        DS_ASSERT_OK(producer1->Close());
+        std::shared_ptr<Producer> producer2;
+        DS_ASSERT_OK(client1_->CreateProducer(streamName, producer2, defaultProducerConf_));
+        for (int i = 0; i < numElement; i++) {
+            data[i] = "Hello World" + std::to_string(i);
+            uint8_t id = i + 1;
+            Element element(reinterpret_cast<uint8_t *>(&data[i].front()), data[i].size());
+            element.id = id;
+            DS_ASSERT_OK(producer2->Send(element));
+        }
+        receiveThread1.join();
+        receiveThread2.join();
+        sendThread.join();
+
+        ASSERT_EQ(outElements1.size(), size_t(totalElement));
+        ASSERT_EQ(outElements2.size(), size_t(totalElement));
+
+        LOG(INFO) << "End test with configuration: " << workers;
+    }
+}
+
+class StreamReserveMemoryTest : public SPMCTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numWorkers = numWorkers;
+        opts.numEtcd = numEtcd;
+        opts.numRpcThreads = numRpcThreads;
+        opts.workerGflagParams = " -sc_local_cache_memory_size_mb=2 -page_size=" + std::to_string(pageSize_);
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+};
+
+TEST_F(StreamReserveMemoryTest, TestReserveLocalCacheMemory1)
+{
+    // Make sure that Subscribe can be rejected if memory reservation fails for local cache memory.
+    // sc_local_cache_memory_size_mb is set to 2MB, so it can only accept 2 streams in this case.
+ std::string streamName1 = "testReservelocalCacheMem1"; + std::string streamName2 = "testReservelocalCacheMemTwo"; + std::string streamName3 = "testReservelocalCacheMemThree"; + ProducerConf conf; + conf.pageSize = MB; + conf.maxStreamSize = TEST_STREAM_SIZE; + std::shared_ptr producer1; + DS_ASSERT_OK(client1_->CreateProducer(streamName1, producer1, conf)); + std::shared_ptr producer2; + DS_ASSERT_OK(client1_->CreateProducer(streamName2, producer2, conf)); + std::shared_ptr producer3; + DS_ASSERT_OK(client1_->CreateProducer(streamName3, producer3, conf)); + std::shared_ptr consumer1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client2_->Subscribe(streamName1, config1, consumer1)); + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client2_->Subscribe(streamName2, config2, consumer2)); + std::shared_ptr consumer3; + SubscriptionConfig config3("sub3", SubscriptionType::STREAM); + DS_ASSERT_NOT_OK(client2_->Subscribe(streamName3, config3, consumer3)); +} + +TEST_F(StreamReserveMemoryTest, TestReserveLocalCacheMemory2) +{ + // Make sure that CloseConsumer would early reclaim the reservation. + // sc_local_cache_memory_size_mb is set to 2MB, so it can only accept 2 streams with consumers in this case. + std::string streamName1 = "testReservelocalCacheMem1"; + std::string streamName2 = "testReservelocalCacheMemTwo"; + std::string streamName3 = "testReservelocalCacheMemThree"; + ProducerConf conf; + conf.pageSize = MB; + conf.maxStreamSize = TEST_STREAM_SIZE; + std::shared_ptr producer1; + DS_ASSERT_OK(client1_->CreateProducer(streamName1, producer1, conf)); + std::shared_ptr producer2; + DS_ASSERT_OK(client1_->CreateProducer(streamName2, producer2, conf)); + std::shared_ptr producer3; + DS_ASSERT_OK(client1_->CreateProducer(streamName3, producer3, conf)); + std::shared_ptr consumer1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client2_->Subscribe(streamName1, config1, consumer1)); + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client2_->Subscribe(streamName2, config2, consumer2)); + std::shared_ptr consumer3; + SubscriptionConfig config3("sub3", SubscriptionType::STREAM); + // Consumer cannot be created because the 2 streams used up the local cache memory. + DS_ASSERT_NOT_OK(client2_->Subscribe(streamName3, config3, consumer3)); + // Now after we close consumer2, local cache memory is reclaimed, so consumer3 can be created. + DS_ASSERT_OK(consumer2->Close()); + DS_ASSERT_OK(client2_->Subscribe(streamName3, config3, consumer3)); + // Now we do the opposite to make sure the reservation can still be done for stream2. 
+ DS_ASSERT_OK(consumer3->Close()); + DS_ASSERT_OK(client2_->Subscribe(streamName2, config3, consumer2)); + DS_ASSERT_NOT_OK(client2_->Subscribe(streamName3, config3, consumer3)); +} + +TEST_F(ConsumerTest, TestOneMsTimeout) +{ + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("TestTimeoutMs", producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("TestTimeoutMs", config, consumer)); + + int dataNum = 5; + std::vector dataList; + for (int i = 0; i < dataNum; ++i) { + std::string data = "Test-Data" + std::to_string(i); + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + dataList.emplace_back(data); + } + + std::vector outElements; + ASSERT_EQ(consumer->Receive(dataNum, 1, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), size_t(dataNum)); +} + +TEST_F(ConsumerTest, TestParallelConsumerUse) +{ + std::shared_ptr producer; + defaultProducerConf_.retainForNumConsumers = 1; + DS_ASSERT_OK(client_->CreateProducer("testParallelConsumerUse", producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("testParallelConsumerUse", config, consumer)); + + std::string data = "Hello"; + Element element(reinterpret_cast(&data.front()), data.size()); + + DS_ASSERT_OK(producer->Send(element)); + DS_ASSERT_OK(producer->Send(element)); + + DS_ASSERT_OK(datasystem::inject::Set("CheckAndSetInUse.success.sleep", "sleep(5000)")); + + // Create a consumer thread that Receive() last at least 5 seconds. + ThreadPool pool(1); + auto consumerReceiveFunc([&consumer]() { + std::vector outElements; + return consumer->Receive(1, RPC_TIMEOUT, outElements); + }); + std::future fut = pool.Submit([&consumerReceiveFunc]() { return consumerReceiveFunc(); }); + + sleep(1); + + // Parallel call from the same consumer should fail. 
+ StatusCode expectedCode = K_SC_STREAM_IN_USE; + std::vector outElements; + ASSERT_EQ(consumer->Receive(1, RPC_TIMEOUT, outElements).GetCode(), expectedCode); + ASSERT_EQ(consumer->Ack(0).GetCode(), expectedCode); + ASSERT_EQ(consumer->Close().GetCode(), expectedCode); + + DS_ASSERT_OK(fut.get()); + + DS_ASSERT_OK(datasystem::inject::Clear("CheckAndSetInUse.success.sleep")); + + DS_ASSERT_OK(consumer->Receive(1, RPC_TIMEOUT, outElements)); + ASSERT_EQ(outElements.size(), 1); + DS_ASSERT_OK(consumer->Ack(outElements[0].id)); + DS_ASSERT_OK(consumer->Close()); +} + +TEST_F(ConsumerTest, TestRecvDelay) +{ + std::shared_ptr producer; + const int delayFlushTimeMs = 10; + defaultProducerConf_.delayFlushTime = delayFlushTimeMs; + DS_ASSERT_OK(client_->CreateProducer("stream", producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe("stream", config, consumer)); + + std::string data = "Hello"; + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + + Timer timer; + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(1, RPC_TIMEOUT, outElements)); + double delayLimit = 500; // ms + ASSERT_LT(timer.ElapsedMilliSecond(), delayLimit); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/delete_stream_concurrent_test.cpp b/tests/st/client/stream_cache/delete_stream_concurrent_test.cpp new file mode 100644 index 0000000..12c34da --- /dev/null +++ b/tests/st/client/stream_cache/delete_stream_concurrent_test.cpp @@ -0,0 +1,388 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * Description: Unit test for stream cache + */ +#include +#include + +#include + +#include "common.h" +#include "common/stream_cache/stream_common.h" +#include "sc_client_common.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream_client.h" +#include "datasystem/client/stream_cache/stream_client_impl.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/common/util/thread_pool.h" + +using namespace datasystem::client::stream_cache; +namespace datasystem { +namespace st { +constexpr int K_TWO = 2; +constexpr int K_TEN = 10; +class DeleteStreamConcurrentTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numWorkers = 3; + opts.numEtcd = 1; + opts.numRpcThreads = 0; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + InitTest(); + } + + void TearDown() override + { + client1_ = nullptr; + client2_ = nullptr; + client3_ = nullptr; + ExternalClusterTest::TearDown(); + } + +protected: + void InitTest() + { + HostPort workerAddress1; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress1)); + HostPort workerAddress2; + DS_ASSERT_OK(cluster_->GetWorkerAddr(1, workerAddress2)); + HostPort workerAddress3; + DS_ASSERT_OK(cluster_->GetWorkerAddr(2, workerAddress3)); + LOG(INFO) << FormatString("\n Worker1: <%s>\n Worker2: <%s>\n Worker3: <%s>", workerAddress1.ToString(), + workerAddress2.ToString(), workerAddress3.ToString()); + InitStreamClient(0, client1_); + InitStreamClient(1, client2_); + InitStreamClient(2, client3_); // worker index is 2 + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + } + + Status TryAndDeleteStream(std::shared_ptr &spClient, std::string streamName) + { + // if pending notifications retry delete + Status rc = Status::OK(); + do { + rc = spClient->DeleteStream(streamName); + if (rc.IsError()) { + sleep(K_TWO); + } + } while (rc.GetCode() == StatusCode::K_SC_STREAM_NOTIFICATION_PENDING); + return rc; + } + + std::shared_ptr client1_ = nullptr; + std::shared_ptr client2_ = nullptr; + std::shared_ptr client3_ = nullptr; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; +}; + +TEST_F(DeleteStreamConcurrentTest, DeleteBySequence) +{ + std::string stream1("DeleteBySequence"); + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + SubscriptionConfig config3("sub3", SubscriptionType::STREAM); + std::shared_ptr node1Producer1; + DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer1, defaultProducerConf_)); + std::shared_ptr node1Consumer1; + DS_ASSERT_OK(client1_->Subscribe(stream1, config1, node1Consumer1)); + + std::shared_ptr node2Producer1; + DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Producer1, defaultProducerConf_)); + std::shared_ptr node2Consumer1; + DS_ASSERT_OK(client2_->Subscribe(stream1, config2, node2Consumer1)); + + std::shared_ptr node3Producer1; + DS_ASSERT_OK(client3_->CreateProducer(stream1, node3Producer1, defaultProducerConf_)); + std::shared_ptr node3Consumer1; + DS_ASSERT_OK(client3_->Subscribe(stream1, config3, node3Consumer1)); + + DS_ASSERT_OK(node1Producer1->Close()); + DS_ASSERT_OK(node1Consumer1->Close()); + DS_ASSERT_NOT_OK(client1_->DeleteStream(stream1)); + + DS_ASSERT_OK(node2Producer1->Close()); + DS_ASSERT_OK(node2Consumer1->Close()); + 
DS_ASSERT_NOT_OK(client2_->DeleteStream(stream1)); + + DS_ASSERT_OK(node3Producer1->Close()); + DS_ASSERT_OK(node3Consumer1->Close()); + // Now close consumer sends async update topo notifications + // So, we need to wait sometime before doing delete stream + DS_ASSERT_OK(TryAndDeleteStream(client3_, stream1)); +} + +TEST_F(DeleteStreamConcurrentTest, DeleteFromUnrelatedNode) +{ + std::string stream1("DeleteFromUnrelatedNode"); + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + + std::promise promise1; + std::promise promise2; + std::future future1 = promise1.get_future(); + std::future future2 = promise2.get_future(); + ThreadPool pool(3); + pool.Submit([this, stream1, &config1, &promise1]() { + std::shared_ptr n1p1; + DS_ASSERT_OK(client1_->CreateProducer(stream1, n1p1, defaultProducerConf_)); + std::shared_ptr n1c1; + DS_ASSERT_OK(client1_->Subscribe(stream1, config1, n1c1)); + DS_ASSERT_OK(n1p1->Close()); + DS_ASSERT_OK(n1c1->Close()); + promise1.set_value(); + }); + pool.Submit([this, stream1, &config2, &promise2]() { + std::shared_ptr n2p1; + DS_ASSERT_OK(client1_->CreateProducer(stream1, n2p1, defaultProducerConf_)); + std::shared_ptr n2c1; + DS_ASSERT_OK(client1_->Subscribe(stream1, config2, n2c1)); + DS_ASSERT_OK(n2p1->Close()); + DS_ASSERT_OK(n2c1->Close()); + promise2.set_value(); + }); + pool.Submit([this, stream1, &future1, &future2]() { + future1.get(); + future2.get(); + DS_ASSERT_OK(client3_->DeleteStream(stream1)); + }); +} + +TEST_F(DeleteStreamConcurrentTest, ConcurrentDelete) +{ + std::string stream1("ConcurrentDelete"); + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + + std::shared_ptr node1Producer1; + DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer1, defaultProducerConf_)); + std::shared_ptr node1Consumer1; + DS_ASSERT_OK(client1_->Subscribe(stream1, config1, node1Consumer1)); + + std::shared_ptr node2Producer1; + DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Producer1, defaultProducerConf_)); + std::shared_ptr node2Consumer1; + DS_ASSERT_OK(client2_->Subscribe(stream1, config2, node2Consumer1)); + + DS_ASSERT_OK(node1Producer1->Close()); + DS_ASSERT_OK(node1Consumer1->Close()); + DS_ASSERT_OK(node2Producer1->Close()); + DS_ASSERT_OK(node2Consumer1->Close()); + { + ThreadPool pool(3); + pool.Submit([this, stream1]() { client1_->DeleteStream(stream1); }); + pool.Submit([this, stream1]() { client2_->DeleteStream(stream1); }); + } +} + +TEST_F(DeleteStreamConcurrentTest, DeleteWhenSub) +{ + std::string stream1("DeleteWhenSub"); + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + + std::shared_ptr node1Producer1; + DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer1, defaultProducerConf_)); + std::shared_ptr node1Consumer1; + DS_ASSERT_OK(client1_->Subscribe(stream1, config1, node1Consumer1)); + + std::shared_ptr node2Producer1; + DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Producer1, defaultProducerConf_)); + std::shared_ptr node2Consumer1; + DS_ASSERT_OK(client2_->Subscribe(stream1, config2, node2Consumer1)); + + { + ThreadPool pool(2); + pool.Submit([this, stream1, &node1Producer1, &node1Consumer1]() { + DS_ASSERT_OK(node1Producer1->Close()); + DS_ASSERT_OK(node1Consumer1->Close()); + LOG(INFO) << "Thread:, State:"; + client2_->DeleteStream(stream1); + LOG(INFO) << "Thread:, State:"; + }); + 
pool.Submit([this, stream1, &node2Producer1, &node2Consumer1]() { + DS_ASSERT_OK(node2Producer1->Close()); + DS_ASSERT_OK(node2Consumer1->Close()); + LOG(INFO) << "Thread:, State:"; + client3_->DeleteStream(stream1); + LOG(INFO) << "Thread:, State:"; + }); + } +} + +TEST_F(DeleteStreamConcurrentTest, DeleteWhenPubSub) +{ + std::string stream1("DeleteWhenPubSub"); + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + SubscriptionConfig config3("sub3", SubscriptionType::STREAM); + + { + ThreadPool pool(3); + pool.Submit([this, stream1, config1]() { + std::shared_ptr node1Producer1; + DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer1, defaultProducerConf_)); + std::shared_ptr node1Consumer1; + DS_ASSERT_OK(client1_->Subscribe(stream1, config1, node1Consumer1)); + DS_ASSERT_OK(node1Producer1->Close()); + DS_ASSERT_OK(node1Consumer1->Close()); + client3_->DeleteStream(stream1); + }); + pool.Submit([this, stream1, config2]() { + std::shared_ptr node2Producer1; + DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Producer1, defaultProducerConf_)); + std::shared_ptr node2Consumer1; + DS_ASSERT_OK(client2_->Subscribe(stream1, config2, node2Consumer1)); + DS_ASSERT_OK(node2Producer1->Close()); + DS_ASSERT_OK(node2Consumer1->Close()); + client1_->DeleteStream(stream1); + }); + pool.Submit([this, stream1, config3]() { + std::shared_ptr node3Producer1; + DS_ASSERT_OK(client3_->CreateProducer(stream1, node3Producer1, defaultProducerConf_)); + std::shared_ptr node3Consumer1; + DS_ASSERT_OK(client3_->Subscribe(stream1, config3, node3Consumer1)); + DS_ASSERT_OK(node3Producer1->Close()); + DS_ASSERT_OK(node3Consumer1->Close()); + client2_->DeleteStream(stream1); + }); + } +} + +TEST_F(DeleteStreamConcurrentTest, ParallelDeleteCreate) +{ + std::string stream1("ParallelDeleteCreate"); + { + ThreadPool pool(2); + std::shared_ptr node1Producer1; + DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer1, defaultProducerConf_)); + DS_ASSERT_OK(node1Producer1->Close()); + // Inject delay in DeleteStream so that it waits before checking in master + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, + "ClientWorkerSCServiceImpl.DeleteStreamLocally.sleep", + "1*sleep(7000)")); + pool.Submit([this, stream1]() { + client1_->DeleteStream(stream1); + }); + pool.Submit([this, stream1]() { + sleep(1); + std::shared_ptr node1Producer2; + ASSERT_EQ(client1_->CreateProducer(stream1, node1Producer2, defaultProducerConf_).GetCode(), + StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS); + }); + } +} + +TEST_F(DeleteStreamConcurrentTest, TestAutoDeleteWhileCloseConsumerNotification) +{ + // 1. Slowdown the async notifications to 3 secs -> So that notifications are sent slowly + DS_ASSERT_OK( + cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "master.SendPendingNotification", "1*sleep(3000)")); + DS_ASSERT_OK( + cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "master.SendPendingNotification", "1*sleep(3000)")); + DS_ASSERT_OK( + cluster_->SetInjectAction(ClusterNodeType::WORKER, 2, "master.SendPendingNotification", "1*sleep(3000)")); + + // 2. 
Create a producer and a consumer on different nodes -> so that they send notifications
+    auto streamName = "AutoDeleteWhileCloseCon";
+    defaultProducerConf_.autoCleanup = true;
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client1_->CreateProducer(streamName, producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2_->Subscribe(streamName, config, consumer));
+
+    // 3. Close the consumer before the producer -> this introduces async notifications
+    consumer->Close();
+    producer->Close();
+
+    // Wait for auto delete to kick in
+    sleep(K_TWO);
+
+    // Auto delete won't delete the stream while it still has pending notifications.
+    // We use the delete stream API to check whether the stream still exists;
+    // this call should fail with the error K_SC_STREAM_NOTIFICATION_PENDING.
+    ASSERT_EQ(client1_->DeleteStream(streamName).GetCode(), StatusCode::K_SC_STREAM_NOTIFICATION_PENDING);
+
+    // The time between auto delete retries is 10 secs.
+    // Sleep for 10 secs so that auto delete kicks in again.
+    sleep(K_TEN);
+
+    // The retry should succeed because each notification takes about 3 secs (due to the injected delay).
+    // Auto delete must have deleted this stream, so a manual delete should return K_NOT_FOUND.
+    ASSERT_EQ(client1_->DeleteStream(streamName).GetCode(), StatusCode::K_NOT_FOUND);
+}
+
+
+TEST_F(DeleteStreamConcurrentTest, ParallelDeleteReset)
+{
+    std::string stream1("ParallelDeleteReset");
+    {
+        ThreadPool pool(2);
+        std::shared_ptr<Producer> node1Producer1;
+        DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer1, defaultProducerConf_));
+        DS_ASSERT_OK(node1Producer1->Close());
+        // Inject delay in DeleteStream so that it waits before checking in master
+        DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                               "ClientWorkerSCServiceImpl.DeleteStreamLocally.sleep",
+                                               "1*sleep(7000)"));
+        pool.Submit([this, stream1]() {
+            client1_->DeleteStream(stream1);
+        });
+        pool.Submit([this, stream1]() {
+            std::vector<std::string> streamNames;
+            streamNames.push_back(stream1);
+            sleep(1);
+        });
+    }
+}
+
+TEST_F(DeleteStreamConcurrentTest, ConcurrentDeleteCheckErrorCodes)
+{
+    std::string stream1("ConcurrentDeleteCheckErrorCodes");
+    std::shared_ptr<Producer> node1Producer;
+    ThreadPool pool(2);
+    DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer, defaultProducerConf_));
+    DS_ASSERT_OK(node1Producer->Close());
+    // When 2 delete operations come in parallel,
+    // 3 types of return codes are possible:
+    // OK - the call was successful
+    // K_SC_STREAM_DELETE_IN_PROGRESS - another call is still in progress
+    // K_NOT_FOUND - the stream was already deleted
+    pool.Submit([this, stream1]() {
+        Status rc = client1_->DeleteStream(stream1);
+        EXPECT_TRUE(rc.IsOk() ||
+                    rc.GetCode() == StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS ||
+                    rc.GetCode() == StatusCode::K_NOT_FOUND);
+    });
+    pool.Submit([this, stream1]() {
+        Status rc = client2_->DeleteStream(stream1);
+        EXPECT_TRUE(rc.IsOk() ||
+                    rc.GetCode() == StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS ||
+                    rc.GetCode() == StatusCode::K_NOT_FOUND);
+    });
+}
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/delete_stream_test.cpp b/tests/st/client/stream_cache/delete_stream_test.cpp
new file mode 100644
index 0000000..42a4ca5
--- /dev/null
+++ b/tests/st/client/stream_cache/delete_stream_test.cpp
@@ -0,0 +1,764 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include "common.h" +#include "sc_client_common.h" +#include "client/stream_cache/pub_sub_utils.h" +#include "common/stream_cache/element_generator.h" +#include "common/stream_cache/stream_common.h" +#include "datasystem/common/util/random_data.h" + +using namespace datasystem::client::stream_cache; +namespace datasystem { +namespace st { +class DeleteStreamTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numEtcd = 1; + opts.numRpcThreads = 0; + opts.numWorkers = 2; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + InitTestClientInstance(); + } + + void TearDown() override + { + client_ = nullptr; + ExternalClusterTest::TearDown(); + } + +protected: + void InitTestClientInstance() + { + InitStreamClient(0, client_); + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + } + + Status CreateClient(int workerNum, std::shared_ptr &spClient) + { + InitStreamClient(workerNum, spClient); + return Status::OK(); + } + + void SendHelper(std::shared_ptr producer, Element element) + { + const int DEFAULT_SLEEP_TIME = 300; + int retryLimit = 30; + datasystem::Status rc = producer->Send(element); + if (rc.IsError()) { + while (rc.GetCode() == K_OUT_OF_MEMORY && retryLimit-- > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME)); + rc = producer->Send(element); + } + } + DS_ASSERT_OK(rc); + } + + void ReceiveHelper(std::shared_ptr consumer, size_t numElements) + { + Timer timer; + size_t remaining = numElements; + int round = 0; + const int DEFAULT_RETRY_TIME = 10; + const size_t PER_RECEIVE_NUM = 500; + const int DEFAULT_WAIT_TIME = 1000; + while (remaining > 0 && timer.ElapsedSecond() < DEFAULT_RETRY_TIME) { + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(std::max(PER_RECEIVE_NUM, remaining), DEFAULT_WAIT_TIME, outElements)); + LOG(INFO) << "receive num : " << outElements.size() << " ;" << round++; + if (!outElements.empty()) { + remaining -= outElements.size(); + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + } + } + } + + /** + * @brief Creates a stream client at the given worker num and timeout + * @param[in] workerNum The worker num to create the stream against + * @param[in] timeout Timeout for RPC requests + * @param[out] spClient Shared pointer to the stream client + * @return status + */ + Status CreateClient(int workerNum, int32_t timeout, std::shared_ptr &spClient) + { + HostPort workerAddress; + RETURN_IF_NOT_OK(cluster_->GetWorkerAddr(workerNum, workerAddress)); + // Create a client with user defined timeout + ConnectOptions connectOptions = { .host = workerAddress.Host(), + .port = workerAddress.Port(), + .connectTimeoutMs = timeout }; + connectOptions.accessKey = accessKey_; + connectOptions.secretKey = secretKey_; + spClient = std::make_shared(connectOptions); + RETURN_IF_NOT_OK(spClient->Init()); + return 
+    }
+
+    std::shared_ptr<StreamClient> client_ = nullptr;
+    ProducerConf defaultProducerConf_;
+    SubscriptionConfig config = SubscriptionConfig("sub1", SubscriptionType::STREAM);
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+TEST_F(DeleteStreamTest, CloseProducersAndConsumers)
+{
+    std::shared_ptr<Producer> producer;
+    std::string streamName = "testCloseProdCon";
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer, defaultProducerConf_));
+
+    std::shared_ptr<Consumer> consumer;
+    DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer));
+
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client_->DeleteStream(streamName));
+}
+
+TEST_F(DeleteStreamTest, ProducerExist)
+{
+    std::shared_ptr<Producer> producer;
+    std::string streamName = "testProdExist";
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer, defaultProducerConf_));
+
+    Status rc = client_->DeleteStream(streamName);
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_RUNTIME_ERROR);
+    ASSERT_EQ(producer->Close(), Status::OK());
+    ASSERT_EQ(client_->DeleteStream(streamName), Status::OK());
+}
+
+TEST_F(DeleteStreamTest, ConsumerExist)
+{
+    std::shared_ptr<Consumer> consumer;
+    std::string streamName = "testConExist";
+    DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer));
+
+    Status rc = client_->DeleteStream(streamName);
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_RUNTIME_ERROR);
+    DS_ASSERT_OK(consumer->Close());
+    ASSERT_EQ(client_->DeleteStream(streamName), Status::OK());
+}
+
+TEST_F(DeleteStreamTest, MultiSubsExist)
+{
+    std::shared_ptr<Producer> producer;
+    std::string streamName = "testMultiExist";
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer, defaultProducerConf_));
+
+    std::shared_ptr<Consumer> consumer1;
+    DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer1));
+
+    std::shared_ptr<Consumer> consumer2;
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config2, consumer2));
+
+    ASSERT_EQ(producer->Close(), Status::OK());
+    ASSERT_EQ(client_->DeleteStream(streamName).GetCode(), StatusCode::K_RUNTIME_ERROR);
+
+    ASSERT_EQ(consumer1->Close(), Status::OK());
+    ASSERT_EQ(client_->DeleteStream(streamName).GetCode(), StatusCode::K_RUNTIME_ERROR);
+    ASSERT_EQ(consumer2->Close(), Status::OK());
+    ASSERT_EQ(client_->DeleteStream(streamName), Status::OK());
+}
+
+TEST_F(DeleteStreamTest, CheckEmptyStreamAfterDelete)
+{
+    std::shared_ptr<Producer> producer;
+    std::string streamName("testEmptyStreamAfterDelete");
+    DS_ASSERT_OK(client_->CreateProducer(streamName, producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer));
+
+    size_t testSize = 1024ul * 1024ul;
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(testSize);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size(), ULONG_MAX);
+    DS_ASSERT_OK(producer->Send(element));
+
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK());
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client_->DeleteStream(streamName));
+
+    // Check that the buffer is cleared. There should not be any elements
+    // to receive after purge.
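+    // Note: a Receive with a 0 ms timeout returns immediately; the asserts below
+    // expect Status::OK() together with an empty output vector.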
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(streamName, config2, consumer));
+    ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), 0);
+    DS_ASSERT_OK(consumer->Close());
+}
+
+// Testing CreateProducer while DeleteStream is called
+TEST_F(DeleteStreamTest, TestParallelDeleteStreamCreateProducer)
+{
+    const int timeout = 10000;
+    std::shared_ptr<StreamClient> spClient;
+    ASSERT_EQ(CreateClient(0, timeout, spClient), Status::OK());
+    std::shared_ptr<StreamClient> spClient1;
+    ASSERT_EQ(CreateClient(1, timeout, spClient1), Status::OK());
+    ThreadPool pool(1);
+    // Create a producer and consumer on different nodes
+    auto stream_name = "testParallelDelStreamCreateProd";
+    ProducerConf prodCfg = { .delayFlushTime = 5, .pageSize = 1 * MB, .maxStreamSize = 2 * MB};
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Producer> producer1;
+    spClient->CreateProducer(stream_name, producer, prodCfg);
+    producer->Close();
+    spClient1->CreateProducer(stream_name, producer1, prodCfg);
+    producer1->Close();
+    // Inject delay into the DeleteStreamContext request in both workers and increase the timeout for the request
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "MasterRemoteWorkerSCApi.DelStreamContextBroadcast.sleep",
+                                           "1*sleep(5000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "MasterLocalWorkerSCApi.DelStreamContextBroadcast.sleep",
+                                           "1*sleep(5000)"));
+    // Make a create producer call while delete stream is active
+    auto delFut = pool.Submit([this, &spClient, &stream_name]() {
+        // Fails with a timeout
+        spClient->DeleteStream(stream_name);
+    });
+    sleep(1);
+    // This should be rejected with K_SC_STREAM_DELETE_IN_PROGRESS while the delete is in progress
+    ASSERT_EQ(spClient1->CreateProducer(stream_name, producer, prodCfg).GetCode(),
+              StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS);
+    delFut.get();
+}
+
+TEST_F(DeleteStreamTest, TestParallelDeleteStreamCreateSubscriber)
+{
+    const int timeout = 10000;
+    std::shared_ptr<StreamClient> spClient;
+    ASSERT_EQ(CreateClient(0, timeout, spClient), Status::OK());
+    std::shared_ptr<StreamClient> spClient1;
+    ASSERT_EQ(CreateClient(1, timeout, spClient1), Status::OK());
+    ThreadPool pool(1);
+    // Create a producer and consumer on different nodes
+    auto stream_name = "testParallelDelStreamCreateCon";
+    ProducerConf prodCfg = { .delayFlushTime = 5, .pageSize = 1 * MB, .maxStreamSize = 2 * MB};
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Producer> producer1;
+    spClient->CreateProducer(stream_name, producer, prodCfg);
+    producer->Close();
+    spClient1->CreateProducer(stream_name, producer1, prodCfg);
+    producer1->Close();
+    // Inject delay into the DeleteStreamContext request in both workers and increase the timeout for the request
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "MasterRemoteWorkerSCApi.DelStreamContextBroadcast.sleep",
+                                           "1*sleep(5000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "MasterLocalWorkerSCApi.DelStreamContextBroadcast.sleep",
+                                           "1*sleep(5000)"));
+    // Make a subscribe call while delete stream is active
+    auto delFut = pool.Submit([this, &spClient, &stream_name]() {
+        // Fails with a timeout
+        spClient->DeleteStream(stream_name);
+    });
+    sleep(1);
+    std::shared_ptr<Consumer> consumer;
+    // This should be rejected with K_SC_STREAM_DELETE_IN_PROGRESS while the delete is in progress
+    ASSERT_EQ(spClient1->Subscribe(stream_name, config, consumer).GetCode(),
+              StatusCode::K_SC_STREAM_DELETE_IN_PROGRESS);
+    delFut.get();
+}
+
+// Testing DeleteStream while CreateProducer is done
+TEST_F(DeleteStreamTest, TestParallelCreateProducerDeleteStream)
+{
+    std::shared_ptr<StreamClient> spClient;
+    ASSERT_EQ(CreateClient(0, spClient), Status::OK());
+    std::shared_ptr<StreamClient> spClient1;
+    ASSERT_EQ(CreateClient(1, spClient1), Status::OK());
+    ThreadPool pool(1);
+    // Create a producer and consumer on different nodes
+    auto stream_name = "testParallelCreateProdDelStream";
+    ProducerConf prodCfg = { .delayFlushTime = 5,
+                             .pageSize = 1 * MB,
+                             .maxStreamSize = 2 * MB};
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Producer> producer1;
+    spClient->CreateProducer(stream_name, producer, prodCfg);
+    producer->Close();
+    spClient1->CreateProducer(stream_name, producer1, prodCfg);
+    producer1->Close();
+    // Inject delay in DeleteStream so that it waits before checking in master
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "SCMetadataManager.DeleteStream.sleep", "1*sleep(7000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "SCMetadataManager.DeleteStream.sleep", "1*sleep(7000)"));
+    // Because the create producer call comes in between, the master-side check fails
+    auto delFut = pool.Submit([this, &spClient, &stream_name]() {
+        // Fails with K_SC_STREAM_IN_USE because a producer is created while the delete sleeps
+        ASSERT_EQ(spClient->DeleteStream(stream_name).GetCode(), StatusCode::K_SC_STREAM_IN_USE);
+    });
+    usleep(4000);
+    // This CreateProducer should succeed; the sleeping delete then finds the stream in use
+    ASSERT_EQ(spClient1->CreateProducer(stream_name, producer, prodCfg).GetCode(),
+              StatusCode::K_OK);
+    delFut.get();
+}
+
+TEST_F(DeleteStreamTest, LEVEL2_TestDeleteLongTimeout)
+{
+    // The request should not time out when the client timeout exceeds the time the master takes
+
+    // Set the timeout to 10 mins
+    std::shared_ptr<StreamClient> client1;
+    const int32_t timeoutMs = 1000 * 60 * 10;
+    ASSERT_EQ(CreateClient(0, timeoutMs, client1), Status::OK());
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 100 * MB;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = 1 * MB;
+
+    // Make the master wait for 1 min; the request should still not time out.
+    // We don't know which worker hosts the master, so inject into both.
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "SCMetadataManager.DeleteStream.sleep", "1*sleep(60000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "SCMetadataManager.DeleteStream.sleep", "1*sleep(60000)"));
+
+    // This request should not time out because the client timeout is 10 mins.
+    DS_ASSERT_OK(client1->CreateProducer("testDelLongTimeout", producer, conf));
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(client1->DeleteStream("testDelLongTimeout"));
+}
+
+TEST_F(DeleteStreamTest, TestDeleteStreamTimingHole1)
+{
+    // The purpose of the test case is to test a timing hole in DeleteStream:
+    // DeleteStreamLocally and DeleteStreamContext can both be skipped,
+    // leaving residue in RemoteWorkerManager, etc., and causing other problems.
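+    // Scenario (inferred from the injected delays below): autoCleanup triggers auto
+    // delete once the last producer closes; the manual DeleteStream issued right after
+    // then finds the stream already gone, and the recreate path verifies that no stale
+    // metadata survived.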
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(CreateClient(0, client1));
+    std::shared_ptr<StreamClient> client2;
+    DS_ASSERT_OK(CreateClient(1, client2));
+    // Create a producer and consumer on different nodes
+    std::string streamName = "testDelStreamTimingHole1";
+    const int64_t DEFAULT_PAGE_SIZE = 4 * KB;
+    ProducerConf prodCfg = { .delayFlushTime = 5,
+                             .pageSize = DEFAULT_PAGE_SIZE,
+                             .maxStreamSize = 1 * MB,
+                             .autoCleanup = true,
+                             .retainForNumConsumers = 1 };
+    const size_t DEFAULT_ELEMENT_SIZE = 2000;
+    const int ELE_NUM = 10;
+    std::string data = RandomData().GetRandomString(DEFAULT_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    // Add injection so that auto delete is triggered after the start of manual delete
+    // but is still executed before manual delete sends out its RPC.
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "ClientWorkerSCServiceImpl.DELETE_IN_PROGRESS.sleep", "1*sleep(5000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "SCMetadataManager.DeleteStream.sleep", "1*sleep(2000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "SCMetadataManager.DeleteStream.sleep", "1*sleep(2000)"));
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfg));
+    for (int i = 0; i < ELE_NUM; i++) {
+        SendHelper(producer, element);
+    }
+    DS_ASSERT_OK(producer->Close());
+
+    // The manual delete fails with K_NOT_FOUND because auto delete already removed the stream
+    DS_ASSERT_NOT_OK(client1->DeleteStream(streamName));
+
+    // Recreate the same stream.
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfg));
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+    for (int i = 0; i < ELE_NUM; i++) {
+        SendHelper(producer, element);
+    }
+    ReceiveHelper(consumer, ELE_NUM);
+    DS_ASSERT_OK(producer->Close());
+}
+
+TEST_F(DeleteStreamTest, TestDeleteStreamTimingHole2)
+{
+    // The purpose of the test case is to test a timing hole in DeleteStream:
+    // DeleteStreamLocally and DeleteStreamContext can both be skipped,
+    // leaving residue in RemoteWorkerManager, etc., and causing other problems.
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(CreateClient(0, client1));
+    std::shared_ptr<StreamClient> client2;
+    DS_ASSERT_OK(CreateClient(1, client2));
+    // Create a producer and consumer on different nodes
+    std::string streamName = "testDeleteStreamTimingHoleTwo";
+    const int64_t DEFAULT_PAGE_SIZE = 4 * KB;
+    ProducerConf prodCfg = { .delayFlushTime = 5,
+                             .pageSize = DEFAULT_PAGE_SIZE,
+                             .maxStreamSize = 1 * MB,
+                             .autoCleanup = true,
+                             .retainForNumConsumers = 1 };
+    const size_t DEFAULT_ELEMENT_SIZE = 2000;
+    const int ELE_NUM = 10;
+    std::string data = RandomData().GetRandomString(DEFAULT_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    // Add injection so that auto delete is triggered after the start of manual delete
+    // but is still executed before manual delete sends out its RPC.
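+    // (The "1*sleep(5000)" spec is read here as "fire once, sleeping 5000 ms"; the
+    // names are worker-side injection hook points used only by tests.)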
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "ClientWorkerSCServiceImpl.DELETE_IN_PROGRESS.sleep", "1*sleep(5000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "SCMetadataManager.DeleteStream.sleep", "1*sleep(2000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "SCMetadataManager.DeleteStream.sleep", "1*sleep(2000)"));
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer;
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+    DS_ASSERT_OK(client2->CreateProducer(streamName, producer, prodCfg));
+    for (int i = 0; i < ELE_NUM; i++) {
+        SendHelper(producer, element);
+    }
+    ReceiveHelper(consumer, ELE_NUM);
+    // Give the producer enough time to process the ack
+    sleep(1);
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+
+    // The manual delete fails with K_NOT_FOUND because auto delete already removed the stream
+    DS_ASSERT_NOT_OK(client1->DeleteStream(streamName));
+
+    // Recreate the same stream.
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfg));
+    for (int i = 0; i < ELE_NUM; i++) {
+        SendHelper(producer, element);
+    }
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+    ReceiveHelper(consumer, ELE_NUM);
+    DS_ASSERT_OK(producer->Close());
+}
+
+class DeleteStreamTimingTest : public DeleteStreamTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.numRpcThreads = 0;
+        opts.numWorkers = 2;
+        DeleteStreamTest::SetClusterSetupOptions(opts);
+        opts.enableDistributedMaster = "false";
+        opts.masterIdx = 1;
+    }
+
+protected:
+    const size_t DEFAULT_ELEMENT_SIZE = 2000;
+    const int64_t DEFAULT_PAGE_SIZE = 4 * KB;
+    const int ELE_NUM = 10, TWO = 2, FIVE = 5;
+    uint64_t producersCount, consumersCount;
+    ProducerConf prodCfgAutoDel = { .delayFlushTime = 5,
+                                    .pageSize = DEFAULT_PAGE_SIZE,
+                                    .maxStreamSize = 1 * MB,
+                                    .autoCleanup = true,
+                                    .retainForNumConsumers = 1 };
+    ProducerConf prodCfg = { .delayFlushTime = 5,
+                             .pageSize = DEFAULT_PAGE_SIZE,
+                             .maxStreamSize = 1 * MB,
+                             .autoCleanup = false,
+                             .retainForNumConsumers = 1 };
+    std::shared_ptr<StreamClient> client1, client2;
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer;
+};
+
+TEST_F(DeleteStreamTimingTest, TestDeleteStreamTimingHole3)
+{
+    // The purpose of the test case is to test a timing hole in DeleteStream:
+    // running both auto delete and manual delete at the same time.
+    DS_ASSERT_OK(CreateClient(0, client1));
+    std::string streamName = "testDelStreamTimingHoleThree";
+    // Create a producer and consumer
+    std::string data = RandomData().GetRandomString(DEFAULT_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    // Add injections so that the master sends the DeleteStream requests (auto delete)
+    // after manual delete has started, but still before manual delete sends out its RPC
+    // to the master to broadcast the delete; slow down the state setting so that the
+    // two paths conflict with each other.
+    // Sleep the manual delete before api->DeleteStream reaches the master
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "ClientWorkerSCServiceImpl.DELETE_IN_PROGRESS.sleep", "1*sleep(3000)"));
+
+    // Sleep the auto delete after it has sent the broadcast to the workers
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "SCMetadataManager.DeleteStream.SentReqs", "1*sleep(5000)"));
+
+    // Create a producer and consumer
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfgAutoDel));
+    for (int i = 0; i < ELE_NUM; i++) {
+        SendHelper(producer, element);
+    }
+    ReceiveHelper(consumer, ELE_NUM);
+    // Give the producer enough time to process the ack
+    sleep(1);
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    // The manual delete will not succeed because auto delete removes the stream first.
+    DS_ASSERT_NOT_OK(client1->DeleteStream(streamName));
+    LOG(INFO) << "Deleted stream";
+    DS_ASSERT_OK(client1->QueryGlobalProducersNum(streamName, producersCount));
+    DS_ASSERT_OK(client1->QueryGlobalConsumersNum(streamName, consumersCount));
+    ASSERT_EQ(producersCount, 0ul);
+    ASSERT_EQ(consumersCount, 0ul);
+    sleep(FIVE);
+    // Check logs for "Setting Active State"
+}
+
+TEST_F(DeleteStreamTimingTest, TestDeleteStreamTimingHole4)
+{
+    // The purpose of the test case is to test running auto delete and CreateProducer
+    // at the same time. Auto delete is not finished when CreateProducer starts.
+    std::string streamName = "TestDelStreamTimingHoleFour";
+    DS_ASSERT_OK(CreateClient(0, client1));
+    std::string data = RandomData().GetRandomString(DEFAULT_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    // Sleep the auto delete before it sends the broadcast to the workers
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "SCMetadataManager.DeleteStream.SendReqs", "1*sleep(5000)"));
+
+    // Create a producer and consumer
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfgAutoDel));
+    for (int i = 0; i < ELE_NUM; i++) {
+        SendHelper(producer, element);
+    }
+    ReceiveHelper(consumer, ELE_NUM);
+    // Give the producer enough time to process the ack
+    sleep(1);
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    sleep(1);
+    // Auto delete is running, so CreateProducer does not succeed until the delete is finished
+    DS_ASSERT_NOT_OK(client1->CreateProducer(streamName, producer, prodCfgAutoDel));
+    sleep(FIVE);
+    // Delete finished
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfgAutoDel));
+
+    DS_ASSERT_OK(client1->QueryGlobalProducersNum(streamName, producersCount));
+    DS_ASSERT_OK(client1->QueryGlobalConsumersNum(streamName, consumersCount));
+    ASSERT_EQ(producersCount, 1ul);
+    ASSERT_EQ(consumersCount, 0ul);
+}
+
+TEST_F(DeleteStreamTimingTest, TestDeleteStreamTimingHole5)
+{
+    // The purpose of the test case is to test running two manual deletes (no auto delete)
+    // and CreateProducer at the same time. The manual deletes ignore the
+    // delete-in-progress state while CreateProducer runs.
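+    // Expected outcome, asserted via ASSERT_TRUE(a != b) below: exactly one of the two
+    // concurrent manual deletes succeeds while the other fails.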
+    DS_ASSERT_OK(CreateClient(0, client1));
+    DS_ASSERT_OK(CreateClient(1, client2));
+    std::string streamName = "TestDelStreamTimingHole5";
+    ThreadPool pool(TWO);
+    std::string data = RandomData().GetRandomString(DEFAULT_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    // Sleep the manual delete before api->DeleteStream reaches the master
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "ClientWorkerSCServiceImpl.DELETE_IN_PROGRESS.sleep", "1*sleep(3000)"));
+
+    // Sleep the master-side delete after it has sent the broadcast to the workers
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "SCMetadataManager.DeleteStream.SentReqs", "1*sleep(5000)"));
+
+    // Create a producer and consumer
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfg));
+    for (int i = 0; i < ELE_NUM; i++) {
+        SendHelper(producer, element);
+    }
+    ReceiveHelper(consumer, ELE_NUM);
+    // Give the producer enough time to process the ack
+    sleep(1);
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    sleep(1);
+    bool a, b;
+    auto delFut = pool.Submit([this, &a, streamName]() {
+        Status rc = client1->DeleteStream(streamName);
+        a = rc.IsOk();
+        LOG(INFO) << FormatString("DeleteStream on W0 %ssuccessful", (a ? "" : "un"));
+    });
+    auto delFut2 = pool.Submit([this, &b, streamName]() {
+        Status rc = client2->DeleteStream(streamName);
+        b = rc.IsOk();
+        LOG(INFO) << FormatString("DeleteStream on W1 %ssuccessful", (b ? "" : "un"));
+    });
+    sleep(1);
+    // The manual deletes are running, so CreateProducer does not succeed until a delete is finished
+    DS_ASSERT_NOT_OK(client1->CreateProducer(streamName, producer, prodCfg));
+    delFut.get();
+    delFut2.get();
+    ASSERT_TRUE(a != b && (a || b));
+    DS_ASSERT_OK(client1->QueryGlobalProducersNum(streamName, producersCount));
+    DS_ASSERT_OK(client1->QueryGlobalConsumersNum(streamName, consumersCount));
+    ASSERT_EQ(producersCount, 0ul);
+    ASSERT_EQ(consumersCount, 0ul);
+    sleep(FIVE);
+    // Delete finished
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfg));
+
+    DS_ASSERT_OK(client1->QueryGlobalProducersNum(streamName, producersCount));
+    DS_ASSERT_OK(client1->QueryGlobalConsumersNum(streamName, consumersCount));
+    ASSERT_EQ(producersCount, 1ul);
+    ASSERT_EQ(consumersCount, 0ul);
+}
+
+TEST_F(DeleteStreamTimingTest, LEVEL2_TestDeleteStreamTimingHole6)
+{
+    // The purpose of the test case is to test running CreateProducer during
+    // DeleteStream + auto delete, but with producers and consumers on both workers.
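+    // With producers and consumers on both workers, the delete broadcast must clean up
+    // remote as well as local stream metadata before a new producer can be admitted.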
+    std::string streamName = "testDelStreamTimingHole6";
+    DS_ASSERT_OK(CreateClient(0, client1));
+    DS_ASSERT_OK(CreateClient(1, client2));
+    std::string data = RandomData().GetRandomString(DEFAULT_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    // Sleep the manual delete before api->DeleteStream reaches the master
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "ClientWorkerSCServiceImpl.DELETE_IN_PROGRESS.sleep", "1*sleep(3000)"));
+
+    // Sleep the auto delete after it has sent the broadcast to the workers
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "SCMetadataManager.DeleteStream.SentReqs", "1*sleep(5000)"));
+
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "master.ProcessDeleteStreams", "pause()"));
+
+    // Create a producer and consumer
+    std::shared_ptr<Producer> producer2;
+    std::shared_ptr<Consumer> consumer2;
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfgAutoDel));
+    DS_ASSERT_OK(client2->Subscribe(streamName, SubscriptionConfig("sub2", SubscriptionType::STREAM), consumer2));
+    DS_ASSERT_OK(client2->CreateProducer(streamName, producer2, prodCfgAutoDel));
+
+    for (int i = 0; i < ELE_NUM; i++) {
+        SendHelper(producer, element);
+        SendHelper(producer2, element);
+    }
+    ReceiveHelper(consumer, TWO * ELE_NUM);
+    ReceiveHelper(consumer2, TWO * ELE_NUM);
+    // Give the producers enough time to process the acks
+    sleep(TWO);
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(producer2->Close());
+    DS_ASSERT_OK(consumer2->Close());
+    sleep(1);
+    DS_ASSERT_OK(cluster_->ClearInjectAction(ClusterNodeType::WORKER, 1, "master.ProcessDeleteStreams"));
+    sleep(1);
+    // Auto delete is already running, so this manual DeleteStream does not succeed
+    DS_ASSERT_NOT_OK(client1->DeleteStream(streamName));
+
+    // Without a lock on undo in StreamMetadata, PubIncreaseNode will succeed when StreamManager
+    // has not been deleted.
+    LOG(INFO) << "Deleted stream";
+    DS_ASSERT_NOT_OK(client1->CreateProducer(streamName, producer, prodCfgAutoDel));
+    DS_ASSERT_OK(client1->QueryGlobalProducersNum(streamName, producersCount));
+    DS_ASSERT_OK(client1->QueryGlobalConsumersNum(streamName, consumersCount));
+    ASSERT_EQ(producersCount, 0ul);
+    ASSERT_EQ(consumersCount, 0ul);
+    sleep(FIVE);
+    DS_ASSERT_OK(client1->QueryGlobalProducersNum(streamName, producersCount));
+    DS_ASSERT_OK(client1->QueryGlobalConsumersNum(streamName, consumersCount));
+    ASSERT_EQ(producersCount, 0ul);
+    ASSERT_EQ(consumersCount, 0ul);
+    // Auto delete finished
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfgAutoDel));
+
+    DS_ASSERT_OK(client1->QueryGlobalProducersNum(streamName, producersCount));
+    DS_ASSERT_OK(client1->QueryGlobalConsumersNum(streamName, consumersCount));
+    ASSERT_EQ(producersCount, 1ul);
+    ASSERT_EQ(consumersCount, 0ul);
+}
+
+TEST_F(DeleteStreamTimingTest, LEVEL1_TestDeleteStreamTimingHole7)
+{
+    std::string streamName = "testDelStreamTimingHole7";
+    // The purpose of the test case is to test UndoDeleteStream. We ensure that there is at least one reference to
+    // DeleteStream on the master, inject an RPC failure so that DeleteStream fails and hits UndoDeleteStream,
+    // and then ensure that a producer can be created successfully.
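+    // Expected per-worker results, asserted in the lambdas below: the delete on W0
+    // fails with K_SC_STREAM_DELETE_IN_PROGRESS, the delete on W1 fails with
+    // K_RPC_DEADLINE_EXCEEDED, and CreateProducer succeeds again afterwards.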
+    int timeoutMs = 15 * 1000;
+    DS_ASSERT_OK(CreateClient(0, timeoutMs, client1));
+    DS_ASSERT_OK(CreateClient(1, timeoutMs, client2));
+    ThreadPool pool(TWO);
+    std::string data = RandomData().GetRandomString(DEFAULT_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    // Sleep the manual delete before api->DeleteStream reaches the master
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "ClientWorkerSCServiceImpl.DELETE_IN_PROGRESS.sleep", "1*sleep(3000)"));
+
+    // Sleep the manual delete after api->DeleteStream reaches the master
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "ClientWorkerSCServiceImpl.DeleteStreamHandleSend.sleep", "1*sleep(5000)"));
+
+    // Force an RPC timeout by sleeping in the master
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "SCMetadataManager.DeleteStream.sleep",
+                                           "2*sleep(2500)"));
+    // Sleep the master-side delete before it sends the broadcast to the workers
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1,
+                                           "SCMetadataManager.DeleteStream.SendReqs", "1*sleep(10000)"));
+
+    // Create a producer and consumer
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfg));
+    for (int i = 0; i < ELE_NUM; i++) {
+        SendHelper(producer, element);
+    }
+    ReceiveHelper(consumer, ELE_NUM);
+    // Give the producer enough time to process the ack
+    sleep(1);
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+
+    bool a, b;
+    auto delFut = pool.Submit([this, &a, streamName]() {
+        ASSERT_EQ(client1->DeleteStream(streamName).GetCode(), K_SC_STREAM_DELETE_IN_PROGRESS);
+        LOG(INFO) << "DeleteStream on W0 failed due to Delete-In-Progress";
+    });
+    auto delFut2 = pool.Submit([this, &b, streamName]() {
+        ASSERT_EQ(client2->DeleteStream(streamName).GetCode(), K_RPC_DEADLINE_EXCEEDED);
+        LOG(INFO) << "DeleteStream on W1 failed due to RPC";
+    });
+    sleep(1);
+    // Deletes are running, so CreateProducer does not succeed until the deletes finish
+    DS_ASSERT_NOT_OK(client1->CreateProducer(streamName, producer, prodCfg));
+    delFut.get();
+    delFut2.get();
+    // Both deletes failed; since the delete on the master failed and was undone,
+    // we should now be able to create again
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, prodCfg));
+    DS_ASSERT_OK(client1->QueryGlobalProducersNum(streamName, producersCount));
+    DS_ASSERT_OK(client1->QueryGlobalConsumersNum(streamName, consumersCount));
+    ASSERT_EQ(producersCount, 1ul);
+    ASSERT_EQ(consumersCount, 0ul);
+}
+
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/mem_ctrl_boundary_case_test.cpp b/tests/st/client/stream_cache/mem_ctrl_boundary_case_test.cpp
new file mode 100644
index 0000000..2022e70
--- /dev/null
+++ b/tests/st/client/stream_cache/mem_ctrl_boundary_case_test.cpp
@@ -0,0 +1,189 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Memory control boundary case test.
+ */
+#include
+
+#include "common.h"
+#include "common/stream_cache/element_generator.h"
+#include "common/stream_cache/stream_common.h"
+#include "sc_client_common.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/stream/element.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream_client.h"
+#include "datasystem/client/stream_cache/stream_client_impl.h"
+
+namespace datasystem {
+namespace st {
+
+constexpr uint64_t SMALL_SHM_SIZE_MB = 2;
+constexpr uint64_t ADDITIONAL_INFO_SZ = 100;
+
+class MemCtrlBoundaryCaseTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.numWorkers = workerNum;
+        opts.workerGflagParams = " -page_size=" + std::to_string(BIG_PAGE_SIZE) +
+                                 " -shared_memory_size_mb=" + std::to_string(SMALL_SHM_SIZE_MB);
+        opts.numRpcThreads = 0;
+        opts.vLogLevel = 2;
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+        InitTest();
+    }
+
+    void TearDown() override
+    {
+        if (client0_) {
+            client0_ = nullptr;
+        }
+        if (client1_) {
+            client1_ = nullptr;
+        }
+        ExternalClusterTest::TearDown();
+    }
+
+protected:
+    void InitTest()
+    {
+        InitStreamClient(0, client0_);
+        InitStreamClient(1, client1_);
+    }
+
+    Status ProduceFixedSzData(std::shared_ptr<Producer> &producer, const std::string &producerName, uint64_t eleSz,
+                              uint64_t eleNum)
+    {
+        ElementGenerator generator(eleSz, eleSz);
+        auto eleList = generator.GenElements(producerName, eleNum, 1);
+        for (size_t i = 0; i < eleNum; ++i) {
+            LOG(INFO) << FormatString("Element idx:%zu, Element size:%zu", i, eleList[i].size());
+            auto mutableData = const_cast<char *>(eleList[i].data());
+            Element ele(reinterpret_cast<uint8_t *>(mutableData), eleList[i].size());
+            RETURN_IF_NOT_OK(producer->Send(ele));
+        }
+        return Status::OK();
+    }
+
+    const int workerNum = 2;
+    std::shared_ptr<StreamClient> client0_;
+    std::shared_ptr<StreamClient> client1_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+TEST_F(MemCtrlBoundaryCaseTest, RepeatRecv)
+{
+    std::string stream1("testRepeatRecv");
+    ThreadPool pool(2);
+    std::vector<std::future<Status>> futs;
+
+    std::promise<void> topoPromise;
+    std::shared_future<void> sFut = topoPromise.get_future();
+
+    std::promise<uint64_t> promise;
+    std::shared_future<uint64_t> eleNumFut = promise.get_future();
+
+    futs.emplace_back(pool.Submit([this, &promise, &stream1, &sFut]() {
+        const uint64_t eleSz0 = 16 * KB - ADDITIONAL_INFO_SZ;
+        const uint64_t eleNum0 = 31;
+        LOG(INFO) << FormatString("Round0, Element size: %zu, Element number: %zu", eleSz0, eleNum0);
+
+        const uint64_t eleSz1 = 16 * KB - ADDITIONAL_INFO_SZ;
+        const uint64_t eleNum1 = 31;
+        LOG(INFO) << FormatString("Round1, Element size: %zu, Element number: %zu", eleSz1, eleNum1);
+
+        auto totalEleNum = eleNum0 + eleNum1;
+
+        // Create producer0 with PAGE_SIZE = 512KB on worker0
+        ProducerConf producerConf = {
+            .delayFlushTime = -1, .pageSize = BIG_PAGE_SIZE, .maxStreamSize = SMALL_SHM_SIZE_MB * MB
+        };
+        std::shared_ptr<Producer> producer0;
+        RETURN_IF_NOT_OK(client0_->CreateProducer(stream1, producer0, producerConf));
+        // Create producer1 with PAGE_SIZE = 512KB on worker0
+        std::shared_ptr<Producer> producer1;
+        RETURN_IF_NOT_OK(client0_->CreateProducer(stream1, producer1, producerConf));
+
+        sFut.get();  // After consumer1 has subscribed, producer0 and producer1 can send data
+
+        // Use producer0 and producer1 to send 62 * 16KB = 992KB (496KB each) of data to worker1
+        auto producer0Id = "producer0";
+        auto producer1Id = "producer1";
"producer1"; + RETURN_IF_NOT_OK(ProduceFixedSzData(producer0, producer0Id, eleSz0, eleNum0)); + RETURN_IF_NOT_OK(ProduceFixedSzData(producer1, producer1Id, eleSz1, eleNum1)); + + promise.set_value(totalEleNum); // After sent all (eleSz0 + eleSz1) data, consumer can receive + RETURN_IF_NOT_OK(producer0->Close()); + RETURN_IF_NOT_OK(producer1->Close()); + return Status::OK(); + })); + futs.emplace_back(pool.Submit([this, &stream1, &eleNumFut, &topoPromise]() { + std::shared_ptr consumer1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + RETURN_IF_NOT_OK(client1_->Subscribe(stream1, config1, consumer1)); + std::shared_ptr producer2; + ProducerConf producerConf = { + .delayFlushTime = -1, .pageSize = BIG_PAGE_SIZE, .maxStreamSize = SMALL_SHM_SIZE_MB * MB + }; + RETURN_IF_NOT_OK(client1_->CreateProducer(stream1, producer2, producerConf)); + int localEleNum = 1; + auto producer2Id = "producer2"; + RETURN_IF_NOT_OK(ProduceFixedSzData(producer2, producer2Id, (1 * KB / 8) - ADDITIONAL_INFO_SZ, localEleNum)); + RETURN_IF_NOT_OK(producer2->Close()); + topoPromise.set_value(); // Notify another thread consumer is set up + + auto remoteEleNum = eleNumFut.get(); + LOG(INFO) << FormatString("Remote element number:%zu", remoteEleNum); + auto totalEleNum = localEleNum + remoteEleNum; + + std::vector outElements; + thread_local auto timer = Timer(); + uint64_t timeOut = 1000; + uint64_t lastRecvCursor = 0; + while (timer.ElapsedMilliSecond() <= timeOut) { + std::vector output; + consumer1->Receive(totalEleNum, 100, output); // Receive 63 element without wait one time + if (!output.empty()) { + outElements.insert(outElements.end(), output.begin(), output.end()); + lastRecvCursor += output.size(); + LOG(INFO) << FormatString("consumer1 received %zu elements in total, hence we ack to cursor:%zu", + outElements.size(), lastRecvCursor); + RETURN_IF_NOT_OK(consumer1->Ack(lastRecvCursor)); // Release shm on worker1 + } + } + CHECK_FAIL_RETURN_STATUS( + outElements.size() == totalEleNum, K_INVALID, + FormatString("Out element1 size %zu not equal to expectedNum: %zu.", outElements.size(), totalEleNum)); + + RETURN_IF_NOT_OK(consumer1->Close()); + return Status::OK(); + })); + for (auto &fut : futs) { + ASSERT_EQ(fut.get(), Status::OK()); + } + LOG(INFO) << "Success"; +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/mem_ctrl_test.cpp b/tests/st/client/stream_cache/mem_ctrl_test.cpp new file mode 100644 index 0000000..4b6781d --- /dev/null +++ b/tests/st/client/stream_cache/mem_ctrl_test.cpp @@ -0,0 +1,607 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Remote send test. 
+ */
+#include
+
+#include
+
+#include "common.h"
+#include "common/stream_cache/element_generator.h"
+#include "common/stream_cache/stream_common.h"
+#include "sc_client_common.h"
+#include "datasystem/stream_client.h"
+#include "datasystem/client/stream_cache/stream_client_impl.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/client/stream_cache/client_worker_api.h"
+#include "datasystem/common/log/log.h"
+
+namespace datasystem {
+namespace st {
+using namespace datasystem::client::stream_cache;
+class RemoteMemCtrlTest : public SCClientCommon {
+public:
+    RemoteMemCtrlTest(uint32_t maxStreamSizeMb = 2, int64_t pageSize = 1024 * 4)
+        : maxStreamSizeMb_(maxStreamSizeMb), pageSize_(pageSize), bigSize_(pageSize_ / 16)
+    {
+        const uint64_t MB = 1024 * 1024;
+        producerConf_.pageSize = pageSize;
+        producerConf_.maxStreamSize = maxStreamSizeMb * MB;
+    }
+    virtual void SetClusterSetupOptions(ExternalClusterOptions &opts) override;
+
+    void SetUp() override;
+
+    void TearDown() override;
+
+    static std::string streamName_;
+    static std::once_flag onceFlag_;
+
+protected:
+    static Status Produce(std::shared_ptr<Producer> &producer, std::string producerName, size_t flushIntervals,
+                          size_t numElements, uint64_t maxEleSz, uint64_t minEleSz, uint64_t *res = nullptr);
+
+    static Status ConsumeAll(std::shared_ptr<Consumer> &consumer, size_t numElements, int batchNum, int ackInterval,
+                             int sleepAfterRecvUs, int timeout = 5000, uint64_t *res = nullptr);
+
+    static Status ConsumeAllClose(std::shared_ptr<Consumer> &consumer, size_t numElements, int batchNum,
+                                  int ackInterval, int sleepAfterRecv, int timeout = 5000, uint64_t *res = nullptr);
+
+    void BasicSPSC(int round, int minEleSz, int maxEleSz, size_t flushIntervals, size_t numElements, int recvBatchNum,
+                   int recvTimeout, int ackInterval, int sleepAfterRecvUs);
+
+    void SendSideMultiProducers(int round, int minEleSz, int maxEleSz, size_t flushIntervals, size_t numElements,
+                                int recvBatchNum, int recvTimeout, int ackInterval, int sleepAfterRecvUs);
+
+    void RecvSideAddConsumer(int round, int minEleSz, int maxEleSz, size_t flushIntervals, size_t numElements,
+                             int recvBatchNum, int recvTimeout, int ackInterval, int sleepAfterRecvUs);
+
+    void SendSideConsumer(int round, int minEleSz, int maxEleSz, size_t flushIntervals, size_t numElements,
+                          int recvBatchNum, int recvTimeout, int ackInterval, int sleepAfterRecvUs);
+
+    void BothDirection(int round, int minEleSz, int maxEleSz, size_t flushIntervals, size_t numElements,
+                       int recvBatchNum, int recvTimeout, int ackInterval, int sleepAfterRecvUs);
+
+    uint32_t maxStreamSizeMb_;
+    uint64_t pageSize_;
+    uint64_t bigSize_;
+
+    // Mock producer worker.
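+    // Worker endpoints and per-worker stream clients used by the tests below.
+    // Typical flow exercised (sketch built from this fixture's own calls):
+    //   std::shared_ptr<Producer> p;
+    //   w1Client_->CreateProducer(streamName, p, producerConf_);
+    //   std::shared_ptr<Consumer> c;
+    //   w2Client_->Subscribe(streamName, SubscriptionConfig("sub1", SubscriptionType::STREAM), c);
+    //   p->Send(element);  c->Receive(expectNum, timeoutMs, out);  c->Ack(out.back().id);
+    //   p->Close();  c->Close();  w1Client_->DeleteStream(streamName);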
+    HostPort w1Addr_;
+    HostPort w2Addr_;
+    HostPort w3Addr_;
+
+    std::shared_ptr<StreamClient> w1Client_ = nullptr;
+    std::shared_ptr<StreamClient> w2Client_ = nullptr;
+    std::shared_ptr<StreamClient> w3Client_ = nullptr;
+    ProducerConf producerConf_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+std::string RemoteMemCtrlTest::streamName_ = "stream";
+std::once_flag RemoteMemCtrlTest::onceFlag_;
+
+void RemoteMemCtrlTest::SetClusterSetupOptions(ExternalClusterOptions &opts)
+{
+    opts.numEtcd = 1;
+    opts.numWorkers = 3;
+    opts.workerGflagParams = " -page_size=" + std::to_string(pageSize_);
+    opts.vLogLevel = 2;
+    SCClientCommon::SetClusterSetupOptions(opts);
+}
+
+void RemoteMemCtrlTest::SetUp()
+{
+    ExternalClusterTest::SetUp();
+    // Connect one client per worker so the W1/W2/W3 scenarios below really span workers
+    InitStreamClient(0, w1Client_);
+    InitStreamClient(1, w2Client_);
+    InitStreamClient(2, w3Client_);
+}
+
+void RemoteMemCtrlTest::TearDown()
+{
+    w1Client_ = nullptr;
+    w2Client_ = nullptr;
+    w3Client_ = nullptr;
+    ExternalClusterTest::TearDown();
+}
+
+Status RemoteMemCtrlTest::Produce(std::shared_ptr<Producer> &producer, std::string producerName, size_t flushIntervals,
+                                  size_t numElements, uint64_t maxEleSz, uint64_t minEleSz, uint64_t *res)
+{
+    (void)flushIntervals;
+    uint64_t totalEleSz = 0;
+    ElementGenerator elementGenerator(maxEleSz, minEleSz);
+    auto strs = elementGenerator.GenElements(producerName, numElements);
+    for (size_t i = 0; i < numElements; i++) {
+        totalEleSz += strs[i].size();
+        Status status;
+        int retry = 0;
+        int cnt = 0;
+        do {
+            ++retry;
+            status = producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size()));
+            if (status.IsError()) {
+                if (cnt % 30 == 0) {
+                    LOG(ERROR) << "Send error:" << status.ToString() << ", with retry: " << retry;
+                }
+                cnt++;
+                std::this_thread::sleep_for(std::chrono::seconds(2));
+            }
+        } while (status.IsError() && status.GetCode() == StatusCode::K_OUT_OF_MEMORY);
+        RETURN_IF_NOT_OK(status);
+    }
+    if (res) {
+        *res = totalEleSz;
+    }
+    RETURN_IF_NOT_OK(producer->Close());
+    return Status::OK();
+}
+
+Status RemoteMemCtrlTest::ConsumeAll(std::shared_ptr<Consumer> &consumer, size_t numElements, int batchNum,
+                                     int ackInterval, int sleepAfterRecvUs, int timeout, uint64_t *res)
+{
+    std::vector<Element> outElements;
+    int remainNum = numElements;
+    std::unordered_map<std::string, uint64_t> seqNoMap;
+    uint64_t eleTotalSz = 0;
+    std::unordered_map<std::string, std::vector<uint64_t>> seqNums;
+    while (remainNum > 0) {
+        int expectNum = std::min(batchNum, remainNum);
+        std::vector<Element> out;
+        RETURN_IF_NOT_OK(consumer->Receive(expectNum, timeout, out));
+        for (const auto &element : out) {
+            std::string info = FormatString("Element ID %zu :", element.id);
+            ElementView view(std::string((const char *)element.ptr, element.size));
+            RETURN_IF_NOT_OK_APPEND_MSG(view.VerifyIntegrity(), info);
+            RETURN_IF_NOT_OK_APPEND_MSG(view.VerifyFifo(seqNoMap, 0), info);
+            eleTotalSz += element.size;
+            uint64_t seqNo;
+            RETURN_IF_NOT_OK(view.GetSeqNo(seqNo));
+            std::string producerId;
+            RETURN_IF_NOT_OK(view.GetProducerId(producerId));
+            seqNums[std::string(producerId)].push_back(seqNo);
+        }
+        outElements.insert(outElements.end(), out.begin(), out.end());
+        remainNum -= out.size();
+        if (sleepAfterRecvUs > 0) {
+            std::this_thread::sleep_for(std::chrono::microseconds(sleepAfterRecvUs));
+        }
+        if (outElements.size() % ackInterval == 0) {
+            RETURN_IF_NOT_OK(consumer->Ack(outElements.back().id));
+        }
+    }
+    for (auto &entry : seqNums) {
+        LOG(INFO) << FormatString("Receive from producer[%s], nums:[%zu].", entry.first, entry.second.size());
+    }
+    LOG(INFO) << "Actual number of received elements: " << outElements.size();
+    CHECK_FAIL_RETURN_STATUS(
+        outElements.size() == numElements, StatusCode::K_RUNTIME_ERROR,
+        FormatString("Should recv %zu elements, actual recv %zu elements", numElements, outElements.size()));
+    if (res != nullptr) {
+        *res = eleTotalSz;
+    }
+    return Status::OK();
+}
+
+Status RemoteMemCtrlTest::ConsumeAllClose(std::shared_ptr<Consumer> &consumer, size_t numElements, int batchNum,
+                                          int ackInterval, int sleepAfterRecv, int timeout, uint64_t *res)
+{
+    RETURN_IF_NOT_OK(ConsumeAll(consumer, numElements, batchNum, ackInterval, sleepAfterRecv, timeout, res));
+    return consumer->Close();
+}
+
+void RemoteMemCtrlTest::BasicSPSC(int round, int minEleSz, int maxEleSz, size_t flushIntervals, size_t numElements,
+                                  int recvBatchNum, int recvTimeout, int ackInterval, int sleepAfterRecvUs)
+{
+    auto streamName = FormatString("%s-%d", streamName_, round);
+    ThreadPool pool(10);
+    for (int i = 0; i < 2; i++) {
+        uint64_t sendDataSz = 0;
+        uint64_t recvDataSz = 0;
+        std::vector<std::future<Status>> futs;
+        std::promise<void> promise;
+        std::vector<std::shared_ptr<Producer>> producers(1);
+        // W1: 1 producer
+        futs.emplace_back(pool.Submit([this, &promise, streamName, &flushIntervals, &numElements, &maxEleSz, &minEleSz,
+                                       &sendDataSz, &producers, i]() {
+            std::shared_ptr<Producer> producer;
+            RETURN_IF_NOT_OK(w1Client_->CreateProducer(streamName, producer, producerConf_));
+            LOG(INFO) << "CreateProducer succeeded for stream " << streamName;
+            promise.get_future().get();
+            // Send.
+            auto producerId = "producer" + std::to_string(i);
+            RETURN_IF_NOT_OK(
+                Produce(producer, producerId, flushIntervals, numElements, maxEleSz, minEleSz, &sendDataSz));
+            producers[0] = std::move(producer);
+            return Status::OK();
+        }));
+        // W2: 1 consumer
+        futs.emplace_back(pool.Submit([this, &promise, streamName, &numElements, &recvBatchNum, &ackInterval,
+                                       &sleepAfterRecvUs, &recvTimeout, &recvDataSz]() {
+            std::shared_ptr<Consumer> consumer;
+            SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+            RETURN_IF_NOT_OK(w2Client_->Subscribe(streamName, config, consumer));
+            promise.set_value();
+            RETURN_IF_NOT_OK(ConsumeAllClose(consumer, numElements, recvBatchNum, ackInterval, sleepAfterRecvUs,
+                                             recvTimeout, &recvDataSz));
+            return Status::OK();
+        }));
+        for (auto &fut : futs) {
+            ASSERT_EQ(fut.get(), Status::OK());
+        }
+        EXPECT_EQ(sendDataSz, recvDataSz);
+        LOG(INFO) << FormatString("Finish: %d, sendDataSz:%zu, recvDataSz:%zu", i, sendDataSz, recvDataSz);
+        ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK());
+    }
+}
+
+// Recv is slower than send, and all elements are small.
+// W1: Producer.
+// W2: Consumer.
+TEST_F(RemoteMemCtrlTest, TestBasicSPSC0)
+{
+    auto rounds = 1;
+    ThreadPool roundThreads(rounds);
+    for (int round = 0; round < rounds; round++) {
+        roundThreads.Submit([this, round]() { BasicSPSC(round, 1, 1, 10, 1000, 1, 30'000, 10, 0); });
+    }
+}
+
+// Recv is slower than send, and some elements are large.
+// W1: Producer.
+// W2: Consumer.
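+// BasicSPSC parameter order, from the declaration above: round, minEleSz, maxEleSz,
+// flushIntervals, numElements, recvBatchNum, recvTimeout (ms), ackInterval, sleepAfterRecvUs.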
+TEST_F(RemoteMemCtrlTest, TestBasicSPSC1)
+{
+    auto rounds = 1;
+    ThreadPool roundThreads(rounds);
+    for (int round = 0; round < rounds; round++) {
+        roundThreads.Submit([this, round]() { BasicSPSC(round, 1 * KB, 2 * KB, 10, 1000, 1, 30'000, 10, 5'000); });
+    }
+}
+
+void RemoteMemCtrlTest::SendSideMultiProducers(int round, int minEleSz, int maxEleSz, size_t flushIntervals,
+                                               size_t numElements, int recvBatchNum, int recvTimeout, int ackInterval,
+                                               int sleepAfterRecvUs)
+{
+    auto streamName = FormatString("%s-%d", streamName_, round);
+    ThreadPool pool(10);
+    for (int i = 0; i < 1; i++) {
+        std::atomic<uint64_t> sendDataSz = { 0 };
+        uint64_t recvDataSz = 0;
+        std::vector<std::future<Status>> futs;
+        std::promise<void> promise;
+        std::shared_future<void> sFut = promise.get_future();
+        std::vector<std::shared_ptr<Producer>> producers(2);
+        for (auto j = 0; j < 2; j++) {
+            futs.emplace_back(pool.Submit([this, &sFut, j, streamName, &flushIntervals, &numElements, &maxEleSz,
+                                           &minEleSz, &sendDataSz, &producers]() {
+                std::shared_ptr<Producer> producer;
+                RETURN_IF_NOT_OK(w1Client_->CreateProducer(streamName, producer, producerConf_));
+                sFut.get();
+                std::string producerName = FormatString("producer-[%d]", j);
+                LOG(INFO) << FormatString("The producer name:%s, ", producerName);
+                // Send.
+                uint64_t sendData = 0;
+                auto producerId = "producer" + std::to_string(j);
+                RETURN_IF_NOT_OK(
+                    Produce(producer, producerId, flushIntervals, numElements, maxEleSz, minEleSz, &sendData));
+                sendDataSz.fetch_add(sendData);
+                producers[j] = std::move(producer);
+                return Status::OK();
+            }));
+        }
+        // W2: 1 consumer
+        futs.emplace_back(pool.Submit([this, &promise, streamName, &numElements, &recvBatchNum, &ackInterval,
+                                       &sleepAfterRecvUs, &recvTimeout, &recvDataSz]() {
+            std::shared_ptr<Consumer> consumer;
+            SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+            RETURN_IF_NOT_OK(w2Client_->Subscribe(streamName, config, consumer));
+            promise.set_value();
+            RETURN_IF_NOT_OK(ConsumeAllClose(consumer, numElements * 2, recvBatchNum, ackInterval, sleepAfterRecvUs,
+                                             recvTimeout, &recvDataSz));
+            return Status::OK();
+        }));
+        for (auto &fut : futs) {
+            ASSERT_EQ(fut.get(), Status::OK());
+        }
+        ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK());
+    }
+}
+
+// W1: Two producers.
+// W2: Consumer.
+// Flushes must preserve FIFO order for each producer.
+TEST_F(RemoteMemCtrlTest, LEVEL1_TestSendSideMultiProducers)
+{
+    auto rounds = 10;
+    ThreadPool roundThreads(rounds);
+    for (int round = 0; round < rounds; round++) {
+        roundThreads.Submit([this, round]() { SendSideMultiProducers(round, 1, 1, 10, 1000, 1, 30'000, 10, 0); });
+    }
+}
+
+void RemoteMemCtrlTest::RecvSideAddConsumer(int round, int minEleSz, int maxEleSz, size_t flushIntervals,
+                                            size_t numElements, int recvBatchNum, int recvTimeout, int ackInterval,
+                                            int sleepAfterRecvUs)
+{
+    auto streamName = FormatString("%s-%d", streamName_, round);
+    ThreadPool pool(10);
+    for (int i = 0; i < 10; i++) {
+        std::vector<std::future<Status>> futs;
+        std::future<Status> futs1;
+        std::promise<void> promise;
+        // Get a shared future from the promise; it can be waited on multiple times.
+        std::shared_future<void> sfut = promise.get_future();
+        std::vector<std::shared_ptr<Producer>> producers(1);
+
+        // W1: P1, send data after C1 is created.
+        futs.emplace_back(
+            pool.Submit([this, &sfut, streamName, &flushIntervals, &numElements, &maxEleSz, &minEleSz, &producers]() {
+                std::shared_ptr<Producer> producer;
+                RETURN_IF_NOT_OK(w1Client_->CreateProducer(streamName, producer, producerConf_));
+                sfut.get();
+
+                // Send.
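+                // Produce() retries on K_OUT_OF_MEMORY, so when the stream is full this
+                // call only makes progress once the consumer frees pages via Ack.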
+                RETURN_IF_NOT_OK(Produce(producer, "producer", flushIntervals, numElements, maxEleSz, minEleSz));
+                producers[0] = std::move(producer);
+                return Status::OK();
+            }));
+
+        // W2: C1
+        futs1 = pool.Submit(
+            [this, &promise, streamName, &numElements, &recvBatchNum, &ackInterval, &sleepAfterRecvUs, &recvTimeout]() {
+                std::shared_ptr<Consumer> consumer;
+                SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+                RETURN_IF_NOT_OK(w2Client_->Subscribe(streamName, config, consumer));
+                promise.set_value();
+                RETURN_IF_NOT_OK(
+                    ConsumeAllClose(consumer, numElements, recvBatchNum, ackInterval, sleepAfterRecvUs, recvTimeout));
+                return Status::OK();
+            });
+
+        // W2: C2, created after C1.
+        futs.emplace_back(pool.Submit([this, &futs1, &sfut, streamName, &numElements]() {
+            std::shared_ptr<Consumer> consumer;
+            SubscriptionConfig config("sub2", SubscriptionType::STREAM);
+            sfut.get();
+            RETURN_IF_NOT_OK(w2Client_->Subscribe(streamName, config, consumer));
+
+            std::vector<Element> outElements;
+            std::unordered_map<std::string, uint64_t> seqNoMap;
+            size_t recvNum = 0;
+            while (!IsThreadFinished(futs1, 0)) {
+                auto stat = consumer->Receive(1, 0, outElements);
+                if (stat == Status::OK() && !outElements.empty()) {
+                    const auto &e = outElements.back();
+                    LOG(INFO) << "Cursor: " << e.id << ", Sz: " << e.size;
+                    ElementView view(std::string((const char *)e.ptr, e.size));
+                    RETURN_IF_NOT_OK(view.VerifyIntegrity());
+                    RETURN_IF_NOT_OK(view.VerifyFifoInitOff(seqNoMap));
+                    recvNum++;
+                }
+            }
+            CHECK_FAIL_RETURN_STATUS(recvNum <= numElements, StatusCode::K_RUNTIME_ERROR, "");
+            return consumer->Close();
+        }));
+        for (auto &fut : futs) {
+            ASSERT_EQ(fut.get(), Status::OK());
+        }
+        futs1.get();
+        ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK());
+        LOG(INFO) << FormatString("Finish: %d", i);
+    }
+}
+
+// W1: Producer.
+// W2: Consumer, then dynamically add another.
+TEST_F(RemoteMemCtrlTest, TestRecvSideAddConsumer)
+{
+    auto rounds = 1;
+    ThreadPool roundThreads(rounds);
+    for (int round = 0; round < rounds; round++) {
+        roundThreads.Submit([this, round]() { RecvSideAddConsumer(round, 1, 1, 10, 1000, 1, 30'000, 10, 0); });
+    }
+}
+
+void RemoteMemCtrlTest::SendSideConsumer(int round, int minEleSz, int maxEleSz, size_t flushIntervals,
+                                         size_t numElements, int recvBatchNum, int recvTimeout, int ackInterval,
+                                         int sleepAfterRecvUs)
+{
+    auto streamName = FormatString("%s-%d", streamName_, round);
+    ThreadPool pool(10);
+    for (int i = 0; i < 10; i++) {
+        std::vector<std::future<Status>> futs;
+        std::vector<std::promise<void>> promises(3);
+        std::vector<std::shared_future<void>> sFuts;  // Can be waited on multiple times
+        for (auto &promise : promises) {
+            sFuts.emplace_back(promise.get_future());
+        }
+        std::vector<std::shared_ptr<Producer>> producers(1);
+        // W1: C1, can receive all elements.
+        futs.emplace_back(pool.Submit([this, &promises, streamName, &numElements, &recvBatchNum, &ackInterval,
+                                       &sleepAfterRecvUs, &recvTimeout]() {
+            SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+            std::shared_ptr<Consumer> consumer;
+            RETURN_IF_NOT_OK(w1Client_->Subscribe(streamName, config, consumer));
+            promises[0].set_value();
+
+            RETURN_IF_NOT_OK(
+                ConsumeAllClose(consumer, numElements, recvBatchNum, ackInterval, sleepAfterRecvUs, recvTimeout));
+            return Status::OK();
+        }));
+        // W1: C2, can receive all elements.
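+        // (Each subscription name creates an independent STREAM-type subscription, so
+        // every consumer below receives the full element sequence on its own cursor.)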
+        futs.emplace_back(pool.Submit([this, &promises, streamName, &numElements, &recvBatchNum, &ackInterval,
+                                       &sleepAfterRecvUs, &recvTimeout]() {
+            SubscriptionConfig config("sub2", SubscriptionType::STREAM);
+            std::shared_ptr<Consumer> consumer;
+            RETURN_IF_NOT_OK(w1Client_->Subscribe(streamName, config, consumer));
+            promises[1].set_value();
+            return ConsumeAllClose(consumer, numElements, recvBatchNum, ackInterval, sleepAfterRecvUs, recvTimeout);
+        }));
+        // W1: P1, sends data after C1, C2, and C3 are created successfully.
+        futs.emplace_back(
+            pool.Submit([this, &sFuts, streamName, &flushIntervals, &numElements, &maxEleSz, &minEleSz, &producers]() {
+                std::shared_ptr<Producer> producer;
+                RETURN_IF_NOT_OK(w1Client_->CreateProducer(streamName, producer, producerConf_));
+
+                // Wait until C1, C2, and C3 are created successfully, then send.
+                for (auto &sFut : sFuts) {
+                    sFut.get();
+                }
+                RETURN_IF_NOT_OK(Produce(producer, "producer", flushIntervals, numElements, maxEleSz, minEleSz));
+                producers[0] = std::move(producer);
+                return Status::OK();
+            }));
+        // W2: C3, can receive all elements.
+        futs.emplace_back(pool.Submit([this, &promises, streamName, &numElements, &recvBatchNum, &ackInterval,
+                                       &sleepAfterRecvUs, &recvTimeout]() {
+            SubscriptionConfig config("sub3", SubscriptionType::STREAM);
+            std::shared_ptr<Consumer> consumer;
+
+            RETURN_IF_NOT_OK(w2Client_->Subscribe(streamName, config, consumer));
+            promises[2].set_value();
+            RETURN_IF_NOT_OK(
+                ConsumeAllClose(consumer, numElements, recvBatchNum, ackInterval, sleepAfterRecvUs, recvTimeout));
+            return Status::OK();
+        }));
+        for (auto &fut : futs) {
+            ASSERT_EQ(fut.get(), Status::OK());
+        }
+        ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK());
+        LOG(INFO) << FormatString("Finish: %d", i);
+    }
+}
+
+// W1: Producer, C1, C2.
+// W2: Consumer.
+TEST_F(RemoteMemCtrlTest, TestSendSideConsumer)
+{
+    auto rounds = 1;
+    ThreadPool roundThreads(rounds);
+    for (int round = 0; round < rounds; round++) {
+        roundThreads.Submit([this, round]() { SendSideConsumer(round, 1, 1, 10, 1000, 1, 30'000, 10, 0); });
+    }
+}
+
+void RemoteMemCtrlTest::BothDirection(int round, int minEleSz, int maxEleSz, size_t flushIntervals, size_t numElements,
+                                      int recvBatchNum, int recvTimeout, int ackInterval, int sleepAfterRecvUs)
+{
+    auto streamName = FormatString("%s-%d", streamName_, round);
+    ThreadPool pool(10);
+    for (int i = 0; i < 10; i++) {
+        LOG(INFO) << FormatString("===================== [Round: %d] Start =====================", i);
+        std::vector<std::future<Status>> futs;
+        std::vector<std::promise<void>> promises(3);
+        std::vector<std::shared_future<void>> sFuts;
+        for (auto &promise : promises) {
+            sFuts.emplace_back(promise.get_future());
+        }
+        std::vector<std::shared_ptr<Producer>> producers(3);
+        // W1: P1
+        futs.emplace_back(
+            pool.Submit([this, &sFuts, streamName, &flushIntervals, &numElements, &maxEleSz, &minEleSz, &producers]() {
+                std::shared_ptr<Producer> producer;
+                RETURN_IF_NOT_OK(w1Client_->CreateProducer(streamName, producer, producerConf_));
+                // Wait for all consumers to be created successfully.
+                for (auto &sFut : sFuts) {
+                    sFut.get();
+                }
+                RETURN_IF_NOT_OK(Produce(producer, "producer1", flushIntervals, numElements, maxEleSz, minEleSz));
+                producers[0] = std::move(producer);
+                return Status::OK();
+            }));
+        // W2: P2
+        futs.emplace_back(
+            pool.Submit([this, &sFuts, streamName, &flushIntervals, &numElements, &maxEleSz, &minEleSz, &producers]() {
+                std::shared_ptr<Producer> producer;
+                RETURN_IF_NOT_OK(w2Client_->CreateProducer(streamName, producer, producerConf_));
+                // Wait for all consumers to be created successfully.
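+                // (sFuts holds shared_futures, so every producer thread can wait on the
+                // same readiness signals without consuming them.)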
+                for (auto &sFut : sFuts) {
+                    sFut.get();
+                }
+                RETURN_IF_NOT_OK(Produce(producer, "producer2", flushIntervals, numElements, maxEleSz, minEleSz));
+                producers[1] = std::move(producer);
+                return Status::OK();
+            }));
+        // W3: P3
+        futs.emplace_back(
+            pool.Submit([this, &sFuts, streamName, &flushIntervals, &numElements, &maxEleSz, &minEleSz, &producers]() {
+                std::shared_ptr<Producer> producer;
+                RETURN_IF_NOT_OK(w3Client_->CreateProducer(streamName, producer, producerConf_));
+                // Wait for all consumers to be created successfully.
+                for (auto &sFut : sFuts) {
+                    sFut.get();
+                }
+                RETURN_IF_NOT_OK(Produce(producer, "producer3", flushIntervals, numElements, maxEleSz, minEleSz));
+                producers[2] = std::move(producer);
+                return Status::OK();
+            }));
+        // W1: C1, C2
+        futs.emplace_back(pool.Submit([this, &promises, streamName, &numElements, &recvBatchNum, &ackInterval,
+                                       &sleepAfterRecvUs, &recvTimeout]() {
+            SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+            std::shared_ptr<Consumer> consumer;
+            RETURN_IF_NOT_OK(w1Client_->Subscribe(streamName, config, consumer));
+            promises[0].set_value();  // Notify producers to send elements
+
+            return ConsumeAllClose(consumer, 3 * numElements, recvBatchNum, ackInterval, sleepAfterRecvUs, recvTimeout);
+        }));
+        futs.emplace_back(pool.Submit([this, &promises, streamName, &numElements, &recvBatchNum, &ackInterval,
+                                       &sleepAfterRecvUs, &recvTimeout]() {
+            SubscriptionConfig config("sub2", SubscriptionType::STREAM);
+            std::shared_ptr<Consumer> consumer;
+            RETURN_IF_NOT_OK(w1Client_->Subscribe(streamName, config, consumer));
+            promises[1].set_value();  // Notify producers to send elements
+
+            return ConsumeAllClose(consumer, 3 * numElements, recvBatchNum, ackInterval, sleepAfterRecvUs, recvTimeout);
+        }));
+
+        // W2: C3
+        futs.emplace_back(pool.Submit([this, &promises, streamName, &numElements, &recvBatchNum, &ackInterval,
+                                       &sleepAfterRecvUs, &recvTimeout]() {
+            SubscriptionConfig config("sub3", SubscriptionType::STREAM);
+            std::shared_ptr<Consumer> consumer;
+            RETURN_IF_NOT_OK(w2Client_->Subscribe(streamName, config, consumer));
+            promises[2].set_value();  // Notify producers to send elements
+            return ConsumeAllClose(consumer, 3 * numElements, recvBatchNum, ackInterval, sleepAfterRecvUs, recvTimeout);
+        }));
+        for (auto &fut : futs) {
+            ASSERT_EQ(fut.get(), Status::OK());
+        }
+        ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK());
+        LOG(INFO) << FormatString("Finish: %d", i);
+        LOG(INFO) << FormatString("===================== [Round: %d] End =====================", i);
+    }
+}
+
+// W1: Producer, two consumers.
+// W2: Producer, consumer.
+// W3: Producer.
+TEST_F(RemoteMemCtrlTest, TestBothDirection)
+{
+    auto rounds = 1;
+    ThreadPool roundThreads(rounds);
+    for (int round = 0; round < rounds; round++) {
+        roundThreads.Submit([this, round]() { BothDirection(round, 1, 1, 10, 1000, 1, 30'000, 10, 0); });
+    }
+}
+
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/multi_producer_multi_consumer_test.cpp b/tests/st/client/stream_cache/multi_producer_multi_consumer_test.cpp
new file mode 100644
index 0000000..3882b3d
--- /dev/null
+++ b/tests/st/client/stream_cache/multi_producer_multi_consumer_test.cpp
@@ -0,0 +1,695 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Unit test for stream cache
+ */
+#include <atomic>
+#include <climits>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "common.h"
+#include "common/stream_cache/stream_common.h"
+#include "sc_client_common.h"
+#include "datasystem/common/util/strings_util.h"
+#include "datasystem/client/stream_cache/stream_client_impl.h"
+#include "datasystem/stream_client.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/common/inject/inject_point.h"
+
+using namespace datasystem::client::stream_cache;
+using ::testing::Values;
+namespace datasystem {
+namespace st {
+struct InputStreamInfo {
+    InputStreamInfo &SetProducers(int producerNum)
+    {
+        this->producerNum = producerNum;
+        return *this;
+    }
+    InputStreamInfo &AddSub(const std::string &subName, SubscriptionType type, int consumerNum)
+    {
+        subscriptions[subName] = std::make_pair(type, consumerNum);
+        return *this;
+    }
+    int producerNum = 0;
+    std::unordered_map<std::string, std::pair<SubscriptionType, int>> subscriptions;
+};
+
+struct StreamParas {
+    explicit StreamParas(int pageSize = -1, int sharedMemMB = 1024, int regularSocketNum = 16, int streamSocketNum = 16)
+        : pageSize(pageSize),
+          sharedMemorySizeMB(sharedMemMB),
+          regularSocketNum(regularSocketNum),
+          streamSocketNum(streamSocketNum){};
+    ~StreamParas() = default;
+    InputStreamInfo &MutableStream(const std::string &streamName)
+    {
+        return this->streams[streamName];
+    }
+    int pageSize;
+    int sharedMemorySizeMB;
+    int regularSocketNum;
+    int streamSocketNum;
+    std::unordered_map<std::string, InputStreamInfo> streams;
+};
+
+struct OutputStreamInfo {
+    std::atomic<size_t> totalSend;
+    std::atomic<size_t> totalRecv;
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::unordered_map<std::string, std::vector<std::shared_ptr<Consumer>>> consumers;
+};
+
+
+// Constant expressions
+constexpr int K_3 = 3, K_5 = 5, K_10 = 10, K_50 = 50, K_100 = 100;
+
+class MultiProducerMultiConsumerTest : public SCClientCommon {
+public:
+    // Called by the base class SetUp
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        if (inputParas_.pageSize > 0) {
+            opts.workerGflagParams = " -page_size=" + std::to_string(inputParas_.pageSize);
+        }
+        opts.workerGflagParams += " -shared_memory_size_mb=" + std::to_string(inputParas_.sharedMemorySizeMB);
+        opts.numScRegularSocket = inputParas_.regularSocketNum;
+        opts.numScStreamSocket = inputParas_.streamSocketNum;
+        opts.numRpcThreads = 0;
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+        InitClient();
+    }
+
+    void TearDown() override
+    {
+        client_ = nullptr;
+        ExternalClusterTest::TearDown();
+    }
+
+    std::shared_ptr<StreamParas> GetInputParas(int streamNum, int producerNum, int subNum);
+    Status AsyncCreateProducersAndConsumers(std::unordered_map<std::string, InputStreamInfo> &input,
+                                            std::unordered_map<std::string, OutputStreamInfo> &output);
+    Status MultiSendRecv(std::unordered_map<std::string, OutputStreamInfo> &output, int elementNum, int elementSize,
+                         int flushIntervals, int elementBatchNum, int blockingMs);
+    Status MultiDeleteStream(std::unordered_map<std::string, OutputStreamInfo> &output);
+    std::vector<Element> GenerateElements(int elementNum, uint64_t elementSize, std::string &outData);
+
+protected:
+    void InitClient()
+    {
+        InitStreamClient(0, client_);
+    }
+    std::shared_ptr<StreamClient> client_ = nullptr;
+    StreamParas inputParas_;
+    std::unordered_map<std::string, OutputStreamInfo> streamInfo_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+std::shared_ptr<StreamParas> MultiProducerMultiConsumerTest::GetInputParas(int streamNum, int producerNum, int subNum)
+{
+    auto paras = std::make_shared<StreamParas>();
+    for (int i = 0; i < streamNum; i++) {
+        std::string streamName = "stream" + std::to_string(i);
+        paras->MutableStream(streamName).SetProducers(producerNum);
+        for (int j = 0; j < subNum; j++) {
+            std::string subName = "sub" + std::to_string(j);
+            paras->MutableStream(streamName).AddSub(subName, SubscriptionType::STREAM, 1);
+        }
+    }
+    return paras;
+}
+
+Status MultiProducerMultiConsumerTest::AsyncCreateProducersAndConsumers(
+    std::unordered_map<std::string, InputStreamInfo> &input, std::unordered_map<std::string, OutputStreamInfo> &output)
+{
+    output.reserve(input.size());
+    std::vector<std::thread> producerThreads;
+    std::vector<std::thread> consumerThreads;
+    // Avoid a data race between output[streamName] inserting a node and producer/consumer threads reading
+    // output[streamName]
+    for (auto &iter : input) {
+        const auto &streamName = iter.first;
+        auto &info = iter.second;
+        // Create producers for the stream
+        output[streamName].producers.resize(info.producerNum);
+    }
+
+    for (auto &iter : input) {
+        const auto &streamName = iter.first;
+        auto &info = iter.second;
+        for (int i = 0; i < info.producerNum; i++) {
+            producerThreads.emplace_back([i, streamName, &output, this]() {
+                LOG(INFO) << "Start create producer " << i << " for stream " << streamName;
+                std::shared_ptr<Producer> producer;
+                EXPECT_EQ(client_->CreateProducer(streamName, producer), Status::OK());
+                LOG(INFO) << FormatString("Finished create producer %d for stream %s", i, streamName);
+                output[streamName].producers[i] = std::move(producer);
+            });
+        }
+        // Create consumers for the subscriptions
+        for (auto &subInfo : info.subscriptions) {
+            const auto &subName = subInfo.first;
+            SubscriptionConfig config(subName, subInfo.second.first);
+            int consumerNum = subInfo.second.second;
+            output[streamName].consumers[subName].resize(consumerNum);
+            for (int i = 0; i < consumerNum; i++) {
+                consumerThreads.emplace_back([i, streamName, subName, config, &output, this]() {
+                    LOG(INFO) << FormatString("Start create consumer %d for (stream %s, sub %s).", i, streamName,
+                                              subName);
+                    std::shared_ptr<Consumer> consumer;
+                    EXPECT_EQ(client_->Subscribe(streamName, config, consumer), Status::OK());
+                    LOG(INFO) << FormatString("Finished create consumer %d for (stream %s, sub %s).", i, streamName,
+                                              subName);
+                    output[streamName].consumers[subName][i] = std::move(consumer);
+                });
+            }
+        }
+    }
+    for (std::thread &producer : producerThreads) {
+        producer.join();
+    }
+    for (std::thread &consumer : consumerThreads) {
+        consumer.join();
+    }
+    return Status::OK();
+}
+
+Status MultiProducerMultiConsumerTest::MultiSendRecv(std::unordered_map<std::string, OutputStreamInfo> &output,
+                                                     int elementNum, int elementSize, int flushIntervals,
+                                                     int elementBatchNum, int blockingMs)
+{
+    std::string data;
+    std::vector<Element> elements = GenerateElements(elementNum, elementSize, data);
+    std::vector<std::thread> threads;
+    for (auto &iter : output) {
+        std::string streamName = iter.first;
+        int producerNum = iter.second.producers.size();
+        iter.second.totalSend.store(0);
+        iter.second.totalRecv.store(0);
+        for (int i = 0; i < producerNum; i++) {
+            threads.emplace_back([&elements, &output, i, streamName, flushIntervals]() {
+                auto producer = output[streamName].producers[i].get();
+                for (size_t j = 0; j < elements.size(); j++) {
+                    ASSERT_EQ(producer->Send(elements[j]), Status::OK());
+                    output[streamName].totalSend.fetch_add(1);
+                }
+                ASSERT_EQ(producer->Close(), Status::OK());
+            });
+        }
+
+        int totalElementNum = elements.size() * producerNum;
+        for (auto &subInfo : iter.second.consumers) {
+            auto &subName = subInfo.first;
+            int consumerNum = subInfo.second.size();
+            for (int i = 0; i < consumerNum; i++) {
+                threads.emplace_back([&output, i, streamName, subName, totalElementNum, elementBatchNum, blockingMs]() {
+                    int remainNum = totalElementNum;
+                    auto consumer = output[streamName].consumers[subName][i].get();
+                    while (remainNum > 0) {
+                        int expectNum = std::min(elementBatchNum, remainNum);
+                        std::vector<Element> outElements;
+                        auto status = consumer->Receive(expectNum, blockingMs, outElements);
+                        if (status.IsError()) {
+                            if (status.GetCode() != StatusCode::K_RUNTIME_ERROR) {
+                                LOG(ERROR) << "Receive error:" << status.ToString();
+                            }
+                        } else if (outElements.empty()) {
+                            continue;
+                        } else {
+                            ASSERT_EQ(consumer->Ack(outElements.back().id), Status::OK());
+                        }
+                        remainNum -= outElements.size();
+                        output[streamName].totalRecv.fetch_add(outElements.size());
+                    }
+                    ASSERT_EQ(consumer->Close(), Status::OK());
+                });
+            }
+        }
+    }
+    for (auto &t : threads) {
+        t.join();
+    }
+    for (auto &iter : output) {
+        EXPECT_EQ(iter.second.totalSend.load() * iter.second.consumers.size(), iter.second.totalRecv.load());
+    }
+    return Status::OK();
+}
+
+Status MultiProducerMultiConsumerTest::MultiDeleteStream(std::unordered_map<std::string, OutputStreamInfo> &output)
+{
+    std::vector<std::thread> threads;
+    for (auto &iter : output) {
+        std::string streamName = iter.first;
+        threads.emplace_back([streamName, this]() { ASSERT_EQ(client_->DeleteStream(streamName), Status::OK()); });
+    }
+    for (auto &t : threads) {
+        t.join();
+    }
+    return Status::OK();
+}
+
+std::vector<Element> MultiProducerMultiConsumerTest::GenerateElements(int elementNum, uint64_t elementSize,
+                                                                      std::string &outData)
+{
+    outData = RandomData().GetRandomString(elementNum * elementSize);
+    std::vector<Element> ret;
+    ret.reserve(elementNum);
+    for (int i = 1; i <= elementNum; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&outData.front()), elementSize, ULONG_MAX);
+        ret.push_back(element);
+    }
+    return ret;
+}
+
+TEST_F(MultiProducerMultiConsumerTest, SPSC)
+{
+    auto paras = GetInputParas(1, 1, 1);
+    std::unordered_map<std::string, OutputStreamInfo> output;
+    AsyncCreateProducersAndConsumers(paras->streams, output);
+    MultiSendRecv(output, K_10 * K_100, K_100, K_10, K_5, K_100);
+    MultiDeleteStream(output);
+}
+
+TEST_F(MultiProducerMultiConsumerTest, SPMC)
+{
+    auto paras = GetInputParas(1, 1, K_10);
+    std::unordered_map<std::string, OutputStreamInfo> output;
+    AsyncCreateProducersAndConsumers(paras->streams, output);
+    MultiSendRecv(output, K_100, K_100, K_10, K_5, K_100);
+    MultiDeleteStream(output);
+}
+
+TEST_F(MultiProducerMultiConsumerTest, MPSC)
+{
+    auto paras = GetInputParas(1, K_10, 1);
+    std::unordered_map<std::string, OutputStreamInfo> output;
+    AsyncCreateProducersAndConsumers(paras->streams, output);
+    MultiSendRecv(output, K_100, K_100, K_10, K_5, K_100);
+    MultiDeleteStream(output);
+}
+
+TEST_F(MultiProducerMultiConsumerTest, MPMC)
+{
+    auto paras = GetInputParas(1, K_3, K_3);
+    std::unordered_map<std::string, OutputStreamInfo> output;
+    AsyncCreateProducersAndConsumers(paras->streams, output);
+    MultiSendRecv(output, K_50, K_100, K_10, K_5, K_100);
+    MultiDeleteStream(output);
+}
+
+TEST_F(MultiProducerMultiConsumerTest, MSSPSC)
+{
+    auto paras = GetInputParas(K_10, 1, 1);
+    std::unordered_map<std::string, OutputStreamInfo> output;
+    AsyncCreateProducersAndConsumers(paras->streams, output);
+    MultiSendRecv(output, K_100, K_100, K_10, K_5, K_100);
+    MultiDeleteStream(output);
+}
+
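+// Note on test names (assumed convention, added as a reading aid): "MS" = multiple streams,
+// "SP"/"MP" = single/multiple producers per stream, "SC"/"MC" = single/multiple consumers
+// (one consumer per subscription). A hypothetical 2-stream, 2-producer, 2-subscription case
+// would follow the same four-step pattern as the tests above and below:
+//     auto paras = GetInputParas(2, 2, 2);
+//     std::unordered_map<std::string, OutputStreamInfo> output;
+//     AsyncCreateProducersAndConsumers(paras->streams, output);
+//     MultiSendRecv(output, K_50, K_100, K_10, K_5, K_100);
+//     MultiDeleteStream(output);
+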
+TEST_F(MultiProducerMultiConsumerTest, MSSPMC)
+{
+    auto paras = GetInputParas(K_10, 1, K_10);
+    std::unordered_map<std::string, OutputStreamInfo> output;
+    AsyncCreateProducersAndConsumers(paras->streams, output);
+    MultiSendRecv(output, K_100, K_100, K_10, K_5, K_100);
+    MultiDeleteStream(output);
+}
+
+TEST_F(MultiProducerMultiConsumerTest, MSMPSC)
+{
+    auto paras = GetInputParas(K_5, K_5, 1);
+    std::unordered_map<std::string, OutputStreamInfo> output;
+    AsyncCreateProducersAndConsumers(paras->streams, output);
+    MultiSendRecv(output, K_100, K_100, K_10, K_5, K_100);
+    MultiDeleteStream(output);
+}
+
+TEST_F(MultiProducerMultiConsumerTest, MSMPMC)
+{
+    auto paras = GetInputParas(K_5, K_3, K_3);
+    std::unordered_map<std::string, OutputStreamInfo> output;
+    AsyncCreateProducersAndConsumers(paras->streams, output);
+    MultiSendRecv(output, K_100, K_100, K_10, K_5, K_100);
+    MultiDeleteStream(output);
+}
+
+// This is a set of basic testcases for the MPSC scenario (2 producers, 1 consumer)
+constexpr int ONE_WORKER = 1;
+constexpr int TWO_WORKER = 2;
+constexpr int THREE_WORKER = 3;
+class BasicMultipleProducerSingleConsumerTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numWorkers = numWorkers;
+        opts.numEtcd = numEtcd;
+        opts.numRpcThreads = numRpcThreads;
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+        InitTest();
+    }
+
+    void TearDown() override
+    {
+        client1_ = nullptr;
+        client2_ = nullptr;
+        client3_ = nullptr;
+        ExternalClusterTest::TearDown();
+    }
+
+    void SendHelper(std::shared_ptr<Producer> producer, int numElement)
+    {
+        const int DEFAULT_SLEEP_TIME = 300;
+        int retryLimit = 5;
+        size_t sizeElement = 512;
+
+        for (int i = 0; i < numElement; i++) {
+            std::vector<uint8_t> writeElement = RandomData().RandomBytes(sizeElement);
+            Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+            datasystem::Status rc = producer->Send(element);
+            if (rc.IsError()) {
+                while (rc.GetCode() == K_OUT_OF_MEMORY && retryLimit-- > 0) {
+                    std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME));
+                    rc = producer->Send(element);
+                }
+            }
+            DS_ASSERT_OK(rc);
+            int logInterval = 100;
+            if (i % logInterval == 0) {
+                LOG(INFO) << "send out " << i << " elements";
+            }
+        }
+        LOG(INFO) << "send end";
+    }
+
+    void ReceiveHelper(std::shared_ptr<Consumer> consumer, size_t numElements)
+    {
+        size_t remaining = numElements;
+        int round = 0;
+        const int roundLimit = 10;
+        const int PER_RECEIVE_NUM = 100;
+        const int DEFAULT_WAIT_TIME = 1000;
+        while (remaining > 0 && round < roundLimit) {
+            std::vector<Element> outElements;
+            DS_ASSERT_OK(consumer->Receive(PER_RECEIVE_NUM, DEFAULT_WAIT_TIME, outElements));
+            LOG(INFO) << "remaining num : " << remaining;
+            LOG(INFO) << "receive num : " << outElements.size() << " ;" << round++;
+            if (!outElements.empty()) {
+                remaining -= outElements.size();
+                DS_ASSERT_OK(consumer->Ack(outElements.back().id));
+                if (remaining == 0) {
+                    break;
+                }
+            }
+        }
+    }
+    void GetProducer(int numOfWorker, std::shared_ptr<Producer> &Producer,
+                     std::string streamName)
+    {
+        (void)numOfWorker;
+        client1_->CreateProducer(streamName, Producer, defaultProducerConf_);
+    }
+
+    void GetConsumer(int numOfWorker, std::shared_ptr<Consumer> &Consumer,
+                     std::string streamName)
+    {
+        SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+        switch (numOfWorker) {
+            case ONE_WORKER:
+                DS_ASSERT_OK(client1_->Subscribe(streamName, config1, Consumer));
+                break;
+            case TWO_WORKER:
+                DS_ASSERT_OK(client2_->Subscribe(streamName, config1, Consumer));
+                break;
+            case THREE_WORKER:
+                DS_ASSERT_OK(client3_->Subscribe(streamName, config1, Consumer));
+                break;
+        }
+    }
+
+    void GetProducerConsumers(int numOfWorker, std::shared_ptr<Producer> &Producer1,
+                              std::shared_ptr<Producer> &Producer2, std::shared_ptr<Consumer> &Consumer1,
+                              std::string streamName)
+    {
+        SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+
+        if (numOfWorker == ONE_WORKER) {
+            DS_ASSERT_OK(client1_->CreateProducer(streamName, Producer1, defaultProducerConf_));
+            DS_ASSERT_OK(client1_->CreateProducer(streamName, Producer2, defaultProducerConf_));
+            DS_ASSERT_OK(client1_->Subscribe(streamName, config1, Consumer1));
+        } else if (numOfWorker == TWO_WORKER) {
+            DS_ASSERT_OK(client1_->CreateProducer(streamName, Producer1, defaultProducerConf_));
+            DS_ASSERT_OK(client1_->CreateProducer(streamName, Producer2, defaultProducerConf_));
+            DS_ASSERT_OK(client2_->Subscribe(streamName, config1, Consumer1));
+        } else if (numOfWorker == THREE_WORKER) {
+            DS_ASSERT_OK(client1_->CreateProducer(streamName, Producer1, defaultProducerConf_));
+            DS_ASSERT_OK(client2_->CreateProducer(streamName, Producer2, defaultProducerConf_));
+            DS_ASSERT_OK(client3_->Subscribe(streamName, config1, Consumer1));
+        } else {
+            LOG(ERROR) << "Invalid numOfWorker value";
+            ASSERT_TRUE(false);
+        }
+    }
+
+    void CloseProducerDuringSend(int numOfWorker)
+    {
+        std::shared_ptr<Producer> Producer1;
+        std::shared_ptr<Producer> Producer2;
+        std::shared_ptr<Consumer> Consumer1;
+        std::string streamName = "CloseProdDuringSend";
+        GetProducerConsumers(numOfWorker, Producer1, Producer2, Consumer1, streamName);
+
+        int numElement = 1;
+        int totalRecvNum = 2 * numElement;
+        int threadNum = 3;
+        ThreadPool pool(threadNum);
+
+        datasystem::inject::Set("ProducerImpl.Send.delay", "call(10)");
+        pool.Submit([this, Producer1, numElement]() { SendHelper(Producer1, numElement); });
+        pool.Submit([this, Producer2, numElement]() { SendHelper(Producer2, numElement); });
+        pool.Submit([this, Consumer1, totalRecvNum]() { ReceiveHelper(Consumer1, totalRecvNum); });
+
+        // wait for the shm page creation to finish
+        const int SLEEP_TIME = 20;
+        std::this_thread::sleep_for(std::chrono::milliseconds(SLEEP_TIME));
+        DS_ASSERT_OK(Producer1->Close());
+    }
+
+    void CloseTwoProducerDuringSendReceive(int numOfWorker)
+    {
+        std::shared_ptr<Producer> Producer1;
+        std::shared_ptr<Producer> Producer2;
+        std::shared_ptr<Consumer> Consumer1;
+        std::string streamName = "CloseTwoProdDuringSendRecv";
+        GetProducerConsumers(numOfWorker, Producer1, Producer2, Consumer1, streamName);
+
+        int numElement = 1;
+        int totalRecvNum = 2;
+        int threadNum = 3;
+        ThreadPool pool(threadNum);
+        datasystem::inject::Set("ProducerImpl.Send.delay", "call()");
+        pool.Submit([this, Producer1, numElement]() { SendHelper(Producer1, numElement); });
+        pool.Submit([this, Producer2, numElement]() { SendHelper(Producer2, numElement); });
+        pool.Submit([this, Consumer1, totalRecvNum]() { ReceiveHelper(Consumer1, totalRecvNum); });
+
+        const int SLEEP_TIME = 20;
+        std::this_thread::sleep_for(std::chrono::milliseconds(SLEEP_TIME));
+        DS_ASSERT_OK(Producer1->Close());
+        DS_ASSERT_OK(Producer2->Close());
+    }
+
+    void CloseConsumerWhileSend(int numOfWorker)
+    {
+        std::shared_ptr<Producer> Producer1;
+        std::shared_ptr<Producer> Producer2;
+        std::shared_ptr<Consumer> Consumer1;
+        std::string streamName = "CloseConWhileSend";
+        GetProducerConsumers(numOfWorker, Producer1, Producer2, Consumer1, streamName);
+
+        int numElement = 1000;
+        int threadNum = 2;
+        ThreadPool pool(threadNum);
+        pool.Submit([this, Producer1, numElement]() { SendHelper(Producer1, numElement); });
+        pool.Submit([this, Producer2, numElement]() { SendHelper(Producer2, numElement); });
+
+        const int SLEEP_TIME = 10;
+        std::this_thread::sleep_for(std::chrono::milliseconds(SLEEP_TIME));
+        DS_ASSERT_OK(Consumer1->Close());
+    }
+
+    void CloseAllWhileSendReceiveDone(int numOfWorker)
+    {
+        std::shared_ptr<Producer> Producer1;
+        std::shared_ptr<Producer> Producer2;
+        std::shared_ptr<Consumer> Consumer1;
+        std::string streamName = "CloseAllWhileSendRecv";
+        GetProducerConsumers(numOfWorker, Producer1, Producer2, Consumer1, streamName);
+
+        int numElement = 500;
+        int totalRecvNum = numElement * 2;
+        SendHelper(Producer1, numElement);
+        SendHelper(Producer2, numElement);
+        ReceiveHelper(Consumer1, totalRecvNum);
+
+        DS_ASSERT_OK(Producer1->Close());
+        DS_ASSERT_OK(Producer2->Close());
+        DS_ASSERT_OK(Consumer1->Close());
+    }
+
+    void NewProducerContinueSend(int numOfWorker)
+    {
+        std::shared_ptr<Producer> Producer1;
+        std::shared_ptr<Producer> Producer2;
+        std::shared_ptr<Consumer> Consumer1;
+        std::string streamName = "NewProducerContinueSend";
+        GetProducerConsumers(numOfWorker, Producer1, Producer2, Consumer1, streamName);
+
+        int numElement = 1;
+        int totalRecvNum = 3 * numElement;
+        int threadNum = 3;
+        ThreadPool pool(threadNum);
+        datasystem::inject::Set("ProducerImpl.Send.delay", "call(10)");
+        pool.Submit([this, Producer1, numElement]() { SendHelper(Producer1, numElement); });
+        pool.Submit([this, Producer2, numElement]() { SendHelper(Producer2, numElement); });
+        pool.Submit([this, Consumer1, totalRecvNum]() { ReceiveHelper(Consumer1, totalRecvNum); });
+
+        const int SLEEP_TIME = 20;
+        std::this_thread::sleep_for(std::chrono::milliseconds(SLEEP_TIME));
+
+        // Close Producer 1, then create Producer 3 to continue sending
+        DS_ASSERT_OK(Producer1->Close());
+        std::shared_ptr<Producer> Producer3;
+        GetProducer(numOfWorker, Producer3, streamName);
+        auto fut = pool.Submit([this, Producer3, numElement]() { SendHelper(Producer3, numElement); });
+        fut.get();
+    }
+
+    void ConsumerContinueReceive(int numOfWorker)
+    {
+        std::shared_ptr<Producer> Producer1;
+        std::shared_ptr<Producer> Producer2;
+        std::shared_ptr<Consumer> Consumer1;
+        std::string streamName = "ConsumerContinueReceive";
+        GetProducerConsumers(numOfWorker, Producer1, Producer2, Consumer1, streamName);
+
+        int numElement = 1;
+        int totalRecvNum = 2 * numElement;
+        int threadNum = 3;
+        ThreadPool pool(threadNum);
+        datasystem::inject::Set("ProducerImpl.Send.delay", "call(100)");
+        pool.Submit([this, Producer1, numElement]() { SendHelper(Producer1, numElement); });
+        pool.Submit([this, Producer2, numElement]() { SendHelper(Producer2, numElement); });
+
+        // Close Consumer 1, then create Consumer 2 to continue receiving
+        DS_ASSERT_OK(Consumer1->Close());
+        std::shared_ptr<Consumer> Consumer2;
+        GetConsumer(numOfWorker, Consumer2, streamName);
+        ReceiveHelper(Consumer2, totalRecvNum);
+    }
+
+protected:
+    void InitTest()
+    {
+        uint32_t workerIndex = 0;
+        HostPort workerAddress1;
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(workerIndex++, workerAddress1));
+        HostPort workerAddress2;
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(workerIndex++, workerAddress2));
+        HostPort workerAddress3;
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(workerIndex, workerAddress3));
+        LOG(INFO) << FormatString("\n Worker1: <%s>\n Worker2: <%s>\n Worker3: <%s>", workerAddress1.ToString(),
+                                  workerAddress2.ToString(), workerAddress3.ToString());
+        InitStreamClient(0, client1_);
+        InitStreamClient(1, client2_);
+        InitStreamClient(2, client3_);  // worker index is 2
+        defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE;
+    }
+    std::shared_ptr<StreamClient> client1_ = nullptr;
+    std::shared_ptr<StreamClient> client2_ = nullptr;
+    std::shared_ptr<StreamClient> client3_ = nullptr;
+    ProducerConf defaultProducerConf_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+
+    // cluster config
+    int numWorkers = 3;
+    int numEtcd = 1;
+    int numRpcThreads = 0;
+};
+
+// These test suites cover all scenarios in one testcase each: 1 worker, 2 workers and 3 workers
+TEST_F(BasicMultipleProducerSingleConsumerTest, CloseProducerDuringSend)
+{
+    for (int i = 1; i <= numWorkers; i++) {
+        CloseProducerDuringSend(i);
+    }
+}
+
+TEST_F(BasicMultipleProducerSingleConsumerTest, CloseTwoProducerDuringSendReceive)
+{
+    for (int i = 1; i <= numWorkers; i++) {
+        CloseTwoProducerDuringSendReceive(i);
+    }
+}
+
+TEST_F(BasicMultipleProducerSingleConsumerTest, CloseConsumerWhileSend)
+{
+    for (int i = 1; i <= numWorkers; i++) {
+        CloseConsumerWhileSend(i);
+    }
+}
+
+TEST_F(BasicMultipleProducerSingleConsumerTest, CloseAllWhileSendReceiveDone)
+{
+    for (int i = 1; i <= numWorkers; i++) {
+        CloseAllWhileSendReceiveDone(i);
+    }
+}
+
+TEST_F(BasicMultipleProducerSingleConsumerTest, NewProducerContinueSend)
+{
+    for (int i = 1; i <= numWorkers; i++) {
+        NewProducerContinueSend(i);
+    }
+}
+
+TEST_F(BasicMultipleProducerSingleConsumerTest, ConsumerContinueReceive)
+{
+    for (int i = 1; i <= numWorkers; i++) {
+        ConsumerContinueReceive(i);
+    }
+}
+
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/producer_test.cpp b/tests/st/client/stream_cache/producer_test.cpp
new file mode 100644
index 0000000..1b46290
--- /dev/null
+++ b/tests/st/client/stream_cache/producer_test.cpp
@@ -0,0 +1,4311 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "datasystem/stream/producer.h"
+#include <gtest/gtest.h>
+
+#include "common.h"
+#include "common/stream_cache/element_generator.h"
+#include "common/stream_cache/stream_common.h"
+#include "sc_client_common.h"
+#include "datasystem/client/mmap_manager.h"
+#include "datasystem/client/stream_cache/client_worker_api.h"
+#include "datasystem/stream_client.h"
+#include "datasystem/common/inject/inject_point.h"
+#include "datasystem/common/util/random_data.h"
+#include "datasystem/common/util/rpc_util.h"
+#include "datasystem/common/util/uuid_generator.h"
+#include "datasystem/common/perf/perf_manager.h"
+#include "datasystem/stream/consumer.h"
+
+DS_DECLARE_uint32(page_size);
+using namespace datasystem::client::stream_cache;
+
+namespace datasystem {
+namespace st {
+constexpr int K_TWO = 2;
+const uint64_t ELEMENTS_TOTAL_SIZE = 64 * 1024 * 1024;      // 64 MB
+const uint64_t DEFAULT_MAX_STREAM_SIZE = 64 * 1024 * 1024;  // 64 MB
+
+/**
+ * @brief SendConfig.
+ * @param[in] streamName Name of stream.
+ * @param[in] producerName Name of producer.
+ * @param[in] producerConf Configuration used to create the producer.
+ * @param[in] numOfElements Number of elements to send.
+ */
+struct SendConfig {
+    std::string streamName;
+    std::string producerName;
+    ProducerConf producerConf;
+    size_t numOfElements;
+};
+
+/**
+ * @brief RecvConfig.
+ * @param[in] streamName Name of stream.
+ * @param[in] subscriptionName Name of subscription (stream mode).
+ * @param[in] numOfBatchElements Number of elements to receive.
+ * @param[in] timeToWaitMs Time to wait in milliseconds.
+ * @param[in] ackInterval Number of elements received per ack.
+ * @param[in] autoAck Whether to acknowledge automatically.
+ */
+struct RecvConfig {
+    std::string streamName;
+    std::string subscriptionName;
+    size_t numOfBatchElements;
+    size_t timeToWaitMs;
+    size_t ackInterval;
+    bool autoAck;
+};
+
+class ProducerTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.numWorkers = NUM_WORKERS;
+        opts.vLogLevel = 2;
+        opts.enableDistributedMaster = "false";
+        opts.masterIdx = 1;
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        signature_ = std::make_unique<Signature>(accessKey_, secretKey_);
+        ExternalClusterTest::SetUp();
+        InitTest();
+    }
+
+    void TearDown() override
+    {
+        ExternalClusterTest::TearDown();
+    }
+
+    const int count = 30;
+
+    /**
+     * @brief Send streaming data.
+     * @param[in] sendConfig Configuration for a producer.
+     * @param[in] expectedNumOfConsumers Number of consumers to wait for this stream.
+     * @param[in] elementsFut Elements to send for this stream.
+     * @param[in] spClient The stream client to use for this send loop.
+     */
+    Status SendStreamData(const SendConfig &sendConfig, uint64_t expectedNumOfConsumers,
+                          std::shared_future<std::vector<std::string>> &elementsFut,
+                          std::shared_ptr<StreamClient> spClient);
+
+    /**
+     * @brief Send streaming data with sleeps in between, for test purposes.
+     * @param[in] sendConfig1/sendConfig2 Configurations for the two producers.
+     * @param[in] expectedNumOfConsumers Number of consumers to wait for this stream.
+     * @param[in] elementsFut1/elementsFut2 Elements to send for this stream.
+     * @param[in] spClient1/spClient2 The stream clients to use for this send loop.
+     */
+    Status SendStreamDataSlow(const SendConfig &sendConfig1, const SendConfig &sendConfig2,
+                              uint64_t expectedNumOfConsumers,
+                              std::shared_future<std::vector<std::string>> &elementsFut1,
+                              std::shared_future<std::vector<std::string>> &elementsFut2,
+                              std::shared_ptr<StreamClient> spClient1, std::shared_ptr<StreamClient> spClient2);
+    /*
+     * @brief Receive streaming data.
+     * @param[in] recvConfig Configuration for a receiver.
+     * @param[in] numOfElements Total number of elements to receive.
+     * @param[in] spClient The stream client to use for this receive loop.
+     * Validation does not work if there are multiple producers of the same stream and should be turned off in
+     * that case.
+     */
+    Status RecvStreamData(const RecvConfig &recvConfig, size_t numOfElements, std::shared_ptr<StreamClient> spClient);
+
+    /*
+     * @brief Receive streaming data with slowReceive set to true.
+     * @param[in] recvConfig Configuration for a receiver.
+     * @param[in] numOfElements Total number of elements to receive.
+     * @param[in] spClient The stream client to use for this receive loop.
+     * Validation does not work if there are multiple producers of the same stream and should be turned off in
+     * that case.
+     */
+    Status RecvStreamDataWithSlowReceive(const RecvConfig &recvConfig, size_t numOfElements,
+                                         std::shared_ptr<StreamClient> spClient);
+
+    /**
+     * @brief Creates a stream client at the given worker num
+     * @param[in] workerNum The worker num to create the stream against
+     * @param[out] spClient Shared pointer to the stream client
+     * @return status
+     */
+    Status CreateClient(int workerNum, std::shared_ptr<StreamClient> &spClient)
+    {
+        InitStreamClient(workerNum, spClient);
+        return Status::OK();
+    }
+
+    /**
+     * @brief Creates a stream client at the given worker num and timeout
+     * @param[in] workerNum The worker num to create the stream against
+     * @param[in] timeout Timeout for RPC requests
+     * @param[out] spClient Shared pointer to the stream client
+     * @return status
+     */
+    Status CreateClient(int workerNum, int32_t timeout, std::shared_ptr<StreamClient> &spClient)
+    {
+        HostPort workerAddress;
+        RETURN_IF_NOT_OK(cluster_->GetWorkerAddr(workerNum, workerAddress));
+        // Create a client with a user-defined timeout
+        ConnectOptions connectOptions = { .host = workerAddress.Host(),
+                                          .port = workerAddress.Port(),
+                                          .connectTimeoutMs = timeout };
+        connectOptions.accessKey = accessKey_;
+        connectOptions.secretKey = secretKey_;
+        spClient = std::make_shared<StreamClient>(connectOptions);
+        RETURN_IF_NOT_OK(spClient->Init());
+        return Status::OK();
+    }
+
+    /**
+     * @brief Creates a client worker api at the given worker num
+     * @param[in] workerNum The worker num to create the api to
+     * @param[out] workerApi Shared pointer to the client worker api
+     * @return status
+     */
+    Status CreateClientWorkerApi(int workerNum, std::shared_ptr<ClientWorkerApi> &workerApi)
+    {
+        HostPort workerAddress;
+        RETURN_IF_NOT_OK(cluster_->GetWorkerAddr(workerNum, workerAddress));
+        workerApi = std::make_shared<ClientWorkerApi>(workerAddress, RpcCredential(), signature_.get());
+        RETURN_IF_NOT_OK(workerApi->Init(CLIENT_RPC_TIMEOUT));
+        return Status::OK();
+    }
+
+    /**
+     * @brief A bunch of the OOM tests all have the same setup/flow but different configs. This function captures
+     * the common logic so that different tests can run the same flow with different configs.
+     * @param[in] prodClient The client for the producer side
+     * @param[in] conClient The client for the consumer side
+     * @param[in] maxStreamSizeMB Max stream size used by the stream
+     */
+    Status RunOOMTest(std::shared_ptr<StreamClient> prodClient, std::shared_ptr<StreamClient> conClient,
+                      std::string streamName, uint64_t maxStreamSizeMB = 2);
+
+protected:
+    const int SLOW_CONSUME_WAIT = 5;  // seconds
+    const int CLIENT_THREAD_POOL_SIZE = 2;
+    const int NUM_WORKERS = 2;
+    const int SLEEP_TIME = 10;
+    const uint32_t RECV_WAIT_MILLI_SECONDS = 20;
+    static constexpr int CLIENT_RPC_TIMEOUT = 4 * 60 * 1000;
+    void InitTest()
+    {
+        defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE;
+        waitForGoCount_ = 1;  // assume a single consumer; a testcase with more consumers must update this
+        interrupt_ = false;
+    }
+
+    void WaitForConsumers(std::shared_ptr<StreamClient> spClient, std::string streamName,
+                          uint64_t expectedNumOfConsumers);
+    Status SendHelper(const std::vector<std::string> &elements, const SendConfig &sendConfig,
+                      std::shared_ptr<Producer> producer, std::atomic<size_t> &numOfRetries,
+                      unsigned long failInterval, bool slowSend);
+
+    uint64_t CheckProducerCount(std::shared_ptr<StreamClient> spClient, std::string streamName)
+    {
+        uint64_t totalProducerNum;
+        spClient->QueryGlobalProducersNum(streamName, totalProducerNum);
+        return totalProducerNum;
+    }
+
+    Status TryAndDeleteStream(std::shared_ptr<StreamClient> spClient, std::string streamName)
+    {
+        // if notifications are pending, retry the delete
+        Status rc = Status::OK();
+        do {
+            rc = spClient->DeleteStream(streamName);
+            if (rc.IsError()) {
+                sleep(K_TWO);
+            }
+        } while (rc.GetCode() == StatusCode::K_SC_STREAM_NOTIFICATION_PENDING);
+        return rc;
+    }
+
+    ProducerConf defaultProducerConf_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+    std::unique_ptr<Signature> signature_;
+    std::atomic<int> waitForGoCount_;
+    std::atomic<bool> interrupt_;
+
+    // These knobs tune how RecvStreamData and SendStreamData behave, since different testcases want to test
+    // different scenarios. It's easier to store them here in the class with defaults rather than pass them around.
+    bool slowConsume_ = false;      // Causes the receiver to pause at the start to allow the sender to fill up and OOM
+    bool slowReceive_ = false;      // Another way to throttle the receive side. It injects a sleep between recvs.
+    bool validate_ = true;          // Inspects every record at the receiver. Cannot be used with multiple producers.
+    bool waitForGo_ = false;        // Used with waitForGoCount_ to sync producers
+    bool checkNumOfFails_ = false;  // Provides a cap on retries and eventually fails if there are too many
+    bool clientRetry_ = true;       // Toggles whether the client should rerun the send if it got an OOM.
+    bool autoAck_ = false;          // Toggles auto-ack mode
+    uint16_t prefetchLWM_ = 0;      // Enable client prefetch if non-zero (0 to 100)
+    uint32_t clientCacheSize_ = 0;  // Client cache size (0 means use the external default)
+    uint64_t earlyExitCount_ = 0;   // Set non-zero to make the receiver quit early once N elements are recv'd
+    int64_t sendTimeout_ = 1;       // Timeout arg for send calls
+    uint64_t eleSz_ = 1024;         // Size of each element
+    size_t ackInterval_ = 0;        // Ack interval
+    uint64_t elementsTotalSize_ = ELEMENTS_TOTAL_SIZE;  // Total size of all elements; gets carved into elements
+};
+
+// Configure the workers with a specific shared-memory size
+class BigShmTest : public ProducerTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.numWorkers = NUM_WORKERS;
+        opts.workerGflagParams = "-shared_memory_size_mb=512";
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    Status MultiTest_NStream(std::vector<std::shared_ptr<StreamClient>> &clients, int64_t pageSize,
+                             uint64_t maxStreamSize, int numStreams, int numProds, int numSubs);
+};
+
+Status BigShmTest::MultiTest_NStream(std::vector<std::shared_ptr<StreamClient>> &clients, int64_t pageSize,
+                                     uint64_t maxStreamSize, int numStreams, int numProds, int numSubs)
+{
+    ThreadPool preparePool(1);
+    uint64_t eleSz = eleSz_;
+    uint64_t numElements = elementsTotalSize_ / eleSz;
+    waitForGoCount_ = numSubs;
+    waitForGo_ = true;  // Ensure everyone is sync'd up before sending starts
+    checkNumOfFails_ = true;
+
+    LOG(INFO) << FormatString("Testing Size: %zu", eleSz);
+    std::shared_future<std::vector<std::string>> elementsFut = preparePool.Submit([eleSz, numElements]() {
+        ElementGenerator elementGenerator(eleSz + 1, std::min(eleSz, KB));
+        auto elements = elementGenerator.GenElements("producer", numElements, 8ul);
+        LOG(INFO) << "Element data generated. return from generator thread.";
+        return elements;
+    });
+
+    // Wait for the data generation to complete before we launch producers and consumers
+    LOG(INFO) << "Waiting for data generation";
+    while (elementsFut.wait_for(std::chrono::seconds(1)) != std::future_status::ready)
+        ;
+    LOG(INFO) << "Data generation complete. kick off threads now";
+
+    ThreadPool pool(numProds + numSubs);  // enough threads for the producers and consumers
+    std::vector<std::string> streamNames;
+
+    for (int i = 0; i < numStreams; ++i) {
+        std::string newStream = "stream" + std::to_string(i);
+        streamNames.push_back(newStream);
+    }
+
+    // Kick off the producers first.
+    std::unordered_map<std::string, uint64_t> expectedRecvCounts;
+    std::vector<std::future<Status>> prodFutures;
+    for (int i = 0; i < numProds; ++i) {
+        // Round robin the client for each producer. This will spread producers over the list of workers.
+        int clientIdx = i % clients.size();
+        std::shared_ptr<StreamClient> prodClient = clients[clientIdx];
+        LOG(INFO) << "Creating a producer/sender client thread for client/worker index: " << clientIdx;
+
+        // Round robin the stream name as well for each producer
+        std::string streamName = streamNames[i % numStreams];
+
+        // What if multiple producers target the same stream? Say stream 1 gets 10 records from producer1.
+        // But producer2 is also sending to stream 1, so the number of records sent to that stream is 20.
+        // Identify the expected receive count for each stream.
+        auto iter = expectedRecvCounts.find(streamName);
+        if (iter == expectedRecvCounts.end()) {
+            // This stream is not accounted for yet. Insert a new key into the hash table with the expected
+            // initial count.
+            expectedRecvCounts[streamName] = numElements;
+        } else {
+            // Another producer targets this same stream. Bump the expected count.
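+            // (Worked example, added for clarity: with numProds = 4, numStreams = 2 and this round-robin
+            // assignment, producers 0 and 2 target stream0 while producers 1 and 3 target stream1, so each
+            // stream's expected receive count ends up at 2 * numElements.)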
+            iter->second += numElements;
+        }
+
+        prodFutures.emplace_back(
+            pool.Submit([this, streamName, &elementsFut, numElements, prodClient, pageSize, maxStreamSize]() {
+                ProducerConf prodCfg = {
+                    .delayFlushTime = 20, .pageSize = pageSize, .maxStreamSize = maxStreamSize, .autoCleanup = true
+                };
+                prodCfg.reserveSize = maxStreamSize;  // reserve everything
+                SendConfig sendCfg = { .streamName = streamName,
+                                       .producerName = "producer",
+                                       .producerConf = prodCfg,
+                                       .numOfElements = numElements };
+                return SendStreamData(sendCfg, 1, elementsFut, prodClient);
+            }));
+    }
+
+    std::vector<std::future<Status>> subFutures;
+    // Instead of using i, create a custom counter for the receiver that is 1 off the value of i.
+    // This results in staggering the clients, so:
+    // - producer on client0 pairs with consumer on client1
+    // - producer on client1 pairs with consumer on client0
+    // What if we want to have same-node clients but multiple workers? Not supported at this time by this
+    // test function.
+    int clientCounter = 1;
+    for (int i = 0; i < numSubs; ++i) {
+        uint64_t numStreamElements = 0;
+        // Round robin the client for each consumer.
+        int clientIdx = clientCounter % clients.size();
+        std::shared_ptr<StreamClient> conClient = clients[clientIdx];
+        LOG(INFO) << "Creating a receiver/consumer client thread for client/worker index: " << clientIdx;
+        ++clientCounter;
+
+        // Round robin the stream name as well for each consumer.
+        std::string streamName = streamNames[i % numStreams];
+
+        numStreamElements = expectedRecvCounts[streamName];
+
+        size_t ackInterval;
+        if (ackInterval_ == 0) {
+            ackInterval = std::max(400ul * KB / eleSz, 1ul);
+        } else {
+            ackInterval = ackInterval_;
+        }
+        subFutures.emplace_back(pool.Submit([this, streamName, numStreamElements, ackInterval, conClient]() {
+            RecvConfig rcvCfg = { .streamName = streamName,
+                                  .subscriptionName = "subscription",
+                                  .numOfBatchElements = 100,
+                                  .timeToWaitMs = 20,
+                                  .ackInterval = ackInterval,
+                                  .autoAck = autoAck_ };
+            return RecvStreamData(rcvCfg, numStreamElements, conClient);
+        }));
+    }
+
+    // If any of the producers or subscribers got a non-ok status, it is assigned to the return rc.
+    // The last non-ok error collected wins and is returned to the caller.
+    Status returnRc = Status::OK();
+    for (int i = 0; i < numSubs; ++i) {
+        Status rc = subFutures[i].get();
+        if (rc.IsError()) {
+            returnRc = rc;
+        }
+    }
+
+    for (int i = 0; i < numProds; ++i) {
+        Status rc = prodFutures[i].get();
+        if (rc.IsError()) {
+            returnRc = rc;
+        }
+    }
+
+    return returnRc;
+}
+
+Status CreateElement(size_t elementSize, Element &element, std::vector<uint8_t> &writeElement)
+{
+    writeElement = RandomData().RandomBytes(elementSize);
+    element = Element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+    return Status::OK();
+}
+
+TEST_F(ProducerTest, LEVEL1_TestSendTriggerBackPressure)
+{
+    Timer timer;
+    ThreadPool pool(2);
+
+    std::shared_ptr<StreamClient> spClient;
+    ASSERT_EQ(CreateClient(0, spClient), Status::OK());
+    // The producer sends in total 32 MB of data to a stream of 4 MB capacity with 64 KB pages.
+    size_t numOfElements = 320000;
+    auto producerFut = pool.Submit([this, numOfElements, spClient]() {
+        size_t numOfRetries = 0;
+        std::shared_ptr<Producer> producer;
+        ProducerConf producerConf{ .delayFlushTime = 20, .pageSize = 64 * KB, .maxStreamSize = 4 * MB };
+        uint64_t elementSize = 100;
+        ElementGenerator elementGenerator(elementSize + 1, elementSize);
+        uint64_t numOfConsumers = 0;
+        while (numOfConsumers != 1) {
+            spClient->QueryGlobalConsumersNum("SendTriggerBackPressure", numOfConsumers);
+        }
+        RETURN_IF_NOT_OK(spClient->CreateProducer("SendTriggerBackPressure", producer, producerConf));
+        auto elements = elementGenerator.GenElements("producer", numOfElements, 8ul);
+        Timer timer;
+        for (size_t i = 0; i < numOfElements; i++) {
+            // Resend on failure caused by the async back-pressure mechanism.
+            auto element = Element(reinterpret_cast<uint8_t *>(&elements[i].front()), elements[i].size());
+            auto status = producer->Send(element);
+            while (!status.IsOk()) {
+                numOfRetries++;
+                status = producer->Send(element);
+            }
+        }
+        LOG(INFO) << FormatString("Producer's number of re-sends: %zu, sending time: %.6lf s", numOfRetries,
+                                  timer.ElapsedSecond());
+        return Status::OK();
+    });
+
+    // The receiver receives all the data.
+    auto consumerFut = pool.Submit([this, numOfElements, spClient]() {
+        std::shared_ptr<Consumer> consumer;
+        SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+        RETURN_IF_NOT_OK(spClient->Subscribe("SendTriggerBackPressure", config, consumer));
+        std::unordered_map<std::string, uint64_t> seqNoMap;
+        Timer timer;
+        size_t ackInterval = 4096;
+        for (size_t i = 0; i < numOfElements;) {
+            std::vector<Element> outElements;
+            consumer->Receive(1, 20, outElements);
+            i += outElements.size();
+            for (auto &element : outElements) {
+                ElementView elementView(std::string(reinterpret_cast<const char *>(element.ptr), element.size));
+                RETURN_IF_NOT_OK(elementView.VerifyFifo(seqNoMap));
+                RETURN_IF_NOT_OK(elementView.VerifyIntegrity());
+            }
+            if (i % ackInterval == (ackInterval - 1)) {
+                LOG(INFO) << FormatString("Ack id: %zu", outElements.back().id);
+                RETURN_IF_NOT_OK(consumer->Ack(outElements.back().id));
+            }
+        }
+        LOG(INFO) << FormatString("Total Recv Time Elapsed: %.6lf s", timer.ElapsedSecond());
+        return Status::OK();
+    });
+
+    ASSERT_EQ(consumerFut.get(), Status::OK());
+    ASSERT_EQ(producerFut.get(), Status::OK());
+    LOG(INFO) << FormatString("End To End Time Elapsed: %.6lf s", timer.ElapsedSecond());
+}
+
+Status ProducerTest::RecvStreamData(const RecvConfig &recvConfig, size_t numOfElements,
+                                    std::shared_ptr<StreamClient> spClient)
+{
+    // Step 1: Subscribe.
+    // On the producer side, we ensure all the subscriptions are in place before data sending starts.
+    struct SubscriptionConfig cfg;
+    cfg.subscriptionName = recvConfig.subscriptionName;
+    cfg.subscriptionType = SubscriptionType::STREAM;
+    if (prefetchLWM_) {
+        cfg.cachePrefetchLWM = prefetchLWM_;
+    }
+    if (clientCacheSize_) {
+        cfg.cacheCapacity = clientCacheSize_;
+    }
+    std::shared_ptr<Consumer> consumer;
+    RETURN_IF_NOT_OK(spClient->Subscribe(recvConfig.streamName, cfg, consumer, recvConfig.autoAck));
+
+    LOG(INFO) << "Stream consumer created";
+    --waitForGoCount_;
+
+    if (slowConsume_) {
+        // Purposely cause a delay so that the producer will fill up the memory and get an OOM.
+        // Once this sleep is done, the consumer will start to drain and clear the OOM condition.
+        LOG(INFO) << "Pausing the consumer to cause the producer to build up data";
+        std::this_thread::sleep_for(std::chrono::seconds(SLOW_CONSUME_WAIT));
+    }
+
+    // Step 2: Receive and verify elements.
+    LOG(INFO) << "Starting receive loop. This consumer will expect to receive " << numOfElements << " elements";
+    std::unordered_map<std::string, uint64_t> seqNoMap;
+    Timer timer;
+    size_t ackInterval = recvConfig.ackInterval;
+    size_t toAckNum = 0;
+    std::vector<Element> outElements;
+    for (size_t i = 0; i < numOfElements && !interrupt_;) {
+        RETURN_IF_NOT_OK(consumer->Receive(recvConfig.numOfBatchElements, RECV_WAIT_MILLI_SECONDS, outElements));
+        i += outElements.size();
+        // Don't log all the retries from timed-out Receive calls. Only log if we got something.
+        if (outElements.size() != 0) {
+            LOG(INFO) << "Consumer received. total elements read so far: " << i;
+        }
+        if (validate_) {
+            for (auto &element : outElements) {
+                ElementView elementView(std::string(reinterpret_cast<const char *>(element.ptr), element.size));
+                RETURN_IF_NOT_OK(elementView.VerifyFifo(seqNoMap));
+                RETURN_IF_NOT_OK(elementView.VerifyIntegrity());
+            }
+        } else if (slowReceive_) {
+            // Make the receiver a bit slower in between each call.
+            // If validate is true, you get this for "free" because validating generally takes some work to compare
+            // the data, which makes the receive side run slower between each receive.
+            const int recvSleepTime = 450;
+            std::this_thread::sleep_for(std::chrono::microseconds(recvSleepTime));
+        }
+        toAckNum += outElements.size();
+        if (!autoAck_ && toAckNum >= ackInterval) {
+            LOG(INFO) << FormatString("Ack id: %zu", outElements.back().id);
+            RETURN_IF_NOT_OK(consumer->Ack(outElements.back().id));
+            toAckNum = 0;
+        }
+
+        if (earlyExitCount_ != 0 && i > earlyExitCount_) {
+            LOG(INFO) << "Consumer has reached its early exit row count of " << earlyExitCount_
+                      << FormatString(". Total Recv Time Elapsed: %.6lf s", timer.ElapsedSecond());
+            return Status::OK();
+        }
+    }
+    if (interrupt_) {
+        LOG(INFO) << "Consumer loop quits due to interrupt";
+    }
+
+    LOG(INFO) << FormatString("Total Recv Time Elapsed: %.6lf s", timer.ElapsedSecond());
+    return Status::OK();
+}
+
+// For test purposes. Testing back pressure. Modified RecvStreamData where slowReceive_ is true.
+// The "else if (slowReceive_)" branch is removed so there is always a delay between recvs.
+Status ProducerTest::RecvStreamDataWithSlowReceive(const RecvConfig &recvConfig, size_t numOfElements,
+                                                   std::shared_ptr<StreamClient> spClient)
+{
+    struct SubscriptionConfig cfg;
+    cfg.subscriptionName = recvConfig.subscriptionName;
+    cfg.subscriptionType = SubscriptionType::STREAM;
+    if (prefetchLWM_) {
+        cfg.cachePrefetchLWM = prefetchLWM_;
+    }
+    if (clientCacheSize_) {
+        cfg.cacheCapacity = clientCacheSize_;
+    }
+    std::shared_ptr<Consumer> consumer;
+    RETURN_IF_NOT_OK(spClient->Subscribe(recvConfig.streamName, cfg, consumer, recvConfig.autoAck));
+
+    LOG(INFO) << "Stream slow consumer created";
+    --waitForGoCount_;
+
+    if (slowConsume_) {
+        LOG(INFO) << "Pausing the slow consumer to cause the producer to build up data";
+        std::this_thread::sleep_for(std::chrono::seconds(SLOW_CONSUME_WAIT));
+    }
+
+    LOG(INFO) << "Starting receive loop. This slow consumer will expect to receive " << numOfElements << " elements";
+    std::unordered_map<std::string, uint64_t> seqNoMap;
+    Timer timer;
+    size_t ackInterval = recvConfig.ackInterval, toAckNum = 0;
+    std::vector<Element> outElements;
+    for (size_t i = 0; i < numOfElements && !interrupt_;) {
+        RETURN_IF_NOT_OK(consumer->Receive(recvConfig.numOfBatchElements, RECV_WAIT_MILLI_SECONDS, outElements));
+        i += outElements.size();
+        LOG(INFO) << "Slow consumer received. total elements read so far: " << i;
+        if (validate_) {
+            for (auto &element : outElements) {
+                ElementView elementView(std::string(reinterpret_cast<const char *>(element.ptr), element.size));
+                RETURN_IF_NOT_OK(elementView.VerifyFifo(seqNoMap));
+                RETURN_IF_NOT_OK(elementView.VerifyIntegrity());
+            }
+        }
+        std::this_thread::sleep_for(std::chrono::milliseconds(SLEEP_TIME));  // delay to slow recv
+        toAckNum += outElements.size();
+        if (!autoAck_ && toAckNum >= ackInterval) {
+            LOG(INFO) << FormatString("Slow consumer Ack id: %zu", outElements.back().id);
+            RETURN_IF_NOT_OK(consumer->Ack(outElements.back().id));
+            toAckNum = 0;
+        }
+
+        if (earlyExitCount_ != 0 && i > earlyExitCount_) {
+            LOG(INFO) << "Slow consumer has reached its early exit row count of " << earlyExitCount_;
+            return Status::OK();
+        }
+    }
+    if (interrupt_) {
+        LOG(INFO) << "Slow consumer loop quits due to interrupt";
+    }
+
+    LOG(INFO) << FormatString("Slow consumer Total Recv Time Elapsed: %.6lf s", timer.ElapsedSecond());
+    return Status::OK();
+}
+
+void ProducerTest::WaitForConsumers(std::shared_ptr<StreamClient> spClient, std::string streamName,
+                                    uint64_t expectedNumOfConsumers)
+{
+    uint64_t numOfConsumers = 0;
+    while (true) {
+        spClient->QueryGlobalConsumersNum(streamName, numOfConsumers);
+        LOG(INFO) << "Current NumOf Consumers: " << numOfConsumers;
+        if (numOfConsumers == expectedNumOfConsumers) {
+            break;
+        }
+        auto sleepUseconds = 5000ul;
+        usleep(sleepUseconds);
+    }
+}
+
+// Helper function used in SendStreamDataSlow. Removes duplicate code and
+// keeps the function <= 50 lines.
+Status ProducerTest::SendHelper(const std::vector<std::string> &elements, const SendConfig &sendConfig,
+                                std::shared_ptr<Producer> producer, std::atomic<size_t> &numOfRetries,
+                                unsigned long failInterval, bool slowSend)
+{
+    for (size_t i = 0; i < sendConfig.numOfElements; ++i) {
+        if (slowSend) {
+            std::this_thread::sleep_for(std::chrono::milliseconds(SLEEP_TIME));  // add delay on send
+        }
+        // Resend on failure caused by the async back-pressure mechanism.
+        auto idx = i % elements.size();
+        // Force cast due to the element interface.
+        auto element = Element((uint8_t *)(elements[idx].data()), elements[idx].size());
+        auto status = producer->Send(element, sendTimeout_);
+        if (status.IsError()) {
+            LOG(INFO) << "Send element " << idx << " failed. Will do client retry? " << std::boolalpha << clientRetry_
+                      << " " << status.ToString();
+            if (!clientRetry_) {
+                interrupt_ = true;  // break the receiver loop also and quit the test
+                return status;
+            }
+        }
+        while (!status.IsOk() && clientRetry_) {
+            // Avoid overwhelming consumers that are too slow to process
+            usleep(failInterval);
+            numOfRetries++;
+            if (checkNumOfFails_) {
+                size_t retryMaxTimes = 30u;
+                CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(numOfRetries < retryMaxTimes, K_RUNTIME_ERROR, "too many retries");
+            }
+            status = producer->Send(element);
+        }
+    }
+    return Status::OK();
+}
+
+Status ProducerTest::SendStreamData(const SendConfig &sendConfig, uint64_t expectedNumOfConsumers,
+                                    std::shared_future<std::vector<std::string>> &elementsFut,
+                                    std::shared_ptr<StreamClient> spClient)
+{
+    auto failInterval = 500000ul;
+    // Step 1: Create a producer.
+    std::shared_ptr<Producer> producer;
+    RETURN_IF_NOT_OK(spClient->CreateProducer(sendConfig.streamName, producer, sendConfig.producerConf));
+
+    // Step 2: Wait for consumers.
+    WaitForConsumers(spClient, sendConfig.streamName, expectedNumOfConsumers);
+
+    if (waitForGo_) {
+        // Block until the receiver side tells us to go. When each receiver is ready, it decrements the
+        // waitForGo count.
+        while (waitForGoCount_ > 0) {
+            LOG(INFO) << "Waiting for consumer...";
+            std::this_thread::sleep_for(std::chrono::seconds(1));
+        }
+    }
+    LOG(INFO) << "Starting send loop";
+
+    // Step 3: Send elements and retry on failure.
+    Timer timer;
+    uint64_t numOfRetries = 0;
+    const std::vector<std::string> &elements = elementsFut.get();
+    for (size_t i = 0; i < sendConfig.numOfElements; i++) {
+        // Resend on failure caused by the async back-pressure mechanism.
+        auto idx = i % elements.size();
+        // Force cast due to the element interface.
+        auto element = Element((uint8_t *)(elements[idx].data()), elements[idx].size());
+        auto status = producer->Send(element, sendTimeout_);
+        if (status.IsError()) {
+            LOG(INFO) << "Send element " << idx << " failed. Will do client retry? " << std::boolalpha << clientRetry_
+                      << " " << status.ToString();
+            if (!clientRetry_) {
+                interrupt_ = true;  // break the receiver loop also and quit the test
+                return status;
+            }
+        }
+        while (!status.IsOk() && clientRetry_) {
+            // Avoid overwhelming consumers that are too slow to process
+            usleep(failInterval);
+            numOfRetries++;
+            if (checkNumOfFails_) {
+                size_t retryMaxTimes = 30u;
+                CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(numOfRetries < retryMaxTimes, K_RUNTIME_ERROR, "too many retries");
+            }
+            status = producer->Send(element);
+        }
+    }
+    LOG(INFO) << FormatString("Producer's number of re-sends: %zu/%zu, sending time: %.6lf s", numOfRetries,
+                              sendConfig.numOfElements, timer.ElapsedSecond());
+
+    LOG(INFO) << "Small sleep so that producing workers fully send all their data before closing.";
+    std::this_thread::sleep_for(std::chrono::seconds(2));
+    // The producer ptr goes out of scope here and the destructor calls Close() (not the recommended way to close!)
+    return Status::OK();
+}
+
+// For test purposes. Testing back pressure. Two producers send; one producer sends slower than the other.
+Status ProducerTest::SendStreamDataSlow(const SendConfig &sendConfig1, const SendConfig &sendConfig2,
+                                        uint64_t expectedNumOfConsumers,
+                                        std::shared_future<std::vector<std::string>> &elementsFut1,
+                                        std::shared_future<std::vector<std::string>> &elementsFut2,
+                                        std::shared_ptr<StreamClient> spClient1,
+                                        std::shared_ptr<StreamClient> spClient2)
+{
+    auto failInterval = 500000ul;
+    // Step 1: Create the producers.
+    std::shared_ptr<Producer> producer1, producer2;
+    RETURN_IF_NOT_OK(spClient1->CreateProducer(sendConfig1.streamName, producer1, sendConfig1.producerConf));
+    RETURN_IF_NOT_OK(spClient2->CreateProducer(sendConfig2.streamName, producer2, sendConfig2.producerConf));
+
+    // Step 2: Wait for consumers.
+    WaitForConsumers(spClient1, sendConfig1.streamName, expectedNumOfConsumers);
+    WaitForConsumers(spClient2, sendConfig2.streamName, expectedNumOfConsumers);
+
+    if (waitForGo_) {
+        // Block until the receiver side tells us to go. When each receiver is ready, it decrements the
+        // waitForGo count.
+        while (waitForGoCount_ > 0) {
+            LOG(INFO) << "Waiting for consumer...";
+            std::this_thread::sleep_for(std::chrono::seconds(1));
+        }
+    }
+    LOG(INFO) << "Starting send loop producers";
+
+    // Step 3: Send elements and retry on failure.
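+    // Back-pressure sketch (explanatory note; assumes internals behave as the knobs above suggest): while
+    // the consumer lags, Send() is expected to fail (e.g. with K_OUT_OF_MEMORY once roughly maxStreamSize
+    // worth of pages is pinned). SendHelper then sleeps failInterval microseconds and retries, so each
+    // producer's effective send rate self-throttles to the consumer's drain rate.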
+    Timer timer;
+    std::atomic<size_t> numOfRetries1{ 0 };
+    std::atomic<size_t> numOfRetries2{ 0 };
+    const std::vector<std::string> &elements1 = elementsFut1.get();
+    const std::vector<std::string> &elements2 = elementsFut2.get();
+
+    // Send data for producer1
+    auto sendProducer1 = [&]() {
+        SendHelper(elements1, sendConfig1, producer1, numOfRetries1, failInterval, false);
+        LOG(INFO) << FormatString("Producer1's number of re-sends: %zu/%zu, sending time: %.6lf s",
+                                  numOfRetries1.load(), sendConfig1.numOfElements, timer.ElapsedSecond());
+        return Status::OK();
+    };
+
+    // Slow send data for producer2
+    auto sendProducer2 = [&]() {
+        SendHelper(elements2, sendConfig2, producer2, numOfRetries2, failInterval, true);
+        LOG(INFO) << FormatString("Producer2's number of re-sends: %zu/%zu, sending time: %.6lf s",
+                                  numOfRetries2.load(), sendConfig2.numOfElements, timer.ElapsedSecond());
+        return Status::OK();
+    };
+
+    std::thread producer1Thread(sendProducer1);
+    std::thread producer2Thread(sendProducer2);
+
+    producer1Thread.join();
+    producer2Thread.join();
+
+    LOG(INFO) << "Producer small sleep so that producing workers fully send all their data before closing.";
+    // std::this_thread::sleep_for(std::chrono::seconds(5));
+    // The producer ptrs go out of scope here and the destructor calls Close() (not the recommended way to close!)
+    return Status::OK();
+}
+
+Status ProducerTest::RunOOMTest(std::shared_ptr<StreamClient> prodClient, std::shared_ptr<StreamClient> conClient,
+                                std::string streamName, uint64_t maxStreamSizeMB)
+{
+    ThreadPool preparePool(1);
+    uint64_t eleSz = 8192ul;  // this should be a small element, not a big element
+    uint64_t numElements = elementsTotalSize_ / eleSz;
+    slowConsume_ = true;
+    waitForGo_ = true;
+
+    LOG(INFO) << FormatString("Testing Size: %zu", eleSz);
+    std::shared_future<std::vector<std::string>> elementsFut = preparePool.Submit([eleSz, numElements]() {
+        ElementGenerator elementGenerator(eleSz + 1, eleSz);
+        auto elements = elementGenerator.GenElements("producer", numElements, 8ul);
+        LOG(INFO) << "Element data generated. return from generator thread.";
+        return elements;
+    });
+
+    // Wait for the data generation to complete before we launch producers and consumers
+    LOG(INFO) << "Waiting for data generation";
+    while (elementsFut.wait_for(std::chrono::seconds(1)) != std::future_status::ready)
+        ;
+    LOG(INFO) << "Data generation complete. kick off threads now";
kick off threads now"; + + ThreadPool pool(CLIENT_THREAD_POOL_SIZE); + auto producerFut = pool.Submit([this, streamName, &elementsFut, numElements, prodClient, maxStreamSizeMB]() { + ProducerConf prodCfg = { .delayFlushTime = 20, .pageSize = 1 * MB, .maxStreamSize = maxStreamSizeMB * MB }; + SendConfig sendCfg = { + .streamName = streamName, .producerName = "producer", .producerConf = prodCfg, .numOfElements = numElements + }; + return SendStreamData(sendCfg, 1, elementsFut, prodClient); + }); + + size_t ackInterval = std::max(400ul * KB / eleSz, 1ul); + auto consumerFut = pool.Submit([this, streamName, numElements, ackInterval, conClient]() { + RecvConfig rcvCfg = { .streamName = streamName, + .subscriptionName = "subscription", + .numOfBatchElements = 100, + .timeToWaitMs = 20, + .ackInterval = ackInterval, + .autoAck = false }; + return RecvStreamData(rcvCfg, numElements, conClient); + }); + + // Collect the status after threads complete + Status cStatus = consumerFut.get(); + Status pStatus = producerFut.get(); + + // If either the producer or consumer run got an error, return their rc as the overall rc of the test run + if (cStatus.IsError()) { + return cStatus; + } else if (pStatus.IsError()) { + return pStatus; + } + + return Status::OK(); +} + +TEST_F(ProducerTest, LEVEL1_TestVaryingEleSz) +{ + ThreadPool preparePool(1); + std::shared_ptr spClient; + ASSERT_EQ(CreateClient(0, spClient), Status::OK()); + + std::vector eleSzs = { 100, 1 * KB, 4 * KB, 16 * KB }; + int i = 0; + for (uint64_t eleSz : eleSzs) { + LOG(INFO) << FormatString("Testing Size: %zu", eleSz); + const int elementsTotalSize_ = 16 * MB; + uint64_t numElements = elementsTotalSize_ / eleSz; + std::shared_future> elementsFut = preparePool.Submit([eleSz, numElements]() { + ElementGenerator elementGenerator(eleSz + 1, eleSz); + auto elements = elementGenerator.GenElements("producer", numElements, 8ul); + return elements; + }); + + ThreadPool pool(2); + std::string streamName = FormatString("stream%d", i++); + auto producerFut = pool.Submit([this, streamName, &elementsFut, numElements, spClient]() { + ProducerConf prodCfg = { .delayFlushTime = 20, .pageSize = 1 * MB, .maxStreamSize = 4 * MB }; + SendConfig sendCfg = { .streamName = streamName, + .producerName = "producer", + .producerConf = prodCfg, + .numOfElements = numElements }; + return SendStreamData(sendCfg, 1, elementsFut, spClient); + }); + size_t ackInterval = std::max(400ul * KB / eleSz, 1ul); + auto consumerFut = pool.Submit([this, streamName, numElements, ackInterval, spClient]() { + RecvConfig rcvCfg = { .streamName = streamName, + .subscriptionName = "subscription", + .numOfBatchElements = 100, + .timeToWaitMs = 20, + .ackInterval = ackInterval, + .autoAck = false }; + return RecvStreamData(rcvCfg, numElements, spClient); + }); + ASSERT_EQ(consumerFut.get(), Status::OK()); + ASSERT_EQ(producerFut.get(), Status::OK()); + } + + spClient = nullptr; +} + +TEST_F(ProducerTest, TestNoConsumer) +{ + for (int j = 0; j < 5; j++) { + std::shared_ptr spClient; + ASSERT_EQ(CreateClient(0, spClient), Status::OK()); + auto stream_name = "test_dfx_streamcache_node_scale_004"; + ProducerConf prodCfg = { .delayFlushTime = 5, .pageSize = 1 * MB, .maxStreamSize = 2 * MB }; + std::shared_ptr producer; + spClient->CreateProducer(stream_name, producer, prodCfg); + for (int i = 0; i < 1000; i++) { + std::string data = "test "; + Element element(reinterpret_cast(const_cast(data.data())), data.length()); + ASSERT_EQ(producer->Send(element), Status::OK()); + } + } +} + 
+TEST_F(ProducerTest, TestInvalidSend) +{ + std::shared_ptr spClient; + ASSERT_EQ(CreateClient(0, spClient), Status::OK()); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient->Subscribe("InvalidSend", config, consumer)); + std::shared_ptr producer; + DS_ASSERT_OK(spClient->CreateProducer("InvalidSend", producer, defaultProducerConf_)); + + std::string data = "Hello World"; + Element element1(reinterpret_cast(&data.front()), 0); + + // 1. Send element with 0 size + const uint32_t timeoutMs = 1000; + std::vector outElements; + DS_ASSERT_NOT_OK(producer->Send(element1)); + DS_ASSERT_OK(consumer->Receive(1, timeoutMs, outElements)); + ASSERT_EQ(outElements.size(), size_t(0)); + outElements.clear(); + + // 2. Send element with nullptr + Element element2(nullptr, data.length()); + DS_ASSERT_NOT_OK(producer->Send(element2)); + DS_ASSERT_OK(consumer->Receive(1, timeoutMs, outElements)); + ASSERT_EQ(outElements.size(), size_t(0)); + outElements.clear(); +} + +TEST_F(ProducerTest, TestVaryingEleSzBigElement) +{ + ThreadPool preparePool(1); + std::shared_ptr spClient; + ASSERT_EQ(CreateClient(0, spClient), Status::OK()); + checkNumOfFails_ = true; + uint64_t eleSz = 64 * KB; + const int elementsTotalSize_ = 32 * MB; + uint64_t numElements = elementsTotalSize_ / eleSz; + LOG(INFO) << FormatString("Testing Size: %zu", eleSz); + std::shared_future> elementsFut = preparePool.Submit([eleSz, numElements]() { + ElementGenerator elementGenerator(eleSz + 1, eleSz); + auto elements = elementGenerator.GenElements("producer", numElements, 8ul); + return elements; + }); + + ThreadPool pool(2); + std::string streamName = "VaryingEleSzBigElement"; + auto producerFut = pool.Submit([this, streamName, &elementsFut, numElements, spClient]() { + ProducerConf prodCfg = { .delayFlushTime = 20, .pageSize = 1 * MB, .maxStreamSize = 4 * MB }; + SendConfig sendCfg = { + .streamName = streamName, .producerName = "producer", .producerConf = prodCfg, .numOfElements = numElements + }; + return SendStreamData(sendCfg, 1, elementsFut, spClient); + }); + + // To analyze big element memory issue. 
+    // Another configuration: size_t ackInterval = std::max(400ul * KB / eleSz, 1ul);
+    size_t ackInterval = 1ul;
+    auto consumerFut = pool.Submit([this, streamName, numElements, ackInterval, spClient]() {
+        RecvConfig rcvCfg = { .streamName = streamName,
+                              .subscriptionName = "subscription",
+                              .numOfBatchElements = 100,
+                              .timeToWaitMs = 20,
+                              .ackInterval = ackInterval,
+                              .autoAck = false };
+        return RecvStreamData(rcvCfg, numElements, spClient);
+    });
+    ASSERT_EQ(producerFut.get(), Status::OK());
+    ASSERT_EQ(consumerFut.get(), Status::OK());
+
+    spClient = nullptr;
+}
+
+TEST_F(ProducerTest, TestBlockingUnBlockingShm)
+{
+    LOG(INFO) << "TestBlockingUnBlockingShm start!";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    conf.maxStreamSize = 10 * 1024 * 1024;
+    DS_ASSERT_OK(client1->CreateProducer("BlockingUnBlockingShm", producer, conf));
+
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe("BlockingUnBlockingShm", config, consumer));
+
+    const size_t testSize = 4ul * 1024ul;
+    // Keep sending until out of memory
+    size_t sendCount = 0;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    CreateElement(testSize, element, writeElement);
+    while (true) {
+        Status rc = producer->Send(element);
+        if (rc.IsOk()) {
+            ++sendCount;
+            usleep(10);
+            continue;
+        }
+        ASSERT_EQ(rc.GetCode(), K_OUT_OF_MEMORY);
+        break;
+    }
+
+    uint64_t elementId = 0;
+    std::vector<Element> outElements;
+    while (sendCount) {
+        DS_ASSERT_OK(consumer->Receive(1, 1000, outElements));
+        DS_ASSERT_OK(consumer->Ack(++elementId));
+        --sendCount;
+        outElements.clear();
+    }
+}
+
+TEST_F(ProducerTest, TestBlockingUnBlockingShmOutofOrder)
+{
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "RemoteWorker.EnableStreamBlocking.sleep", "11*sleep(2000)"));
+    LOG(INFO) << "TestBlockingUnBlockingShmOutofOrder start!";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    conf.maxStreamSize = 10 * 1024 * 1024;
+    DS_ASSERT_OK(client1->CreateProducer("ShmOutofOrder", producer, conf));
+
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe("ShmOutofOrder", config, consumer));
+
+    const size_t testSize = 1024 * 1024ul;
+    // The producer thread sends exactly 1000 elements; this outer count tracks how many remain to be received.
+    size_t sendCount = 1000;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    CreateElement(testSize, element, writeElement);
+    ThreadPool pool(2);
+    auto producerFut = pool.Submit([this, producer, element]() {
+        size_t sendCount = 0;
+        while (sendCount != 1000) {
+            Status rc = producer->Send(element);
+            if (rc.IsOk()) {
+                ++sendCount;
+                usleep(10);
+                continue;
+            }
+            LOG(INFO) << rc.GetCode();
+        }
+    });
+    sleep(1);
+    LOG(INFO) << "sendCount " << sendCount;
+    int count = 0;
+    while (sendCount) {
+        std::vector<Element> outElements;
+        const int recvTimeout = 1000;
+        DS_ASSERT_OK(consumer->Receive(1, recvTimeout, outElements));
+        if (outElements.empty()) {
+            LOG(INFO) << "empty........ 
" << count; + continue; + } + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + count += outElements.size(); + --sendCount; + } + producerFut.get(); +} + +TEST_F(ProducerTest, TestUnblockingWithEarlyAck) +{ + LOG(INFO) << "TestUnblockingWithEarlyAck start!"; + std::shared_ptr client1, client2; + + ASSERT_EQ(CreateClient(0, client1), Status::OK()); + ASSERT_EQ(CreateClient(1, client2), Status::OK()); + + std::shared_ptr producer; + const int DEFAULT_MAX_STREAM_SIZE = 5 * MB; + ProducerConf conf; + conf.maxStreamSize = DEFAULT_MAX_STREAM_SIZE; + conf.pageSize = 1 * MB; + DS_ASSERT_OK(client1->CreateProducer("EarlyAck", producer, conf)); + + std::shared_ptr consumer; + SubscriptionConfig config("sub", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe("EarlyAck", config, consumer)); + + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "StreamManager.SendBlockProducerReq.delay", "1*sleep(5000)")); + + const size_t testSize = 500 * KB; + // Keep sending until out of memory + const size_t SEND_COUNT = 100; + std::thread producerThrd([&producer]() { + const int DEFAULT_SLEEP_TIME = 200; + Element element; + std::vector writeElement; + CreateElement(testSize, element, writeElement); + for (size_t i = 0; i < SEND_COUNT; i++) { + Status rc = producer->Send(element); + int retryCount = 30; + while (rc.GetCode() == K_OUT_OF_MEMORY && retryCount-- > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME)); + rc = producer->Send(element); + } + DS_ASSERT_OK(rc); + } + }); + + const int DEFAULT_WAIT_TIME = 1000; + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_WAIT_TIME)); + const int K_100 = 100; + const int DEFAULT_RETRY_TIME = 30; + Timer timer; + std::vector outElements; + int sendCount = SEND_COUNT; + while (sendCount > 0 && timer.ElapsedSecond() < DEFAULT_RETRY_TIME) { + DS_ASSERT_OK(consumer->Receive(1, K_100, outElements)); + if (!outElements.empty()) { + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + sendCount -= outElements.size(); + } + } + ASSERT_EQ(sendCount, 0); + producerThrd.join(); +} + +TEST_F(ProducerTest, TestWriteLessThanWritePage) +{ + LOG(INFO) << "test WriteLessThanWritePage start!"; + // Subscribe before send. 
+ std::shared_ptr consumer; + std::shared_ptr spClient; + ASSERT_EQ(CreateClient(0, spClient), Status::OK()); + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient->Subscribe("WriteLessThanWritePage", config, consumer)); + + size_t testSize = 4ul * 1024ul; + Element element; + std::vector writeElement; + std::shared_ptr producer; + DS_ASSERT_OK(spClient->CreateProducer("WriteLessThanWritePage", producer, defaultProducerConf_)); + DS_ASSERT_OK(CreateElement(testSize, element, writeElement)); + ASSERT_EQ(producer->Send(element), Status::OK()); + + std::vector outElements; + ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), size_t(1)); + ASSERT_EQ(outElements[0].id, size_t(1)); + std::string actualData(reinterpret_cast(outElements[0].ptr), outElements[0].size); + std::string data(reinterpret_cast(writeElement.data()), writeElement.size()); + EXPECT_EQ(data, actualData); + spClient = nullptr; +} + +TEST_F(ProducerTest, TestWriteWithPages) +{ + LOG(INFO) << "test WriteWithPages start!"; + size_t testSize = 60ul * 1024ul; + Element element; + std::vector writeElement; + std::shared_ptr producer; + std::shared_ptr spClient; + ASSERT_EQ(CreateClient(0, spClient), Status::OK()); + DS_ASSERT_OK(spClient->CreateProducer("WriteWithPages", producer, defaultProducerConf_)); + + // Subscribe before send. + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient->Subscribe("WriteWithPages", config, consumer)); + + std::string writeData; + for (int i = 0; i < count; i++) { + DS_ASSERT_OK(CreateElement(testSize, element, writeElement)); + ASSERT_EQ(producer->Send(element), Status::OK()); + std::string data(reinterpret_cast(writeElement.data()), writeElement.size()); + writeData += data; + } + + std::vector outElements; + ASSERT_EQ(consumer->Receive(count, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), static_cast(count)); + std::string actualData; + for (int i = 0; i < count; i++) { + std::string tmp(reinterpret_cast(outElements[i].ptr), outElements[i].size); + actualData += tmp; + } + EXPECT_EQ(writeData, actualData); + spClient = nullptr; +} + +TEST_F(ProducerTest, TestWriteFlushAndCreatePage) +{ + LOG(INFO) << "test WriteBigElement start!"; + size_t testSize = 1000ul * 1024ul; + Element element1, element2; + std::shared_ptr producer; + std::shared_ptr spClient; + ASSERT_EQ(CreateClient(0, spClient), Status::OK()); + DS_ASSERT_OK(spClient->CreateProducer("WriteFlushAndCreatePage", producer, defaultProducerConf_)); + + // Subscribe before send. + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient->Subscribe("WriteFlushAndCreatePage", config, consumer)); + + // Disable autoflush + DS_ASSERT_OK(inject::Set("Client.ProducerImpl.SendWithNoAutoFlush", "1*call(-1)")); + std::vector writeElement; + DS_ASSERT_OK(CreateElement(testSize, element1, writeElement)); + DS_ASSERT_OK(producer->Send(element1)); + // Page size 1024*1024 byte. Second write will overflow the page. + // Send again for the flush and create page request together. 
+ DS_ASSERT_OK(CreateElement(testSize, element2, writeElement)); + DS_ASSERT_OK(producer->Send(element2)); + std::vector outElements; + ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), (size_t)1); + outElements.clear(); + ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), (size_t)1); + spClient.reset(); +} + +TEST_F(ProducerTest, TestWriteLocalBigElementSuccess) +{ + LOG(INFO) << "test WriteLocalBigElement start!"; + size_t testSize = 2ul * 1024ul * 1024ul; + Element element; + std::shared_ptr producer; + std::shared_ptr spClient; + ASSERT_EQ(CreateClient(0, spClient), Status::OK()); + DS_ASSERT_OK(spClient->CreateProducer("WriteLocalBigElement", producer, defaultProducerConf_)); + + // Subscribe before send. + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient->Subscribe("WriteLocalBigElement", config, consumer)); + + std::vector writeElement; + DS_ASSERT_OK(CreateElement(testSize, element, writeElement)); + Status rc = producer->Send(element); + LOG(INFO) << "Expected to get K_OK. Rc returned: " << rc.ToString(); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + + const int expectedNum = 1; + const int timeoutMs = 5000; + std::vector outElements; + rc = consumer->Receive(expectedNum, timeoutMs, outElements); + LOG(INFO) << "Expected to get K_OK. Rc returned: " << rc.ToString(); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + ASSERT_EQ(outElements.size(), size_t(1)); + ASSERT_EQ(outElements[0].id, size_t(1)); + std::string actualData(reinterpret_cast(outElements[0].ptr), outElements[0].size); + std::string data(reinterpret_cast(writeElement.data()), writeElement.size()); + EXPECT_EQ(data, actualData); + + spClient.reset(); +} + +TEST_F(ProducerTest, TestReleaseBigElementMemoryInErrorCase) +{ + const size_t testSize = 2ul * 1024ul * 1024ul; + std::shared_ptr producer; + std::shared_ptr spClient; + ASSERT_EQ(CreateClient(0, spClient), Status::OK()); + DS_ASSERT_OK(spClient->CreateProducer("ReleaseBigEle", producer, defaultProducerConf_)); + Element element; + std::vector writeElement; + DS_ASSERT_OK(CreateElement(testSize, element, writeElement)); + datasystem::inject::Set("ProducerImpl.ReleaseBigElementMemory", "1*return(K_RUNTIME_ERROR)"); + Status rc = producer->Send(element); + LOG(INFO) << "Expected to get error. Rc returned: " << rc.ToString(); + ASSERT_NE(rc.GetCode(), StatusCode::K_OK); +} + +TEST_F(ProducerTest, LEVEL1_TestBigElementBatchInsertRollback) +{ + // This testcase tests that the Big Element allocated at BatchInsert can be rollback correctly. 
+ std::shared_ptr client1; + ASSERT_EQ(CreateClient(0, client1), Status::OK()); + std::shared_ptr client2; + ASSERT_EQ(CreateClient(1, client2), Status::OK()); + // Inject to the consumer side worker + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "InsertBigElement.Rollback", "100*return(K_TRY_AGAIN)")); + const size_t testSize = 2ul * 1024ul * 1024ul; + std::string streamName = "BigEleBatchInsertRollback"; + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer)); + const size_t SEND_COUNT = 1000; + auto func = [&producer]() { + const int DEFAULT_SLEEP_TIME = 300; + Element element; + std::vector writeElement; + CreateElement(testSize, element, writeElement); + for (size_t i = 0; i < SEND_COUNT; i++) { + Status rc = producer->Send(element); + int retryCount = 30; + while (rc.GetCode() == K_OUT_OF_MEMORY && retryCount-- > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME)); + rc = producer->Send(element); + } + DS_ASSERT_OK(rc); + } + }; + std::thread producerThrd(func); + const int DEFAULT_WAIT_TIME = 10'000; + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_WAIT_TIME)); + + const int DEFAULT_RETRY_TIME = 20; + Timer timer; + std::vector outElements; + int sendCount = SEND_COUNT; + while (sendCount > 0 && timer.ElapsedSecond() < DEFAULT_RETRY_TIME) { + DS_ASSERT_OK(consumer->Receive(1, DEFAULT_WAIT_TIME, outElements)); + if (!outElements.empty()) { + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + sendCount -= outElements.size(); + } + } + ASSERT_EQ(sendCount, 0); + producerThrd.join(); +} + +TEST_F(ProducerTest, LEVEL1_TestEarlyProducerCloseWhileSendingData) +{ + size_t testSize = 2ul * 1024ul * 1024ul; + Element element; + std::shared_ptr producer; + std::shared_ptr client1; + ASSERT_EQ(CreateClient(0, client1), Status::OK()); + DS_ASSERT_OK(client1->CreateProducer("EarlyProducerClose", producer, defaultProducerConf_)); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "RemoteWorker.BatchAsyncFlushEntry.delay", "11*sleep(20000)")); + + // Subscribe before send. + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr client2; + ASSERT_EQ(CreateClient(1, client2), Status::OK()); + DS_ASSERT_OK(client2->Subscribe("EarlyProducerClose", config, consumer)); + + // Send data to producer + std::vector writeElement; + DS_ASSERT_OK(CreateElement(testSize, element, writeElement)); + Status rc = producer->Send(element); + DS_ASSERT_OK(producer->Close()); + + std::shared_ptr producer1; + DS_ASSERT_OK(client1->CreateProducer("EarlyProducerClose", producer1, defaultProducerConf_)); + for (int i = 0; i < 100; ++i) { + // Send data to producer + std::vector writeElement; + DS_ASSERT_OK(CreateElement(testSize, element, writeElement)); + Status rc = producer1->Send(element); + } + DS_ASSERT_OK(producer1->Close()); + + sleep(SLEEP_TIME); + // Get data and test for contents + LOG(INFO) << "Expected to get K_OK. Rc returned: " << rc.ToString(); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + const int expectedNum = 1; + const int timeoutMs = 5000; + std::vector outElements; + rc = consumer->Receive(expectedNum, timeoutMs, outElements); + LOG(INFO) << "Expected to get K_OK. 
Rc returned: " << rc.ToString(); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + ASSERT_EQ(outElements.size(), size_t(1)); + ASSERT_EQ(outElements[0].id, size_t(1)); + std::string actualData(reinterpret_cast(outElements[0].ptr), outElements[0].size); + std::string data(reinterpret_cast(writeElement.data()), writeElement.size()); + EXPECT_EQ(data, actualData); + + client1.reset(); + client2.reset(); +} + +TEST_F(ProducerTest, DISABLED_TestConsumerCloseWhileDoingScanEval) +{ + // Testcase is disabled because ClearAllRemoteConsumer operations + // no longer waits for FlushAllChanges wait post. + // Create one producer and one consumer + std::shared_ptr spClient1; + ASSERT_EQ(CreateClient(0, spClient1), Status::OK()); + std::shared_ptr producer; + DS_ASSERT_OK(spClient1->CreateProducer("test1", producer, defaultProducerConf_)); + + std::shared_ptr spClient2; + ASSERT_EQ(CreateClient(1, spClient2), Status::OK()); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient2->Subscribe("test1", config, consumer)); + + // Create a delay in waking up the WaitPost without sleep + DS_ASSERT_OK( + cluster_->SetInjectAction(WORKER, 0, "StreamDataPool.ScanChangesAndEval.delaywakeup", "100000*call(0)")); + // Producer produces data + Element element; + std::vector writeElement; + DS_ASSERT_OK(CreateElement(1024, element, writeElement)); + Status rc = producer->Send(element); + LOG(INFO) << "Expected to get K_OK. Rc returned: " << rc.ToString(); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + DS_ASSERT_NOT_OK(producer->Close()); // This will timeout + // Consumer closes will fail because we are still waiting on wait post + DS_ASSERT_OK(consumer->Close()); +} + +TEST_F(ProducerTest, TestWriteRemoteBigElementSuccess) +{ + LOG(INFO) << "test WriteRemoteBigElement start!"; + size_t testSize = 2ul * 1024ul * 1024ul; + Element element; + std::shared_ptr producer; + std::shared_ptr client1; + ASSERT_EQ(CreateClient(0, client1), Status::OK()); + DS_ASSERT_OK(client1->CreateProducer("RemoteBigEleSuccess", producer, defaultProducerConf_)); + + // Subscribe before send. + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr client2; + ASSERT_EQ(CreateClient(1, client2), Status::OK()); + DS_ASSERT_OK(client2->Subscribe("RemoteBigEleSuccess", config, consumer)); + + std::vector writeElement; + DS_ASSERT_OK(CreateElement(testSize, element, writeElement)); + Status rc = producer->Send(element); + LOG(INFO) << "Expected to get K_OK. Rc returned: " << rc.ToString(); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + + const int expectedNum = 1; + const int timeoutMs = 5000; + std::vector outElements; + rc = consumer->Receive(expectedNum, timeoutMs, outElements); + LOG(INFO) << "Expected to get K_OK. 
Rc returned: " << rc.ToString();
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+    ASSERT_EQ(outElements.size(), size_t(1));
+    ASSERT_EQ(outElements[0].id, size_t(1));
+    std::string actualData(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+    std::string data(reinterpret_cast<const char *>(writeElement.data()), writeElement.size());
+    EXPECT_EQ(data, actualData);
+
+    client1.reset();
+    client2.reset();
+}
+
+TEST_F(ProducerTest, TestWriteBigElementWhenMmapFailed)
+{
+    Element element;
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<StreamClient> client;
+    ASSERT_EQ(CreateClient(0, client), Status::OK());
+    DS_ASSERT_OK(client->CreateProducer("MmapFailed", producer, defaultProducerConf_));
+
+    DS_ASSERT_OK(inject::Set("MmapTableEntry.mmap", "1*return(K_RUNTIME_ERROR)"));
+    std::vector<uint8_t> writeElement;
+    size_t testSize = 2ul * 1024ul * 1024ul;
+    DS_ASSERT_OK(CreateElement(testSize, element, writeElement));
+    Status rc = producer->Send(element);
+    LOG(INFO) << "Expected to get an error. Rc returned: " << rc.ToString();
+    DS_ASSERT_NOT_OK(rc);
+    client.reset();
+}
+
+TEST_F(ProducerTest, TestSendBigEleDuringLostHeartbeat)
+{
+    LOG(INFO) << "test SendBigEleDuringLostHeartbeat start!";
+    size_t testSize = 2ul * 1024ul * 1024ul;
+    Element element;
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<StreamClient> client1;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    DS_ASSERT_OK(client1->CreateProducer("RemoteBigEleSuccess", producer, defaultProducerConf_));
+
+    // Subscribe before send.
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+    DS_ASSERT_OK(client2->Subscribe("RemoteBigEleSuccess", config, consumer));
+
+    std::vector<uint8_t> writeElement;
+    DS_ASSERT_OK(CreateElement(testSize, element, writeElement));
+
+    DS_ASSERT_OK(datasystem::inject::Set("ProducerImpl.SendImpl.postInsertSuccess", "1*return(K_RPC_UNAVAILABLE)"));
+    DS_ASSERT_OK(datasystem::inject::Set("ProducerConsumerWorkerApi.ReleaseBigElementMemory.preReleaseBigShmMemory",
+                                         "1*sleep(3000)"));
+    DS_ASSERT_OK(
+        cluster_->SetInjectAction(WORKER, 0, "RemoteWorkerManager.SendElementsView.PostIncRefCount", "1*sleep(5000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "ShmUnit.FreeMemory", "call()"));
+    DS_ASSERT_NOT_OK(producer->Send(element));
+
+    const int expectedNum = 1;
+    const int timeoutMs = 10000;
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(expectedNum, timeoutMs, outElements));
+    ASSERT_EQ(outElements.size(), size_t(1));
+    ASSERT_EQ(outElements[0].id, size_t(1));
+    std::string actualData(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+    std::string data(reinterpret_cast<const char *>(writeElement.data()), writeElement.size());
+    EXPECT_EQ(data, actualData);
+
+    client1.reset();
+    client2.reset();
+}
+
+TEST_F(ProducerTest, TestWriteMixedElements)
+{
+    LOG(INFO) << "test WriteMixedElements start!";
+    size_t testSize = 60ul * 1024ul;
+    size_t bigSize = 2ul * 1024ul * 1024ul;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<StreamClient> spClient;
+    ProducerConf prodConf = { .delayFlushTime = 0, .pageSize = 4 * MB, .maxStreamSize = TEST_STREAM_SIZE };
+    ASSERT_EQ(CreateClient(0, spClient), Status::OK());
+    DS_ASSERT_OK(spClient->CreateProducer("MixedElements", producer, prodConf));
+
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(spClient->Subscribe("MixedElements", config, consumer));
+    std::string writeData;
+    for (int i = 0; i < 
count; i++) { + if (i % 10 != 0) { + DS_ASSERT_OK(CreateElement(testSize, element, writeElement)); + ASSERT_EQ(producer->Send(element), Status::OK()); + std::string data(reinterpret_cast(writeElement.data()), writeElement.size()); + writeData += data; + } else { + DS_ASSERT_OK(CreateElement(bigSize, element, writeElement)); + ASSERT_EQ(producer->Send(element), Status::OK()); + std::string bigData(reinterpret_cast(writeElement.data()), writeElement.size()); + writeData += bigData; + } + } + + std::vector outElements; + ASSERT_EQ(consumer->Receive(count, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), static_cast(count)); + std::string actualData; + for (int i = 0; i < count; i++) { + std::string tmp(reinterpret_cast(outElements[i].ptr), outElements[i].size); + actualData += tmp; + } + EXPECT_EQ(writeData, actualData); + spClient = nullptr; +} + +TEST_F(ProducerTest, TestSendInvalidTimeout) +{ + std::shared_ptr client; + DS_ASSERT_OK(CreateClient(0, client)); + std::shared_ptr producer; + ProducerConf conf; + conf.pageSize = 40 * 1024; + conf.maxStreamSize = 50 * 1024 * 1024; + DS_ASSERT_OK(client->CreateProducer("SendInvalidTimeout", producer, conf)); + + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client->Subscribe("SendInvalidTimeout", config, consumer)); + + std::string data = "hello"; + Element element((uint8_t *)data.data(), data.size()); + ASSERT_EQ(producer->Send(element, 0).GetCode(), K_OK); + ASSERT_EQ(producer->Send(element, -1).GetCode(), K_INVALID); + + const int timeout = 500; + DS_ASSERT_OK(producer->Send(element, timeout)); + + std::vector outElements; + const int maxRecvNum = 2; + DS_ASSERT_OK(consumer->Receive(maxRecvNum, timeout, outElements)); + ASSERT_EQ(outElements.size(), maxRecvNum); + ASSERT_EQ(outElements[0].size, data.size()); + DS_ASSERT_OK(consumer->Ack(outElements[0].id)); +} + +TEST_F(ProducerTest, TestSendSmallElementBigElement) +{ + std::shared_ptr spClient0; + DS_ASSERT_OK(CreateClient(0, spClient0)); + + std::shared_ptr spClient1; + DS_ASSERT_OK(CreateClient(1, spClient1)); + + // Create a Producer + std::shared_ptr producer; + ProducerConf conf; + conf.pageSize = 40 * 1024; + conf.maxStreamSize = 50 * 1024 * 1024; + DS_ASSERT_OK(spClient0->CreateProducer("SendSmallElementBigElement", producer, conf)); + + // Create a Consumer on a different node + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient1->Subscribe("SendSmallElementBigElement", config, consumer)); + + // Add delay to ScanEval thread so that it picks up both elements at the same time + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "ExclusivePageQueue.ScanAndEval.wait", "sleep(500)")); + + // Send two elements: one small element and one big element + const uint32_t eleSz1 = 37910; + ElementGenerator elementGenerator1(eleSz1); + auto strs1 = elementGenerator1.GenElements("producer1", 1, 1); + DS_ASSERT_OK(producer->Send(Element((uint8_t *)strs1[0].data(), strs1[0].size()), 1000)); + + const uint32_t eleSz2 = 185015; + ElementGenerator elementGenerator2(eleSz2); + auto strs2 = elementGenerator2.GenElements("producer2", 1, 1); + DS_ASSERT_OK(producer->Send(Element((uint8_t *)strs2[0].data(), strs2[0].size()), 1000)); + + sleep(1); + + // Receiver should get both correctly + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(2, 5000, outElements)); + ASSERT_EQ(outElements.size(), 2); + ASSERT_EQ(outElements[0].size, strs1[0].size()); + 
ASSERT_EQ(outElements[1].size, strs2[0].size()); + DS_ASSERT_OK(consumer->Ack(2)); + + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(consumer->Close()); +} + +TEST_F(ProducerTest, TestBigElementOOM) +{ + std::shared_ptr spClient0; + DS_ASSERT_OK(CreateClient(0, spClient0)); + + // Create a Producer + const int64_t pageSize = 16 * KB; + const int numPages = 4; + const int64_t streamSize = numPages * pageSize; + std::shared_ptr producer; + ProducerConf conf; + conf.pageSize = pageSize; + conf.maxStreamSize = streamSize; + conf.retainForNumConsumers = 1; + DS_ASSERT_OK(spClient0->CreateProducer("BigEleOOM", producer, conf)); + + // Send three elements same as the page size. The new logic will convert them into BigElement + RandomData rand; + auto str = rand.GetRandomString(pageSize); + for (int i = 0; i < numPages - 1; ++i) { + DS_ASSERT_OK(producer->Send(Element((uint8_t *)str.data(), str.size()))); + } + + // Insert one more time should get OOM (because we have already created one data page and three big element pages) + const int64_t timeoutMs = 5000; + Status rc = producer->Send(Element((uint8_t *)str.data(), str.size()), timeoutMs); + DS_ASSERT_TRUE(rc.GetCode(), K_OUT_OF_MEMORY); + + // Create a Consumer on the same node + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient0->Subscribe("BigEleOOM", config, consumer)); + + // Receive one, ack, and send again. Should be successful + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(RPC_TIMEOUT, outElements)); + DS_ASSERT_TRUE(outElements.size(), numPages - 1); + consumer->Ack(outElements[0].id); + std::this_thread::sleep_for(std::chrono::seconds(1)); + DS_ASSERT_OK(producer->Send(Element((uint8_t *)str.data(), str.size()))); + + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(consumer->Close()); + DS_ASSERT_OK(spClient0->DeleteStream("BigEleOOM")); +} + +TEST_F(ProducerTest, TestCreateProducerWithPage) +{ + std::shared_ptr producer; + std::shared_ptr producer2; + std::shared_ptr producer3; + std::shared_ptr producer4; + std::shared_ptr spClient; + uint32_t baseSize = 4 * 1024; + ASSERT_EQ(CreateClient(0, spClient), Status::OK()); + DS_ASSERT_OK(spClient->CreateProducer( + "test_producer", producer, { .delayFlushTime = 0, .pageSize = baseSize, .maxStreamSize = TEST_STREAM_SIZE })); + + Status rc = + spClient->CreateProducer("test_producer2", producer2, + { .delayFlushTime = 0, .pageSize = 4294967295, .maxStreamSize = TEST_STREAM_SIZE }); + LOG_IF_ERROR(rc, "Expected failure for invalid page size"); + ASSERT_EQ(rc.GetCode(), StatusCode::K_INVALID); + + const uint64_t maxPageSize = 16 * MB; + const uint64_t nextInvalidPageSize = maxPageSize + 4 * KB; + DS_ASSERT_OK( + spClient->CreateProducer("test_producer3", producer3, + { .delayFlushTime = 0, .pageSize = maxPageSize, .maxStreamSize = TEST_STREAM_SIZE })); + + rc = spClient->CreateProducer( + "test_producer4", producer4, + { .delayFlushTime = 0, .pageSize = nextInvalidPageSize, .maxStreamSize = TEST_STREAM_SIZE }); + LOG_IF_ERROR(rc, "Expected failure for invalid page size"); + ASSERT_EQ(rc.GetCode(), StatusCode::K_INVALID); + + spClient = nullptr; +} + +TEST_F(ProducerTest, TestCreateProducerWithoutPage) +{ + std::shared_ptr producer; + std::shared_ptr spClient; + ProducerConf conf; + ASSERT_EQ(CreateClient(0, spClient), Status::OK()); + conf.delayFlushTime = 0; + conf.maxStreamSize = TEST_STREAM_SIZE; + DS_ASSERT_OK(spClient->CreateProducer("test_default_producer", producer, conf)); + spClient = nullptr; +} + 
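+// Illustrative sketch (not part of the original change set): a ProducerConf that
+// satisfies the validation rules exercised by the surrounding tests -- a pageSize
+// no larger than 16 MB, and a reserveSize that is a multiple of pageSize and no
+// larger than maxStreamSize. The test name is hypothetical and the test is
+// disabled so it does not alter the suite.
+TEST_F(ProducerTest, DISABLED_ExampleValidProducerConf)
+{
+    std::shared_ptr<StreamClient> spClient;
+    ASSERT_EQ(CreateClient(0, spClient), Status::OK());
+
+    ProducerConf conf;
+    conf.delayFlushTime = 0;
+    conf.pageSize = 1 * MB;                 // within the accepted page size range
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    conf.reserveSize = 2 * MB;              // a multiple of pageSize, not above maxStreamSize
+
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(spClient->CreateProducer("example_valid_conf", producer, conf));
+    DS_ASSERT_OK(producer->Close());
+    spClient = nullptr;
+}
+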
+TEST_F(ProducerTest, TestCreateProducerReserveSize)
+{
+    // This testcase checks that an invalid reserve size leads to CreateProducer failure.
+    ProducerConf conf;
+    const int64_t DEFAULT_PAGE_SIZE = 8 * KB;
+    const int64_t NOT_MULTIPLE_RESERVE_SIZE = 12 * KB;
+    conf.pageSize = DEFAULT_PAGE_SIZE;
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    std::shared_ptr<StreamClient> client;
+    ASSERT_EQ(CreateClient(0, client), Status::OK());
+    std::string streamName = "test_reserve_size";
+    std::shared_ptr<Producer> producer;
+
+    // A valid reserve size must be less than or equal to the max stream size.
+    ProducerConf exceedSizeConf(conf);
+    exceedSizeConf.reserveSize = TEST_STREAM_SIZE + DEFAULT_PAGE_SIZE;
+    DS_ASSERT_NOT_OK(client->CreateProducer(streamName, producer, exceedSizeConf));
+
+    // A valid reserve size must be a multiple of the page size.
+    ProducerConf notMultipleConf(conf);
+    notMultipleConf.reserveSize = NOT_MULTIPLE_RESERVE_SIZE;
+    DS_ASSERT_NOT_OK(client->CreateProducer(streamName, producer, notMultipleConf));
+
+    // 0 is an acceptable input for reserve size; the default reserve size will then be the page size.
+    ProducerConf zeroConf(conf);
+    zeroConf.reserveSize = 0;
+    DS_ASSERT_OK(client->CreateProducer(streamName, producer, zeroConf));
+}
+
+TEST_F(ProducerTest, TestCreateMultiProducerWithPage)
+{
+    uint32_t baseSize = 4096;
+    std::shared_ptr<StreamClient> spClient;
+    ASSERT_EQ(CreateClient(0, spClient), Status::OK());
+    for (int i = 1; i < 11; i++) {
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(
+            spClient->CreateProducer("test_multi_producer", producer,
+                                     { .delayFlushTime = 0, .pageSize = baseSize, .maxStreamSize = TEST_STREAM_SIZE }));
+    }
+    spClient = nullptr;
+}
+
+TEST_F(ProducerTest, TestFlushPGIsFull)
+{
+    Timer timer;
+    ThreadPool pool(2);
+
+    std::shared_ptr<StreamClient> spClient;
+    ASSERT_EQ(CreateClient(0, spClient), Status::OK());
+    // The producer sends 100 elements of ~3 KB to a stream of 4 MB capacity with 64 KB-sized pages.
+    size_t numOfElements = 100;
+    auto producerFut = pool.Submit([this, numOfElements, spClient]() {
+        size_t numOfRetries = 0;
+        std::shared_ptr<Producer> producer;
+        ProducerConf producerConf{ .delayFlushTime = -1, .pageSize = 64 * KB, .maxStreamSize = 4 * MB };
+        ElementGenerator elementGenerator(3 * KB, 3 * KB);
+        uint64_t numOfConsumers = 0;
+        while (numOfConsumers != 1) {
+            spClient->QueryGlobalConsumersNum("FlushPGIsFull", numOfConsumers);
+        }
+        RETURN_IF_NOT_OK(spClient->CreateProducer("FlushPGIsFull", producer, producerConf));
+        auto elements = elementGenerator.GenElements("producer", numOfElements, 8ul);
+        Timer timer;
+        for (size_t i = 0; i < numOfElements; i++) {
+            auto element = Element(reinterpret_cast<uint8_t *>(&elements[i].front()), elements[i].size());
+            auto status = producer->Send(element);
+        }
+        LOG(INFO) << FormatString("Producer's number of re-sends: %zu, sending time: %.6lf s", numOfRetries,
+                                  timer.ElapsedSecond());
+        return Status::OK();
+    });
+
+    // The receiver receives all the data.
+    auto consumerFut = pool.Submit([this, numOfElements, spClient]() {
+        std::shared_ptr<Consumer> consumer;
+        SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+        RETURN_IF_NOT_OK(spClient->Subscribe("FlushPGIsFull", config, consumer));
+        std::unordered_map<std::string, uint64_t> seqNoMap;
+        Timer timer;
+        size_t ackInterval = 4096;
+        for (size_t i = 0; i < numOfElements;) {
+            std::vector<Element> outElements;
+            consumer->Receive(1, 20, outElements);
+            i += outElements.size();
+            for (auto &element : outElements) {
+                ElementView elementView(std::string(reinterpret_cast<const char *>(element.ptr), element.size));
+                RETURN_IF_NOT_OK(elementView.VerifyFifo(seqNoMap));
+                RETURN_IF_NOT_OK(elementView.VerifyIntegrity());
+            }
+            if (i % ackInterval == (ackInterval - 1)) {
+                LOG(INFO) << FormatString("Ack id: %zu", outElements.back().id);
+                RETURN_IF_NOT_OK(consumer->Ack(outElements.back().id));
+            }
+        }
+        LOG(INFO) << FormatString("Total Recv Time Elapsed: %.6lf s", timer.ElapsedSecond());
+        return Status::OK();
+    });
+
+    ASSERT_EQ(consumerFut.get(), Status::OK());
+    ASSERT_EQ(producerFut.get(), Status::OK());
+    LOG(INFO) << FormatString("End To End Time Elapsed: %.6lf s", timer.ElapsedSecond());
+}
+
+// Test a "local" OOM where the consumer is on the same node as the producer, but does not consume right away, causing
+// a build-up of data and OOM conditions. Client retry will eventually run, leading to overall sending success.
+// Send timeout of 0.
+TEST_F(ProducerTest, LEVEL2_TestOOM1)
+{
+    Status rc;
+    std::shared_ptr<StreamClient> spClient;
+    std::string streamName = "TestOOM1";
+
+    rc = CreateClient(0, spClient);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+
+    rc = RunOOMTest(spClient, spClient, streamName);
+    LOG_IF_ERROR(rc, "Running OOM test gave error.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+
+    spClient = nullptr;
+}
+
+// Test a "local" OOM where the consumer is on the same node as the producer, but does not consume right away, causing
+// a build-up of data and OOM conditions. Use a send-side timeout such that OOM is only returned after timing out,
+// and disable client retry so that the OOM timeout/failure will terminate the test run.
+TEST_F(ProducerTest, TestOOM2)
+{
+    Status rc;
+    clientRetry_ = false;
+    sendTimeout_ = 3000;  // 3 second timeout
+    std::shared_ptr<StreamClient> spClient;
+    std::string streamName = "TestOOM2";
+
+    rc = CreateClient(0, spClient);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+
+    rc = RunOOMTest(spClient, spClient, streamName);
+    LOG_IF_ERROR(rc, "Running OOM test gave error.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OUT_OF_MEMORY);  // expected fail
+
+    spClient = nullptr;
+}
+
+// Test a "local" OOM where the consumer is on the same node as the producer, but does not consume right away, causing
+// a build-up of data and OOM conditions. Use a send-side timeout that is large enough that a blocked send will
+// eventually complete successfully (before the timeout) once the consumer drains some data to free up memory.
+TEST_F(ProducerTest, TestOOM3)
+{
+    Status rc;
+    clientRetry_ = false;
+    sendTimeout_ = 10000;  // 10 second timeout to give it lots of wait time on send blocking
+    elementsTotalSize_ = 32 * 1024 * 1024;
+    std::shared_ptr<StreamClient> spClient;
+    std::string streamName = "TestOOM3";
+
+    rc = CreateClient(0, spClient);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+
+    rc = RunOOMTest(spClient, spClient, streamName);
+    LOG_IF_ERROR(rc, "Running OOM test gave error.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);  // expected success
+
+    spClient = nullptr;
+}
+
+// Tests LEVEL2_TestOOM4 and TestOOM5:
+// Test a remote OOM which then leads to local OOM.
+// Producer is on worker0 and consumer is on worker1.
+// Flushes naturally free the pages locally on worker0 since there is no local consumer.
+// However, if the remote node is returning OOMs when pages are sent from worker0 to worker1, then the
+// local free of the page is delayed until the remote sends can reduce their ref count. This leads to local OOM and
+// send waiting.
+
+// No timeout on the send; use client-side retry when OOMs are returned.
+TEST_F(ProducerTest, LEVEL2_TestOOM4)
+{
+    Status rc;
+    std::shared_ptr<StreamClient> prodClient;
+    std::shared_ptr<StreamClient> conClient;
+    std::string streamName = "TestOOM4";
+
+    rc = CreateClient(0, prodClient);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+
+    rc = CreateClient(1, conClient);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+
+    rc = RunOOMTest(prodClient, conClient, streamName);
+    LOG_IF_ERROR(rc, "Running OOM test gave error.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);  // expected success
+
+    prodClient = nullptr;
+    conClient = nullptr;
+}
+
+// Disable client retry and use a timeout on the send calls that is large enough to support the work.
+TEST_F(ProducerTest, TestOOM5)
+{
+    Status rc;
+    clientRetry_ = false;
+    sendTimeout_ = 10000;  // 10 second timeout to give it lots of wait time on send blocking
+    std::shared_ptr<StreamClient> prodClient;
+    std::shared_ptr<StreamClient> conClient;
+    std::string streamName = "TestOOM5";
+
+    rc = CreateClient(0, prodClient);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+
+    rc = CreateClient(1, conClient);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+
+    rc = RunOOMTest(prodClient, conClient, streamName);
+    LOG_IF_ERROR(rc, "Running OOM test gave error.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);  // expected success
+
+    prodClient = nullptr;
+    conClient = nullptr;
+}
+
+// Ensure that creating multiple producers rejects invalid configs.
+TEST_F(ProducerTest, TestMultiProdCfg)
+{
+    Status rc;
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+
+    rc = CreateClient(0, client1);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+
+    rc = CreateClient(1, client2);
+    LOG_IF_ERROR(rc, "Creating client failed. 
"); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + + std::shared_ptr producer1; + std::shared_ptr producer2; + ProducerConf producerConf1{ .delayFlushTime = 20, .pageSize = 64 * KB, .maxStreamSize = 2 * MB }; + ProducerConf producerConf2{ .delayFlushTime = 20, .pageSize = 32 * KB, .maxStreamSize = 2 * MB }; + DS_ASSERT_OK(client1->CreateProducer("MultiProdCfg", producer1, producerConf1)); + + // Same worker, different page size + rc = client1->CreateProducer("MultiProdCfg", producer2, producerConf2); + LOG_IF_ERROR(rc, "Expected failure for invalid page size of existing stream"); + ASSERT_EQ(rc.GetCode(), StatusCode::K_INVALID); + + // remote worker, different page size + rc = client2->CreateProducer("MultiProdCfg", producer2, producerConf2); + LOG_IF_ERROR(rc, "Expected failure for invalid page size of existing stream"); + ASSERT_EQ(rc.GetCode(), StatusCode::K_INVALID); + + // Same worker, turn on auto delete + ProducerConf producerConf3 = producerConf1; + producerConf3.autoCleanup = true; + rc = client1->CreateProducer("MultiProdCfg", producer2, producerConf3); + LOG_IF_ERROR(rc, "Expected failure for invalid auto cleanup of existing stream"); + ASSERT_EQ(rc.GetCode(), StatusCode::K_INVALID); + + client1 = nullptr; + client2 = nullptr; +} + +TEST_F(ProducerTest, TestMultiProdClose) +{ + const int numClients = 2; + const int numStreams = 2; + std::shared_ptr producerClients[numClients]; + std::shared_ptr consumerClients[numClients]; + std::vector> producers; + Status rc; + ProducerConf producerConf{ .delayFlushTime = 20, .pageSize = 64 * KB, .maxStreamSize = 4 * MB }; + + // 2 clients for producers + ASSERT_EQ(CreateClient(0, producerClients[0]), Status::OK()); + ASSERT_EQ(CreateClient(1, producerClients[1]), Status::OK()); + + // Create 8 producers, 4 on the stream named "test0 and 4 on "test1" + // 2 producers form each client. Layout: + // + // stream client producer + // ------- ------ -------- + // "test0" p0 0 + // "test0" p1 1 + // "test0" p0 2 + // "test0" p1 3 + // "test1" p0 4 + // "test1" p1 5 + // "test1" p0 6 + // "test1" p1 7 + + std::vector streamNames; + for (int i = 0; i < numStreams; ++i) { + streamNames.push_back("test" + std::to_string(i)); + } + + const int numProds = 4; + for (int j = 0; j < numStreams; ++j) { + for (int i = 0; i < numProds; ++i) { + std::shared_ptr newProducer; + rc = producerClients[i % numClients]->CreateProducer(streamNames[j], newProducer, producerConf); + LOG_IF_ERROR(rc, "Creating producer failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + producers.push_back(std::move(newProducer)); + } + } + + const int numCons = 4; + std::vector subCfgs; + for (int i = 0; i < numCons; ++i) { + SubscriptionConfig consumerConf("sub" + std::to_string(i), SubscriptionType::STREAM); + subCfgs.push_back(consumerConf); + } + + // With different clients, create subscribers/consumers. + // + // stream client consumer + // ------- ------ -------- + // "test0" c0 0 + // "test1" c0 1 + // "test0" c1 2 + // "test1" c1 3 + + std::vector> consumers; + ASSERT_EQ(CreateClient(0, consumerClients[0]), Status::OK()); + ASSERT_EQ(CreateClient(1, consumerClients[1]), Status::OK()); + + std::shared_ptr newConsumer; + int cIdx = 0; + rc = consumerClients[0]->Subscribe(streamNames[0], subCfgs[cIdx++], newConsumer); + LOG_IF_ERROR(rc, "(subscribe): Creating consumer failed. 
"); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + consumers.push_back(std::move(newConsumer)); + + rc = consumerClients[0]->Subscribe(streamNames[1], subCfgs[cIdx++], newConsumer); + LOG_IF_ERROR(rc, "(subscribe): Creating consumer failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + consumers.push_back(std::move(newConsumer)); + + rc = consumerClients[1]->Subscribe(streamNames[0], subCfgs[cIdx++], newConsumer); + LOG_IF_ERROR(rc, "(subscribe): Creating consumer failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + consumers.push_back(std::move(newConsumer)); + + rc = consumerClients[1]->Subscribe(streamNames[1], subCfgs[cIdx], newConsumer); + LOG_IF_ERROR(rc, "(subscribe): Creating consumer failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + consumers.push_back(std::move(newConsumer)); + + // Do a simple receive from each consumer. nothing sent so we expect no results here. + LOG(INFO) << "Loop over the 4 consumers and do a receive. Expect each to give 0 records."; + for (auto consumer : consumers) { + std::vector outElements; + consumer->Receive(1, RECV_WAIT_MILLI_SECONDS, outElements); + LOG(INFO) << "Used a consumer to call receive. Got " << outElements.size() << " elements returned"; + } + + // Instead of individually closing each consumer and producer, do a client reset to drive a forced + // shutdown of the producers in clients 1 and 2. + LOG(INFO) << "Client reset of client 0 start"; + producerClients[0].reset(); + LOG(INFO) << "Client reset of client 0 done."; + + LOG(INFO) << "Client reset of client 1 start"; + producerClients[1].reset(); + LOG(INFO) << "Client reset of client 1 done"; + + // Repeat the receive attempts. This time, since the client disconnect drove a close with force mode + // true, these consumers will not get any data as the producers are already closed. + // The Receive() calls might get triggered after subscriber receiving producer close notification. Hence + // no one will wake them up from the pending state other than the timer. + LOG(INFO) << "Loop over the 4 consumers and do a receive. Expect them all to fail."; + for (auto consumer : consumers) { + Status rc; + std::vector outElements; + rc = consumer->Receive(1, RECV_WAIT_MILLI_SECONDS, outElements); + LOG_IF_ERROR(rc, "Calling receive failed "); + ASSERT_EQ(outElements.size(), (size_t)0); + } +} + +TEST_F(ProducerTest, TestMultiProdClose2) +{ + std::shared_ptr pClient; + std::shared_ptr cClient; + std::vector> producers; + std::shared_ptr consumer; + Status rc; + ProducerConf producerConf{ .delayFlushTime = 20, .pageSize = 64 * KB, .maxStreamSize = 4 * MB }; + + ASSERT_EQ(CreateClient(0, pClient), Status::OK()); + std::string streamName("MultiProdClose2"); + + // Create 4 producers + const int numProds = 4; + LOG(INFO) << "Creating 4 producers"; + for (int i = 0; i < numProds; ++i) { + std::shared_ptr newProducer; + rc = pClient->CreateProducer(streamName, newProducer, producerConf); + LOG_IF_ERROR(rc, "Creating producer failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + producers.push_back(std::move(newProducer)); + } + + // Create a consumer on the remote worker + LOG(INFO) << "Creating consumer"; + ASSERT_EQ(CreateClient(1, cClient), Status::OK()); + SubscriptionConfig consumerConf("sub", SubscriptionType::STREAM); + rc = cClient->Subscribe(streamName, consumerConf, consumer); + LOG_IF_ERROR(rc, "Creating producer failed. 
"); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + + // Inject a failure on the master worker dealing with close to test the retry handling of the + // client disconnect force close + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "master.PubDecreaseNode.afterSendNotification", + "1*return(K_RPC_UNAVAILABLE)")); + + LOG(INFO) << "Client reset start"; + pClient.reset(); + LOG(INFO) << "Client reset 0 done."; +} + +// The generic MultiTest_NStream test driver function can provide different configurations. +// This run will test 8 producer, single consumer, over a 2 worker setup. +TEST_F(BigShmTest, MultiTest_NStream1) +{ + Status rc; + sendTimeout_ = 10000; // 10 second timeout to give it lots of wait time on send blocking + std::shared_ptr newClient; + std::vector> clients; + earlyExitCount_ = 327680; + slowReceive_ = true; + validate_ = false; + + rc = CreateClient(0, newClient); + LOG_IF_ERROR(rc, "Creating client failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + clients.push_back(std::move(newClient)); + + rc = CreateClient(1, newClient); + LOG_IF_ERROR(rc, "Creating client failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + clients.push_back(std::move(newClient)); + + int64_t pageSize = 1048576; // 1 mb + uint64_t maxStreamSize = 16 * MB; + int numStreams = 1; + int numProds = 8; + int numSubs = 1; + rc = MultiTest_NStream(clients, pageSize, maxStreamSize, numStreams, numProds, numSubs); + LOG_IF_ERROR(rc, "Running multi test gave error. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); // expected success + + clients.clear(); // frees the client ptr's + + // The sleep at the end just ensures the final testcase kill happens after things are cleaned up and + // makes it easier to read the logs if needed. + LOG(INFO) << "sleep at the end of testcase"; + const int SLEEP = 5; + std::this_thread::sleep_for(std::chrono::seconds(SLEEP)); + LOG(INFO) << "sleep at the end of testcase done"; +} + +// The generic MultiTest_NStream test driver function can provide different configurations. +// This run will test 8 producer, single consumer, over a 2 worker setup. This one uses autoAck +// feature, so manual acks are not sent and instead each receive will drive acks. +TEST_F(BigShmTest, LEVEL1_MultiTest_NStream2) +{ + Status rc; + sendTimeout_ = 10000; // 10 second timeout to give it lots of wait time on send blocking + std::shared_ptr newClient; + std::vector> clients; + autoAck_ = true; // configure for auto ack mode + validate_ = false; + + rc = CreateClient(0, newClient); + LOG_IF_ERROR(rc, "Creating client failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + clients.push_back(std::move(newClient)); + + rc = CreateClient(1, newClient); + LOG_IF_ERROR(rc, "Creating client failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + clients.push_back(std::move(newClient)); + + int64_t pageSize = 1048576; // 1 mb + uint64_t maxStreamSize = 16 * MB; + int numStreams = 1; + int numProds = 8; + int numSubs = 1; + rc = MultiTest_NStream(clients, pageSize, maxStreamSize, numStreams, numProds, numSubs); + LOG_IF_ERROR(rc, "Running multi test gave error. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); // expected success + clients.clear(); // frees the client ptr's +} + +// The generic MultiTest_NStream test driver function can provide different configurations. +// This run will test 4 producer, single consumer, over a 2 worker setup. 
This one does not
+// use the auto ack feature, and it is configured to ack frequently to test the smart ack
+// logic.
+TEST_F(BigShmTest, MultiTest_NStream3)
+{
+    // README
+    // The eleSz_ has been increased from 4 * KB to reduce run time during CI.
+    // To run the intended load locally, edit the value.
+    Status rc;
+    sendTimeout_ = 10000;  // 10 second timeout to give it lots of wait time on send blocking
+    std::shared_ptr<StreamClient> newClient;
+    std::vector<std::shared_ptr<StreamClient>> clients;
+    ackInterval_ = 1;  // ack after every element!
+    validate_ = false;
+    eleSz_ = 8 * KB;
+
+    rc = CreateClient(0, newClient);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+    clients.push_back(std::move(newClient));
+
+    rc = CreateClient(1, newClient);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+    clients.push_back(std::move(newClient));
+
+    int64_t pageSize = 1 * MB;
+    uint64_t maxStreamSize = 32 * MB;
+    int numStreams = 1;
+    int numProds = 4;
+    int numSubs = 1;
+    rc = MultiTest_NStream(clients, pageSize, maxStreamSize, numStreams, numProds, numSubs);
+    LOG_IF_ERROR(rc, "Running multi test gave error.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);  // expected success
+    clients.clear();  // frees the client ptrs
+}
+
+// README
+// MultiTest_NStream4 and 5 are consolidated into one testcase. The only difference is slowReceive_.
+// The elementsTotalSize_ is decreased from 256 * MB to reduce run time during CI.
+// To run the intended load locally, edit the value.
+// Toggle the flag inside the testcase to switch between the NStream4 and NStream5
+// 1 producer, 1 consumer (remote worker) scenario.
+// NStream4 - Consumer is slow in between each receive to ensure the worker has data. Dumps a perf log at the end
+// for perf analysis of logs.
+// NStream5 - Consumer is not slow in between each receive. This is not the winning scenario for prefetching,
+// but we can use this to test prefetching in a case where it does not aid performance, to assess overhead.
+TEST_F(BigShmTest, LEVEL1_MultiTest_NStream4)
+{
+    Status rc;
+    sendTimeout_ = 10000;  // 10 second timeout to give it lots of wait time on send blocking
+    std::shared_ptr<StreamClient> newClient;
+    std::vector<std::shared_ptr<StreamClient>> clients;
+    slowConsume_ = true;
+    // Set true for the NStream4 setting or false for the NStream5 setting
+    slowReceive_ = true;
+    validate_ = false;
+    clientCacheSize_ = 4096;  // override default for the cache size
+    prefetchLWM_ = 50;        // cache threshold for prefetching
+    elementsTotalSize_ = 64 * MB;
+    eleSz_ = 128;  // small elements; default was 1K
+
+    rc = CreateClient(0, newClient);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+    clients.push_back(std::move(newClient));
+
+    rc = CreateClient(1, newClient);
+    LOG_IF_ERROR(rc, "Creating client failed.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);
+    clients.push_back(std::move(newClient));
+
+    int64_t pageSize = 1048576;  // 1 MB
+    uint64_t maxStreamSize = 32 * MB;
+    int numStreams = 1;
+    int numProds = 1;
+    int numSubs = 1;
+    rc = MultiTest_NStream(clients, pageSize, maxStreamSize, numStreams, numProds, numSubs);
+    LOG_IF_ERROR(rc, "Running multi test gave error.");
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_OK);  // expected success
+    clients.clear();  // frees the client ptrs
+
+    PerfManager *perfManager = PerfManager::Instance();
+    perfManager->PrintPerfLog();
+}
+
+// 3 producers, 1 consumer, 1 worker. 
+TEST_F(BigShmTest, LEVEL2_MultiTest_NStream6) +{ + // README + // The eleSz_ has been increased from 48 to reduce run time during CI + // To run the intended load locally, edit the value. + Status rc; + sendTimeout_ = 10000; // 10 second timeout to give it lots of wait time on send blocking + std::shared_ptr newClient; + std::vector> clients; + autoAck_ = true; // configure for auto ack mode + validate_ = false; + eleSz_ = 128; + + rc = CreateClient(0, newClient); + LOG_IF_ERROR(rc, "Creating client failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + clients.push_back(std::move(newClient)); + + int64_t pageSize = 1048576; // 1 mb + uint64_t maxStreamSize = 16 * MB; + int numStreams = 1; + int numProds = 3; + int numSubs = 1; + rc = MultiTest_NStream(clients, pageSize, maxStreamSize, numStreams, numProds, numSubs); + LOG_IF_ERROR(rc, "Running multi test gave error. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); // expected success + clients.clear(); // frees the client ptr's + + PerfManager *perfManager = PerfManager::Instance(); + perfManager->PrintPerfLog(); +} + +class ProducerLocalMemTest : public ProducerTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numEtcd = 1; + // This will cause local cache to send OOMs + opts.workerGflagParams = " -sc_local_cache_memory_size_mb=10 -v=1"; + opts.numWorkers = NUM_WORKERS; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + signature_ = std::make_unique(accessKey_, secretKey_); + ExternalClusterTest::SetUp(); + InitTest(); + } + + void TearDown() override + { + ExternalClusterTest::TearDown(); + } +}; + +TEST_F(ProducerLocalMemTest, DISABLED_TestLocalCacheOOM1) +{ + Status rc; + std::shared_ptr prodClient; + std::shared_ptr conClient; + std::string streamName = "TestLocalCacheOOM1"; + rc = CreateClient(0, prodClient); + LOG_IF_ERROR(rc, "Creating client failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + + rc = CreateClient(0, conClient); + LOG_IF_ERROR(rc, "Creating client failed. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); + const int K_10 = 10; + rc = RunOOMTest(prodClient, conClient, streamName, K_10); + LOG_IF_ERROR(rc, "Running OOM test gave error. "); + ASSERT_EQ(rc.GetCode(), StatusCode::K_OK); // expected success + + prodClient = nullptr; + conClient = nullptr; +} + +/* +Create 2 producers, 1 consumer. Testing Back pressure with modified RunOOMTest with a second producer. +One producer sends slower than other. Use future object to assert status is OK. Each producer sends +half of numElements so no TimeOut for test purposes. 
+*/ +TEST_F(ProducerTest, LEVEL1_TestDifferentSendSpeed) +{ + std::shared_ptr prodClient1, prodClient2, conClient; + DS_ASSERT_OK(CreateClient(0, prodClient1)); + DS_ASSERT_OK(CreateClient(0, prodClient2)); + DS_ASSERT_OK(CreateClient(1, conClient)); + + uint64_t maxStreamSizeMB = 2; + ThreadPool preparePool(CLIENT_THREAD_POOL_SIZE); + uint64_t eleSz = 8192ul; // this should be small element, not big element + uint64_t numElements = elementsTotalSize_ / eleSz / 2; + uint64_t quarterNumElements = numElements / 4; + uint64_t halfNumElements = numElements / 2; + slowConsume_ = waitForGo_ = true; + + LOG(INFO) << FormatString("Testing Size: %zu", eleSz); + std::shared_future> elementsFut1 = preparePool.Submit([eleSz, quarterNumElements]() { + ElementGenerator elementGenerator(eleSz + 1, eleSz); + auto elements1 = elementGenerator.GenElements("producer1", quarterNumElements, 8ul); + return elements1; + }); + std::shared_future> elementsFut2 = preparePool.Submit([eleSz, quarterNumElements]() { + ElementGenerator elementGenerator(eleSz + 1, eleSz); + auto elements2 = elementGenerator.GenElements("producer2", quarterNumElements, 8ul); + return elements2; + }); + + // Wait for the data generation to complete before we launch producers and consumers + while (elementsFut1.wait_for(std::chrono::seconds(1)) != std::future_status::ready) + ; + while (elementsFut2.wait_for(std::chrono::seconds(1)) != std::future_status::ready) + ; + LOG(INFO) << "Data generation complete. kick off threads now"; + + ThreadPool pool(CLIENT_THREAD_POOL_SIZE); + std::string streamName = "TestDifferentSendSpeed"; + ProducerConf prodCfg = { .delayFlushTime = 20, .pageSize = 1 * MB, .maxStreamSize = maxStreamSizeMB * MB }; + auto producerFut = pool.Submit( + [this, streamName, prodCfg, &elementsFut1, &elementsFut2, quarterNumElements, prodClient1, prodClient2]() { + SendConfig sendCfg1 = { .streamName = streamName, + .producerName = "producer1", + .producerConf = prodCfg, + .numOfElements = quarterNumElements }; + SendConfig sendCfg2 = { .streamName = streamName, + .producerName = "producer2", + .producerConf = prodCfg, + .numOfElements = quarterNumElements }; + return SendStreamDataSlow(sendCfg1, sendCfg2, 1, elementsFut1, elementsFut2, prodClient1, prodClient2); + }); + + size_t ackInterval = std::max(400ul * KB / eleSz, 1ul); + auto consumerFut = pool.Submit([this, streamName, halfNumElements, ackInterval, conClient]() { + RecvConfig rcvCfg = { .streamName = streamName, + .subscriptionName = "subscription", + .numOfBatchElements = 100, + .timeToWaitMs = 20, + .ackInterval = ackInterval, + .autoAck = false }; + return RecvStreamData(rcvCfg, halfNumElements, conClient); + }); + + // assert status after threads complete + ASSERT_EQ(producerFut.get(), Status::OK()); + ASSERT_EQ(consumerFut.get(), Status::OK()); +} + +/* +Create 2 consumers 1 producer. Testing Back pressure with modified RunOOMTest with a second consumer. +One Consumer receives slower than other. Use future object to assert status is OK. 
+/*
+Create 2 consumers, 1 producer. Tests back pressure via a modified RunOOMTest with a second consumer.
+One consumer receives slower than the other. Use future objects to assert that the status is OK.
+*/
+TEST_F(ProducerTest, LEVEL1_TestDifferentReceiveSpeed)
+{
+    std::shared_ptr<StreamClient> prodClient, conClient1, conClient2;
+    DS_ASSERT_OK(CreateClient(0, prodClient));
+    DS_ASSERT_OK(CreateClient(1, conClient1));
+    DS_ASSERT_OK(CreateClient(1, conClient2));
+
+    uint64_t maxStreamSizeMB = 2;
+    ThreadPool preparePool(1);
+    uint64_t eleSz = 8192ul;  // this should be a small element, not a big element
+    uint64_t numElements = elementsTotalSize_ / eleSz / 2;
+    slowConsume_ = waitForGo_ = true;
+
+    LOG(INFO) << FormatString("Testing Size: %zu", eleSz);
+    std::shared_future<std::vector<std::string>> elementsFut = preparePool.Submit([eleSz, numElements]() {
+        ElementGenerator elementGenerator(eleSz + 1, eleSz);
+        auto elements = elementGenerator.GenElements("producer", numElements, 8ul);
+        return elements;
+    });
+
+    // Wait for the data generation to complete before we launch producers and consumers
+    while (elementsFut.wait_for(std::chrono::seconds(1)) != std::future_status::ready)
+        ;
+    LOG(INFO) << "Data generation complete. Kick off threads now";
+
+    const int POOL_SIZE = 3;
+    int numOfConsumer = 2;
+    ThreadPool pool(POOL_SIZE);
+    std::string streamName = "DifferentRecvSpeed";
+    auto producerFut =
+        pool.Submit([this, streamName, numOfConsumer, &elementsFut, numElements, prodClient, maxStreamSizeMB]() {
+            ProducerConf prodCfg = { .delayFlushTime = 20, .pageSize = 1 * MB, .maxStreamSize = maxStreamSizeMB * MB };
+            SendConfig sendCfg = { .streamName = streamName,
+                                   .producerName = "producer",
+                                   .producerConf = prodCfg,
+                                   .numOfElements = numElements };
+            return SendStreamData(sendCfg, numOfConsumer, elementsFut, prodClient);
+        });
+
+    size_t ackInterval = std::max(400ul * KB / eleSz, 1ul);
+    auto consumerFut1 = pool.Submit([this, streamName, numElements, ackInterval, conClient1]() {
+        RecvConfig rcvCfg = { .streamName = streamName,
+                              .subscriptionName = "sub1",
+                              .numOfBatchElements = 100,
+                              .timeToWaitMs = 20,
+                              .ackInterval = ackInterval,
+                              .autoAck = false };
+        return RecvStreamData(rcvCfg, numElements, conClient1);
+    });
+
+    auto consumerFut2 = pool.Submit([this, streamName, numElements, ackInterval, conClient2]() {
+        RecvConfig rcvCfg = { .streamName = streamName,
+                              .subscriptionName = "sub2",
+                              .numOfBatchElements = 100,
+                              .timeToWaitMs = 20,
+                              .ackInterval = ackInterval,
+                              .autoAck = false };
+        return RecvStreamDataWithSlowReceive(rcvCfg, numElements, conClient2);
+    });
+
+    // Assert status after threads complete
+    ASSERT_EQ(producerFut.get(), Status::OK());
+    ASSERT_EQ(consumerFut1.get(), Status::OK());
+    ASSERT_EQ(consumerFut2.get(), Status::OK());
+}
+
+TEST_F(ProducerTest, TestReCreateProducerDiscardData)
+{
+    std::shared_ptr<StreamClient> client1;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    std::shared_ptr<Producer> producer;
+    std::string streamName = "DiscardData";
+
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+
+    ProducerConf conf;
+    conf.maxStreamSize = 10 * 1024 * 1024;
+    const int NUM_ITER = 5;
+    const int DEFAULT_WAIT_TIME = 5000;
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "worker.UsageMonitor.CheckOverUsedForStream.MockError",
+                                           "return(K_OUT_OF_MEMORY)"));
+    for (int iteration = 0; iteration < NUM_ITER; iteration++) {
+        LOG(INFO) << "Iteration number " << iteration;
+        DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+
+        const size_t testSize = 500 * KB;
+        Element element;
+        std::vector<uint8_t> writeElement;
+        DS_ASSERT_OK(CreateElement(testSize, element, writeElement));
+        DS_ASSERT_OK(producer->Send(element));
+        DS_ASSERT_OK(producer->Close());
+    }
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 1, "worker.UsageMonitor.CheckOverUsedForStream.MockError"));
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(NUM_ITER, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), (size_t)NUM_ITER);
+    DS_ASSERT_OK(consumer->Ack(outElements.back().id));
+}
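+// A sketch (illustrative only; the real implementation is the datasystem inject
+// framework) of how the "<count>*<action>" strings used throughout these tests,
+// e.g. "1*sleep(5000)->1*return(K_TRY_AGAIN)", are typically interpreted: each named
+// point holds an ordered list of (remaining count, action) pairs, and instrumented
+// code polls the point at the matching call site.
+#include <deque>  // normally at the top of the file
+#include <map>
+#include <mutex>
+#include <string>
+
+struct InjectActionSketch {
+    int remaining;       // how many times this action still fires
+    std::string action;  // e.g. "sleep(5000)" or "return(K_TRY_AGAIN)"
+};
+
+static std::mutex g_injectMtxSketch;
+static std::map<std::string, std::deque<InjectActionSketch>> g_injectPointsSketch;
+
+// Returns the next armed action for a point, or an empty string when none is armed.
+static std::string NextInjectActionSketch(const std::string &point)
+{
+    std::lock_guard<std::mutex> lk(g_injectMtxSketch);
+    auto it = g_injectPointsSketch.find(point);
+    if (it == g_injectPointsSketch.end() || it->second.empty()) {
+        return "";
+    }
+    InjectActionSketch &front = it->second.front();
+    std::string action = front.action;
+    if (--front.remaining <= 0) {
+        it->second.pop_front();  // exhausted; fall through to the next chained action
+    }
+    return action;
+}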
+TEST_F(ProducerTest, TestConsumerFutexWake)
+{
+    FLAGS_v = SC_DEBUG_LOG_LEVEL;
+    std::shared_ptr<StreamClient> client;
+    ASSERT_EQ(CreateClient(0, client), Status::OK());
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client->Subscribe("ConsumerFutexWake", config, consumer));
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const int pageSize = 4 * KB;
+    conf.pageSize = pageSize;
+    conf.maxStreamSize = DEFAULT_MAX_STREAM_SIZE;
+    DS_ASSERT_OK(client->CreateProducer("ConsumerFutexWake", producer, conf));
+    const int eleSize = 3 * KB;
+    std::string a(eleSize, 'a');
+    size_t numElementRecv = 0;
+    size_t numElementSend = 0;
+    ThreadPool pool(2);
+    auto consFut = pool.Submit([&consumer, &numElementRecv]() {
+        std::vector<Element> out;
+        RETURN_IF_NOT_OK(consumer->Receive(RPC_TIMEOUT, out));
+        numElementRecv += out.size();
+        datasystem::inject::Set("StreamDataPage.WaitOnFutexForever", "1*call()");
+        RETURN_IF_NOT_OK(consumer->Receive(RPC_TIMEOUT, out));
+        numElementRecv += out.size();
+        return Status::OK();
+    });
+    auto prodFut = pool.Submit([&a, &producer, &numElementSend]() {
+        Element ele1(reinterpret_cast<uint8_t *>(a.data()), a.size());
+        // Send one element
+        RETURN_IF_NOT_OK(producer->Send(ele1));
+        numElementSend++;
+        std::this_thread::sleep_for(std::chrono::seconds(5));
+        RETURN_IF_NOT_OK(producer->Send(ele1));
+        numElementSend++;
+        return Status::OK();
+    });
+    auto rc1 = consFut.get();
+    auto rc2 = prodFut.get();
+    DS_ASSERT_OK(rc1);
+    DS_ASSERT_OK(rc2);
+    ASSERT_EQ(numElementSend, numElementRecv);
+    producer->Close();
+    consumer->Close();
+}
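+// A sketch of the futex wait/wake primitive behind the "StreamDataPage.WaitOnFutexForever"
+// inject point above: a consumer that finds no data parks on a 32-bit word in shared
+// memory with FUTEX_WAIT, and the producer wakes it with FUTEX_WAKE after publishing an
+// element. Linux-specific and illustrative only; the real page layout and wake protocol
+// live in StreamDataPage.
+#include <linux/futex.h>  // normally at the top of the file
+#include <sys/syscall.h>
+#include <unistd.h>
+#include <atomic>
+
+static void FutexWaitSketch(std::atomic<int> *word, int expected)
+{
+    // Sleeps only if *word still equals `expected`, which avoids a lost-wakeup race:
+    // if the producer already bumped the word, the call returns immediately.
+    syscall(SYS_futex, reinterpret_cast<int *>(word), FUTEX_WAIT, expected, nullptr, nullptr, 0);
+}
+
+static void FutexWakeOneSketch(std::atomic<int> *word)
+{
+    // Wakes at most one waiter parked on the word.
+    syscall(SYS_futex, reinterpret_cast<int *>(word), FUTEX_WAKE, 1, nullptr, nullptr, 0);
+}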
+TEST_F(ProducerTest, TestConsumerTimingHole)
+{
+    std::shared_ptr<StreamClient> client1;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    std::shared_ptr<Producer> producer;
+    const size_t SEND_COUNT = 100;
+    const size_t testSize = 500 * KB;
+    ProducerConf conf;
+    conf.maxStreamSize = 10 * 1024 * 1024;
+    DS_ASSERT_OK(client1->CreateProducer("ConsumerTimingHole", producer, conf));
+    std::thread producerThrd([&client1, &producer]() {
+        const int DEFAULT_SLEEP_TIME = 300;
+        Element element;
+        std::vector<uint8_t> writeElement;
+        uint64_t numOfConsumers = 0;
+        while (numOfConsumers != 1) {
+            client1->QueryGlobalConsumersNum("ConsumerTimingHole", numOfConsumers);
+            std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME));
+        }
+        CreateElement(testSize, element, writeElement);
+        for (size_t i = 0; i < SEND_COUNT; i++) {
+            Status rc = producer->Send(element);
+            int retryCount = 30;
+            while (rc.GetCode() == K_OUT_OF_MEMORY && retryCount-- > 0) {
+                std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME));
+                rc = producer->Send(element);
+            }
+            DS_ASSERT_OK(rc);
+        }
+    });
+
+    // Inject sleep to extend the consumer timing hole
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "ClientWorkerSC.Subscribe.TimingHole", "1*sleep(1000)"));
+
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe("ConsumerTimingHole", config, consumer));
+
+    const int DEFAULT_WAIT_TIME = 1000;
+    const int DEFAULT_RETRY_TIME = 10;
+    Timer timer;
+    std::vector<Element> outElements;
+    int sendCount = SEND_COUNT;
+    while (sendCount > 0 && timer.ElapsedSecond() < DEFAULT_RETRY_TIME) {
+        DS_ASSERT_OK(consumer->Receive(1, DEFAULT_WAIT_TIME, outElements));
+        if (!outElements.empty()) {
+            DS_ASSERT_OK(consumer->Ack(outElements.back().id));
+            sendCount -= outElements.size();
+        }
+    }
+    ASSERT_EQ(sendCount, 0);
+    producerThrd.join();
+}
+
+TEST_F(ProducerTest, TestBlockedCreateRequestTimingHole)
+{
+    // 1 Producer -> 1 Consumer, same node.
+    std::shared_ptr<StreamClient> client1;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 100 * MB;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = 1 * MB;
+    std::string streamName = "BlockedCreateRequestTimingHole";
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+
+    // Only 1 element per page.
+    const size_t elementSize = 900 * KB;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    CreateElement(elementSize, element, writeElement);
+
+    // Normal Send.
+    DS_ASSERT_OK(producer->Send(element));
+
+    // The timer for the BlockedCreateRequest remains active; we do not cancel it.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "do.not.cancel.timer", "1*return()"));
+    DS_ASSERT_OK(producer->Send(element, 10));  // Timer with less than 10 ms
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "do.not.cancel.timer"));
+
+    // The timer above will attempt to process the BlockedCreateRequest created below; it should be a no-op.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "GetBlockedCreateRequest.sleep", "sleep(2000)"));
+    DS_ASSERT_OK(producer->Send(element));  // No timer
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "GetBlockedCreateRequest.sleep"));
+
+    // Clean up
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client1->DeleteStream(streamName));
+}
+
+TEST_F(ProducerTest, LEVEL2_TestCreateProducerLongTimeout1)
+{
+    // The request should not time out if the client timeout is 10 minutes, even when the master takes extra time.
+
+    // Set timeout to 10 minutes
+    std::shared_ptr<StreamClient> client1;
+    const int32_t timeoutMs = 1000 * 60 * 10;
+    ASSERT_EQ(CreateClient(0, timeoutMs, client1), Status::OK());
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 100 * MB;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = 1 * MB;
+
+    // Make the master wait for 1 minute; the request should not time out.
+    // We don't actually know which node is the master, so inject into both.
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "SCMetadataManager.CreateProducer.wait",
+                                           "1*sleep(60000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "SCMetadataManager.CreateProducer.wait",
+                                           "1*sleep(60000)"));
+
+    // This request should not time out, as the client timeout is 10 minutes.
+    DS_ASSERT_OK(client1->CreateProducer("ProducerLongTimeout1", producer, conf));
+    DS_ASSERT_OK(producer->Close());
+}
+
+TEST_F(ProducerTest, LEVEL1_TestCreateProducerLongTimeout2)
+{
+    // The request should time out if the client timeout is set to 15 s and the master takes longer.
+
+    // Set timeout to default
+    std::shared_ptr<StreamClient> client1;
+    const int timeoutMs = 15000;
+    ASSERT_EQ(CreateClient(0, timeoutMs, client1), Status::OK());
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 100 * MB;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = 1 * MB;
+
+    // Make the master wait for 20 s; the whole CreateProducer() request should time out.
+    // We don't actually know which node is the master, so inject into both.
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "SCMetadataManager.CreateProducer.wait",
+                                           "1*sleep(20000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "SCMetadataManager.CreateProducer.wait",
+                                           "1*sleep(20000)"));
+
+    // This request should fail, as the timeout is 15 seconds and the master takes longer than that.
+    DS_ASSERT_NOT_OK(client1->CreateProducer("ProducerLongTimeout2", producer, conf));
+}
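+// A sketch of the client-side deadline arithmetic these timeout tests probe: the
+// caller sets one overall budget (15 s or 10 minutes above) and each hop checks the
+// remaining time before issuing the next RPC. Pure std::chrono and illustrative only;
+// the real timeout plumbing lives in the client/worker RPC layer.
+#include <chrono>  // normally at the top of the file
+#include <cstdint>
+
+class DeadlineSketch {
+public:
+    explicit DeadlineSketch(int64_t budgetMs)
+        : deadline_(std::chrono::steady_clock::now() + std::chrono::milliseconds(budgetMs))
+    {
+    }
+
+    // Remaining budget in ms, clamped at 0 once expired.
+    int64_t RemainingMs() const
+    {
+        auto left =
+            std::chrono::duration_cast<std::chrono::milliseconds>(deadline_ - std::chrono::steady_clock::now());
+        return left.count() > 0 ? left.count() : 0;
+    }
+
+    bool Expired() const { return RemainingMs() == 0; }
+
+private:
+    std::chrono::steady_clock::time_point deadline_;
+};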
+TEST_F(ProducerTest, LEVEL2_TestCreateProducerLongTimeout3)
+{
+    // MasterWorkerSCServiceImpl::SyncConsumerNode takes a long time.
+
+    // Set client timeout to 10 minutes
+    std::shared_ptr<StreamClient> client1;
+    const int32_t timeoutMs = 1000 * 60 * 10;
+    ASSERT_EQ(CreateClient(0, timeoutMs, client1), Status::OK());
+
+    // Create a consumer so that we can get SyncConsumerNode
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+    std::shared_ptr<Consumer> newConsumer;
+    SubscriptionConfig consumerConf("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe("ProducerLongTimeout3", consumerConf, newConsumer));
+
+    // Make the worker wait for 1 minute in SyncConsumerNode(); the CreateProducer request should not time out.
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0,
+                                           "MasterWorkerSCServiceImpl.SyncConsumerNode.sleep", "1*sleep(60000)"));
+
+    // The request should not time out, as the client timeout is 10 minutes.
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    conf.maxStreamSize = 67108864;
+    DS_ASSERT_OK(client1->CreateProducer("ProducerLongTimeout3", producer, conf));
+}
+
+TEST_F(ProducerTest, TestMultiLocalProducerCreateClose)
+{
+    // This test case tests multiple local producers:
+    // they generate a single master call, and
+    // they can all be created and closed without an error.
+    const int num_producer = 10;
+    std::shared_ptr<StreamClient> client1;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+
+    // Create a producer config
+    auto streamName = "MultiLocalProdCreateClose";
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 100 * MB;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = 1 * MB;
+
+    // Create 10 producers on the same worker for the same stream
+    std::vector<std::shared_ptr<Producer>> producerList;
+    for (int i = 0; i < num_producer; i++) {
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+        producerList.emplace_back(producer);
+    }
+
+    // The master should get only one request, so the count should be 1
+    ASSERT_EQ(CheckProducerCount(client1, streamName), 1);
+
+    DS_ASSERT_OK(producerList[0]->Close());
+    // The count should not change
+    ASSERT_EQ(CheckProducerCount(client1, streamName), 1);
+
+    // Close the remaining 9 producers on the same worker for the same stream
+    for (int i = 1; i < num_producer; i++) {
+        DS_ASSERT_OK(producerList[i]->Close());
+    }
+    ASSERT_EQ(CheckProducerCount(client1, streamName), 0);
+    DS_ASSERT_OK(TryAndDeleteStream(client1, streamName));
+}
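+// A sketch of the reference counting the multi-local-producer tests rely on: the
+// worker aggregates all local producers of a stream into one master-visible entry,
+// registering on the 0 -> 1 transition and deregistering on 1 -> 0, so ten local
+// CreateProducer calls yield a single master call. Illustrative only; the real
+// bookkeeping lives in the worker's stream metadata code.
+#include <mutex>  // normally at the top of the file
+#include <string>
+#include <unordered_map>
+
+class LocalProducerRefsSketch {
+public:
+    // Returns true when this is the first local producer, i.e. the caller must notify the master.
+    bool AddLocal(const std::string &stream)
+    {
+        std::lock_guard<std::mutex> lk(mtx_);
+        return ++counts_[stream] == 1;
+    }
+
+    // Returns true when this was the last local producer, i.e. the caller must tell
+    // the master to drop this node from the stream topology.
+    bool RemoveLocal(const std::string &stream)
+    {
+        std::lock_guard<std::mutex> lk(mtx_);
+        auto it = counts_.find(stream);
+        if (it == counts_.end() || it->second == 0) {
+            return false;  // nothing to remove
+        }
+        if (--it->second == 0) {
+            counts_.erase(it);
+            return true;
+        }
+        return false;
+    }
+
+private:
+    std::mutex mtx_;
+    std::unordered_map<std::string, int> counts_;
+};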
+TEST_F(ProducerTest, TestMultiLocalProducerSendReceive)
+{
+    // This test case tests multiple local producers:
+    // they generate a single master call, and
+    // all of them can send data to a consumer.
+    const int num_producer = 10;
+    std::shared_ptr<StreamClient> client1;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+
+    // Create a producer config
+    auto streamName = "MultiLocalProdSendReceive";
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 2 * MB;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = 1 * MB;
+
+    // Create 10 producers on worker1 for the same stream
+    std::vector<std::shared_ptr<Producer>> producerList;
+    for (int i = 0; i < num_producer; i++) {
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+        producerList.emplace_back(producer);
+    }
+
+    // Create a consumer on worker2 for the stream
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+
+    // The master should get only one request, so the count should be 1
+    ASSERT_EQ(CheckProducerCount(client1, streamName), 1);
+
+    // Send and receive data from all producers.
+    // Only 1 element per page.
+    const size_t elementSize = KB;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    CreateElement(elementSize, element, writeElement);
+
+    // Normal Send.
+    for (int i = 0; i < num_producer; i++) {
+        DS_ASSERT_OK(producerList[i]->Send(element));
+    }
+
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(num_producer, RECV_WAIT_MILLI_SECONDS, outElements));
+    ASSERT_EQ(outElements.size(), (size_t)num_producer);
+    std::string actualData(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+    std::string data(reinterpret_cast<const char *>(writeElement.data()), writeElement.size());
+    EXPECT_EQ(data, actualData);
+
+    // Close all 10 producers on the same worker for the same stream
+    for (int i = 0; i < num_producer; i++) {
+        DS_ASSERT_OK(producerList[i]->Close());
+    }
+    ASSERT_EQ(CheckProducerCount(client1, streamName), 0);
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(TryAndDeleteStream(client1, streamName));
+}
+
+TEST_F(ProducerTest, TestDuplicatedBlockedCreateRequest)
+{
+    // 1 Producer -> 1 Consumer, same node.
+    std::shared_ptr<StreamClient> client1;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 100 * MB;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = 1 * MB;
+    std::string streamName = "DupBlockedCreateReq";
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+
+    // Only 1 element per page.
+    const size_t elementSize = 900 * KB;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    CreateElement(elementSize, element, writeElement);
+
+    // Normal Send.
+    DS_ASSERT_OK(producer->Send(element));
+
+    // ZMQ times out before the BlockedCreateRequest is processed normally, or the timer expires but has not yet
+    // removed the BlockedCreateRequest.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "UnblockCreators.sleep", "sleep(10000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "GetBlockedCreateRequest.sleep", "sleep(10000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "ClientWorkerSCServiceImpl.HandleBlockedCreateTimeout.sleep",
+                                           "sleep(10000)"));
+    DS_ASSERT_OK(inject::Set("client.CreateWritePage", "call()"));  // rpc timeout = timeout below
+    DS_ASSERT_NOT_OK(producer->Send(element, 10));  // Timer with less than 10 ms
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "GetBlockedCreateRequest.sleep"));
+    DS_ASSERT_OK(inject::Clear("client.CreateWritePage"));
+
+    // Add a new BlockedCreateRequest to the unordered map; one should already exist for the same producer.
+    DS_ASSERT_OK(producer->Send(element, 10));  // Timer with less than 10 ms
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "ClientWorkerSCServiceImpl.HandleBlockedCreateTimeout.sleep"));
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "UnblockCreators.sleep"));
+
+    // Clean up
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client1->DeleteStream(streamName));
+}
+
+TEST_F(ProducerTest, TestDuplicatedBlockedCreateRequestOutOfOrder)
+{
+    // 1 Producer -> 1 Consumer, same node.
+    std::shared_ptr<StreamClient> client1;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 100 * MB;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = 1 * MB;
+    std::string streamName = "DupBlockReqOutOfOrder";
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+
+    // Only 1 element per page.
+    const size_t elementSize = 900 * KB;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    CreateElement(elementSize, element, writeElement);
+
+    // Normal Send.
+    DS_ASSERT_OK(producer->Send(element));
+
+    // Worker thread A received CreateShmPage request A, but is stuck before getting the StreamManager lock.
+    // The client rpc times out.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "UnblockCreators.sleep", "sleep(10000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamManager.AddBlockCreateRequest.sleep", "sleep(5000)"));
+    DS_ASSERT_OK(inject::Set("client.CreateWritePage", "call()"));  // rpc timeout = timeout below
+    DS_ASSERT_NOT_OK(producer->Send(element, 10));  // Timer with less than 10 ms
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "StreamManager.AddBlockCreateRequest.sleep"));
+    DS_ASSERT_OK(inject::Clear("client.CreateWritePage"));
+
+    // Worker thread B received CreateShmPage request B, then created and added a new BlockedCreateRequest to the
+    // blockedList.
+    // Worker thread B is stuck getting BlockedCreateRequest B out of the blockedList.
+    // Worker thread A tries to add a new BlockedCreateRequest A but finds BlockedCreateRequest B.
+    // Worker thread A: since BlockedCreateRequest B has a request pb timestamp later than BlockedCreateRequest A's,
+    // worker thread A does not add BlockedCreateRequest A to the blockedList.
+    // Worker thread B continues to process BlockedCreateRequest B and returns success to the client.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "GetBlockedCreateRequest.sleep", "sleep(6000)"));
+    DS_ASSERT_OK(producer->Send(element));
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "GetBlockedCreateRequest.sleep"));
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "UnblockCreators.sleep"));
+
+    // Clean up
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client1->DeleteStream(streamName));
+}
+
+class ProducerNoKeysTest : public ProducerTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        ProducerTest::SetClusterSetupOptions(opts);
+        // Set these to not generate a signature.
+        // Requests to create a shm page by the same producer should have a non-zero unique timestamp.
+        accessKey_ = "";
+        secretKey_ = "";
+        opts.systemAccessKey = accessKey_;
+        opts.systemSecretKey = secretKey_;
+    }
+};
+
+TEST_F(ProducerNoKeysTest, TestDuplicatedBlockedCreateRequestNoSignature)
+{
+    // 1 Producer -> 1 Consumer, same node.
+    std::shared_ptr<StreamClient> client1;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 100 * MB;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = 1 * MB;
+    std::string streamName = "DupBlockReqNoSig";
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+
+    // Only 1 element per page.
+    const size_t elementSize = 900 * KB;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    CreateElement(elementSize, element, writeElement);
+
+    // Normal Send.
+    DS_ASSERT_OK(producer->Send(element));
+
+    // ZMQ times out before the BlockedCreateRequest is processed normally, or the timer expires but has not yet
+    // removed the BlockedCreateRequest.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "UnblockCreators.sleep", "sleep(10000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "GetBlockedCreateRequest.sleep", "sleep(10000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "ClientWorkerSCServiceImpl.HandleBlockedCreateTimeout.sleep",
+                                           "sleep(10000)"));
+    DS_ASSERT_OK(inject::Set("client.CreateWritePage", "call()"));  // rpc timeout = timeout below
+    DS_ASSERT_NOT_OK(producer->Send(element, 10));  // Timer with less than 10 ms
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "GetBlockedCreateRequest.sleep"));
+    DS_ASSERT_OK(inject::Clear("client.CreateWritePage"));
+
+    // Add a new BlockedCreateRequest to the unordered map; one should already exist for the same producer.
+    DS_ASSERT_OK(producer->Send(element, 10));  // Timer with less than 10 ms
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "ClientWorkerSCServiceImpl.HandleBlockedCreateTimeout.sleep"));
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "UnblockCreators.sleep"));
+
+    // Clean up
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client1->DeleteStream(streamName));
+}
+
+TEST_F(ProducerTest, TestScanAndEvalRecycledPage)
+{
+    // 1 Producer -> 1 remote Consumer
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe("ScanAndEvalRecycledPage", config, consumer));
+
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 128 * KB;
+    const uint64_t pageSize = 28 * KB;  // up to 4 pages allowed in the stream.
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = pageSize;
+    DS_ASSERT_OK(client1->CreateProducer("ScanAndEvalRecycledPage", producer, conf));
+
+    const size_t elementSize = 20 * KB;
+    const size_t numElementPerPage = 1;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    CreateElement(elementSize, element, writeElement);
+
+    // What do we want to solve here?
+    // 1. Producer sends 2 elements, so 2 pages (A and B).
+    // 2. Page A is acked.
+    // 3. The ScanAndEval thread locates the last page (B).
+    // 4. The ScanAndEval thread sleeps.
+    // 5. Page B is acked.
+    // 6. Producer sends 2 elements, reusing the acked pages (A and B).
+    // 7. The ScanAndEval thread finishes sleeping and tries to receive elements from page B.
+    // 8. The begCursor is updated while the ScanAndEval thread holds page B,
+    //    causing a coredump because the slot value is garbage.
+    //
+    // What do we do to make the test case pass?
+    // In step 6 above, do not reuse page B, since page B is held by the ScanAndEval thread; instead,
+    // create a new page.
+    DS_ASSERT_OK(producer->Send(element));
+    sleep(1);
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "UpdateLastAckCursorUnlocked.sleep", "sleep(1000)"));
+    DS_ASSERT_OK(producer->Send(element));
+    sleep(1);
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamDataPage::Receive.sleep", "sleep(15000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "AppendFreePagesImplNotLocked", "sleep(5000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamDataPage::Receive.fake.BIG_ELEMENT", "call()"));
+    const int TWO_SEC = 2;
+    sleep(TWO_SEC);
+    DS_ASSERT_OK(producer->Send(element));
+    const int TEN_SEC = 10;
+    sleep(TEN_SEC);
+    DS_ASSERT_OK(producer->Send(element));
+
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "StreamDataPage::Receive.sleep"));
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "UpdateLastAckCursorUnlocked.sleep"));
+
+    // Receive all 4 elements.
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(numElementPerPage, RPC_TIMEOUT, outElements));
+    ASSERT_EQ(outElements.size(), numElementPerPage);
+    outElements.clear();
+
+    DS_ASSERT_OK(consumer->Receive(numElementPerPage, RPC_TIMEOUT, outElements));
+    ASSERT_EQ(outElements.size(), numElementPerPage);
+    outElements.clear();
+
+    DS_ASSERT_OK(consumer->Receive(numElementPerPage, RPC_TIMEOUT, outElements));
+    ASSERT_EQ(outElements.size(), numElementPerPage);
+    outElements.clear();
+
+    DS_ASSERT_OK(consumer->Receive(numElementPerPage, RPC_TIMEOUT, outElements));
+    ASSERT_EQ(outElements.size(), numElementPerPage);
+    outElements.clear();
+
+    // Clean up
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client1->DeleteStream("ScanAndEvalRecycledPage"));
+
+    DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "StreamDataPage::Receive.fake.BIG_ELEMENT"));
+}
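+// A sketch of the invariant the test above pins down: a page must not be recycled for
+// reuse while a scanning reader still holds it. One common shape is a per-page pin
+// count that the recycler checks before handing the page back to the free list.
+// Illustrative only; the real logic lives in the worker's page pool.
+#include <atomic>  // normally at the top of the file
+
+struct PageSketch {
+    std::atomic<int> pins{ 0 };
+    bool acked{ false };
+};
+
+// Reader side: pin before touching the page, unpin when done.
+struct PagePinGuardSketch {
+    explicit PagePinGuardSketch(PageSketch &p) : page(p) { page.pins.fetch_add(1); }
+    ~PagePinGuardSketch() { page.pins.fetch_sub(1); }
+    PageSketch &page;
+};
+
+// Recycler side: only acked pages with no readers may be reused; otherwise allocate
+// a fresh page rather than clobbering cursors a reader may still dereference.
+static bool CanRecycleSketch(const PageSketch &page)
+{
+    return page.acked && page.pins.load() == 0;
+}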
+TEST_F(ProducerTest, TestProducerDiscardPrivateBuffer)
+{
+    std::shared_ptr<StreamClient> client1;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    std::shared_ptr<Producer> producer;
+    const size_t SEND_COUNT = 10;
+    const size_t testSize = 100 * KB;
+    ProducerConf conf;
+    conf.maxStreamSize = 10 * 1024 * 1024;
+    DS_ASSERT_OK(client1->CreateProducer("DiscardPrivateBuffer", producer, conf));
+    std::thread producerThrd([&client1, &producer]() {
+        const int DEFAULT_SLEEP_TIME = 300;
+        Element element;
+        std::vector<uint8_t> writeElement;
+        uint64_t numOfConsumers = 0;
+        while (numOfConsumers != 1) {
+            client1->QueryGlobalConsumersNum("DiscardPrivateBuffer", numOfConsumers);
+            std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME));
+        }
+        CreateElement(testSize, element, writeElement);
+        for (size_t i = 0; i < SEND_COUNT; i++) {
+            Status rc = producer->Send(element);
+            int retryCount = 30;
+            while (rc.GetCode() == K_OUT_OF_MEMORY && retryCount-- > 0) {
+                std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME));
+                rc = producer->Send(element);
+            }
+            DS_ASSERT_OK(rc);
+        }
+        DS_ASSERT_OK(producer->Close());
+    });
+
+    // Inject an error so the worker reports OOM for the first 100 checks
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "worker.UsageMonitor.CheckOverUsedForStream.MockError",
+                                           "100*return(K_OUT_OF_MEMORY)"));
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe("DiscardPrivateBuffer", config, consumer));
+
+    const int DEFAULT_WAIT_TIME = 1000;
+    const int DEFAULT_RETRY_TIME = 10;
+    Timer timer;
+    std::vector<Element> outElements;
+    int sendCount = SEND_COUNT;
+    while (sendCount > 0 && timer.ElapsedSecond() < DEFAULT_RETRY_TIME) {
+        DS_ASSERT_OK(consumer->Receive(1, DEFAULT_WAIT_TIME, outElements));
+        if (!outElements.empty()) {
+            DS_ASSERT_OK(consumer->Ack(outElements.back().id));
+            sendCount -= outElements.size();
+        }
+    }
+    ASSERT_EQ(sendCount, 0);
+    producerThrd.join();
+}
+
+TEST_F(ProducerTest, TestStaleCursorEarlyReclaim)
+{
+    // This test case tests that early reclaim of shm based on a stale last-append cursor
+    // could trigger data loss; verify that no elements are lost.
+    std::vector<std::shared_ptr<StreamClient>> clients;
+    for (int i = 0; i < NUM_WORKERS; i++) {
+        std::shared_ptr<StreamClient> client;
+        ASSERT_EQ(CreateClient(i, client), Status::OK());
+        clients.push_back(client);
+    }
+    std::string streamName = "StaleCursorEarlyReclaim";
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients[1]->Subscribe(streamName, config, consumer));
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 10 * MB;
+    conf.maxStreamSize = maxStreamSize;
+    const size_t testPageSize = 8 * KB;
+    conf.pageSize = testPageSize;
+    DS_ASSERT_OK(clients[0]->CreateProducer(streamName, producer, conf));
+    const size_t testSize = 5 * KB;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    CreateElement(testSize, element, writeElement);
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamManager.RemoteAck.delay", "1*sleep(2000)"));
+    DS_ASSERT_OK(producer->Send(element));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamDataPool.SendElementsToRemote.wait", "1*sleep(2000)"));
+    DS_ASSERT_OK(producer->Send(element));
+    DS_ASSERT_OK(producer->Close());
+    const int DEFAULT_WAIT_TIME = 10000;
+    std::vector<Element> outElements;
+    const size_t expectedNum = 2;
+    DS_ASSERT_OK(consumer->Receive(expectedNum, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), expectedNum);
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(clients[0]->DeleteStream(streamName));
+}
+
+TEST_F(ProducerTest, TestCloseProducerEarlyReclaim)
+{
+    // Test that closing the producer triggers early reclaim when the consumer is also local.
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+    std::string streamNameBase = "CloseProducerEarlyReclaim";
+    // There is only 64 MB of shm, so only 6 streams can be created on this node.
+    ProducerConf conf;
+    const uint64_t pageSize = 10 * MB;
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    conf.pageSize = pageSize;
+    const int streamNum = 6;
+
+    std::vector<std::shared_ptr<Consumer>> consumers(streamNum);
+    std::vector<std::shared_ptr<Producer>> producers(streamNum);
+    for (int i = 0; i < streamNum; i++) {
+        std::string streamName = "CloseProducerEarlyReclaim" + std::to_string(i);
+        SubscriptionConfig config("sub", SubscriptionType::STREAM);
+        DS_ASSERT_OK(client1->Subscribe(streamName, config, consumers[i]));
+
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(client1->CreateProducer(streamName, producers[i], conf));
+    }
+    // Now the 7th stream should fail to create.
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_NOT_OK(client1->CreateProducer(streamNameBase, producer, conf));
+    // But if one of the streams gets its producer and consumer both closed,
+    // the 7th stream can be created.
+    // Close the producer last.
+    DS_ASSERT_OK(consumers[0]->Close());
+    DS_ASSERT_OK(producers[0]->Close());
+    DS_ASSERT_OK(client1->CreateProducer(streamNameBase, producer, conf));
+}
+
+TEST_F(ProducerTest, TestCloseConsumerEarlyReclaim)
+{
+    // Test that closing the consumer triggers early reclaim.
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+    std::string streamNameBase = "CloseConsumerEarlyReclaim";
+    // There is only 64 MB of shm, so only 6 streams can be created on this node.
+    ProducerConf conf;
+    const uint64_t pageSize = 10 * MB;
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    conf.pageSize = pageSize;
+    const int streamNum = 6;
+
+    std::vector<std::shared_ptr<Consumer>> consumers(streamNum);
+    std::vector<std::shared_ptr<Producer>> producers(streamNum);
+    for (int i = 0; i < streamNum; i++) {
+        std::string streamName = "CloseProducerEarlyReclaim" + std::to_string(i);
+        SubscriptionConfig config("sub", SubscriptionType::STREAM);
+        DS_ASSERT_OK(client1->Subscribe(streamName, config, consumers[i]));
+
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(client1->CreateProducer(streamName, producers[i], conf));
+    }
+    // Now the 7th stream should fail to create.
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_NOT_OK(client1->CreateProducer(streamNameBase, producer, conf));
+    // But if one of the streams gets its producer and consumer both closed,
+    // the 7th stream can be created.
+    // Close the consumer last.
+    DS_ASSERT_OK(producers[0]->Close());
+    DS_ASSERT_OK(consumers[0]->Close());
+    DS_ASSERT_OK(client1->CreateProducer(streamNameBase, producer, conf));
+}
+
+TEST_F(ProducerTest, TestEarlyReclaimDeadlock)
+{
+    // Test that, with an incorrect code order, a deadlock could be triggered on reclaimMutex_.
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(0, client1), Status::OK());
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+    std::string streamNameBase = "CloseConsumerEarlyReclaim";
+    // There is only 64 MB of shm, so only 6 streams can be created on this node.
+    ProducerConf conf;
+    const uint64_t pageSize = 10 * MB;
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    conf.pageSize = pageSize;
+    const int streamNum = 6;
+
+    std::vector<std::shared_ptr<Consumer>> consumers(streamNum);
+    std::vector<std::shared_ptr<Producer>> producers(streamNum);
+    for (int i = 0; i < streamNum; i++) {
+        std::string streamName = "CloseProducerEarlyReclaim" + std::to_string(i);
+        SubscriptionConfig config("sub", SubscriptionType::STREAM);
+        DS_ASSERT_OK(client1->Subscribe(streamName, config, consumers[i]));
+
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(client1->CreateProducer(streamName, producers[i], conf));
+    }
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamNameBase, config, consumer));
+    // Now the 7th stream should fail to create.
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_NOT_OK(client1->CreateProducer(streamNameBase, producer, conf));
+    // Now test that CloseConsumer does not deadlock.
+    DS_ASSERT_OK(consumer->Close());
+}
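+// A sketch of the classic shape behind the deadlock test above: if one path takes the
+// reclaim mutex and then a stream mutex, while another path takes them in the opposite
+// order, the two threads can deadlock. Acquiring both through std::scoped_lock (or
+// enforcing a single global lock order) removes the cycle. Illustrative only;
+// reclaimMutex_ itself lives in the worker code.
+#include <mutex>  // normally at the top of the file
+
+static std::mutex g_reclaimMtxSketch;
+static std::mutex g_streamMtxSketch;
+
+static void ReclaimThenStreamSafeSketch()
+{
+    // Deadlock-free: both mutexes are locked atomically regardless of caller order.
+    std::scoped_lock lk(g_reclaimMtxSketch, g_streamMtxSketch);
+    // ... reclaim pages, update stream state ...
+}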
+TEST_F(ProducerTest, LEVEL1_TestCreateProducerTimeout1)
+{
+    // This test case covers CreateProducer requests taking too long on the master:
+    // the master checks the timeout and returns early, before the actual timeout.
+    // Sleep is injected so that by the time the thread pool picks up the request, it has already timed out.
+    // The consumer is on the same node as the producers, so SyncConsumerNode and UpdateTopoNotification are not sent.
+    const int timeoutMs = 10000;
+    std::vector<std::shared_ptr<StreamClient>> clients;
+    for (int i = 0; i < NUM_WORKERS; i++) {
+        std::shared_ptr<StreamClient> client;
+        ASSERT_EQ(CreateClient(i, timeoutMs, client), Status::OK());
+        clients.push_back(client);
+        DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "master.CreateProducer", "1*sleep(8000)"));
+    }
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients[0]->Subscribe("CreateProducerTimeout1", config, consumer));
+    ThreadPool producerPool(NUM_WORKERS);
+    auto producerFunc([&clients](uint32_t index) {
+        std::shared_ptr<Producer> producer;
+        ProducerConf conf;
+        const uint64_t maxStreamSize = 10 * MB;
+        conf.maxStreamSize = maxStreamSize;
+        conf.pageSize = 1 * MB;
+        if (clients[index]->CreateProducer("CreateProducerTimeout1", producer, conf).IsError()) {
+            return std::shared_ptr<Producer>();
+        }
+        return producer;
+    });
+    Timer timer;
+    auto ptr = producerFunc(0);
+    ASSERT_EQ(ptr, nullptr);
+    auto timeCost = timer.ElapsedMilliSecond();
+    LOG(INFO) << "Elapsed time: " << timeCost;
+    // Sleep a bit so the CreateProducer request actually goes through on the master after the timeout.
+    const int DEFAULT_WAIT_TIME = 3;
+    sleep(DEFAULT_WAIT_TIME);
+    auto producer = producerFunc(0);
+    ASSERT_NE(producer, nullptr);
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(clients[0]->DeleteStream("CreateProducerTimeout1"));
+}
+
+TEST_F(ProducerTest, LEVEL1_TestCreateProducerTimeout2)
+{
+    // This test case covers rollback failing with a timeout: the producer count must still be handled.
+    // Injection simulates SyncConsumerNode failing with a timeout, triggering the rollback logic,
+    // and also makes the rollback's ClearAllRemoteConsumer fail with a timeout.
+    // This is to make sure the producer count is still decremented.
+    std::vector<std::shared_ptr<StreamClient>> clients;
+    for (int i = 0; i < NUM_WORKERS; i++) {
+        std::shared_ptr<StreamClient> client;
+        ASSERT_EQ(CreateClient(i, client), Status::OK());
+        clients.push_back(client);
+        DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "master.PubIncreaseNodeImpl.beforeSendNotification",
+                                               "1*return(K_RPC_UNAVAILABLE)"));
+        DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "MasterWorkerSCServiceImpl.ClearAllRemoteConsumer.sleep",
+                                               "1*sleep(40000)"));
+    }
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients[1]->Subscribe("CreateProducerTimeout2", config, consumer));
+    ThreadPool producerPool(NUM_WORKERS);
+    auto producerFunc([&clients](uint32_t index) {
+        std::shared_ptr<Producer> producer;
+        ProducerConf conf;
+        const uint64_t maxStreamSize = 10 * MB;
+        conf.maxStreamSize = maxStreamSize;
+        conf.pageSize = 1 * MB;
+        if (clients[index]->CreateProducer("CreateProducerTimeout2", producer, conf).IsError()) {
+            return std::shared_ptr<Producer>();
+        }
+        return producer;
+    });
+    producerPool.Execute([&producerFunc, i = 0]() { ASSERT_EQ(nullptr, producerFunc(i)); });
+    // Sleep a bit so the CreateProducer RPC request is sent.
+    sleep(1);
+    const int producerCount = 3;
+    std::vector<std::future<std::shared_ptr<Producer>>> prodFutures;
+    for (int i = 0; i < producerCount; i++) {
+        prodFutures.push_back(producerPool.Submit([&producerFunc, i = 0]() { return producerFunc(i); }));
+    }
+    for (auto &fut : prodFutures) {
+        auto producer = fut.get();
+        ASSERT_NE(producer, nullptr);
+        DS_ASSERT_OK(producer->Close());
+    }
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(clients[0]->DeleteStream("CreateProducerTimeout2"));
+}
+
+TEST_F(ProducerTest, LEVEL1_TestCreateProducerTimeout3)
+{
+    // This test case covers CreateProducer sending UpdateTopoNotification through the local bypass instead of
+    // an actual RPC. In that case it can be blocked by some locks and go beyond scTimeoutDuration. The
+    // worker->master CreateProducer then times out, which releases the create lock on the worker, so other
+    // CreateProducer calls for the same stream from the same worker can go through. Since the related change
+    // allows parallel CreateProducer on the master, a request can now, for example, get OK because it is not the
+    // first producer from the worker, while the first producer request is still running and can fail.
+    // Injection simulates UpdateTopoNotification taking too long.
+    std::vector<std::shared_ptr<StreamClient>> clients;
+    for (int i = 0; i < NUM_WORKERS; i++) {
+        std::shared_ptr<StreamClient> client;
+        const int timeout = 30000;
+        ASSERT_EQ(CreateClient(i, timeout, client), Status::OK());
+        clients.push_back(client);
+        DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "master.PubIncreaseNodeImpl.beforeSendNotification",
+                                               "1*sleep(25000)"));
+    }
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients[1]->Subscribe("CreateProducerTimeout3", config, consumer));
+    ThreadPool producerPool(NUM_WORKERS);
+    auto producerFunc([&clients](uint32_t index, std::shared_ptr<Producer> &producer) {
+        ProducerConf conf;
+        const uint64_t maxStreamSize = 10 * MB;
+        conf.maxStreamSize = maxStreamSize;
+        conf.pageSize = 1 * MB;
+        return clients[index]->CreateProducer("CreateProducerTimeout3", producer, conf);
+    });
+    producerPool.Execute([&producerFunc, i = 0]() {
+        std::shared_ptr<Producer> producer;
+        ASSERT_EQ(producerFunc(i, producer).GetCode(), K_RPC_UNAVAILABLE);
+    });
+    // Sleep a bit so the first CreateProducer RPC request is sent.
+    sleep(1);
+    std::shared_ptr<Producer> producer;
+    Status rc = producerFunc(0, producer);
+    ASSERT_EQ(rc.GetCode(), K_TRY_AGAIN);
+    // Wait for the first CreateProducer request to roll back on the master; then the new requests should work.
+    const int DEFAULT_WAIT_TIME = 5;
+    sleep(DEFAULT_WAIT_TIME);
+    DS_ASSERT_OK(producerFunc(0, producer));
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(clients[0]->DeleteStream("CreateProducerTimeout3"));
+}
+
+TEST_F(ProducerTest, LEVEL1_TestParallelCreateCloseProducer)
+{
+    // This test case tests that CreateProducer is blocked by the worker-level create lock
+    // while the last producer on the worker is being closed on the master.
+    // This mitigates the conflict between CreateProducer and CloseProducer.
+    std::vector<std::shared_ptr<StreamClient>> clients;
+    for (int i = 0; i <= 1; i++) {
+        std::shared_ptr<StreamClient> client;
+        ASSERT_EQ(CreateClient(i, client), Status::OK());
+        clients.push_back(client);
+        DS_ASSERT_OK(
+            cluster_->SetInjectAction(WORKER, i, "master.PubDecreaseNode.beforeSendNotification", "1*sleep(7000)"));
+    }
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients[1]->Subscribe("ParallelCreateCloseProducer", config, consumer));
+    ThreadPool producerPool(1);
+    auto producerFunc([&clients](uint32_t index, std::shared_ptr<Producer> &producer) {
+        ProducerConf conf;
+        const uint64_t maxStreamSize = 10 * MB;
+        conf.maxStreamSize = maxStreamSize;
+        conf.pageSize = 1 * MB;
+        return clients[index]->CreateProducer("ParallelCreateCloseProducer", producer, conf);
+    });
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(producerFunc(0, producer));
+    producerPool.Execute([&producer]() { DS_ASSERT_OK(producer->Close()); });
+    // Sleep a bit so the CloseProducer RPC request is sent.
+    sleep(1);
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_OK(producerFunc(0, producer2));
+    DS_ASSERT_OK(producer2->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(clients[0]->DeleteStream("ParallelCreateCloseProducer"));
+}
+
+TEST_F(ProducerTest, TestLocalClearAllRemoteConsumerParallelSubscribe)
+{
+    // This test case tests that the local ClearAllRemoteConsumer after CloseProducer
+    // is triggered after a remote Subscribe is done.
+    // This creates a timing hole where ClearAllRemoteConsumer is called after the new remote consumer is added,
+    // and before the scan is done.
+    // Functionally it should be unaffected, because ClearAllRemoteConsumer does neither flush nor RemoveStreamObject.
+    std::vector<std::shared_ptr<StreamClient>> clients;
+    for (int i = 0; i < NUM_WORKERS; i++) {
+        std::shared_ptr<StreamClient> client;
+        ASSERT_EQ(CreateClient(i, client), Status::OK());
+        clients.push_back(client);
+        // Delay the local ClearAllRemoteConsumer after CloseProducer.
+        DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "StreamManager.CloseProducer.timing", "1*sleep(2000)"));
+    }
+    // Delay the scan.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamDataPool.SendElementsToRemote.wait", "1*sleep(4000)"));
+    ThreadPool producerPool(NUM_WORKERS);
+    auto producerFunc([&clients](uint32_t index, std::shared_ptr<Producer> &producer) {
+        ProducerConf conf;
+        const uint64_t maxStreamSize = 10 * MB;
+        conf.maxStreamSize = maxStreamSize;
+        conf.pageSize = 1 * MB;
+        conf.retainForNumConsumers = 1;
+        return clients[index]->CreateProducer("ConsumerParallelSubscribe", producer, conf);
+    });
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(producerFunc(0, producer));
+    auto fut = producerPool.Submit([&producer]() {
+        std::string data = "H";
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+        RETURN_IF_NOT_OK(producer->Send(element));
+        return producer->Close();
+    });
+    // Sleep a bit so the CloseProducer RPC request is sent.
+    sleep(1);
+    const int DEFAULT_WAIT_TIME = 5000;
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients[1]->Subscribe("ConsumerParallelSubscribe", config, consumer));
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(1, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(1));
+    DS_ASSERT_OK(fut.get());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(clients[0]->DeleteStream("ConsumerParallelSubscribe"));
+}
+
+TEST_F(ProducerTest, SendReturnOOMTest)
+{
+    std::shared_ptr<StreamClient> client;
+    ASSERT_EQ(CreateClient(0, client), Status::OK());
+
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const uint64_t maxStreamSize = 10 * MB;
+    conf.maxStreamSize = maxStreamSize;
+    conf.pageSize = 1 * MB;
+    conf.retainForNumConsumers = 1;
+    DS_ASSERT_OK(client->CreateProducer("SendReturnOOM", producer, conf));
+
+    Element element;
+    std::vector<uint8_t> writeElement;
+    const uint elementSize = 900 * KB;
+    CreateElement(elementSize, element, writeElement);
+
+    DS_ASSERT_OK(producer->Send(element));
+
+    // We first sleep 5 seconds to let the timeout expire in the first loop; the producer then requests a new page.
+    // After getting the new page, we return K_TRY_AGAIN in writePage_->Insert(...) in the second loop to simulate
+    // a timeout on getting the page lock on the new page.
+    DS_ASSERT_OK(datasystem::inject::Set("producer_insert", "1*sleep(5000)->1*return(K_TRY_AGAIN)"));
+    const uint timeout = 1000;
+    ASSERT_EQ(producer->Send(element, timeout).GetCode(), K_OUT_OF_MEMORY);
+}
+
+TEST_F(ProducerTest, TestParallelProducerUse)
+{
+    std::shared_ptr<StreamClient> client;
+    ASSERT_EQ(CreateClient(0, client), Status::OK());
+
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client->CreateProducer("ParallelProducerUse", producer));
+
+    DS_ASSERT_OK(datasystem::inject::Set("CheckAndSetInUse.success.sleep", "sleep(5000)"));
+
+    std::string data = "Hello";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    // Create a producer thread whose Send() lasts at least 5 seconds.
+    ThreadPool pool(1);
+    auto producerSendFunc([&producer, &element]() { return producer->Send(element); });
+    std::future<Status> fut = pool.Submit([&producerSendFunc]() { return producerSendFunc(); });
+
+    sleep(1);
+
+    // A parallel call on the same producer should fail.
+    StatusCode expectedCode = K_SC_STREAM_IN_USE;
+    ASSERT_EQ(producer->Send(element).GetCode(), expectedCode);
+    ASSERT_EQ(producer->Close().GetCode(), expectedCode);
+
+    DS_ASSERT_OK(fut.get());
+
+    DS_ASSERT_OK(datasystem::inject::Clear("CheckAndSetInUse.success.sleep"));
+
+    DS_ASSERT_OK(producer->Send(element));
+    DS_ASSERT_OK(producer->Close());
+}
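+// A sketch of the single-caller guard behind K_SC_STREAM_IN_USE above: the producer
+// handle is not meant for concurrent use, so each entry point atomically flips an
+// "in use" flag and bails out with a busy status if it is already set. Names here are
+// illustrative; the real check is the CheckAndSetInUse path in the client impl.
+#include <atomic>  // normally at the top of the file
+
+class InUseGuardSketch {
+public:
+    // Returns false if another thread is already inside the producer.
+    bool TryEnter() { return !inUse_.test_and_set(std::memory_order_acquire); }
+    void Leave() { inUse_.clear(std::memory_order_release); }
+
+private:
+    std::atomic_flag inUse_ = ATOMIC_FLAG_INIT;
+};
+
+// Usage shape inside an API call (pseudocode):
+//   if (!guard.TryEnter()) { return busy status with K_SC_STREAM_IN_USE; }
+//   ... do the send ...
+//   guard.Leave();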
+TEST_F(ProducerTest, EXCLUSIVE_TestParallelLocalCreateCloseProducer)
+{
+    // In this test case, we create multiple local producers in parallel.
+    const int num_producers = 10;
+    std::shared_ptr<StreamClient> client;
+    ASSERT_EQ(CreateClient(0, client), Status::OK());
+
+    ThreadPool producerPool(num_producers);
+    auto producerCreateFunc([&client]() {
+        std::shared_ptr<Producer> producer;
+        ProducerConf conf;
+        const uint64_t maxStreamSize = 10 * MB;
+        conf.maxStreamSize = maxStreamSize;
+        conf.pageSize = 1 * MB;
+        if (client->CreateProducer("ParallelLocalCreateCloseProd", producer, conf).IsError()) {
+            return std::shared_ptr<Producer>();
+        }
+        return producer;
+    });
+
+    // Create multiple local producers on the same node in parallel
+    std::vector<std::future<std::shared_ptr<Producer>>> prodFutures;
+    for (int i = 0; i < num_producers; i++) {
+        prodFutures.push_back(producerPool.Submit([&producerCreateFunc]() { return producerCreateFunc(); }));
+    }
+
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe("ParallelLocalCreateCloseProd", config, consumer));
+
+    std::vector<std::shared_ptr<Producer>> producers;
+    for (auto &fut : prodFutures) {
+        auto producer = fut.get();
+        ASSERT_NE(producer, nullptr);
+        producers.push_back(producer);
+    }
+
+    // The master should get only one request, so the count should be 1
+    ASSERT_EQ(CheckProducerCount(client, "ParallelLocalCreateCloseProd"), 1);
+
+    // Close multiple local producers on the same node in parallel
+    std::vector<std::future<Status>> prodCloseFutures;
+    for (int i = 0; i < num_producers; i++) {
+        prodCloseFutures.push_back(producerPool.Submit([&producers, i]() { return producers[i]->Close(); }));
+    }
+    for (auto &fut : prodCloseFutures) {
+        DS_ASSERT_OK(fut.get());
+    }
+
+    // The master should get only one request, so the count should be 0
+    ASSERT_EQ(CheckProducerCount(client, "ParallelLocalCreateCloseProd"), 0);
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(TryAndDeleteStream(client, "ParallelLocalCreateCloseProd"));
+}
+
+TEST_F(ProducerTest, EXCLUSIVE_TestParallelLocalCreateCloseProducerRollBack)
+{
+    // In this test case, we create multiple local producers in parallel.
+    const int num_producers = 10;
+    std::shared_ptr<StreamClient> client;
+    ASSERT_EQ(CreateClient(0, client), Status::OK());
+    std::string streamName = "ParallelLocalCreateCloseProdRollBack";
+    ThreadPool producerPool(num_producers);
+    auto producerCreateFunc([&client, streamName]() {
+        std::shared_ptr<Producer> producer;
+        ProducerConf conf;
+        const uint64_t maxStreamSize = 10 * MB;
+        conf.maxStreamSize = maxStreamSize;
+        conf.pageSize = 1 * MB;
+        if (client->CreateProducer(streamName, producer, conf).IsError()) {
+            return std::shared_ptr<Producer>();
+        }
+        return producer;
+    });
+
+    // Make the first CreateProducer call fail
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CreateProducer.beforeSendToMaster",
+                                           "1*return(K_RUNTIME_ERROR)"));
+
+    // Create multiple local producers on the same node in parallel
+    std::vector<std::future<std::shared_ptr<Producer>>> prodFutures;
+    for (int i = 0; i < num_producers; i++) {
+        prodFutures.push_back(producerPool.Submit([&producerCreateFunc]() { return producerCreateFunc(); }));
+    }
+
+    std::shared_ptr<StreamClient> client2;
+    ASSERT_EQ(CreateClient(1, client2), Status::OK());
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+
+    std::vector<std::shared_ptr<Producer>> producers;
+    bool gotFailedProducer = false;
+    for (auto &fut : prodFutures) {
+        auto producer = fut.get();
+        if (gotFailedProducer) {
+            // Only one producer should fail
+            ASSERT_NE(producer, nullptr);
+        }
+        if (producer == nullptr) {
+            // Mark that the first producer failed
+            gotFailedProducer = true;
+            continue;
+        }
+        producers.push_back(producer);
+    }
+
+    // The master should get only one request, so the count should be 1
+    ASSERT_EQ(CheckProducerCount(client, streamName), 1);
+
+    // Close multiple local producers on the same node in parallel
+    std::vector<std::future<Status>> prodCloseFutures;
+    for (auto &producer : producers) {
+        prodCloseFutures.push_back(producerPool.Submit([producer]() { return producer->Close(); }));
+    }
+    for (auto &fut : prodCloseFutures) {
+        DS_ASSERT_OK(fut.get());
+    }
+
+    // The master should get only one request, so the count should be 0
+    ASSERT_EQ(CheckProducerCount(client, streamName), 0);
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(TryAndDeleteStream(client, streamName));
+}
+
+class LargeScaleProducerTest : public ProducerTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        ProducerTest::SetClusterSetupOptions(opts);
+        opts.numWorkers = NUM_WORKERS;
+        // Enable stream data verification for testing purposes
+        opts.workerGflagParams += " -enable_stream_data_verification=true";
+    }
+
+    void SetUp() override
+    {
+        ProducerTest::SetUp();
+    }
+
+    void TearDown() override
+    {
+        ProducerTest::TearDown();
+    }
+
+protected:
+    const int NUM_WORKERS = 10;
+};
+
+TEST_F(LargeScaleProducerTest, EXCLUSIVE_TestParallelCreateProducer1)
+{
+    // This test case tests that CreateProducer can be handled in parallel on the master.
+    std::vector<std::shared_ptr<StreamClient>> clients;
+    for (int i = 0; i < NUM_WORKERS; i++) {
+        std::shared_ptr<StreamClient> client;
+        ASSERT_EQ(CreateClient(i, client), Status::OK());
+        clients.push_back(client);
+        DS_ASSERT_OK(
+            cluster_->SetInjectAction(WORKER, i, "MasterWorkerSCServiceImpl.SyncConsumerNode.sleep", "sleep(5000)"));
+    }
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients[0]->Subscribe("ParallelCreate1", config, consumer));
+    ThreadPool producerPool(NUM_WORKERS);
+    auto producerFunc([&clients](uint32_t index) {
+        std::shared_ptr<Producer> producer;
+        ProducerConf conf;
+        const uint64_t maxStreamSize = 10 * MB;
+        conf.maxStreamSize = maxStreamSize;
+        conf.pageSize = 1 * MB;
+        if (clients[index]->CreateProducer("ParallelCreate1", producer, conf).IsError()) {
+            return std::shared_ptr<Producer>();
+        }
+        return producer;
+    });
+    Timer timer;
+    // Multiple producers per node, to also exercise the worker-level create lock.
+    std::vector<std::future<std::shared_ptr<Producer>>> prodFutures;
+    for (int i = 0; i < NUM_WORKERS; i++) {
+        prodFutures.push_back(producerPool.Submit([&producerFunc, i]() { return producerFunc(i); }));
+    }
+    std::vector<std::shared_ptr<Producer>> producers;
+    for (auto &fut : prodFutures) {
+        auto producer = fut.get();
+        ASSERT_NE(producer, nullptr);
+        producers.push_back(producer);
+    }
+    auto timeCost = timer.ElapsedMilliSecond();
+    LOG(INFO) << "Elapsed time for Create Producer: " << timeCost;
+    // A 5-second sleep is injected into the SyncConsumerNode request, so the requests should take more than 5 seconds.
+    const uint64_t minExpectedTime = 5000;
+    // They should run in parallel, so the total elapsed time should not be far off from 5 seconds.
+    const uint64_t maxExpectedTime = minExpectedTime + 500;
+    ASSERT_TRUE(timeCost >= minExpectedTime && timeCost <= maxExpectedTime);
+    for (auto &producer : producers) {
+        DS_ASSERT_OK(producer->Close());
+    }
+    Timer timer1;
+    DS_ASSERT_OK(consumer->Close());
+    timeCost = timer1.ElapsedMilliSecond();
+    LOG(INFO) << "Elapsed time for Close Consumer: " << timeCost;
+
+    DS_ASSERT_OK(clients.back()->DeleteStream("ParallelCreate1"));
+}
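+// A self-contained sketch (pure std C++) of the timing assertion used by the parallel
+// create/close tests: if N tasks each sleep for roughly the same injected delay but run
+// concurrently, total wall time stays near one delay, whereas serialized execution would
+// take about N times that. Timer and ThreadPool in the tests play the roles of the clock
+// and the pool here.
+#include <chrono>  // normally at the top of the file
+#include <thread>
+#include <vector>
+
+static bool ParallelWallTimeDemo()
+{
+    using Clock = std::chrono::steady_clock;
+    const int numTasks = 4;
+    const auto delay = std::chrono::milliseconds(200);
+    auto start = Clock::now();
+    std::vector<std::thread> tasks;
+    for (int i = 0; i < numTasks; i++) {
+        tasks.emplace_back([delay]() { std::this_thread::sleep_for(delay); });
+    }
+    for (auto &t : tasks) {
+        t.join();
+    }
+    auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(Clock::now() - start);
+    // Parallel: elapsed is close to one delay, far below numTasks * delay.
+    return elapsed >= delay && elapsed < delay * numTasks;
+}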
+TEST_F(LargeScaleProducerTest, EXCLUSIVE_TestParallelCreateProducer2)
+{
+    // This test case tests that CreateProducer can be handled in parallel on the master.
+    // In this case, create multiple producers per node, to also exercise the worker-level create lock.
+    // Also test the stream data verification logic, to make sure the producer number is still handled correctly.
+    std::vector<std::shared_ptr<StreamClient>> clients;
+    for (int i = 0; i < NUM_WORKERS; i++) {
+        std::shared_ptr<StreamClient> client;
+        ASSERT_EQ(CreateClient(i, client), Status::OK());
+        clients.push_back(client);
+    }
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients[0]->Subscribe("ParallelCreate2", config, consumer));
+    ThreadPool producerPool(NUM_WORKERS);
+    auto producerFunc([&clients](uint32_t index) {
+        std::shared_ptr<Producer> producer;
+        ProducerConf conf;
+        const uint64_t maxStreamSize = 10 * MB;
+        conf.maxStreamSize = maxStreamSize;
+        conf.pageSize = 1 * MB;
+        if (clients[index]->CreateProducer("ParallelCreate2", producer, conf).IsError()) {
+            return std::shared_ptr<Producer>();
+        }
+        return producer;
+    });
+    const int producerCount = 4;
+    Timer timer;
+    std::vector<std::future<std::shared_ptr<Producer>>> prodFutures;
+    for (int i = 0; i < NUM_WORKERS; i++) {
+        for (int j = 0; j < producerCount; j++) {
+            prodFutures.push_back(producerPool.Submit([&producerFunc, i]() { return producerFunc(i); }));
+        }
+    }
+    std::vector<std::shared_ptr<Producer>> producers;
+    for (auto &fut : prodFutures) {
+        auto producer = fut.get();
+        ASSERT_NE(producer, nullptr);
+        std::string data = "H";
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+        DS_ASSERT_OK(producer->Send(element));
+        DS_ASSERT_OK(producer->Close());
+    }
+    const int totalNum = producerCount * NUM_WORKERS;
+    const int DEFAULT_WAIT_TIME = 5000;
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(totalNum, DEFAULT_WAIT_TIME, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(totalNum));
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(clients.back()->DeleteStream("ParallelCreate2"));
+}
+
+TEST_F(LargeScaleProducerTest, EXCLUSIVE_TestParallelCreateProducer3)
+{
+    // This test case tests a Subscribe that happens in between CreateProducer requests.
+    // Injection is done at the master level, so 5 s are spent in SyncConsumerNode even if there is no consumer.
+    std::vector<std::shared_ptr<StreamClient>> clients;
+    for (int i = 0; i < NUM_WORKERS; i++) {
+        std::shared_ptr<StreamClient> client;
+        ASSERT_EQ(CreateClient(i, client), Status::OK());
+        clients.push_back(client);
+        DS_ASSERT_OK(
+            cluster_->SetInjectAction(WORKER, i, "master.PubIncreaseNodeImpl.beforeSendNotification", "sleep(5000)"));
+    }
+    ThreadPool producerPool(NUM_WORKERS);
+    auto producerFunc([&clients](uint32_t index) {
+        std::shared_ptr<Producer> producer;
+        ProducerConf conf;
+        const uint64_t maxStreamSize = 10 * MB;
+        conf.maxStreamSize = maxStreamSize;
+        conf.pageSize = 1 * MB;
+        if (clients[index]->CreateProducer("ParallelCreate3", producer, conf).IsError()) {
+            return std::shared_ptr<Producer>();
+        }
+        return producer;
+    });
+    Timer timer;
+    // Multiple producers per node, to also exercise the worker-level create lock.
+ std::vector<std::future<std::shared_ptr<Producer>>> prodFutures;
+ const int halfWorkers = NUM_WORKERS / 2;
+ for (int i = 0; i < halfWorkers; i++) {
+ prodFutures.push_back(producerPool.Submit([&producerFunc, i]() { return producerFunc(i); }));
+ }
+ std::shared_ptr<Consumer> consumer;
+ SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+ DS_ASSERT_OK(clients[0]->Subscribe("ParallelCreate3", config, consumer));
+ for (int i = halfWorkers; i < NUM_WORKERS; i++) {
+ prodFutures.push_back(producerPool.Submit([&producerFunc, i]() { return producerFunc(i); }));
+ }
+ std::vector<std::shared_ptr<Producer>> producers;
+ for (auto &fut : prodFutures) {
+ auto producer = fut.get();
+ ASSERT_NE(producer, nullptr);
+ producers.push_back(producer);
+ }
+ auto timeCost = timer.ElapsedMilliSecond();
+ LOG(INFO) << "Elapsed time: " << timeCost;
+ // A 5-second sleep is injected into the SyncConsumerNode request, so the requests should take more than 5 seconds.
+ // But the Subscribe request will hold the xlock and accessor, so in total it should take more than 10 seconds.
+ const uint64_t minExpectedTime = 10000;
+ // They should still run in parallel, so the total elapsed time should not be far off from 10 seconds.
+ const uint64_t maxExpectedTime = minExpectedTime + 500;
+ ASSERT_TRUE(timeCost >= minExpectedTime && timeCost <= maxExpectedTime);
+ for (auto &producer : producers) {
+ DS_ASSERT_OK(producer->Close());
+ }
+ DS_ASSERT_OK(consumer->Close());
+ DS_ASSERT_OK(clients.back()->DeleteStream("ParallelCreate3"));
+}
+
+TEST_F(LargeScaleProducerTest, DISABLED_TestParallelCloseProducer)
+{
+ // This testcase tests that CloseProducer can be handled in parallel on the master.
+ std::vector<std::shared_ptr<StreamClient>> clients;
+ for (int i = 0; i < NUM_WORKERS; i++) {
+ std::shared_ptr<StreamClient> client;
+ ASSERT_EQ(CreateClient(i, client), Status::OK());
+ clients.push_back(client);
+ DS_ASSERT_OK(
+ cluster_->SetInjectAction(WORKER, i, "master.PubDecreaseNode.beforeSendNotification", "sleep(5000)"));
+ }
+ std::shared_ptr<Consumer> consumer;
+ SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+ DS_ASSERT_OK(clients[0]->Subscribe("ParallelClose", config, consumer));
+ ThreadPool producerPool(NUM_WORKERS);
+ const int producerCount = 4;
+ std::vector<std::shared_ptr<Producer>> producers;
+ for (int i = 0; i < NUM_WORKERS; i++) {
+ for (int j = 0; j < producerCount; j++) {
+ std::shared_ptr<Producer> producer;
+ ProducerConf conf;
+ const uint64_t maxStreamSize = 10 * MB;
+ conf.maxStreamSize = maxStreamSize;
+ conf.pageSize = 1 * MB;
+ DS_ASSERT_OK(clients[i]->CreateProducer("ParallelClose", producer, conf));
+ producers.push_back(producer);
+ }
+ }
+ Timer timer;
+ std::vector<std::future<Status>> prodFutures;
+ for (auto &producer : producers) {
+ prodFutures.push_back(producerPool.Submit([producer]() { return producer->Close(); }));
+ }
+ for (auto &fut : prodFutures) {
+ DS_ASSERT_OK(fut.get());
+ }
+ auto timeCost = timer.ElapsedMilliSecond();
+ LOG(INFO) << "Elapsed time: " << timeCost;
+ // A 5-second sleep is injected into PubDecreaseNode, so the requests should take more than 5 seconds.
+ const uint64_t minExpectedTime = 5000;
+ // They should run in parallel, though, so the total elapsed time should not be far off from 5 seconds.
+ const uint64_t maxExpectedTime = minExpectedTime + 500;
+ ASSERT_TRUE(timeCost >= minExpectedTime && timeCost <= maxExpectedTime);
+ DS_ASSERT_OK(consumer->Close());
+ DS_ASSERT_OK(clients.back()->DeleteStream("ParallelClose"));
+}
+
+TEST_F(ProducerTest, UpdateLocalPubLastDataPageFailed)
+{
+ // 2 producers -> 1 remote consumer.
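+ // Scenario: two producers on the same worker share stream pages. An inject makes producer 1
+ // skip UpdateLocalPubLastDataPage, so producer 2's cached view of the last data page goes
+ // stale once that page is recycled, and its next Send must fetch a fresh page from the worker.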
+ std::shared_ptr<StreamClient> client1;
+ std::shared_ptr<StreamClient> client2;
+ DS_ASSERT_OK(CreateClient(0, client1));
+ DS_ASSERT_OK(CreateClient(1, client2));
+ std::string streamName = "UpdateLocalPubLastDataPgFail";
+ std::shared_ptr<Consumer> consumer;
+ SubscriptionConfig config("sub", SubscriptionType::STREAM);
+ DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer, true));
+
+ std::shared_ptr<Producer> producer1;
+ std::shared_ptr<Producer> producer2;
+ ProducerConf conf;
+ const uint maxStreamSize = 10 * MB;
+ conf.maxStreamSize = maxStreamSize;
+ conf.pageSize = 1 * MB;
+ DS_ASSERT_OK(client1->CreateProducer(streamName, producer1, conf));
+ DS_ASSERT_OK(client1->CreateProducer(streamName, producer2, conf));
+
+ // Each page is big enough to hold only 1 element, so a new page is created for every element.
+ const uint testSize1 = 600 * KB;
+ Element element1;
+ std::vector<uint8_t> writeElement1;
+ DS_ASSERT_OK(CreateElement(testSize1, element1, writeElement1));
+ for (uint i = 0; i < K_TWO; ++i) {
+ DS_ASSERT_OK(producer1->Send(element1, RPC_TIMEOUT));
+ DS_ASSERT_OK(producer2->Send(element1, RPC_TIMEOUT));
+ }
+
+ // Producer 1's and Producer 2's cursor ShmView is at page 4.
+ DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "UpdateLocalPubLastDataPage.skip", "return(K_OK)"));
+ DS_ASSERT_OK(producer1->Send(element1, RPC_TIMEOUT));
+ // Producer 2's cursor ShmView remains at page 4.
+
+ const uint SLEEP_TIME_SEC = 5;
+ sleep(SLEEP_TIME_SEC); // Page 4 is recycled during the sleep.
+
+ DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "UpdateLocalPubLastDataPage.skip"));
+ const uint testSize2 = 300 * KB;
+ Element element2;
+ std::vector<uint8_t> writeElement2;
+ DS_ASSERT_OK(CreateElement(testSize2, element2, writeElement2));
+
+ // Since page 4 is recycled, Producer 2 should create a new page from the worker.
+ DS_ASSERT_OK(producer2->Send(element2, RPC_TIMEOUT));
+}
+
+TEST_F(ProducerTest, TestProducerCloseAndNewProducerCreate)
+{
+ // 2 producers -> 1 remote consumer.
+ std::shared_ptr<StreamClient> client1;
+ std::shared_ptr<StreamClient> client2;
+ DS_ASSERT_OK(CreateClient(0, client1));
+ DS_ASSERT_OK(CreateClient(1, client2));
+ std::string streamName = "ProducerCloseAndNewProducerCreate";
+ std::shared_ptr<Consumer> consumer;
+ SubscriptionConfig config("sub", SubscriptionType::STREAM);
+ DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer, true));
+
+ std::shared_ptr<Producer> producer1;
+ std::shared_ptr<Producer> producer2;
+ ProducerConf conf;
+ const uint maxStreamSize = 10 * MB;
+ conf.maxStreamSize = maxStreamSize;
+ conf.pageSize = 1 * MB;
+ DS_ASSERT_OK(client1->CreateProducer(streamName, producer1, conf));
+
+ // Each page is big enough to hold only 1 element, so a new page is created for every element.
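+ // (Assumption: a 600 KB element plus page metadata leaves no room for a second element in a
+ // 1 MB page, which is what forces a fresh page per Send in this test.)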
+ const uint testSize1 = 600 * KB;
+ Element element1;
+ std::vector<uint8_t> writeElement1;
+ DS_ASSERT_OK(CreateElement(testSize1, element1, writeElement1));
+ for (uint i = 0; i < K_TWO; ++i) {
+ writeElement1[0] = '1' + i;
+ DS_ASSERT_OK(producer1->Send(element1, RPC_TIMEOUT));
+ }
+ DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamDataPool::ObjectPartition::RemoveStreamObject.sleep",
+ "1*sleep(2000)"));
+ DS_ASSERT_OK(producer1->Close());
+ sleep(1);
+ DS_ASSERT_OK(client1->CreateProducer(streamName, producer2, conf));
+
+ const uint testSize2 = 300 * KB;
+ Element element2;
+ std::vector<uint8_t> writeElement2;
+ DS_ASSERT_OK(CreateElement(testSize2, element2, writeElement2));
+
+ for (uint i = 0; i < K_TWO; ++i) {
+ writeElement1[0] = '3' + i;
+ DS_ASSERT_OK(producer2->Send(element1, RPC_TIMEOUT));
+ }
+
+ std::vector<Element> outElements;
+ DS_ASSERT_OK(consumer->Receive(K_TWO + K_TWO, RPC_TIMEOUT, outElements));
+ for (auto &ele : outElements) {
+ LOG(INFO) << ele.ptr[0];
+ }
+ ASSERT_EQ(outElements.size(), K_TWO + K_TWO);
+ // The 1st element from producer2 was inserted into page 2, which the worker believed was the
+ // reserved page, but that page was actually freed before the CreateProducerRsp for producer2
+ // was returned. Therefore the LOG(INFO) output above is 1, 2, 4, with 3 missing.
+}
+
+TEST_F(ProducerTest, TestCreateProducerWhenAllocPage)
+{
+ std::shared_ptr<StreamClient> client;
+ std::string streamName = "stream001";
+
+ DS_ASSERT_OK(CreateClient(0, client));
+
+ const uint maxStreamSize = 10 * MB;
+ const uint pageSize = 64 * KB;
+ ProducerConf conf;
+ conf.maxStreamSize = maxStreamSize;
+ conf.pageSize = pageSize;
+
+ std::shared_ptr<Producer> producer;
+ DS_ASSERT_OK(client->CreateProducer(streamName, producer, conf));
+
+ SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+ std::shared_ptr<Consumer> consumer;
+ DS_ASSERT_OK(client->Subscribe(streamName, config, consumer));
+
+ const size_t sizeElement = 10 * KB;
+ std::string writeElement = RandomData().GetRandomString(sizeElement);
+ Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+ const int threadNum = 3;
+ const size_t numElements = 1000;
+ ThreadPool pool(threadNum);
+ DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.AddCursor.afterLockCursorMutex", "sleep(10)"));
+ DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.UpdateLocalCursorLastDataPage.beforeLockCursorMutex",
+ "sleep(10)"));
+ std::vector<std::future<Status>> futs;
+ futs.push_back(pool.Submit([&producer, element] {
+ const int DEFAULT_SLEEP_TIME = 300;
+ int retryLimit = 30;
+
+ for (size_t i = 0; i < numElements; i++) {
+ auto rc = producer->Send(element);
+ if (rc.IsError()) {
+ while (rc.GetCode() == K_OUT_OF_MEMORY && retryLimit-- > 0) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME));
+ rc = producer->Send(element);
+ }
+ }
+ RETURN_IF_NOT_OK(rc);
+ }
+
+ return Status::OK();
+ }));
+
+ futs.push_back(pool.Submit([&consumer] {
+ Timer timer;
+ size_t remaining = numElements;
+ int round = 0;
+ const int PER_RECEIVE_NUM = 1;
+ const int DEFAULT_WAIT_TIME = 1000;
+ const int DEFAULT_RETRY_TIME = 30;
+ while (remaining > 0 && timer.ElapsedSecond() < DEFAULT_RETRY_TIME) {
+ std::vector<Element> outElements;
+ RETURN_IF_NOT_OK(consumer->Receive(PER_RECEIVE_NUM, DEFAULT_WAIT_TIME, outElements));
+ LOG(INFO) << "remaining num : " << remaining << ", receive num : " << outElements.size() << " ;" << round++;
+ if (!outElements.empty()) {
+ remaining -= outElements.size();
+ RETURN_IF_NOT_OK(consumer->Ack(outElements.back().id));
+ }
+ }
+
CHECK_FAIL_RETURN_STATUS(remaining == 0, K_RUNTIME_ERROR, "failed to receive all data"); + return Status::OK(); + })); + + futs.push_back(pool.Submit([&client, &conf, &streamName] { + const int createCount = 10; + for (int i = 0; i < createCount; i++) { + std::shared_ptr producer; + RETURN_IF_NOT_OK(client->CreateProducer(streamName, producer, conf)); + } + return Status::OK(); + })); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } +} + +TEST_F(ProducerTest, TestSendWithZeroTimeoutNotWait) +{ + std::shared_ptr client; + std::string streamName = "stream001"; + DS_ASSERT_OK(CreateClient(0, client)); + + const uint maxStreamSize = 10 * MB; + const uint pageSize = 64 * KB; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + conf.pageSize = pageSize; + + std::shared_ptr producer; + DS_ASSERT_OK(client->CreateProducer(streamName, producer, conf)); + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer; + DS_ASSERT_OK(client->Subscribe(streamName, config, consumer)); + + // Each element occupies a single page. + const size_t sizeElement = 50 * KB; + std::string writeElement = RandomData().GetRandomString(sizeElement); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + // using the reserve page. + DS_ASSERT_OK(producer->Send(element)); + + cluster_->SetInjectAction(WORKER, 0, "worker.Allocator.AllocateMemory", "return(K_OUT_OF_MEMORY)"); + Timer timer; + int maxSendElapsed = 100; + ASSERT_EQ(producer->Send(element, 0).GetCode(), K_OUT_OF_MEMORY); + auto elapsed = timer.ElapsedMilliSecondAndReset(); + LOG(INFO) << "elapsed:" << elapsed; + ASSERT_LT(elapsed, maxSendElapsed); +} + +TEST_F(ProducerTest, DISABLED_TestSendWithZeroTimeoutParallel) +{ + std::shared_ptr client; + std::string streamName = "stream001"; + + DS_ASSERT_OK(CreateClient(0, client)); + + const uint maxStreamSize = 10 * MB; + const uint pageSize = 64 * KB; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + conf.pageSize = pageSize; + + std::vector> producers; + int producerCount = 5; + for (int i = 0; i < producerCount; i++) { + std::shared_ptr producer; + DS_ASSERT_OK(client->CreateProducer(streamName, producer, conf)); + producers.emplace_back(std::move(producer)); + } + + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer; + DS_ASSERT_OK(client->Subscribe(streamName, config, consumer)); + + const size_t testDataSize = pageSize * 100; + const size_t sizeElement = 512; + const size_t numElements = testDataSize / sizeElement / producers.size(); + std::string writeElement = RandomData().GetRandomString(sizeElement); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + ThreadPool pool(producers.size() + 1); + std::vector> futs; + std::atomic_bool failed{ false }; + for (auto &producer : producers) { + futs.push_back(pool.Submit([producer, element, numElements, &failed] { + for (size_t i = 0; i < numElements; i++) { + if (failed) { + return Status::OK(); + } + // send with 0 timeout. 
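+ // A zero timeout means "fail fast": Send should return K_OUT_OF_MEMORY immediately when no
+ // page can be allocated instead of blocking and waiting for memory, as asserted below.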
+ Status rc = producer->Send(element, 0); + if (rc.IsError()) { + failed = true; + LOG(ERROR) << "Send failed with:" << rc.ToString(); + return rc; + } + } + return Status::OK(); + })); + } + const size_t recvNumElements = numElements * producers.size(); + futs.push_back(pool.Submit([&consumer, recvNumElements, &failed] { + Timer timer; + size_t remaining = recvNumElements; + int round = 0; + const int PER_RECEIVE_NUM = 1; + const int DEFAULT_WAIT_TIME = 1000; + const int DEFAULT_RETRY_TIME = 30; + while (remaining > 0 && timer.ElapsedSecond() < DEFAULT_RETRY_TIME) { + if (failed) { + return Status::OK(); + } + std::vector outElements; + Status rc = consumer->Receive(PER_RECEIVE_NUM, DEFAULT_WAIT_TIME, outElements); + if (rc.IsError()) { + failed = true; + LOG(ERROR) << "Receive failed with:" << rc.ToString(); + return rc; + } + LOG(INFO) << "remaining num : " << remaining << ", receive num : " << outElements.size() << " ;" << round++; + if (!outElements.empty()) { + remaining -= outElements.size(); + RETURN_IF_NOT_OK(consumer->Ack(outElements.back().id)); + } + } + CHECK_FAIL_RETURN_STATUS(remaining == 0, K_RUNTIME_ERROR, "failed to receive all data"); + return Status::OK(); + })); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } + ASSERT_TRUE(!failed); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/pub_sub_complex_test.cpp b/tests/st/client/stream_cache/pub_sub_complex_test.cpp new file mode 100644 index 0000000..e490c0c --- /dev/null +++ b/tests/st/client/stream_cache/pub_sub_complex_test.cpp @@ -0,0 +1,313 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * Description: Unit test for stream cache + */ +#include + +#include + +#include "common.h" +#include "sc_client_common.h" +#include "datasystem/stream_client.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/util/thread_pool.h" +#include "common/stream_cache/stream_common.h" + +using namespace datasystem::client::stream_cache; +namespace datasystem { +namespace st { +class PubSubComplexTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numWorkers = 2; + opts.numEtcd = 1; + opts.workerGflagParams = " -page_size=" + std::to_string(PAGE_SIZE); + opts.vLogLevel = 2; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + InitTest(); + } + + void TearDown() override + { + client0_ = nullptr; + client1_ = nullptr; + ExternalClusterTest::TearDown(); + } + +protected: + void InitTest() + { + HostPort workerAddress0; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress0)); + HostPort workerAddress1; + DS_ASSERT_OK(cluster_->GetWorkerAddr(1, workerAddress1)); + LOG(INFO) << FormatString("\n Worker1: <%s>\n Worker2: <%s>\n", workerAddress0.ToString(), + workerAddress1.ToString()); + + InitStreamClient(0, client0_); + InitStreamClient(1, client1_); + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + } + std::shared_ptr client0_ = nullptr; + std::shared_ptr client1_ = nullptr; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; +}; + +TEST_F(PubSubComplexTest, MSBiDirectionRpc) +{ + const int streamNum = 16; + ThreadPool pool(streamNum); + for (int i = 0; i < streamNum; ++i) { + pool.Submit([this, i]() { + std::string streamName = "MSBiDirectionRpc" + std::to_string(i); + SubscriptionConfig config("sub_core", SubscriptionType::STREAM); + std::shared_ptr n1c0; + DS_ASSERT_OK(client1_->Subscribe(streamName, config, n1c0)); + + std::shared_ptr producer; + DS_ASSERT_OK(client0_->CreateProducer(streamName, producer, defaultProducerConf_)); + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(n1c0->Close()); + DS_ASSERT_OK(client0_->DeleteStream(streamName)); + }); + } +} + +TEST_F(PubSubComplexTest, DISABLED_TestProducerConsumerTiming) +{ + // This is a general test of ordering events. + std::shared_ptr prod1, prod2, prod3, prod4; + std::shared_ptr localCon, remoteCon; + DS_ASSERT_OK(client0_->CreateProducer("test_stream", prod1, defaultProducerConf_)); + DS_ASSERT_OK(client0_->CreateProducer("test_stream", prod2, defaultProducerConf_)); + DS_ASSERT_OK(client0_->CreateProducer("test_stream", prod3, defaultProducerConf_)); + DS_ASSERT_OK(client0_->CreateProducer("test_stream", prod4, defaultProducerConf_)); + + SubscriptionConfig config1("localSub", SubscriptionType::STREAM); + SubscriptionConfig config2("remoteSub", SubscriptionType::STREAM); + // start remote consumer immediately + DS_ASSERT_OK(client1_->Subscribe("test_stream", config2, remoteCon)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + + // send by all four producers + DS_ASSERT_OK(prod1->Send(element)); + + DS_ASSERT_OK(prod2->Send(element)); + + // start local consumer after two flush. + // No rows have been ack back. So its starting point is the same + // as the remote consumer. 
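+ // (Assumed semantics: a STREAM-mode consumer starts from the oldest un-acked element, so the
+ // late local consumer below should still observe all four sends.)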
+ DS_ASSERT_OK(client0_->Subscribe("test_stream", config1, localCon));
+
+ DS_ASSERT_OK(prod3->Send(element));
+
+ DS_ASSERT_OK(prod4->Send(element));
+
+ std::vector<Element> outElements;
+ DS_ASSERT_OK(remoteCon->Receive(4, -1, outElements));
+ ASSERT_EQ(outElements.size(), (uint32_t) 4);
+
+ DS_ASSERT_OK(localCon->Receive(0, outElements));
+ ASSERT_EQ(outElements.size(), (uint32_t) 4);
+}
+
+TEST_F(PubSubComplexTest, DISABLED_TestProducerOrderInFlush)
+{
+ // Since we have no local consumer, Ack should not be part of the Flush operation.
+ // Therefore, no error should arise from the order of Acks inside Flush.
+ // Create 5 concurrent producers on one stream; they will flush in a different order.
+ int producerNum = 5;
+ ThreadPool pool(producerNum);
+ for (int i = 0; i < producerNum; i++) {
+ pool.Submit([this, i]() {
+ std::shared_ptr<Producer> producer;
+ DS_ASSERT_OK(client0_->CreateProducer("test_stream", producer, defaultProducerConf_));
+
+ std::string data = "Hello World " + std::to_string(i);
+ Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+ if (i % 2 == 0) {
+ DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.stream.sleep_while_flush", "sleep(500)"));
+ } else {
+ DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.stream.sleep_while_flush", "sleep(100)"));
+ }
+ DS_ASSERT_OK(producer->Send(element));
+ });
+ }
+}
+
+TEST_F(PubSubComplexTest, CreateAndCloseProducerWithinSameStream)
+{
+ const size_t testNum = 5;
+ std::vector<std::shared_ptr<Producer>> producerList(testNum);
+ for (auto &producer : producerList) {
+ DS_ASSERT_OK(client0_->CreateProducer("CreateAndClose", producer, defaultProducerConf_));
+ }
+ for (auto &producer : producerList) {
+ DS_ASSERT_OK(producer->Close());
+ }
+}
+
+TEST_F(PubSubComplexTest, CreateAndCloseProducerWithinDiffStream)
+{
+ const size_t testNum = 20;
+ const size_t producerOneStream = 5;
+ std::vector<std::shared_ptr<Producer>> producerList(testNum);
+ size_t cnt = 0;
+ std::string streamName;
+ for (auto &producer : producerList) {
+ streamName = "test" + std::to_string(cnt / producerOneStream);
+ DS_ASSERT_OK(client0_->CreateProducer(streamName, producer, defaultProducerConf_));
+ cnt++;
+ }
+ for (auto &producer : producerList) {
+ DS_ASSERT_OK(producer->Close());
+ }
+}
+
+TEST_F(PubSubComplexTest, SubscribeInStreamMode)
+{
+ const size_t testNum = 5;
+ const uint64_t mockId = 0;
+ std::vector<std::shared_ptr<Consumer>> consumerList(testNum);
+ for (size_t i = 0; i < testNum; ++i) {
+ std::string subName = "sub" + std::to_string(i);
+ SubscriptionConfig config(subName, SubscriptionType::STREAM);
+ DS_ASSERT_OK(client0_->Subscribe("testSubscribeInStreamMode", config, consumerList[i]));
+ }
+ for (size_t i = 0; i < testNum; ++i) {
+ DS_ASSERT_OK(consumerList[i]->Close());
+ DS_EXPECT_NOT_OK(consumerList[i]->Ack(mockId));
+ }
+}
+
+TEST_F(PubSubComplexTest, SubscribeDuplicateStreamMode)
+{
+ const size_t testNum = 2;
+ std::vector<std::shared_ptr<Consumer>> consumerList(testNum);
+
+ std::string subName = "sub1";
+ SubscriptionConfig config(subName, SubscriptionType::STREAM);
+ DS_EXPECT_OK(client0_->Subscribe("DuplicateSub", config, consumerList[0]));
+ DS_EXPECT_NOT_OK(client0_->Subscribe("DuplicateSub", config, consumerList[1]));
+ DS_EXPECT_OK(consumerList[0]->Close());
+}
+
+TEST_F(PubSubComplexTest, MultiStreamInMixMode)
+{
+ // Create 2 streams and 4 producers; each stream has 2 producers
+ const size_t streamNum = 2;
+ std::vector<std::string> streamNameList;
+ std::vector<std::shared_ptr<Producer>> producerList(streamNum * 2);
+
+ for (size_t i = 0; i < streamNum; ++i) {
+ streamNameList.emplace_back("MultiStreamInMixMode" + std::to_string(i));
+ DS_EXPECT_OK(client0_->CreateProducer(streamNameList.back(), producerList[2 * i], defaultProducerConf_));
+ DS_EXPECT_OK(client0_->CreateProducer(streamNameList.back(), producerList[2 * i + 1], defaultProducerConf_));
+ }
+
+ // Now create consumers that span the 2 streams, mixing subscription modes.
+ std::vector<SubscriptionConfig> configList;
+
+ // Two configs: sub0 in ROUND_ROBIN (queue-style) mode, sub1 in STREAM mode.
+ configList.emplace_back(SubscriptionConfig("sub0", SubscriptionType::ROUND_ROBIN));
+ configList.emplace_back(SubscriptionConfig("sub1", SubscriptionType::STREAM));
+
+ std::vector<std::shared_ptr<Consumer>> consumerList(streamNum);
+ for (size_t i = 0; i < consumerList.size(); ++i) {
+ DS_ASSERT_OK(client0_->Subscribe(streamNameList[i], configList[1], consumerList[i]));
+ }
+ // The third consumer subscribes with a duplicate STREAM-mode subscription, hence it fails.
+ std::shared_ptr<Consumer> dupStreamConsumer;
+ DS_ASSERT_NOT_OK(client0_->Subscribe(streamNameList[0], configList[1], dupStreamConsumer));
+
+ // The fourth consumer subscribes to the stream in ROUND_ROBIN mode, hence it fails.
+ std::shared_ptr<Consumer> queueConsumer;
+ DS_ASSERT_NOT_OK(client0_->Subscribe(streamNameList[1], configList[0], queueConsumer));
+
+ for (auto &producer : producerList) {
+ DS_EXPECT_OK(producer->Close());
+ // The second close should be a no-op if a producer is closed twice
+ DS_EXPECT_OK(producer->Close());
+ }
+ for (auto &consumer : consumerList) {
+ DS_ASSERT_OK(consumer->Close());
+ // The second close should be a no-op if a consumer is closed twice
+ DS_ASSERT_OK(consumer->Close());
+ }
+}
+
+TEST_F(PubSubComplexTest, TestSyncSubTableRetry)
+{
+ // Maintain a connection from worker0 to worker1 by creating a dummy stream s2.
+ std::string s2("dummyStream");
+ std::shared_ptr<Consumer> consumer2;
+ SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+ DS_ASSERT_OK(client1_->Subscribe(s2, config2, consumer2));
+ std::shared_ptr<Producer> producer2;
+ DS_ASSERT_OK(client0_->CreateProducer(s2, producer2, defaultProducerConf_));
+
+ // Create a stream s1 from worker0 to worker1
+ std::string s1("SyncSubTableRetry");
+ std::shared_ptr<Consumer> consumer1;
+ SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+ DS_ASSERT_OK(client1_->Subscribe(s1, config1, consumer1));
+ std::shared_ptr<Producer> producer1;
+ DS_ASSERT_OK(client0_->CreateProducer(s1, producer1, defaultProducerConf_));
+
+ DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "RemoteWorker.SleepBeforeAsyncRead", "call(500)"));
+ const size_t numElements = 1000;
+ const size_t elementLength = 1024;
+ std::string data(elementLength, 'a');
+ Element ele(reinterpret_cast<uint8_t *>(data.data()), data.size());
+ for (size_t i = 0; i < numElements; ++i) {
+ DS_ASSERT_OK(producer1->Send(ele));
+ }
+ producer1->Close();
+ consumer1->Close();
+ client1_->DeleteStream(s1);
+
+ // Start all over again.
Same stream name + DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "RemoteWorker.SleepBeforeAsyncRead")); + DS_ASSERT_OK(client1_->Subscribe(s1, config1, consumer1)); + DS_ASSERT_OK(client0_->CreateProducer(s1, producer1, defaultProducerConf_)); + for (size_t i = 0; i < numElements; ++i) { + DS_ASSERT_OK(producer1->Send(ele)); + } + std::vector out; + out.reserve(numElements); + const uint64_t timeoutMs = 60'000; + DS_ASSERT_OK(consumer1->Receive(numElements, timeoutMs, out)); + ASSERT_EQ(out.size(), numElements); + producer1->Close(); + consumer1->Close(); + client1_->DeleteStream(s1); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/pub_sub_multinode_test.cpp b/tests/st/client/stream_cache/pub_sub_multinode_test.cpp new file mode 100644 index 0000000..1737594 --- /dev/null +++ b/tests/st/client/stream_cache/pub_sub_multinode_test.cpp @@ -0,0 +1,977 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Remote send test. + */ +#include +#include + +#include "common.h" +#include "common/stream_cache/element_generator.h" +#include "common/stream_cache/stream_common.h" +#include "sc_client_common.h" +#include "datasystem/stream_client.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/client/stream_cache/client_worker_api.h" +namespace datasystem { +namespace st { +using namespace datasystem::client::stream_cache; +#define MULTI_NODE +#ifdef MULTI_NODE +constexpr int K_TWO = 2; +class PubSubMultiNodeTest : public SCClientCommon { +#else +class PubSubMultiNode : public CommonTest { +#endif +public: +#ifdef MULTI_NODE + void SetClusterSetupOptions(ExternalClusterOptions &opts) override; +#endif + void SetUp() override; + + void TearDown() override; + + static std::once_flag onceFlag_; + +protected: + // Mock producer worker. 
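+ // Host/port endpoints of the three workers in the cluster, resolved in SetUp().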
+ HostPort w1Addr_; + HostPort w2Addr_; + HostPort w3Addr_; + + Status TryAndDeleteStream(std::shared_ptr &spClient, std::string streamName) + { + // if pending notifications retry delete + Status rc = Status::OK(); + do { + rc = spClient->DeleteStream(streamName); + if (rc.IsError()) { + sleep(K_TWO); + } + } while (rc.GetCode() == StatusCode::K_SC_STREAM_NOTIFICATION_PENDING); + return rc; + } + + std::shared_ptr w1Client_ = nullptr; + std::shared_ptr w2Client_ = nullptr; + std::shared_ptr w3Client_ = nullptr; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; + const uint32_t WORKER_COUNT = 3; +}; +std::once_flag PubSubMultiNodeTest::onceFlag_; + +#ifdef MULTI_NODE +void PubSubMultiNodeTest::SetClusterSetupOptions(ExternalClusterOptions &opts) +{ + opts.numEtcd = 1; + opts.numWorkers = WORKER_COUNT; + opts.workerGflagParams = " -page_size=" + std::to_string(PAGE_SIZE); + opts.numRpcThreads = 16; + SCClientCommon::SetClusterSetupOptions(opts); +} +#endif + +void PubSubMultiNodeTest::SetUp() +{ +#ifdef MULTI_NODE + ExternalClusterTest::SetUp(); + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, w1Addr_)); + DS_ASSERT_OK(cluster_->GetWorkerAddr(1, w2Addr_)); + DS_ASSERT_OK(cluster_->GetWorkerAddr(2, w3Addr_)); +#else + w1Addr_ = HostPort("127.0.0.1", 2295); + w2Addr_ = HostPort("127.0.0.1", 11589); + w3Addr_ = HostPort("127.0.0.1", 8666); +#endif + InitStreamClient(0, w1Client_); + InitStreamClient(1, w2Client_); + InitStreamClient(2, w3Client_); // index is 2 + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; +} + +void PubSubMultiNodeTest::TearDown() +{ + w1Client_ = nullptr; + w2Client_ = nullptr; + w3Client_ = nullptr; +#ifdef MULTI_NODE + ExternalClusterTest::TearDown(); +#endif +} + +TEST_F(PubSubMultiNodeTest, StreamModeTwoConsumer) +{ + std::string streamName = "testStreamMode2Con"; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer; + ASSERT_EQ(w2Client_->Subscribe(streamName, config, consumer), Status::OK()); + + std::shared_ptr consumer2; + w1Client_->Subscribe(streamName, config, consumer2); + ASSERT_TRUE(w1Client_->Subscribe(streamName, config, consumer2) != Status::OK()); +} + +TEST_F(PubSubMultiNodeTest, PubCloseFirst) +{ + std::string streamName = "testPubCloseFirst"; + ThreadPool pool(5); + std::promise promise; + std::future subFut = promise.get_future(); + std::vector> futs; + futs.emplace_back(pool.Submit([this, &subFut, streamName]() { + std::shared_ptr producer; + RETURN_IF_NOT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + subFut.get(); + return producer->Close(); + })); + futs.emplace_back(pool.Submit([this, &subFut, streamName]() { + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + RETURN_IF_NOT_OK(w1Client_->Subscribe(streamName, config2, consumer2)); + Status status; + sleep(1); // wait for pub to close first + status = consumer2->Close(); + return status; + })); + futs.emplace_back(pool.Submit([this, &promise, streamName]() { + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer; + RETURN_IF_NOT_OK(w2Client_->Subscribe(streamName, config, consumer)); + promise.set_value(); + sleep(1); // wait for pub to close first + Status status; + status = consumer->Close(); + LOG(INFO) << FormatString("%s", status.ToString()); + return status; + })); + for (auto &fut : futs) { + ASSERT_EQ(fut.get(), Status::OK()); + } + 
ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK()); +} + +TEST_F(PubSubMultiNodeTest, AvoidMissAddRemote) +{ + std::string strmName = "testAvoidMissAddRemote"; + ThreadPool pool(10); + for (int i = 0; i < 10; i++) { + pool.Submit([this, i, strmName]() { + ThreadPool pool(5); + std::vector> promises(2); + std::vector> subFuts; + for (auto &promise : promises) { + subFuts.emplace_back(promise.get_future()); + } + std::vector> futs; + std::string streamName = strmName + std::to_string(i); + std::shared_ptr producer; + + futs.emplace_back(pool.Submit([this, &subFuts, &producer, streamName]() { + RETURN_IF_NOT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + for (auto subFut : subFuts) { + subFut.get(); + } + std::string str(10, 'c'); + Element e; + e.ptr = reinterpret_cast(const_cast(str.data())); + e.size = str.size(); + RETURN_IF_NOT_OK(producer->Send(e)); + return Status::OK(); + })); + futs.emplace_back(pool.Submit([this, streamName, &promises]() { + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + RETURN_IF_NOT_OK(w1Client_->Subscribe(streamName, config2, consumer2)); + promises[0].set_value(); + std::vector elements; + RETURN_IF_NOT_OK(consumer2->Receive(1, 1'000, elements)); + return consumer2->Close(); + })); + futs.emplace_back(pool.Submit([this, streamName, &promises]() { + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer; + RETURN_IF_NOT_OK(w2Client_->Subscribe(streamName, config, consumer)); + promises[1].set_value(); + std::vector elements; + RETURN_IF_NOT_OK(consumer->Receive(1, 1'000, elements)); + return consumer->Close(); + })); + for (auto &fut : futs) { + ASSERT_EQ(fut.get(), Status::OK()); + } + ASSERT_EQ(producer->Close(), Status::OK()); + ASSERT_EQ(TryAndDeleteStream(w1Client_, streamName), Status::OK()); + }); + } +} + +TEST_F(PubSubMultiNodeTest, TestCreateOrderSingleNode) +{ + const std::string stream1 = "testCreateOrder_sameNode_s1"; + // Create consumer first, then producer + std::shared_ptr con1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(w1Client_->Subscribe(stream1, config1, con1)); + std::shared_ptr prod1; + DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod1, defaultProducerConf_)); + + // Switch the order. 
Producer first, then consumer
+ const std::string stream2 = "testCreateOrder_sameNode_s2";
+ std::shared_ptr<Producer> prod2;
+ DS_ASSERT_OK(w1Client_->CreateProducer(stream2, prod2, defaultProducerConf_));
+ std::shared_ptr<Consumer> con2;
+ SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+ DS_ASSERT_OK(w1Client_->Subscribe(stream2, config2, con2));
+
+ // Sanity test by sending/receiving one element
+ RandomData rand;
+ auto str = rand.GetRandomString(defaultProducerConf_.pageSize);
+ DS_ASSERT_OK(prod1->Send(Element((uint8_t *)str.data(), str.size())));
+ std::vector<Element> outElements;
+ DS_ASSERT_OK(con1->Receive(RPC_TIMEOUT, outElements));
+ DS_ASSERT_TRUE(outElements.size(), 1);
+ con1->Ack(outElements[0].id);
+ std::string res1(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+ DS_ASSERT_TRUE(res1, str);
+
+ DS_ASSERT_OK(prod2->Send(Element((uint8_t *)str.data(), str.size())));
+ outElements.clear();
+ DS_ASSERT_OK(con2->Receive(RPC_TIMEOUT, outElements));
+ DS_ASSERT_TRUE(outElements.size(), 1);
+ con2->Ack(outElements[0].id);
+ std::string res2(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+ DS_ASSERT_TRUE(res2, str);
+
+ DS_ASSERT_OK(prod1->Close());
+ DS_ASSERT_OK(prod2->Close());
+ DS_ASSERT_OK(con1->Close());
+ DS_ASSERT_OK(con2->Close());
+ DS_ASSERT_OK(w1Client_->DeleteStream(stream1));
+ DS_ASSERT_OK(w1Client_->DeleteStream(stream2));
+}
+
+TEST_F(PubSubMultiNodeTest, TestCreateOrderCrossNode)
+{
+ const std::string stream1 = "testCreateOrder_diffNode_s1";
+ // Create consumer first, then producer
+ std::shared_ptr<Consumer> con1;
+ SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+ DS_ASSERT_OK(w2Client_->Subscribe(stream1, config1, con1));
+ std::shared_ptr<Producer> prod1;
+ DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod1, defaultProducerConf_));
+
+ // Switch the order: producer first, then consumer
+ const std::string stream2 = "testCreateOrder_diffNode_s2";
+ std::shared_ptr<Producer> prod2;
+ DS_ASSERT_OK(w1Client_->CreateProducer(stream2, prod2, defaultProducerConf_));
+ std::shared_ptr<Consumer> con2;
+ SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+ DS_ASSERT_OK(w2Client_->Subscribe(stream2, config2, con2));
+
+ // Sanity test by sending/receiving one element
+ RandomData rand;
+ auto str = rand.GetRandomString(defaultProducerConf_.pageSize);
+ DS_ASSERT_OK(prod1->Send(Element((uint8_t *)str.data(), str.size())));
+ std::vector<Element> outElements;
+ DS_ASSERT_OK(con1->Receive(RPC_TIMEOUT, outElements));
+ DS_ASSERT_TRUE(outElements.size(), 1);
+ con1->Ack(outElements[0].id);
+ std::string res1(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+ DS_ASSERT_TRUE(res1, str);
+
+ DS_ASSERT_OK(prod2->Send(Element((uint8_t *)str.data(), str.size())));
+ outElements.clear();
+ DS_ASSERT_OK(con2->Receive(RPC_TIMEOUT, outElements));
+ DS_ASSERT_TRUE(outElements.size(), 1);
+ con2->Ack(outElements[0].id);
+ std::string res2(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+ DS_ASSERT_TRUE(res2, str);
+
+ DS_ASSERT_OK(prod1->Close());
+ DS_ASSERT_OK(prod2->Close());
+ DS_ASSERT_OK(con1->Close());
+ DS_ASSERT_OK(con2->Close());
+ DS_ASSERT_OK(w1Client_->DeleteStream(stream1));
+ DS_ASSERT_OK(w1Client_->DeleteStream(stream2));
+}
+
+TEST_F(PubSubMultiNodeTest, TestCreateOrderSingleNodeOOM)
+{
+ RandomData rand;
+ auto str = rand.GetRandomString(defaultProducerConf_.pageSize);
+
+ Status rc;
+ const std::string stream1 = "testCreateOrder_sameNodeOOM";
+ DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "CreatePageZero.AllocMemory",
+ "1*return(K_OUT_OF_MEMORY)"));
+ // Create consumer first, then producer
+ std::shared_ptr<Consumer> con1;
+ SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+ DS_ASSERT_OK(w1Client_->Subscribe(stream1, config1, con1));
+ std::shared_ptr<Producer> prod1;
+ rc = w1Client_->CreateProducer(stream1, prod1, defaultProducerConf_);
+ // This should fail with OOM the first time.
+ ASSERT_EQ(rc.GetCode(), K_SC_STREAM_RESOURCE_ERROR);
+ // Trying again should succeed
+ DS_ASSERT_OK(cluster_->ClearInjectAction(ClusterNodeType::WORKER, 0, "CreatePageZero.AllocMemory"));
+ prod1.reset();
+ DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod1, defaultProducerConf_));
+ DS_ASSERT_OK(prod1->Send(Element((uint8_t *)str.data(), str.size())));
+ std::vector<Element> outElements;
+ DS_ASSERT_OK(con1->Receive(RPC_TIMEOUT, outElements));
+ DS_ASSERT_TRUE(outElements.size(), 1);
+ con1->Ack(outElements[0].id);
+ std::string res1(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+ DS_ASSERT_TRUE(res1, str);
+ DS_ASSERT_OK(prod1->Close());
+ DS_ASSERT_OK(con1->Close());
+}
+
+TEST_F(PubSubMultiNodeTest, TestCreateOrderCrossNodeOOM)
+{
+ RandomData rand;
+ auto str = rand.GetRandomString(defaultProducerConf_.pageSize);
+ Status rc;
+ const std::string stream1 = "testCreateOrder_diffNodeOOM_s1";
+ DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "CreatePageZero.AllocMemory",
+ "1*return(K_OUT_OF_MEMORY)"));
+ // Create consumer first, then producer
+ std::shared_ptr<Consumer> con1;
+ SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+ // We don't know the page size at this point, so we can't reserve the memory yet.
+ // Hence the subscribe should be successful.
+ DS_ASSERT_OK(w2Client_->Subscribe(stream1, config1, con1)); + const int MULTIPLIER = 16; + defaultProducerConf_.maxStreamSize = MULTIPLIER * defaultProducerConf_.pageSize; + defaultProducerConf_.reserveSize = defaultProducerConf_.maxStreamSize; + std::shared_ptr prod1; + rc = w1Client_->CreateProducer(stream1, prod1, defaultProducerConf_); + // This should fail with OOM when the remote worker got the topo change, and return OOM to the producer + ASSERT_EQ(rc.GetCode(), K_SC_STREAM_RESOURCE_ERROR); + DS_ASSERT_OK(cluster_->ClearInjectAction(ClusterNodeType::WORKER, 1, "CreatePageZero.AllocMemory")); + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "ReserveAdditionalMemory.AllocMemory", + "2*return(K_OK)->return(K_OUT_OF_MEMORY)")); + // This should fail with OOM when the remote worker got the topo change, and return OOM to the producer + prod1.reset(); + rc = w1Client_->CreateProducer(stream1, prod1, defaultProducerConf_); + ASSERT_EQ(rc.GetCode(), K_SC_STREAM_RESOURCE_ERROR); + // Try again should be successful + DS_ASSERT_OK(cluster_->ClearInjectAction(ClusterNodeType::WORKER, 1, "ReserveAdditionalMemory.AllocMemory")); + prod1.reset(); + DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod1, defaultProducerConf_)); + DS_ASSERT_OK(prod1->Send(Element((uint8_t *)str.data(), str.size()))); + std::vector outElements; + DS_ASSERT_OK(con1->Receive(RPC_TIMEOUT, outElements)); + DS_ASSERT_TRUE(outElements.size(), 1); + con1->Ack(outElements[0].id); + std::string res1(reinterpret_cast(outElements[0].ptr), outElements[0].size); + DS_ASSERT_TRUE(res1, str); + DS_ASSERT_OK(prod1->Close()); + DS_ASSERT_OK(con1->Close()); + + // Switch the order. Producer first, then consumer + const std::string stream2 = "testCreateOrder_diffNodeOOM_s2"; + std::shared_ptr prod2; + DS_ASSERT_OK(w1Client_->CreateProducer(stream2, prod2, defaultProducerConf_)); + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "CreatePageZero.AllocMemory", + "1*return(K_OUT_OF_MEMORY)")); + std::shared_ptr con2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + // Subscribe should pick up the page size and returns OOM. + rc = w2Client_->Subscribe(stream2, config2, con2); + // This should fail with OOM + ASSERT_EQ(rc.GetCode(), K_SC_STREAM_RESOURCE_ERROR); +} + +TEST_F(PubSubMultiNodeTest, TestMultiStreamsOOM) +{ + Status rc; + int i = 0; + std::unordered_map> prodList; + while (rc.IsOk()) { + std::string streamName = FormatString("testMultiStreamsOOM%d", i); + std::shared_ptr prod; + // Each producer will reserve one page of memory. + rc = w1Client_->CreateProducer(streamName, prod, defaultProducerConf_); + if (rc.IsOk()) { + LOG(INFO) << FormatString("[%s] Create producer success", streamName); + prodList.emplace(streamName, std::move(prod)); + ++i; + continue; + } + // Expect we will run out of resources at some point + DS_ASSERT_TRUE(rc.GetCode(), K_SC_STREAM_RESOURCE_ERROR); + } + + // Expect prodList is not empty. + ASSERT_TRUE(!prodList.empty()); + // Page size is 1m and shared_memory_size_mb is 64m. + // Maximum number of streams we can create is 64 but there can be some overhead and + // in reality we create less than the maximum. + const int maxProducer = 64; + ASSERT_TRUE(prodList.size() <= maxProducer); + + // Delete one of the stream. We expect the stream memory is released and we can create one more. 
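+ // (Assumed ordering requirement: the producer is closed before DeleteStream so that the
+ // stream's reserved pages can actually be released and the next CreateProducer can succeed.)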
+ auto iter = prodList.begin(); + auto oneProducer = std::move(iter->second); + auto oneStreamName = iter->first; + prodList.erase(oneStreamName); + DS_ASSERT_OK(oneProducer->Close()); + DS_ASSERT_OK(w1Client_->DeleteStream(oneStreamName)); + LOG(INFO) << FormatString("[%s] delete success", oneStreamName); + + std::string streamName = FormatString("stream%d", i++); + std::shared_ptr prod; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, prod, defaultProducerConf_)); + LOG(INFO) << FormatString("[%s] Create producer success", streamName); + prodList.emplace(streamName, std::move(prod)); + + for (auto &ele: prodList) { + DS_ASSERT_OK(ele.second->Close()); + DS_ASSERT_OK(w1Client_->DeleteStream(ele.first)); + } +} + +TEST_F(PubSubMultiNodeTest, TestMultiProducersUndo) +{ + defaultProducerConf_.retainForNumConsumers = 1; + for (uint32_t i = 0; i < WORKER_COUNT; i++) { + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, i, + "master.PubIncreaseNodeImpl.beforeSendNotification", + "1*return(K_RUNTIME_ERROR)")); + } + const std::string stream1 = "testMultiProdUndo"; + const int NUM_THREADS = 2; + const int NUM_ELEMENTS = 100; + const int ELEMENT_SIZE = 48; + ThreadPool pool(NUM_THREADS); + std::vector> futs; + for (int i = 0; i < NUM_THREADS; ++i) { + futs.emplace_back(pool.Submit([this, &stream1]() { + std::shared_ptr prod; + Status rc = w1Client_->CreateProducer(stream1, prod, defaultProducerConf_); + if (rc.IsError()) { return rc; } + RandomData rand; + auto str = rand.GetRandomString(ELEMENT_SIZE); + for (int k = 0; k < NUM_ELEMENTS; ++k) { + rc = prod->Send(Element((uint8_t *)str.data(), str.size())); + if (rc.IsError()) { return rc; } + } + return Status::OK(); + })); + } + + std::shared_ptr con; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(w2Client_->Subscribe(stream1, config, con)); + + std::vector status; + for (int i = 0; i < NUM_THREADS; ++i) { + status.push_back(futs[i].get()); + } + // Expect exactly one of the producer should hit K_RUNTIME_ERROR; + if (status[0].IsOk()) { + ASSERT_TRUE(status[1].GetCode() == K_RUNTIME_ERROR); + } else { + ASSERT_TRUE(status[0].GetCode() == K_RUNTIME_ERROR); + ASSERT_TRUE(status[1].IsOk()); + } + + std::vector outElements; + DS_ASSERT_OK(con->Receive(NUM_ELEMENTS, RPC_TIMEOUT, outElements)); + DS_ASSERT_TRUE(outElements.size(), NUM_ELEMENTS); + DS_ASSERT_OK(con->Close()); + DS_ASSERT_OK(w1Client_->DeleteStream(stream1)); + sleep(1); +} + +TEST_F(PubSubMultiNodeTest, BigElement2S2P2C) +{ + // 2 streams: 2 producers -> 2 consumers for each stream. + const std::string stream1 = "testBigEle2S2P2C_s1"; + const std::string stream2 = "testBigEle2S2P2C_s2"; + defaultProducerConf_.pageSize = 4 * KB; + const int timeout = 10000; + + // Create the 4 producers. + std::shared_ptr prod11; // client1, stream 1, producer 1 + std::shared_ptr prod12; // client2, stream 1, producer 2 + std::shared_ptr prod21; // client3, stream 2, producer 1 + std::shared_ptr prod22; // client2, stream 2, producer 2 + DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod11, defaultProducerConf_)); + DS_ASSERT_OK(w2Client_->CreateProducer(stream1, prod12, defaultProducerConf_)); + DS_ASSERT_OK(w3Client_->CreateProducer(stream2, prod21, defaultProducerConf_)); + DS_ASSERT_OK(w3Client_->CreateProducer(stream2, prod22, defaultProducerConf_)); + + // Create the 4 consumers. 
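+ // Topology recap from the mapping comments below: stream1 is produced on w1/w2 and consumed
+ // on w2/w3, while stream2 is produced on w3 and consumed on w1.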
+ std::shared_ptr con11; // client 2, stream 1, consumer 1 + std::shared_ptr con12; // client 3, stream 1, consumer 2 + std::shared_ptr con21; // client 1, stream 2, consumer 1 + std::shared_ptr con22; // client 1, stream 2, consumer 2 + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(w2Client_->Subscribe(stream1, config1, con11)); + DS_ASSERT_OK(w3Client_->Subscribe(stream1, config2, con12)); + DS_ASSERT_OK(w1Client_->Subscribe(stream2, config1, con21)); + DS_ASSERT_OK(w1Client_->Subscribe(stream2, config2, con22)); + + // Create big element. + const uint64_t numOfElementPerProducer = 50; + const uint64_t expectNumOfElementReceivePerConsumer = numOfElementPerProducer * 2; + const uint64_t elementSize = 8 * KB; + RandomData rand; + auto str = rand.GetRandomString(elementSize); + + // Each producer send 50 big elements. + for (uint64_t i = 1; i <= numOfElementPerProducer; i++) { + Element element(reinterpret_cast(&str.front()), elementSize); + ASSERT_EQ(prod11->Send(element), Status::OK()); + ASSERT_EQ(prod12->Send(element), Status::OK()); + ASSERT_EQ(prod21->Send(element), Status::OK()); + ASSERT_EQ(prod22->Send(element), Status::OK()); + } + + // Each consumer receive 100 big elements. + std::vector outElements; + DS_ASSERT_OK(con11->Receive(expectNumOfElementReceivePerConsumer, timeout, outElements)); + ASSERT_EQ(outElements.size(), expectNumOfElementReceivePerConsumer); + outElements.clear(); + DS_ASSERT_OK(con12->Receive(expectNumOfElementReceivePerConsumer, timeout, outElements)); + ASSERT_EQ(outElements.size(), expectNumOfElementReceivePerConsumer); + outElements.clear(); + DS_ASSERT_OK(con21->Receive(expectNumOfElementReceivePerConsumer, timeout, outElements)); + ASSERT_EQ(outElements.size(), expectNumOfElementReceivePerConsumer); + outElements.clear(); + DS_ASSERT_OK(con22->Receive(expectNumOfElementReceivePerConsumer, timeout, outElements)); + ASSERT_EQ(outElements.size(), expectNumOfElementReceivePerConsumer); + outElements.clear(); + + DS_ASSERT_OK(prod11->Close()); + DS_ASSERT_OK(prod12->Close()); + DS_ASSERT_OK(prod21->Close()); + DS_ASSERT_OK(prod22->Close()); + DS_ASSERT_OK(con11->Close()); + DS_ASSERT_OK(con12->Close()); + DS_ASSERT_OK(con21->Close()); + DS_ASSERT_OK(con22->Close()); + DS_ASSERT_OK(w1Client_->DeleteStream(stream1)); + DS_ASSERT_OK(w1Client_->DeleteStream(stream2)); +} + +class PubSubMultiNodeDataVerificationTest : public PubSubMultiNodeTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + PubSubMultiNodeTest::SetClusterSetupOptions(opts); + opts.workerGflagParams += " -v=2 -enable_stream_data_verification=true "; + } +}; + +TEST_F(PubSubMultiNodeDataVerificationTest, 2S2P2C) +{ + for (uint32_t i = 0; i < WORKER_COUNT; i++) { + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "master.PubIncreaseNode.initProducerno", + FormatString("1*call(%zu)", INT64_MAX))); + } + // 2 streams: 2 producers -> 2 consumers for each stream. + const std::string stream1 = "test2S2P2C_s1"; + const std::string stream2 = "test2S2P2C_s2"; + + // Create the 4 producers. 
+ std::shared_ptr<Producer> prod11; // client1, stream 1, producer 1
+ std::shared_ptr<Producer> prod12; // client2, stream 1, producer 2
+ std::shared_ptr<Producer> prod21; // client3, stream 2, producer 1
+ std::shared_ptr<Producer> prod22; // client3, stream 2, producer 2
+ DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod11, defaultProducerConf_));
+ DS_ASSERT_OK(w2Client_->CreateProducer(stream1, prod12, defaultProducerConf_));
+ DS_ASSERT_OK(w3Client_->CreateProducer(stream2, prod21, defaultProducerConf_));
+ DS_ASSERT_OK(w3Client_->CreateProducer(stream2, prod22, defaultProducerConf_));
+
+ // Create the 4 consumers.
+ std::shared_ptr<Consumer> con11; // client 2, stream 1, consumer 1
+ std::shared_ptr<Consumer> con12; // client 3, stream 1, consumer 2
+ std::shared_ptr<Consumer> con21; // client 1, stream 2, consumer 1
+ std::shared_ptr<Consumer> con22; // client 1, stream 2, consumer 2
+ SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+ SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+ DS_ASSERT_OK(w2Client_->Subscribe(stream1, config1, con11));
+ DS_ASSERT_OK(w3Client_->Subscribe(stream1, config2, con12));
+ DS_ASSERT_OK(w1Client_->Subscribe(stream2, config1, con21));
+ DS_ASSERT_OK(w1Client_->Subscribe(stream2, config2, con22));
+
+ // Each producer sends 51 elements; the 51st element is out of order.
+ const uint64_t outOfOrderElementNo = 51;
+ const uint64_t expectNumOfElementReceive = (outOfOrderElementNo - 1) * 2;
+ const uint64_t elementSize = defaultProducerConf_.pageSize / 10;
+ RandomData rand;
+ auto str = rand.GetRandomString(elementSize);
+
+ // Each producer sends the first 50 elements.
+ for (uint64_t i = 1; i < outOfOrderElementNo; i++) {
+ Element element(reinterpret_cast<uint8_t *>(&str.front()), elementSize);
+ ASSERT_EQ(prod11->Send(element), Status::OK());
+ ASSERT_EQ(prod12->Send(element), Status::OK());
+ ASSERT_EQ(prod21->Send(element), Status::OK());
+ ASSERT_EQ(prod22->Send(element), Status::OK());
+ }
+
+ // The first 100 elements received (50 from each producer of the stream) should be in order.
+ std::vector<Element> outElements;
+ DS_ASSERT_OK(con11->Receive(expectNumOfElementReceive, RPC_TIMEOUT, outElements));
+ ASSERT_EQ(outElements.size(), expectNumOfElementReceive);
+ for (auto &ele : outElements) {
+ std::string actualData(reinterpret_cast<const char *>(ele.ptr), ele.size);
+ EXPECT_EQ(str, actualData);
+ }
+ outElements.clear();
+ DS_ASSERT_OK(con12->Receive(expectNumOfElementReceive, RPC_TIMEOUT, outElements));
+ ASSERT_EQ(outElements.size(), expectNumOfElementReceive);
+ outElements.clear();
+ DS_ASSERT_OK(con21->Receive(expectNumOfElementReceive, RPC_TIMEOUT, outElements));
+ ASSERT_EQ(outElements.size(), expectNumOfElementReceive);
+ outElements.clear();
+ DS_ASSERT_OK(con22->Receive(expectNumOfElementReceive, RPC_TIMEOUT, outElements));
+ ASSERT_EQ(outElements.size(), expectNumOfElementReceive);
+ outElements.clear();
+
+ // Send the out-of-order element.
+ datasystem::inject::Set("DataVerificationOutOfOrder", "call(1)");
+ Element element(reinterpret_cast<uint8_t *>(&str.front()), elementSize);
+ ASSERT_EQ(prod11->Send(element), Status::OK());
+ ASSERT_EQ(prod12->Send(element), Status::OK());
+ ASSERT_EQ(prod21->Send(element), Status::OK());
+ ASSERT_EQ(prod22->Send(element), Status::OK());
+ datasystem::inject::Clear("DataVerificationOutOfOrder");
+
+ // Receiving the elements with incorrect sequence numbers.
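+ // The verification layer tracks a per-producer sequence number on the consumer side; the
+ // injected out-of-order send creates a gap, so Receive surfaces K_DATA_INCONSISTENCY.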
+ ASSERT_EQ(con11->Receive(1, RPC_TIMEOUT, outElements).GetCode(), K_DATA_INCONSISTENCY);
+ outElements.clear();
+ ASSERT_EQ(con12->Receive(1, RPC_TIMEOUT, outElements).GetCode(), K_DATA_INCONSISTENCY);
+ outElements.clear();
+ ASSERT_EQ(con21->Receive(1, RPC_TIMEOUT, outElements).GetCode(), K_DATA_INCONSISTENCY);
+ outElements.clear();
+ ASSERT_EQ(con22->Receive(1, RPC_TIMEOUT, outElements).GetCode(), K_DATA_INCONSISTENCY);
+ outElements.clear();
+
+ DS_ASSERT_OK(prod11->Close());
+ DS_ASSERT_OK(prod12->Close());
+ DS_ASSERT_OK(prod21->Close());
+ DS_ASSERT_OK(prod22->Close());
+ DS_ASSERT_OK(con11->Close());
+ DS_ASSERT_OK(con12->Close());
+ DS_ASSERT_OK(con21->Close());
+ DS_ASSERT_OK(con22->Close());
+ DS_ASSERT_OK(w1Client_->DeleteStream(stream1));
+ DS_ASSERT_OK(w1Client_->DeleteStream(stream2));
+}
+
+TEST_F(PubSubMultiNodeDataVerificationTest, BigElement2S2P1C)
+{
+ // 2 streams: 2 producers -> 1 consumer for each stream.
+ const std::string stream1 = "testBigEle2S2P1C_s1";
+ const std::string stream2 = "testBigEle2S2P1C_s2";
+ defaultProducerConf_.pageSize = 4 * KB;
+
+ // Create the 4 producers.
+ std::shared_ptr<Producer> prod11; // client1, stream 1, producer 1
+ std::shared_ptr<Producer> prod12; // client2, stream 1, producer 2
+ std::shared_ptr<Producer> prod21; // client3, stream 2, producer 1
+ std::shared_ptr<Producer> prod22; // client3, stream 2, producer 2
+ DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod11, defaultProducerConf_));
+ DS_ASSERT_OK(w2Client_->CreateProducer(stream1, prod12, defaultProducerConf_));
+ DS_ASSERT_OK(w3Client_->CreateProducer(stream2, prod21, defaultProducerConf_));
+ DS_ASSERT_OK(w3Client_->CreateProducer(stream2, prod22, defaultProducerConf_));
+
+ // Create the 2 consumers.
+ std::shared_ptr<Consumer> con1; // client 3, stream 1, consumer 1
+ std::shared_ptr<Consumer> con2; // client 1, stream 2, consumer 1
+ SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+ DS_ASSERT_OK(w3Client_->Subscribe(stream1, config1, con1));
+ DS_ASSERT_OK(w1Client_->Subscribe(stream2, config1, con2));
+
+ // Each producer sends 50 elements; the 50th element is out of order.
+ const uint64_t outOfOrderElementNo = 50;
+ const uint64_t expectNumOfElementReceive = (outOfOrderElementNo - 1) * 2;
+ const uint64_t elementSize = 8 * KB;
+ RandomData rand;
+ auto str = rand.GetRandomString(elementSize);
+
+ // Each producer sends the first 49 elements.
+ for (uint64_t i = 1; i < outOfOrderElementNo; i++) {
+ Element element(reinterpret_cast<uint8_t *>(&str.front()), elementSize);
+ ASSERT_EQ(prod11->Send(element), Status::OK());
+ ASSERT_EQ(prod12->Send(element), Status::OK());
+ ASSERT_EQ(prod21->Send(element), Status::OK());
+ ASSERT_EQ(prod22->Send(element), Status::OK());
+ }
+
+ // The first 98 elements received per consumer (49 from each producer) should be in order.
+ std::vector<Element> outElements;
+ DS_ASSERT_OK(con1->Receive(expectNumOfElementReceive, RPC_TIMEOUT, outElements));
+ ASSERT_EQ(outElements.size(), expectNumOfElementReceive);
+ for (auto &ele : outElements) {
+ std::string actualData(reinterpret_cast<const char *>(ele.ptr), ele.size);
+ EXPECT_EQ(str, actualData);
+ }
+ outElements.clear();
+ DS_ASSERT_OK(con2->Receive(expectNumOfElementReceive, RPC_TIMEOUT, outElements));
+ ASSERT_EQ(outElements.size(), expectNumOfElementReceive);
+ outElements.clear();
+
+ // Send the out-of-order element.
+ datasystem::inject::Set("DataVerificationOutOfOrder", "call(1)");
+ Element element(reinterpret_cast<uint8_t *>(&str.front()), elementSize);
+ ASSERT_EQ(prod11->Send(element), Status::OK());
+ ASSERT_EQ(prod12->Send(element), Status::OK());
+ ASSERT_EQ(prod21->Send(element), Status::OK());
+ ASSERT_EQ(prod22->Send(element), Status::OK());
+ datasystem::inject::Clear("DataVerificationOutOfOrder");
+
+ // Receiving the elements with incorrect sequence numbers.
+ ASSERT_EQ(con1->Receive(1, RPC_TIMEOUT, outElements).GetCode(), K_DATA_INCONSISTENCY);
+ outElements.clear();
+ ASSERT_EQ(con2->Receive(1, RPC_TIMEOUT, outElements).GetCode(), K_DATA_INCONSISTENCY);
+ outElements.clear();
+ DS_ASSERT_OK(prod11->Close());
+ DS_ASSERT_OK(prod12->Close());
+ DS_ASSERT_OK(prod21->Close());
+ DS_ASSERT_OK(prod22->Close());
+ DS_ASSERT_OK(con1->Close());
+ DS_ASSERT_OK(con2->Close());
+ DS_ASSERT_OK(w1Client_->DeleteStream(stream1));
+ DS_ASSERT_OK(w1Client_->DeleteStream(stream2));
+}
+
+TEST_F(PubSubMultiNodeDataVerificationTest, LEVEL2_ConsumerSubscribleLater)
+{
+ // 1 stream: 1 producer -> 2 consumers.
+ const std::string stream1 = "testConSubscribLater";
+ size_t size = 1 * MB;
+ RandomData rand;
+ std::string data = rand.GetRandomString(size);
+ defaultProducerConf_.maxStreamSize = 20 * MB;
+ defaultProducerConf_.pageSize = 4 * MB;
+ const int timeout = 10000;
+
+ // Create 1 producer.
+ std::shared_ptr<Producer> prod1;
+ DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod1, defaultProducerConf_));
+
+ // Create the first consumer.
+ std::shared_ptr<Consumer> con1;
+ SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+ DS_ASSERT_OK(w2Client_->Subscribe(stream1, config1, con1, true));
+
+ // The producer sends elements to the first consumer until the stream is full.
+ int send_count = 0;
+ Status status;
+ while (true) {
+ Element element1(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+ status = prod1->Send(element1);
+ if (status.IsOk()) {
+ send_count += 1;
+ } else {
+ break;
+ }
+ }
+
+ // The first consumer receives the elements.
+ std::vector<Element> outElements;
+ DS_ASSERT_OK(con1->Receive(send_count, timeout, outElements));
+ for (auto &ele : outElements) {
+ DS_ASSERT_OK(con1->Ack(ele.id));
+ }
+ outElements.clear();
+
+ // The producer sends more elements (the existing page on the consumer side is dropped, so the element with seqNo = 0 is gone).
+ send_count = 0;
+ while (true) {
+ Element element1(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+ status = prod1->Send(element1);
+ if (status.IsOk()) {
+ send_count += 1;
+ } else {
+ break;
+ }
+ }
+
+ // Create the second consumer.
+ std::shared_ptr<Consumer> con2;
+ SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+ DS_ASSERT_OK(w2Client_->Subscribe(stream1, config2, con2));
+
+ // The second consumer receives elements starting with seqNo != 0.
+ DS_ASSERT_OK(con2->Receive(send_count, timeout, outElements));
+ outElements.clear();
+
+ DS_ASSERT_OK(prod1->Close());
+ DS_ASSERT_OK(con1->Close());
+ DS_ASSERT_OK(con2->Close());
+ DS_ASSERT_OK(w1Client_->DeleteStream(stream1));
+}
+
+TEST_F(PubSubMultiNodeDataVerificationTest, ProducerInsertFailed)
+{
+ // 1 stream: 1 producer -> 1 consumer
+ const std::string stream1 = "testProdInsertFailed";
+
+ // Create 1 producer.
+ std::shared_ptr<Producer> prod1;
+ DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod1, defaultProducerConf_));
+
+ // Create 1 consumer.
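+ // Expected arithmetic for the send loop below: 100 sends with every 4th insert failing
+ // gives 25 failed inserts, so the consumer should receive exactly 75 elements.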
+    std::shared_ptr<Consumer> con1;
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w2Client_->Subscribe(stream1, config1, con1));
+
+    // The producer sends 100 elements; insertion fails for 1 out of every 4 elements.
+    const uint64_t numOfElement = 100;
+    const uint64_t failToInsertFrequency = 4;
+    const uint64_t expectNumOfElementReceive = numOfElement - (numOfElement / failToInsertFrequency);
+    const uint64_t elementSize = defaultProducerConf_.pageSize / 10;
+    RandomData rand;
+    auto str = rand.GetRandomString(elementSize);
+
+    for (uint64_t i = 1; i <= numOfElement; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&str.front()), elementSize);
+        if (i % failToInsertFrequency == 0) {
+            datasystem::inject::Set("producer_insert", "1*return(K_INVALID)");
+            DS_ASSERT_NOT_OK(prod1->Send(element));
+            datasystem::inject::Clear("producer_insert");
+        } else {
+            ASSERT_EQ(prod1->Send(element), Status::OK());
+        }
+    }
+
+    // The consumer receives the 75 successfully inserted elements.
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(con1->Receive(expectNumOfElementReceive, RPC_TIMEOUT, outElements));
+    ASSERT_EQ(outElements.size(), expectNumOfElementReceive);
+    outElements.clear();
+
+    DS_ASSERT_OK(prod1->Close());
+    DS_ASSERT_OK(con1->Close());
+    DS_ASSERT_OK(w1Client_->DeleteStream(stream1));
+}
+
+class PubSubMultiNode1Of2ProducersEnableDataVerificationTest : public PubSubMultiNodeTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        PubSubMultiNodeTest::SetClusterSetupOptions(opts);
+        opts.workerSpecifyGflagParams[0] += " -enable_stream_data_verification=true ";
+        opts.workerSpecifyGflagParams[1] += " -enable_stream_data_verification=false ";
+    }
+};
+
+TEST_F(PubSubMultiNode1Of2ProducersEnableDataVerificationTest, EnableDataVerification1Of2ProducersDifferentNode)
+{
+    // 1 stream: 2 producers -> 1 consumer.
+    const std::string stream1 = "testDataVerify1Of2ProdDiffNode";
+
+    // Create the 2 producers.
+    std::shared_ptr<Producer> prod1; // client1, stream 1, producer 1
+    std::shared_ptr<Producer> prod2; // client2, stream 1, producer 2
+    DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod1, defaultProducerConf_));
+    DS_ASSERT_OK(w2Client_->CreateProducer(stream1, prod2, defaultProducerConf_));
+
+    // Create the consumer.
+    std::shared_ptr<Consumer> con1; // client 3, stream 1, consumer 1
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w3Client_->Subscribe(stream1, config1, con1));
+
+    // Each producer sends 100 elements.
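+    // Note: verification is enabled on worker 0 but disabled on worker 1 (see
+    // SetClusterSetupOptions above), so this exercises a mixed cluster. The
+    // expectation below is that delivery still succeeds for both producers,
+    // presumably because elements without verification metadata skip the
+    // sequence check.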
+    const uint64_t numOfElementPerProducer = 100;
+    const uint64_t expectNumOfElementReceive = numOfElementPerProducer * 2;
+    const uint64_t elementSize = defaultProducerConf_.pageSize / 10;
+    RandomData rand;
+    auto str = rand.GetRandomString(elementSize);
+
+    for (uint64_t i = 1; i <= numOfElementPerProducer; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&str.front()), elementSize);
+        ASSERT_EQ(prod1->Send(element), Status::OK());
+        ASSERT_EQ(prod2->Send(element), Status::OK());
+    }
+
+    // The consumer receives 200 elements.
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(con1->Receive(expectNumOfElementReceive, RPC_TIMEOUT, outElements));
+    ASSERT_EQ(outElements.size(), expectNumOfElementReceive);
+    for (auto &ele : outElements) {
+        std::string data(reinterpret_cast<const char *>(ele.ptr), ele.size);
+        ASSERT_EQ(data, str);
+    }
+    outElements.clear();
+
+    DS_ASSERT_OK(prod1->Close());
+    DS_ASSERT_OK(prod2->Close());
+    DS_ASSERT_OK(con1->Close());
+    DS_ASSERT_OK(w1Client_->DeleteStream(stream1));
+}
+
+TEST_F(PubSubMultiNode1Of2ProducersEnableDataVerificationTest, EnableDataVerification1Of2ProducersSameNode)
+{
+    // 1 stream: 2 producers -> 1 consumer.
+    const std::string stream1 = "testDataVerify1Of2ProdSameNode";
+
+    // Create the 2 producers on the same node.
+    std::shared_ptr<Producer> prod1; // client1, stream 1, producer 1
+    std::shared_ptr<Producer> prod2; // client1, stream 1, producer 2 (mimics an old-version producer)
+    DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod1, defaultProducerConf_));
+    datasystem::inject::Set("Mimic.Producer.Old.Version", "1*call()");
+    DS_ASSERT_OK(w1Client_->CreateProducer(stream1, prod2, defaultProducerConf_));
+    datasystem::inject::Clear("Mimic.Producer.Old.Version");
+
+    // Create the consumer.
+    std::shared_ptr<Consumer> con1; // client 3, stream 1, consumer 1
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w3Client_->Subscribe(stream1, config1, con1));
+
+    // Each producer sends 100 elements.
+    const uint64_t numOfElementPerProducer = 100;
+    const uint64_t expectNumOfElementReceive = numOfElementPerProducer * 2;
+    const uint64_t elementSize = defaultProducerConf_.pageSize / 10;
+    RandomData rand;
+    auto str = rand.GetRandomString(elementSize);
+
+    for (uint64_t i = 1; i <= numOfElementPerProducer; i++) {
+        Element element(reinterpret_cast<uint8_t *>(&str.front()), elementSize);
+        ASSERT_EQ(prod1->Send(element), Status::OK());
+        ASSERT_EQ(prod2->Send(element), Status::OK());
+    }
+
+    // The consumer receives 200 elements.
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(con1->Receive(expectNumOfElementReceive, RPC_TIMEOUT, outElements));
+    ASSERT_EQ(outElements.size(), expectNumOfElementReceive);
+    for (auto &ele : outElements) {
+        std::string data(reinterpret_cast<const char *>(ele.ptr), ele.size);
+        ASSERT_EQ(data, str);
+    }
+    outElements.clear();
+
+    DS_ASSERT_OK(prod1->Close());
+    DS_ASSERT_OK(prod2->Close());
+    DS_ASSERT_OK(con1->Close());
+    DS_ASSERT_OK(w1Client_->DeleteStream(stream1));
+}
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/pub_sub_test.cpp b/tests/st/client/stream_cache/pub_sub_test.cpp
new file mode 100644
index 0000000..f845af6
--- /dev/null
+++ b/tests/st/client/stream_cache/pub_sub_test.cpp
@@ -0,0 +1,299 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Unit test for stream cache + */ +#include +#include "common.h" +#include "common/stream_cache/stream_common.h" +#include "datasystem/stream/stream_config.h" +#include "sc_client_common.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream_client.h" + +using namespace datasystem::client::stream_cache; +namespace datasystem { +namespace st { +class PubSubTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numEtcd = 1; + SCClientCommon::SetClusterSetupOptions(opts); + } + void SetUp() override + { + ExternalClusterTest::SetUp(); + InitTest(); + } + void TearDown() override + { + client_ = nullptr; + ExternalClusterTest::TearDown(); + } + +protected: + void InitTest() + { + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + ConnectOptions options; + options.accessKey = accessKey_; + options.secretKey = secretKey_; + options.host = workerAddress.Host(); + options.port = workerAddress.Port(); + client_ = std::make_shared(options); + EXPECT_NE(client_, nullptr); + DS_ASSERT_OK(client_->Init()); + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + } + std::shared_ptr client_ = nullptr; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; +}; + +class SingleStreamTest : public PubSubTest {}; + +TEST_F(SingleStreamTest, MultiSubMultiConsumer) +{ + std::string stream1("singleStream"); + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer(stream1, producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config, consumer)); + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config2, consumer2)); + + // The last two should fail + std::shared_ptr consumer3; + DS_EXPECT_NOT_OK(client_->Subscribe(stream1, config, consumer3)); + std::shared_ptr consumer4; + DS_ASSERT_NOT_OK(client_->Subscribe(stream1, config2, consumer4)); +} + +class MultiStreamTest : public PubSubTest {}; +// Single sub test for multi stream +TEST_F(MultiStreamTest, SingleProducerSingleConsumerWithOneSub) +{ + // Create stream1 with one producer and one consumer in one subscription + std::string stream1("stream1_SPSCSS"); + std::shared_ptr producerS1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producerS1, defaultProducerConf_)); + std::shared_ptr consumerS1; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config, consumerS1)); + + // Create stream2 with one producer and one consumer in one subscription + std::string stream2("stream2_SPSCSS"); + std::shared_ptr producerS2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producerS2, defaultProducerConf_)); + std::shared_ptr consumerS2; + DS_ASSERT_OK(client_->Subscribe(stream2, config, consumerS2)); +} + 
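+// Illustrative sketch (not part of the original topology matrix above): a
+// minimal produce -> receive -> ack round trip with the same client API.
+// The stream name, payload, and timeout are arbitrary; only the call
+// sequence is meaningful.
+TEST_F(PubSubTest, MinimalRoundTripSketch)
+{
+    std::string stream("singleStreamRoundTrip");
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client_->CreateProducer(stream, producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe(stream, config, consumer));
+
+    // Send one element and receive it back.
+    std::string data("hello stream");
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+    std::vector<Element> outElements;
+    const int timeoutMs = 5000;
+    DS_ASSERT_OK(consumer->Receive(1, timeoutMs, outElements));
+    ASSERT_EQ(outElements.size(), 1u);
+    DS_ASSERT_OK(consumer->Ack(outElements[0].id));
+
+    // Tear down: close both endpoints before deleting the stream.
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client_->DeleteStream(stream));
+}
+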
+TEST_F(MultiStreamTest, SingleProducerMultiConsumerWithOneSub) +{ + // Create stream1 with one producer and two consumers in one subscription + std::string stream1("stream1_SPMCSS"); + std::shared_ptr producerS1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producerS1, defaultProducerConf_)); + std::shared_ptr consumer1S1; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config, consumer1S1)); + std::shared_ptr consumer2S1; + DS_ASSERT_NOT_OK(client_->Subscribe(stream1, config, consumer2S1)); + + // Create stream2 with one producer and two consumers in one subscription + std::string stream2("stream2_SPMCSS"); + std::shared_ptr producerS2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producerS2, defaultProducerConf_)); + std::shared_ptr consumer1S2; + DS_ASSERT_OK(client_->Subscribe(stream2, config, consumer1S2)); + std::shared_ptr consumer2S2; + DS_ASSERT_NOT_OK(client_->Subscribe(stream2, config, consumer2S2)); +} + +TEST_F(MultiStreamTest, MultiProducerSingleConsumerWithOneSub) +{ + // Create stream1 with two producer and one consumer + std::string stream1("stream1_MPSCSS"); + std::shared_ptr producer1S1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producer1S1, defaultProducerConf_)); + std::shared_ptr producer2S1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producer2S1, defaultProducerConf_)); + std::shared_ptr consumerS1; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config, consumerS1)); + + // Create stream2 with one producer and one consumer + std::string stream2("stream2_MPSCSS"); + std::shared_ptr producer1S2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producer1S2, defaultProducerConf_)); + std::shared_ptr producer2S2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producer2S2, defaultProducerConf_)); + std::shared_ptr consumerS2; + DS_ASSERT_OK(client_->Subscribe(stream2, config, consumerS2)); +} + +TEST_F(MultiStreamTest, MultiProducerMultiConsumerWithOneSub) +{ + // Create stream1 with two producer and two consumer in one subscription + std::string stream1("stream1_MPMCSS"); + std::shared_ptr producer1S1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producer1S1, defaultProducerConf_)); + std::shared_ptr producer2S1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producer2S1, defaultProducerConf_)); + std::shared_ptr consumer1S1; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config, consumer1S1)); + std::shared_ptr consumer2S1; + DS_ASSERT_NOT_OK(client_->Subscribe(stream1, config, consumer2S1)); + + // Create stream2 with two producer and two consumer in one subscription + std::string stream2("stream2_MPMCSS"); + std::shared_ptr producer1S2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producer1S2, defaultProducerConf_)); + std::shared_ptr producer2S2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producer2S2, defaultProducerConf_)); + std::shared_ptr consumer1S2; + DS_ASSERT_OK(client_->Subscribe(stream2, config, consumer1S2)); + std::shared_ptr consumer2S2; + DS_ASSERT_NOT_OK(client_->Subscribe(stream2, config, consumer2S2)); +} + +// Multi sub test for multi stream +TEST_F(MultiStreamTest, SingleProducerSingleConsumerWithMultiSub) +{ + // Create stream1 with one producer and two subscription(each subscription have one consumer) + std::string stream1("stream1_SPSCMS"); + std::shared_ptr producerS1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producerS1, 
defaultProducerConf_)); + std::shared_ptr consumerSub1S1; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config, consumerSub1S1)); + std::shared_ptr consumerSub2S1; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config2, consumerSub2S1)); + + // Create stream2 with one producer and one consumer in one subscription + std::string stream2("stream2_SPSCMS"); + std::shared_ptr producerS2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producerS2, defaultProducerConf_)); + std::shared_ptr consumerSub1S2; + DS_ASSERT_OK(client_->Subscribe(stream2, config, consumerSub1S2)); + std::shared_ptr consumerSub2S2; + DS_ASSERT_OK(client_->Subscribe(stream2, config2, consumerSub2S2)); +} + +TEST_F(MultiStreamTest, SingleProducerMultiConsumerWithMultiSub) +{ + // Create stream1 with one producer and two subscription, each subscription have two consumer + std::string stream1("stream1_SPMCMS"); + std::shared_ptr producerS1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producerS1, defaultProducerConf_)); + std::shared_ptr consumer1Sub1S1; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config, consumer1Sub1S1)); + std::shared_ptr consumer2Sub1S1; + DS_ASSERT_NOT_OK(client_->Subscribe(stream1, config, consumer2Sub1S1)); + std::shared_ptr consumer1Sub2S1; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config2, consumer1Sub2S1)); + std::shared_ptr consumer2Sub2S1; + DS_ASSERT_NOT_OK(client_->Subscribe(stream1, config2, consumer2Sub2S1)); + + // Create stream2 with one producer and two consumer in one subscription + std::string stream2("stream2_SPMCMS"); + std::shared_ptr producerS2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producerS2, defaultProducerConf_)); + std::shared_ptr consumer1Sub1S2; + DS_ASSERT_OK(client_->Subscribe(stream2, config, consumer1Sub1S2)); + std::shared_ptr consumer2Sub1S2; + DS_ASSERT_NOT_OK(client_->Subscribe(stream2, config, consumer2Sub1S2)); + std::shared_ptr consumer1Sub2S2; + DS_ASSERT_OK(client_->Subscribe(stream2, config2, consumer1Sub2S2)); + std::shared_ptr consumer2Sub2S2; + DS_ASSERT_NOT_OK(client_->Subscribe(stream2, config2, consumer2Sub2S2)); +} + +TEST_F(MultiStreamTest, MultiProducerSingleConsumerWithMultiSub) +{ + // Create stream1 with two producer and two subscription each subscription have one consumer + std::string stream1("stream1_MPSCMS"); + std::shared_ptr producer1S1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producer1S1, defaultProducerConf_)); + std::shared_ptr producer2S1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producer2S1, defaultProducerConf_)); + std::shared_ptr consumerSub1S1; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config, consumerSub1S1)); + std::shared_ptr consumerSub2S1; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config2, consumerSub2S1)); + + // Create stream2 with two producer and two subscription each subscription have one consumer + std::string stream2("stream2_MPSCMS"); + std::shared_ptr producer1S2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producer1S2, defaultProducerConf_)); + std::shared_ptr producer2S2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producer2S2, defaultProducerConf_)); + std::shared_ptr consumerSub1S2; + 
DS_ASSERT_OK(client_->Subscribe(stream2, config, consumerSub1S2)); + std::shared_ptr consumerSub2S2; + DS_ASSERT_OK(client_->Subscribe(stream2, config2, consumerSub2S2)); +} + +TEST_F(MultiStreamTest, MultiProducerMultiConsumerWithMultiSub) +{ + // Create stream1 with two producer and two subscription , each subscription have two consumer + std::string stream1("stream1_MPMCMS"); + std::shared_ptr producer1S1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producer1S1, defaultProducerConf_)); + std::shared_ptr producer2S1; + DS_ASSERT_OK(client_->CreateProducer(stream1, producer2S1, defaultProducerConf_)); + std::shared_ptr consumer1Sub1S1; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config, consumer1Sub1S1)); + std::shared_ptr consumer2Sub1S1; + DS_ASSERT_NOT_OK(client_->Subscribe(stream1, config, consumer2Sub1S1)); + std::shared_ptr consumer1Sub2S1; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(stream1, config2, consumer1Sub2S1)); + std::shared_ptr consumer2Sub2S1; + DS_ASSERT_NOT_OK(client_->Subscribe(stream1, config2, consumer2Sub2S1)); + + // Create stream2 with two producer and two subscription, each subscription have two consumer + std::string stream2("stream2_MPMCMS"); + std::shared_ptr producer1S2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producer1S2, defaultProducerConf_)); + std::shared_ptr producer2S2; + DS_ASSERT_OK(client_->CreateProducer(stream2, producer2S2, defaultProducerConf_)); + std::shared_ptr consumer1Sub1S2; + DS_ASSERT_OK(client_->Subscribe(stream2, config, consumer1Sub1S2)); + std::shared_ptr consumer2Sub1S2; + DS_ASSERT_NOT_OK(client_->Subscribe(stream2, config, consumer2Sub1S2)); + std::shared_ptr consumer1Sub2S2; + DS_ASSERT_OK(client_->Subscribe(stream2, config2, consumer1Sub2S2)); + std::shared_ptr consumer2Sub2S2; + DS_ASSERT_NOT_OK(client_->Subscribe(stream2, config2, consumer2Sub2S2)); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/pub_sub_utils.h b/tests/st/client/stream_cache/pub_sub_utils.h new file mode 100644 index 0000000..3909f65 --- /dev/null +++ b/tests/st/client/stream_cache/pub_sub_utils.h @@ -0,0 +1,85 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef DATASYSTEM_PUB_SUB_UTILS_H +#define DATASYSTEM_PUB_SUB_UTILS_H + +#include +#include +#include + +#include "common/stream_cache/stream_common.h" +#include "datasystem/client/stream_cache/stream_client_impl.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream_client.h" + +namespace datasystem { +namespace st { +using namespace datasystem::client::stream_cache; +namespace mock { +struct InputStreamInfo { + int producerNum; + std::unordered_map> subscriptions; +}; + +struct OutputStreamInfo { + std::vector> producers; + std::unordered_map>> consumers; +}; +} // namespace mock + +template +Status CreateProducersAndConsumers(std::shared_ptr &client, + std::unordered_map &input, + std::unordered_map &output) +{ + ProducerConf conf; + conf.maxStreamSize = TEST_STREAM_SIZE; + output.reserve(input.size()); + for (auto &iter : input) { + std::string streamName = iter.first; + auto &info = iter.second; + // Create producers for the stream + output[streamName].producers.reserve(info.producerNum); + for (int i = 0; i < info.producerNum; i++) { + LOG(INFO) << "Start create producer " << i << " for stream " << streamName; + std::shared_ptr producer; + RETURN_IF_NOT_OK(client->CreateProducer(streamName, producer, conf)); + LOG(INFO) << "Finished create producer for stream " << streamName; + output[streamName].producers.push_back(std::move(producer)); + } + // Create consumers for the subscriptions + for (auto &subInfo : info.subscriptions) { + auto &subName = subInfo.first; + SubscriptionConfig config(subName, subInfo.second.first); + int consumerNum = subInfo.second.second; + output[streamName].consumers[subName].reserve(consumerNum); + for (int i = 0; i < consumerNum; i++) { + LOG(INFO) << "Start create consumer" << i << " for stream " << streamName; + std::shared_ptr consumer; + RETURN_IF_NOT_OK(client->Subscribe(streamName, config, consumer)); + LOG(INFO) << "Finished create consumer for sub" << subName; + output[streamName].consumers[subName].push_back(std::move(consumer)); + } + } + } + return Status::OK(); +} +} // namespace st +} // namespace datasystem + +#endif // DATASYSTEM_PUB_SUB_UTILS_H diff --git a/tests/st/client/stream_cache/query_stream_topo_test.cpp b/tests/st/client/stream_cache/query_stream_topo_test.cpp new file mode 100644 index 0000000..1c548cc --- /dev/null +++ b/tests/st/client/stream_cache/query_stream_topo_test.cpp @@ -0,0 +1,543 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * Description: Unit test for stream cache + */ +#include + +#include +#include +#include + +#include "common.h" +#include "common/stream_cache/element_generator.h" +#include "common/stream_cache/stream_common.h" +#include "datasystem/common/util/random_data.h" +#include "sc_client_common.h" +#include "datasystem/common/util/thread_pool.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream_client.h" + +using namespace datasystem::client::stream_cache; +namespace datasystem { +namespace st { +class QueryStreamTopoTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numEtcd = 1; + opts.numWorkers = 5; + opts.numMasters = 1; + opts.enableDistributedMaster = "false"; + opts.numRpcThreads = 0; + opts.vLogLevel = 2; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + InitTest(); + } + + void TearDown() override + { + clientVector_.clear(); + ExternalClusterTest::TearDown(); + } + +protected: + void InitTest() + { + workerAddressVector_.resize(clientNum_); + for (int i = 0; i < clientNum_; ++i) { + DS_ASSERT_OK(cluster_->GetWorkerAddr(i, workerAddressVector_[i])); + LOG(INFO) << FormatString("Worker%d: <%s>", i, workerAddressVector_[i].ToString()); + } + + clientVector_.resize(clientNum_); + for (size_t i = 0; i < clientVector_.size(); i++) { + InitStreamClient(i, clientVector_[i]); + } + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + } + + Status ProduceRandomData(std::shared_ptr &producer, const std::string &producerName, uint64_t eleSz, + uint64_t eleNum) + { + ElementGenerator elementGenerator(eleSz); + auto strs = elementGenerator.GenElements(producerName, eleNum, 1); + for (size_t i = 0; i < eleNum; i++) { + RETURN_IF_NOT_OK(producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size()))); + } + return producer->Close(); + } + + Status ProduceMockData(std::shared_ptr &producer, size_t round, int timeoutMs = 0) + { + std::vector strs; + strs.emplace_back("hello world"); + strs.emplace_back("hello China"); + strs.emplace_back("hello 2022"); + + std::stringstream ss; + ss << FormatString("\n============= Round:%d Send Data =============\n", round); + for (size_t i = 0; i < strs.size(); i++) { + if (timeoutMs != 0 && i == strs.size() - 1) { + LOG(INFO) << FormatString("Round:%d, after number %d element flush, sleep for %d ms", round, i, + timeoutMs); + std::this_thread::sleep_for(std::chrono::milliseconds(timeoutMs)); + } + ss << FormatString("String %d is:%s\n", i, strs[i]); + RETURN_IF_NOT_OK(producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size()))); + } + ss << "=============================================\n"; + LOG(INFO) << ss.str(); + return producer->Close(); + } + + void FinanceCase(size_t round, bool withRandomData = true, bool withRandomNode = false); + + void TimeoutCase(size_t round, int timeoutMs = 0); + + std::vector> clientVector_; + std::vector workerAddressVector_; + uint8_t clientNum_ = 5; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; +}; + +class NewRandom : RandomData { +public: + std::vector RandomSequenceFromSet(std::set &set) + { + std::vector sequence; + auto upperLimit = *std::max_element(set.begin(), set.end()); + auto lowerLimit = *std::min_element(set.begin(), set.end()); + thread_local static std::uniform_int_distribution 
distribution(lowerLimit, upperLimit); + while (!set.empty()) { + thread_local static auto generator = randomDevice_; + auto ele = distribution(generator); + if (set.erase(ele) == 1) { + sequence.emplace_back(ele); + } + } + return sequence; + } +}; + +TEST_F(QueryStreamTopoTest, QueryTest) +{ + std::string stream1("testQueryTest"); + + uint64_t producersCount = 0; + uint64_t consumersCount = 0; + + // stream not exists. + DS_ASSERT_OK(clientVector_[0]->QueryGlobalProducersNum(stream1, producersCount)); + ASSERT_EQ(producersCount, 0ul); + DS_ASSERT_OK(clientVector_[0]->QueryGlobalConsumersNum(stream1, consumersCount)); + ASSERT_EQ(consumersCount, 0ul); + + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + + std::shared_ptr node1Producer1; + std::shared_ptr node1Producer2; + DS_ASSERT_OK(clientVector_[0]->CreateProducer(stream1, node1Producer1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[0]->CreateProducer(stream1, node1Producer2, defaultProducerConf_)); + std::shared_ptr node1Consumer1; + DS_ASSERT_OK(clientVector_[0]->Subscribe(stream1, config1, node1Consumer1)); + + DS_ASSERT_OK(clientVector_[0]->QueryGlobalProducersNum(stream1, producersCount)); + // Producer count will still be 1 + // as master just counts number of workers having atleast one producer + ASSERT_EQ(producersCount, 1ul); + DS_ASSERT_OK(clientVector_[0]->QueryGlobalConsumersNum(stream1, consumersCount)); + ASSERT_EQ(consumersCount, size_t(1)); + + producersCount = 0; + consumersCount = 100; + std::shared_ptr node2Producer1; + DS_ASSERT_OK(clientVector_[1]->CreateProducer(stream1, node2Producer1, defaultProducerConf_)); + std::shared_ptr node2Consumer1; + DS_ASSERT_OK(clientVector_[1]->Subscribe(stream1, config2, node2Consumer1)); + + DS_ASSERT_OK(clientVector_[1]->QueryGlobalProducersNum(stream1, producersCount)); + DS_ASSERT_OK(clientVector_[1]->QueryGlobalConsumersNum(stream1, consumersCount)); + ASSERT_EQ(consumersCount, size_t(2)); +} + +TEST_F(QueryStreamTopoTest, ConcurrentQueryTest) +{ + std::string stream1("testConcurrentQueryTest"); + std::vector configVector = { SubscriptionConfig("sub0", SubscriptionType::STREAM), + SubscriptionConfig("sub1", SubscriptionType::STREAM), + SubscriptionConfig("sub2", SubscriptionType::STREAM) }; + ThreadPool pool(clientNum_); + pool.Submit([this, stream1, &configVector]() { + thread_local uint32_t queryRet = 0; + std::shared_ptr n0p0; + std::shared_ptr n0c0; + DS_ASSERT_OK(clientVector_[0]->CreateProducer(stream1, n0p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[0]->Subscribe(stream1, configVector[0], n0c0)); + + std::string localHostName = workerAddressVector_[0].ToString(); + uint64_t producerNum = 0; + uint64_t consumerNum = 0; + DS_ASSERT_OK(clientVector_[0]->QueryGlobalProducersNum(stream1, producerNum)); + DS_ASSERT_OK(clientVector_[0]->QueryGlobalConsumersNum(stream1, consumerNum)); + + producerNum = 0; + consumerNum = 0; + DS_ASSERT_OK(clientVector_[0]->QueryGlobalProducersNum(stream1, producerNum)); + ASSERT_GE(producerNum, size_t(1)); + ASSERT_LE(producerNum, size_t(3)); + LOG(INFO) << FormatString("Thread 0, #<%d>, global pub node number query ret:<%d>", + std::hash{}(std::this_thread::get_id()), queryRet); + consumerNum = 0; + DS_ASSERT_OK(clientVector_[0]->QueryGlobalConsumersNum(stream1, consumerNum)); + ASSERT_GE(consumerNum, size_t(1)); + ASSERT_LE(consumerNum, size_t(3)); + LOG(INFO) << FormatString("Thread 0, #<%d>, global consumer number query ret:<%d>", + 
std::hash{}(std::this_thread::get_id()), queryRet); + + DS_ASSERT_OK(n0p0->Close()); + DS_ASSERT_OK(n0c0->Close()); + clientVector_[0]->DeleteStream(stream1); + }); + pool.Submit([this, stream1, &configVector]() { + thread_local uint32_t queryRet = 0; + std::shared_ptr n1p0; + std::shared_ptr n1c0; + DS_ASSERT_OK(clientVector_[1]->CreateProducer(stream1, n1p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[1]->Subscribe(stream1, configVector[1], n1c0)); + + std::string localHostName = workerAddressVector_[1].ToString(); + uint64_t producerNum = 0; + uint64_t consumerNum = 0; + DS_ASSERT_OK(clientVector_[1]->QueryGlobalProducersNum(stream1, producerNum)); + DS_ASSERT_OK(clientVector_[1]->QueryGlobalConsumersNum(stream1, consumerNum)); + + producerNum = 0; + consumerNum = 0; + DS_ASSERT_OK(clientVector_[1]->QueryGlobalProducersNum(stream1, producerNum)); + ASSERT_GE(producerNum, size_t(1)); + ASSERT_LE(producerNum, size_t(3)); + consumerNum = 0; + LOG(INFO) << FormatString("Thread 1, #<%d>, global pub node number query ret:<%d>", + std::hash{}(std::this_thread::get_id()), queryRet); + DS_ASSERT_OK(clientVector_[1]->QueryGlobalConsumersNum(stream1, consumerNum)); + ASSERT_GE(consumerNum, size_t(1)); + ASSERT_LE(consumerNum, size_t(3)); + LOG(INFO) << FormatString("Thread 1, #<%d>, global consumer number query ret:<%d>", + std::hash{}(std::this_thread::get_id()), queryRet); + + DS_ASSERT_OK(n1p0->Close()); + DS_ASSERT_OK(n1c0->Close()); + clientVector_[1]->DeleteStream(stream1); + }); + pool.Submit([this, stream1, &configVector]() { + thread_local uint32_t queryRet = 0; + std::shared_ptr n2p0; + std::shared_ptr n2c0; + DS_ASSERT_OK(clientVector_[2]->CreateProducer(stream1, n2p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[2]->Subscribe(stream1, configVector[2], n2c0)); + + uint64_t producerNum = 0; + uint64_t consumerNum = 0; + DS_ASSERT_OK(clientVector_[2]->QueryGlobalProducersNum(stream1, producerNum)); + DS_ASSERT_OK(clientVector_[2]->QueryGlobalConsumersNum(stream1, consumerNum)); + + LOG(INFO) << FormatString("Thread 2, #<%d>, global pub node number query ret:<%d>", + std::hash{}(std::this_thread::get_id()), queryRet); + ASSERT_GE(consumerNum, size_t(1)); + ASSERT_LE(consumerNum, size_t(3)); + LOG(INFO) << FormatString("Thread 2, #<%d>, global consumer number query ret:<%d>", + std::hash{}(std::this_thread::get_id()), queryRet); + + DS_ASSERT_OK(n2p0->Close()); + DS_ASSERT_OK(n2c0->Close()); + clientVector_[2]->DeleteStream(stream1); + }); +} + +void QueryStreamTopoTest::FinanceCase(size_t round, bool withRandomData, bool withRandomNode) +{ + size_t nodeNum = 3; + std::string stream1 = FormatString("FinanceCase%d-S1", round); + std::string stream2 = FormatString("FinanceCase%d-S2", round); + LOG(INFO) << FormatString("Src ----> Process stream name:%s", stream1); + LOG(INFO) << FormatString("Process ----> Sink stream name:%s", stream2); + + ThreadPool pool(nodeNum); + std::vector> futs; + size_t eleSz = 16; + uint64_t eleNum = 3; + + std::vector nodeIdx; + std::set idxSet = { 0, 1, 2 }; + if (withRandomNode) { + auto rndGenerator = NewRandom(); + nodeIdx = rndGenerator.RandomSequenceFromSet(idxSet); + } else { + nodeIdx = { 0, 1, 2 }; + } + + // w1:p1(stream1) and then continuously query, then producer data + futs.emplace_back(pool.Submit([this, &stream1, eleSz, eleNum, round, withRandomData, &nodeIdx]() { + std::shared_ptr producer0; + RETURN_IF_NOT_OK(clientVector_[nodeIdx[0]]->CreateProducer(stream1, producer0, defaultProducerConf_)); + thread_local auto begin = Timer(); 
+ while (begin.ElapsedMilliSecond() <= 10 * 1000) { + // Simulate finance scenario, detect downstream consumer register successfully + uint64_t consumerNum = 0; + RETURN_IF_NOT_OK(clientVector_[nodeIdx[0]]->QueryGlobalConsumersNum(stream1, consumerNum)); + if (consumerNum > 0) { + LOG(INFO) << FormatString("[S:%s] Node 1 detect %d consumer subscribe success, start to producer data", + stream1, consumerNum); + break; + } + } + // Send. + if (withRandomData) { + RETURN_IF_NOT_OK(ProduceRandomData(producer0, "producer1", eleSz, eleNum)); + } else { + RETURN_IF_NOT_OK(ProduceMockData(producer0, round)); + } + return Status::OK(); + })); + // w2:c1 (stream1, sub1) + futs.emplace_back(pool.Submit([this, &stream1, &stream2, eleNum, round, withRandomData, &nodeIdx]() { + std::shared_ptr consumer1 = nullptr; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + RETURN_IF_NOT_OK(clientVector_[nodeIdx[1]]->Subscribe(stream1, config1, consumer1)); + CHECK_FAIL_RETURN_STATUS(consumer1, StatusCode::K_RUNTIME_ERROR, "Fail to subscribe"); + std::vector outElements; + thread_local auto begin = Timer(); + uint64_t timeOut = 3 * 1000; + uint64_t retry = 0; + while (outElements.size() < eleNum && begin.ElapsedMilliSecond() <= timeOut) { + std::vector output; + consumer1->Receive(3, 0, output); + outElements.insert(outElements.end(), output.begin(), output.end()); + retry++; + } + LOG(INFO) << FormatString("Round:%d, Process receive retry time:%d", round, retry); + CHECK_FAIL_RETURN_STATUS(outElements.size() == eleNum, StatusCode::K_RUNTIME_ERROR, + FormatString("Round:%d, Expect receive %d elements, Actually got %d elements", round, + eleNum, outElements.size())); + + std::shared_ptr producer1; + RETURN_IF_NOT_OK(clientVector_[nodeIdx[1]]->CreateProducer(stream2, producer1, defaultProducerConf_)); + thread_local auto begin1 = Timer(); + while (begin1.ElapsedMilliSecond() <= timeOut) { + // Simulate finance scenario, detect downstream consumer register successfully + uint64_t consumerNum = 0; + RETURN_IF_NOT_OK(clientVector_[nodeIdx[1]]->QueryGlobalConsumersNum(stream2, consumerNum)); + if (consumerNum > 0) { + LOG(INFO) << FormatString("[S:%s] Node2 detect %d consumer subscribe success, start to producer data", + stream2, consumerNum); + break; + } + } + if (!withRandomData) { + std::stringstream ss; + ss << FormatString("\n============= Round:%d Middle Data =============\n", round); + size_t idx = 0; + for (auto &ele : outElements) { + std::string tmpString{ reinterpret_cast(ele.ptr), ele.size }; + ss << FormatString("String %d is:%s\n", idx, tmpString); + idx++; + } + ss << "===============================================\n"; + LOG(INFO) << ss.str(); + } + for (auto &ele : outElements) { + RETURN_IF_NOT_OK(producer1->Send(ele)); + } + RETURN_IF_NOT_OK(producer1->Close()); + return consumer1->Close(); + })); + futs.emplace_back(pool.Submit([this, &stream2, eleNum, round, withRandomData, &nodeIdx]() { + std::shared_ptr consumer2 = nullptr; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + RETURN_IF_NOT_OK(clientVector_[nodeIdx[2]]->Subscribe(stream2, config2, consumer2)); + CHECK_FAIL_RETURN_STATUS(consumer2, StatusCode::K_RUNTIME_ERROR, "Fail to subscribe"); + std::vector outElements; + thread_local auto begin = Timer(); + uint64_t retry = 0; + uint64_t timeOut = 10 * 1000; + while (outElements.size() < eleNum && begin.ElapsedMilliSecond() <= timeOut) { + std::vector output; + consumer2->Receive(3, 0, output); + outElements.insert(outElements.end(), output.begin(), output.end()); 
+ retry++; + } + LOG(INFO) << FormatString("Round:%d, Src Receive retry time:%d", round, retry); + + if (!withRandomData) { + std::stringstream ss; + ss << FormatString("\n============= Round:%d Recv Data =============\n", round); + size_t idx = 0; + for (auto &ele : outElements) { + std::string tmpString{ reinterpret_cast(ele.ptr), ele.size }; + ss << FormatString("String %d is:%s\n", idx, tmpString); + idx++; + } + ss << "=============================================\n"; + LOG(INFO) << ss.str(); + } + + CHECK_FAIL_RETURN_STATUS(outElements.size() == eleNum, StatusCode::K_RUNTIME_ERROR, + FormatString("Round:%d, Expect to receive %d elements, Actually got %d elements", + round, eleNum, outElements.size())); + return consumer2->Close(); + })); + for (auto &fut : futs) { + ASSERT_EQ(fut.get(), Status::OK()); + } +} + +void QueryStreamTopoTest::TimeoutCase(size_t round, int timeoutMs) +{ + std::string stream1 = FormatString("TimeoutCase%d-S1", round); + size_t nodeNum = 2; + size_t eleNum = 3; + std::vector> futs; + ThreadPool pool(nodeNum); + futs.emplace_back(pool.Submit([this, &stream1, round, timeoutMs]() { + std::shared_ptr producer0; + RETURN_IF_NOT_OK(clientVector_[0]->CreateProducer(stream1, producer0, defaultProducerConf_)); + thread_local auto begin = Timer(); + while (begin.ElapsedMilliSecond() <= 10 * 1000) { + // Simulate finance scenario, detect downstream consumer register successfully + uint64_t consumerNum = 0; + RETURN_IF_NOT_OK(clientVector_[0]->QueryGlobalConsumersNum(stream1, consumerNum)); + if (consumerNum > 0) { + LOG(INFO) << FormatString("[S:%s] Node 1 detect %d consumer subscribe success, start to producer data", + stream1, consumerNum); + break; + } + } + // Send. + RETURN_IF_NOT_OK(ProduceMockData(producer0, round, timeoutMs)); + return Status::OK(); + })); + futs.emplace_back(pool.Submit([this, &stream1, round, eleNum]() { + std::shared_ptr consumer = nullptr; + SubscriptionConfig config("sub0", SubscriptionType::STREAM); + RETURN_IF_NOT_OK(clientVector_[1]->Subscribe(stream1, config, consumer)); + CHECK_FAIL_RETURN_STATUS(consumer, StatusCode::K_RUNTIME_ERROR, "Fail to subscribe"); + std::vector outElements; + thread_local auto begin = Timer(); + uint64_t retry = 0; + // Set receive loop timeout = 100 seconds + uint64_t timeOut = 100 * 1000; + while (outElements.size() < eleNum && begin.ElapsedMilliSecond() <= timeOut) { + std::vector output; + consumer->Receive(3, 0, output); + outElements.insert(outElements.end(), output.begin(), output.end()); + retry++; + } + LOG(INFO) << FormatString("Round:%d, Src Receive retry time:%d", round, retry); + std::stringstream ss; + ss << FormatString("\n============= Round:%d Recv Data =============\n", round); + size_t idx = 0; + for (auto &ele : outElements) { + std::string tmpString{ reinterpret_cast(ele.ptr), ele.size }; + ss << FormatString("String %d is:%s\n", idx, tmpString); + idx++; + } + ss << "=============================================\n"; + LOG(INFO) << ss.str(); + + CHECK_FAIL_RETURN_STATUS(outElements.size() == eleNum, StatusCode::K_RUNTIME_ERROR, + FormatString("Round:%d, Expect to receive %d elements, Actually got %d elements", + round, eleNum, outElements.size())); + return consumer->Close(); + })); +} + +TEST_F(QueryStreamTopoTest, FinanceCaseFixedNode) +{ + auto rounds = 8; + ThreadPool roundThreads(rounds); + for (int round = 0; round < rounds; round++) { + roundThreads.Submit([this, round]() { FinanceCase(round, false); }); + } +} + +class QueryStreamTopoTest1 : public QueryStreamTopoTest { +public: + 
void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + QueryStreamTopoTest::SetClusterSetupOptions(opts); + } +}; + +TEST_F(QueryStreamTopoTest1, FinanceCaseRndNode) +{ + auto rounds = 8; + ThreadPool roundThreads(rounds); + for (int round = 0; round < rounds; round++) { + roundThreads.Submit([this, round]() { FinanceCase(round, false, true); }); + } +} + +TEST_F(QueryStreamTopoTest, DISABLED_TimeoutCase) +{ + const int timeoutMs = 80 * 1000; + TimeoutCase(0, timeoutMs); +} + +TEST_F(QueryStreamTopoTest, ConcurrentWithSubscribe) +{ + FinanceCase(0); +} + +TEST_F(QueryStreamTopoTest, QueryDistributionTest) +{ + std::string stream1("stream1"); + std::vector configVector = { SubscriptionConfig("sub1", SubscriptionType::STREAM), + SubscriptionConfig("sub2", SubscriptionType::STREAM), + SubscriptionConfig("sub3", SubscriptionType::STREAM), + SubscriptionConfig("sub4", SubscriptionType::STREAM), + SubscriptionConfig("sub5", SubscriptionType::STREAM) }; + + std::vector> producerVector(5); + std::vector> consumerVector(5); + uint64_t producerNum = 0; + uint64_t consumerNum = 0; + for (int i = 0; i < 5; ++i) { + DS_ASSERT_OK(clientVector_[i]->CreateProducer(stream1, producerVector[i], defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[i]->Subscribe(stream1, configVector[i], consumerVector[i])); + } + for (int i = 0; i < 5; ++i) { + producerNum = 0; + DS_ASSERT_OK(clientVector_[i]->QueryGlobalProducersNum(stream1, producerNum)); + ASSERT_EQ(producerNum, size_t(5)); + + consumerNum = 0; + DS_ASSERT_OK(clientVector_[i]->QueryGlobalConsumersNum(stream1, consumerNum)); + ASSERT_EQ(consumerNum, size_t(5)); + } +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/remote_push_test.cpp b/tests/st/client/stream_cache/remote_push_test.cpp new file mode 100644 index 0000000..b7e50bf --- /dev/null +++ b/tests/st/client/stream_cache/remote_push_test.cpp @@ -0,0 +1,339 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Remote cache push. 
+ */ +#include +#include + +#include "common.h" +#include "common/stream_cache/element_generator.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/utils/status.h" +#include "sc_client_common.h" +#include "datasystem/stream_client.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/client/stream_cache/client_worker_api.h" +namespace datasystem { +namespace st { +using namespace datasystem::client::stream_cache; +constexpr int NUM_ELES = 100; +constexpr uint64_t PAGE_SIZE = 4 * 1024; +constexpr uint64_t BIG_SIZE_RATIO = 16; +constexpr uint64_t BIG_SIZE = PAGE_SIZE / BIG_SIZE_RATIO; +constexpr uint64_t TEST_STREAM_SIZE = 64 * 1024 * 1024; +class RemotePushTest : public SCClientCommon { +public: + static std::string streamName_; + + void SetClusterSetupOptions(ExternalClusterOptions &opts) override; + + void SetUp() override; + + void TearDown() override; + + void StartStream(const std::unique_ptr &cluster, const std::string &streamName, const std::string &ak, + const std::string &sk, uint64_t numElements); + +protected: + void InitTest(); + + // Mock producer worker. + HostPort pubWorkerAddress_; + HostPort subWorkerAddress_; + + // Construct consumer worker client. + std::shared_ptr consumer_; + std::shared_ptr consumerClient_ = nullptr; + + static constexpr int CLIENT_RPC_TIMEOUT = 4 * 60 * 1000; + std::shared_ptr producer_; + std::shared_ptr producerClient_ = nullptr; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; + std::unique_ptr signature_ = std::make_unique(accessKey_, secretKey_); +}; +std::string RemotePushTest::streamName_ = "stream"; + +void RemotePushTest::SetClusterSetupOptions(ExternalClusterOptions &opts) +{ + opts.numWorkers = 3; + opts.numEtcd = 1; + opts.workerGflagParams = "-shared_memory_size_mb=10000"; + opts.numRpcThreads = 0; + opts.vLogLevel = 2; + SCClientCommon::SetClusterSetupOptions(opts); +} + +void RemotePushTest::SetUp() +{ + ExternalClusterTest::SetUp(); + InitTest(); +} + +void RemotePushTest::TearDown() +{ + producerClient_ = nullptr; + consumerClient_ = nullptr; + ExternalClusterTest::TearDown(); +} + +void RemotePushTest::InitTest() +{ + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, pubWorkerAddress_)); + DS_ASSERT_OK(cluster_->GetWorkerAddr(1, subWorkerAddress_)); + + ProducerConf conf; + conf.maxStreamSize = TEST_STREAM_SIZE; + InitStreamClient(0, producerClient_); + ASSERT_EQ(producerClient_->CreateProducer(streamName_, producer_, conf), Status::OK()); + + auto api = std::make_unique(pubWorkerAddress_, RpcCredential(), signature_.get()); + ASSERT_EQ(api->Init(CLIENT_RPC_TIMEOUT), Status::OK()); + + InitStreamClient(1, consumerClient_); + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + ASSERT_EQ(consumerClient_->Subscribe(streamName_, config, consumer_), Status::OK()); +} + +TEST_F(RemotePushTest, TestSingleProducer) +{ + ElementGenerator elementGenerator(BIG_SIZE * 5 / 4); + std::string producerName = "producer1"; + auto strs = elementGenerator.GenElements(producerName, NUM_ELES); + + // Send. + for (int i = 0; i < NUM_ELES; i++) { + ASSERT_EQ(producer_->Send(Element((uint8_t *)strs[i].data(), strs[i].size())), Status::OK()); + LOG(INFO) << FormatString("Sz: [%d]: [%zu]", i, strs[i].size()); + } + + // Recv. 
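+    // ElementGenerator embeds the producer name and a running sequence number
+    // in each payload; VerifyIntegrity() validates the payload itself and
+    // VerifyFifo() checks per-producer ordering via seqNoMap (inferred from
+    // how these helpers are used across the stream cache tests).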
+ std::vector outElements; + ASSERT_EQ(consumer_->Receive(NUM_ELES, 0, outElements), Status::OK()); + std::unordered_map seqNoMap; + for (const auto &element : outElements) { + LOG(INFO) << FormatString("Cursor: [%zu], Sz: [%zu]", element.id, element.size); + ElementView view(std::string((const char *)element.ptr, element.size)); + ASSERT_EQ(view.VerifyIntegrity(), Status::OK()); + ASSERT_EQ(view.VerifyFifo(seqNoMap), Status::OK()); + } +} + +void RemotePushTest::StartStream(const std::unique_ptr &cluster, const std::string &streamName, + const std::string &ak, const std::string &sk, uint64_t numElements) +{ + const int numPubs = 24; + const int numSubs = 3; + const int numNodes = 3; + auto producerSend = [this, &cluster, ak, sk, &streamName](int index, uint64_t numElements, + const std::string &producerName) { + std::shared_ptr client; + InitStreamClient(index, client); + + ProducerConf conf; + conf.maxStreamSize = 100 * 1024 * 1024; + conf.pageSize = 2 * 1024 * 1024; + std::shared_ptr producer; + DS_ASSERT_OK(client->CreateProducer(streamName, producer, conf)); + ElementGenerator elementGenerator(1024 * 1024, 100); + while (numElements > 0) { + auto str = elementGenerator.GenElement(producerName); + Status rc = producer->Send(Element((uint8_t *)str.data(), str.size())); + if (rc.GetCode() == K_OUT_OF_MEMORY) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + continue; + } + DS_ASSERT_OK(rc); + --numElements; + } + producer->Close(); + client.reset(); + }; + + auto consumerRecv = [&cluster, ak, sk, &streamName, this](int index, uint64_t numElements) { + std::shared_ptr client; + InitStreamClient(index, client); + std::shared_ptr consumer; + std::string subName = "sub" + std::to_string(index); + SubscriptionConfig sub(subName, SubscriptionType::STREAM); + DS_ASSERT_OK(client->Subscribe(streamName, sub, consumer)); + uint64_t numElementsToReceive = numPubs * numElements; + LOG(INFO) << FormatString("[%s] Consumer expects to receive %zu elements", subName, numElementsToReceive); + const uint64_t timeoutMs = 60000; + while (numElementsToReceive > 0) { + std::vector outElements; + ASSERT_EQ(consumer->Receive(timeoutMs, outElements), Status::OK()); + for (const auto &element : outElements) { + LOG(INFO) << FormatString("[%s] Cursor: [%zu], Sz: [%zu]", subName, element.id, element.size); + ElementView view(std::string((const char *)element.ptr, element.size)); + ASSERT_EQ(view.VerifyIntegrity(), Status::OK()); + --numElementsToReceive; + consumer->Ack(element.id); + } + } + consumer->Close(); + client.reset(); + }; + + std::vector threads; + + for (int i = 0; i < numPubs; i++) { + threads.emplace_back([i, &producerSend, numElements] { + std::string pubName = "pub" + std::to_string(i); + producerSend(i % numNodes, numElements, pubName); + }); + } + + for (int i = 0; i < numSubs; ++i) { + threads.emplace_back([i, &consumerRecv, numElements] { consumerRecv(i % numNodes, numElements); }); + } + + for (auto &t : threads) { + t.join(); + } +} + +TEST_F(RemotePushTest, DISABLED_LEVEL1_TestDifferentWorkersMultipleProducers) +{ + FLAGS_v = 0; + std::vector threads; + + cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "RemoteWorkerStreamBlockingWakeupTimeout", "1*call(45000)"); + const int numElementsPerPub = 128; + const int numStreams = 1; + for (int i = 0; i < numStreams; i++) { + threads.emplace_back([i, this] { + std::string streamName = "stream000" + std::to_string(i); + StartStream(cluster_, streamName, accessKey_, secretKey_, numElementsPerPub); + }); + } + + for (auto &t : 
threads) { + t.join(); + } +} + +class RemotePushOOMTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + const int workerNumber = 2; + opts.numWorkers = workerNumber; + opts.numEtcd = 1; + opts.workerGflagParams = "-shared_memory_size_mb=10000"; + opts.enableDistributedMaster = "true"; + opts.numRpcThreads = 0; + opts.vLogLevel = 0; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + InitTest(); + } + + void TearDown() override + { + producerClient_ = nullptr; + consumerClient_ = nullptr; + ExternalClusterTest::TearDown(); + } + +protected: + void InitTest() + { + InitStreamClient(0, producerClient_); + InitStreamClient(1, consumerClient_); + } + + // Mock producer worker. + HostPort pubWorkerAddress_; + HostPort subWorkerAddress_; + + // Construct consumer worker client. + std::shared_ptr consumerClient_ = nullptr; + + static constexpr int CLIENT_RPC_TIMEOUT = 4 * 60 * 1000; + std::shared_ptr producerClient_ = nullptr; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; + // std::unique_ptr signature_ = std::make_unique(accessKey_, secretKey_); + size_t maxPageCount_ = 2; + size_t pageSize_ = 1024 * 1024; +}; + +TEST_F(RemotePushOOMTest, TestRecvOOM) +{ + std::vector> consumers; + std::vector> producers; + int streamCount = 1; + int oomTimeout = 3; + + ProducerConf conf; + conf.pageSize = pageSize_; + conf.maxStreamSize = pageSize_ * maxPageCount_; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + for (int index = 0; index < streamCount; index++) { + std::string streamName = "TestOOMStream-" + std::to_string(index); + std::shared_ptr producer; + std::shared_ptr consumer; + ASSERT_EQ(producerClient_->CreateProducer(streamName, producer, conf), Status::OK()); + ASSERT_EQ(consumerClient_->Subscribe(streamName, config, consumer, true), Status::OK()); + consumers.emplace_back(std::move(consumer)); + producers.emplace_back(std::move(producer)); + } + + const size_t elementSize = 10240; // 10k. + size_t nums = pageSize_ * maxPageCount_ * 2 / elementSize - 10; + + std::string data(elementSize, 'a'); + Element element((uint8_t *)data.data(), data.size()); + // Send. + for (int index = 0; index < streamCount; index++) { + for (size_t i = 0; i < nums; i++) { + Status rc = producers[index]->Send(element); + Timer timer; + const int maxTimeout = 10; + while (rc.GetCode() == K_OUT_OF_MEMORY && timer.ElapsedSecond() < maxTimeout) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + rc = producers[index]->Send(element); + } + DS_ASSERT_OK(rc); + LOG(INFO) << FormatString("Stream index %zu,: send count: %zu", streamCount, i); + } + } + + sleep(oomTimeout); + + for (int index = 0; index < streamCount; index++) { + size_t recvNum = 0; + while (recvNum < nums) { + std::vector outElements; + const int recvTimeout = 1000; + ASSERT_EQ(consumers[index]->Receive(nums, recvTimeout, outElements), Status::OK()); + recvNum += outElements.size(); + LOG(INFO) << "Recv num:" << recvNum; + } + ASSERT_EQ(recvNum, nums); + } +} + +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/remote_send_recv_test.cpp b/tests/st/client/stream_cache/remote_send_recv_test.cpp new file mode 100644 index 0000000..0ec0eb3 --- /dev/null +++ b/tests/st/client/stream_cache/remote_send_recv_test.cpp @@ -0,0 +1,1546 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. 
All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Remote send test. + */ +#include +#include + +#include "common.h" +#include "common/stream_cache/element_generator.h" +#include "common/stream_cache/stream_common.h" +#include "sc_client_common.h" +#include "datasystem/stream_client.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/client/stream_cache/client_worker_api.h" +namespace datasystem { +namespace st { +using namespace datasystem::client::stream_cache; +#define MULTI_NODE +#ifdef MULTI_NODE +constexpr int K_TWO = 2; +constexpr int K_TEN = 10; +constexpr int K_TWENTY = 20; +class RemoteSendRecvTest : public SCClientCommon { +#else +class RemoteSendRecvMoreTest : public CommonTest { +#endif +public: +#ifdef MULTI_NODE + + void SetClusterSetupOptions(ExternalClusterOptions &opts) override; +#endif + void SetUp() override; + + void TearDown() override; + + static std::string streamName_; + static std::once_flag onceFlag_; + +protected: + static Status Produce(std::shared_ptr &producer, std::string producerName, uint64_t eleSz); + + static Status ConsumeAll(std::shared_ptr &consumer, int timeout = 5000, bool checkFIFO = true, + uint64_t *res = nullptr, int producerNum = 1); + + static Status ConsumeAllClose(std::shared_ptr &consumer, int timeout = 5000, bool checkFIFO = true, + uint64_t *res = nullptr, int producerNum = 1); + + Status TestConsumerSetup(const std::string &subName, const std::string &streamName, + std::shared_ptr &client, std::promise &promise, + std::shared_ptr &consumer); + + Status TestProducerSetup(std::vector> &sFuts, std::shared_ptr &client, + const std::string &streamName, const std::string &producerName, + std::shared_ptr &producer); + + void BothDirectionTestCreateProducers(std::vector> &futs, std::shared_ptr &pool, + std::vector> &sFuts, const std::string &streamName, + std::vector> &producers); + + void BothDirectionTestCreateConsumers(std::vector> &futs, std::shared_ptr &pool, + std::vector> &promises, const std::string &streamName); + + Status CheckFuts(std::vector> &futs); + + // Different tests for a given stream name. + void SingleThreaded(int round, uint64_t num_eles, ProducerConf producerConf); + + void BasicSPSC(int round, bool checkFIFO = true, uint64_t eleSz = 2 * KB); + + void SendSideMultiProducers(int round); + + void RecvSideAddConsumer(int round); + + void SendSideConsumer(int round); + + void BothDirection(int round); + + void CreateStream_different_client(std::string base_data, int pagesize, std::string name, int send_num); + + Status TryAndDeleteStream(std::shared_ptr &spClient, std::string streamName) + { + // if pending notifications retry delete + Status rc = Status::OK(); + do { + rc = spClient->DeleteStream(streamName); + if (rc.IsError()) { + sleep(K_TWO); + } + } while (rc.GetCode() == StatusCode::K_SC_STREAM_NOTIFICATION_PENDING); + return rc; + } + + // Mock producer worker. 
+ HostPort w1Addr_; + HostPort w2Addr_; + HostPort w3Addr_; + + std::shared_ptr w1Client_ = nullptr; + std::shared_ptr w2Client_ = nullptr; + std::shared_ptr w3Client_ = nullptr; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; +}; +std::string RemoteSendRecvTest::streamName_ = "stream"; +std::once_flag RemoteSendRecvTest::onceFlag_; + +#ifdef MULTI_NODE +void RemoteSendRecvTest::SetClusterSetupOptions(ExternalClusterOptions &opts) +{ + opts.numEtcd = 1; + opts.numWorkers = 3; + opts.enableDistributedMaster = "false"; + opts.workerGflagParams = " -page_size=" + std::to_string(PAGE_SIZE); + opts.numRpcThreads = 0; + opts.vLogLevel = 2; + SCClientCommon::SetClusterSetupOptions(opts); +} +#endif + +void RemoteSendRecvTest::SetUp() +{ +#ifdef MULTI_NODE + ExternalClusterTest::SetUp(); + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, w1Addr_)); + DS_ASSERT_OK(cluster_->GetWorkerAddr(1, w2Addr_)); + DS_ASSERT_OK(cluster_->GetWorkerAddr(2, w3Addr_)); +#else + w1Addr_ = HostPort("127.0.0.1", 2295); + w3Addr_ = HostPort("127.0.0.1", 8666); + w2Addr_ = HostPort("127.0.0.1", 11589); +#endif + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 1)); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 2)); + // Worker 1. + InitStreamClient(0, w1Client_); + InitStreamClient(1, w2Client_); + InitStreamClient(2, w3Client_); + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; +} + +void RemoteSendRecvTest::TearDown() +{ + w1Client_ = nullptr; + w2Client_ = nullptr; + w3Client_ = nullptr; +#ifdef MULTI_NODE + ExternalClusterTest::TearDown(); +#endif +} + +Status RemoteSendRecvTest::Produce(std::shared_ptr &producer, std::string producerName, uint64_t eleSz) +{ + Status stat; + ElementGenerator elementGenerator(eleSz); + auto strs = elementGenerator.GenElements(producerName, NUM_ELES, 1); + + for (int i = 0; i < NUM_ELES; i++) { + RETURN_IF_NOT_OK(producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size()))); + } + return producer->Close(); +} + +Status RemoteSendRecvTest::ConsumeAll(std::shared_ptr &consumer, int timeout, bool checkFIFO, uint64_t *res, + int producerNum) +{ + std::vector outElements; + size_t expectNum = NUM_ELES * producerNum; + RETURN_IF_NOT_OK(consumer->Receive(expectNum, timeout, outElements)); + LOG(INFO) << FormatString("Stream Consumer Receive %d elements.", outElements.size()); + std::unordered_map seqNoMap; + uint64_t eleTotalSz = 0; + for (const auto &element : outElements) { + ElementView view(std::string((const char *)element.ptr, element.size)); + RETURN_IF_NOT_OK(view.VerifyIntegrity()); + if (checkFIFO) { + RETURN_IF_NOT_OK(view.VerifyFifo(seqNoMap, 0)); + } + eleTotalSz += element.size; + } + if (res != nullptr) { + *res = eleTotalSz; + } + return Status::OK(); +} + +Status RemoteSendRecvTest::ConsumeAllClose(std::shared_ptr &consumer, int timeout, bool checkFIFO, + uint64_t *res, int producerNum) +{ + RETURN_IF_NOT_OK(ConsumeAll(consumer, timeout, checkFIFO, res, producerNum)); + RETURN_IF_NOT_OK(consumer->Close()); + return Status::OK(); +} + +void RemoteSendRecvTest::SingleThreaded(int round, uint64_t num_eles, ProducerConf producerConf) +{ + auto streamName = FormatString("%s-%d", streamName_, round); + std::shared_ptr w2consumer; + std::shared_ptr w1Consumer; + std::shared_ptr w1Producer; + // Worker 1 pub/subs. 
+    ASSERT_EQ(w1Client_->CreateProducer(streamName, w1Producer, producerConf), Status::OK());
+    SubscriptionConfig localConfig("sub2", SubscriptionType::STREAM);
+    ASSERT_EQ(w1Client_->Subscribe(streamName, localConfig, w1Consumer), Status::OK());
+
+    // Worker 2 subs.
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    ASSERT_EQ(w2Client_->Subscribe(streamName, config, w2consumer), Status::OK());
+
+    ElementGenerator elementGenerator(BIG_SIZE / 4);
+    std::string producerName = "producer1";
+    auto strs = elementGenerator.GenElements(producerName, num_eles);
+
+    // Send. (Use a 64-bit counter: a uint8_t would wrap when num_eles exceeds 255.)
+    for (uint64_t i = 0; i < num_eles; i++) {
+        ASSERT_EQ(w1Producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size())), Status::OK());
+    }
+
+    std::vector<Element> outElements1;
+    outElements1.reserve(num_eles);
+    std::vector<Element> outElements2;
+    outElements2.reserve(num_eles);
+    bool c1Finish = false;
+    bool c2Finish = false;
+    // Recv: poll both consumers until each has all elements or the timeout expires.
+    auto begin = Timer();
+    int timeoutMs = 30 * 1000;
+    while (begin.ElapsedMilliSecond() <= timeoutMs) {
+        if (!c1Finish) {
+            std::vector<Element> output1;
+            w1Consumer->Receive(5, 0, output1);
+            outElements1.insert(outElements1.end(), output1.begin(), output1.end());
+            c1Finish = (outElements1.size() == num_eles);
+        }
+        if (!c2Finish) {
+            std::vector<Element> output2;
+            w2consumer->Receive(5, 0, output2);
+            outElements2.insert(outElements2.end(), output2.begin(), output2.end());
+            c2Finish = (outElements2.size() == num_eles);
+        }
+        if (c1Finish && c2Finish) {
+            break;
+        }
+    }
+
+    ASSERT_EQ(w1Consumer->Ack(num_eles), Status::OK());
+    ASSERT_EQ(w1Producer->Close(), Status::OK());
+    ASSERT_EQ(w2consumer->Ack(num_eles), Status::OK());
+    ASSERT_EQ(w1Consumer->Close(), Status::OK());
+    ASSERT_EQ(w2consumer->Close(), Status::OK());
+}
+
+// W1: Producer, Consumer. (Flushes once every five elements.)
+// W2: Consumer.
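+//
+// A minimal sketch of the lifecycle every test below exercises (the names here
+// are placeholders; only the calls already used in this file are assumed):
+//
+//     std::shared_ptr<Producer> producer;
+//     client->CreateProducer(stream, producer, conf);          // publish side
+//     SubscriptionConfig sub("subX", SubscriptionType::STREAM);
+//     std::shared_ptr<Consumer> consumer;
+//     client->Subscribe(stream, sub, consumer);                // subscribe side
+//     producer->Send(element);                                 // may block on flow control
+//     consumer->Receive(expectNum, timeoutMs, outElements);    // batched receive
+//     consumer->Ack(lastElementId);                            // releases stream pages
+//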
+TEST_F(RemoteSendRecvTest, TestSingleThreaded)
+{
+    auto rounds = 5;
+    ThreadPool roundThreads(rounds);
+    for (int round = 0; round < rounds; round++) {
+        roundThreads.Execute([this, round]() { SingleThreaded(round, NUM_ELES, defaultProducerConf_); });
+    }
+}
+
+TEST_F(RemoteSendRecvTest, TestSingleThreadedWindow)
+{
+    // Use a page size of a quarter of the stream size, i.e. a window of four pages.
+    ProducerConf producerConf;
+    producerConf.pageSize = TEST_STREAM_SIZE / 4;
+    producerConf.maxStreamSize = TEST_STREAM_SIZE;
+
+    auto rounds = 1;
+    ThreadPool roundThreads(rounds);
+    for (int round = 0; round < rounds; round++) {
+        roundThreads.Execute([this, round, producerConf]() { SingleThreaded(round, NUM_ELES, producerConf); });
+    }
+}
+
+TEST_F(RemoteSendRecvTest, TestSingleThreadedWindowHalf)
+{
+    // Cap the stream size at half of the available shared memory.
+    ProducerConf producerConf;
+    producerConf.pageSize = TEST_STREAM_SIZE / 4;
+    producerConf.maxStreamSize = TEST_STREAM_SIZE / 2;
+
+    auto rounds = 1;
+    ThreadPool roundThreads(rounds);
+    for (int round = 0; round < rounds; round++) {
+        roundThreads.Execute([this, round, producerConf]() { SingleThreaded(round, NUM_ELES, producerConf); });
+    }
+}
+
+void RemoteSendRecvTest::BasicSPSC(int round, bool checkFIFO, uint64_t eleSz)
+{
+    auto streamName = FormatString("%s-%d", streamName_, round);
+    ThreadPool pool(10);
+    for (int i = 0; i < 1; i++) {
+        LOG(INFO) << FormatString("===================== [Round: %d] [%s] Start =====================", i, streamName);
+        std::vector<std::future<Status>> futs;
+        std::promise<void> promise;
+        // w1:p1
+        futs.emplace_back(pool.Submit([this, &promise, streamName, eleSz]() {
+            std::shared_ptr<Producer> producer;
+            RETURN_IF_NOT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_));
+            promise.get_future().get();
+
+            // Send.
+            RETURN_IF_NOT_OK(Produce(producer, "producer1", eleSz));
+            return Status::OK();
+        }));
+        // w2:c1
+        futs.emplace_back(pool.Submit([this, &promise, streamName, checkFIFO]() {
+            std::shared_ptr<Consumer> consumer;
+            SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+            RETURN_IF_NOT_OK(w2Client_->Subscribe(streamName, config, consumer));
+            promise.set_value();
+
+            RETURN_IF_NOT_OK(ConsumeAllClose(consumer, 10'000, checkFIFO));
+            return Status::OK();
+        }));
+        DS_ASSERT_OK(CheckFuts(futs));
+        LOG(INFO) << FormatString("Finish: %d", i);
+        LOG(INFO) << FormatString("===================== [Round: %d] [%s] End =====================", i, streamName);
+    }
+    ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK());
+}
+
+// Recv slower than send.
+// W1: Producer.
+// W2: Consumer.
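+// BasicSPSC above coordinates its two threads through a std::promise<void>:
+// the producer thread creates its producer and then blocks on the future until
+// the consumer thread has subscribed and called set_value(). The barrier makes
+// sure the subscription exists before the first Send() of the round.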
+TEST_F(RemoteSendRecvTest, DISABLED_TestBasicSPSC)
+{
+    auto rounds = 1;
+    ThreadPool roundThreads(rounds);
+    for (int round = 0; round < rounds; round++) {
+        roundThreads.Execute([this, round]() { BasicSPSC(round); });
+    }
+}
+
+TEST_F(RemoteSendRecvTest, TestParallelRemoteSendProfiling)
+{
+    auto rounds = 8;
+    std::unique_ptr<ThreadPool> roundThreads;
+    LOG_IF_EXCEPTION_OCCURS(roundThreads = std::make_unique<ThreadPool>(rounds));
+
+    for (int round = 0; round < rounds; round++) {
+        roundThreads->Execute([this, round]() { BasicSPSC(round, false, 64); });
+    }
+}
+
+void RemoteSendRecvTest::SendSideMultiProducers(int round)
+{
+    auto streamName = FormatString("%s-%d", streamName_, round);
+    std::unique_ptr<ThreadPool> pool;
+    int poolSize = 10;
+    LOG_IF_EXCEPTION_OCCURS(pool = std::make_unique<ThreadPool>(poolSize));
+    std::stringstream ss;
+    int okCnt = 0;
+    std::vector<std::string> ids;
+    ids.resize(15);
+    for (int i = 0; i < 5; i++) {
+        std::vector<std::future<Status>> futs;
+        std::promise<void> promise;
+        std::shared_future<void> sFut = promise.get_future();
+        for (auto j = 0; j < 2; j++) {
+            futs.emplace_back(pool->Submit([this, &sFut, i, j, streamName, &ids]() {
+                std::shared_ptr<Producer> producer;
+                RETURN_IF_NOT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_));
+                sFut.get();
+                auto producerId = "producer" + std::to_string(i * 5 + j);
+                ids[i * 3 + j] = producerId;
+
+                // Send.
+                RETURN_IF_NOT_OK(Produce(producer, producerId, 2 * KB));
+                return Status::OK();
+            }));
+        }
+        futs.emplace_back(pool->Submit([this, &promise, streamName, i, &ids]() {
+            std::shared_ptr<Consumer> consumer;
+            SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+            RETURN_IF_NOT_OK(w2Client_->Subscribe(streamName, config, consumer));
+            auto consumerId = "consumer" + std::to_string(i);
+            ids[i * 3 + 2] = consumerId;
+            promise.set_value();
+
+            return ConsumeAllClose(consumer, 20'000, false, nullptr, 2);
+        }));
+        int index = 0;
+        for (auto &fut : futs) {
+            auto status = fut.get();
+            ss << ((index % 3 == 2) ? "Consumer: " : "Producer: ") << ids[index];
+            index++;
+            if (status.IsError()) {
+                ss << "----------------------" << status.ToString();
+            } else {
+                okCnt++;
+                ss << "------------OK------------";
+            }
+            ss << std::endl;
+        }
+        LOG(INFO) << FormatString("Finish-iteration: %d", i);
+    }
+    // Delete the stream after the last iteration.
+    ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK());
+    LOG(INFO) << "ok:" << okCnt << ", status:" << ss.str();
+    EXPECT_EQ(okCnt, 15);
+}
+
+// W1: Two producers.
+// W2: Consumer.
+// Flushes must preserve FIFO order for each producer.
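+// ConsumeAllClose() is invoked with checkFIFO=false in SendSideMultiProducers:
+// FIFO order is only defined per producer, so with two producers interleaving on
+// one stream there is no total order to assert. producerNum=2 is passed purely
+// to size the expected element count (NUM_ELES * producerNum).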
+TEST_F(RemoteSendRecvTest, TestSendSideMultiProducers)
+{
+    auto rounds = 2;
+    std::unique_ptr<ThreadPool> roundThreads;
+    LOG_IF_EXCEPTION_OCCURS(roundThreads = std::make_unique<ThreadPool>(rounds));
+    for (int round = 0; round < rounds; round++) {
+        roundThreads->Execute([this, round]() { SendSideMultiProducers(round); });
+    }
+}
+
+void RemoteSendRecvTest::RecvSideAddConsumer(int round)
+{
+    auto streamName = FormatString("%s-%d", streamName_, round);
+    std::unique_ptr<ThreadPool> pool;
+    LOG_IF_EXCEPTION_OCCURS(pool = std::make_unique<ThreadPool>(10));
+    for (int i = 0; i < 10; i++) {
+        std::vector<std::future<Status>> futs;
+        std::future<Status> futs1;
+        std::promise<void> promise;
+        std::shared_future<void> sfut = promise.get_future();
+
+        futs.emplace_back(pool->Submit([this, &sfut, streamName]() {
+            std::shared_ptr<Producer> producer;
+            std::vector<std::shared_future<void>> sfuts = { sfut };
+            return TestProducerSetup(sfuts, w1Client_, streamName, "producer", producer);
+        }));
+        futs1 = pool->Submit([this, &promise, streamName]() {
+            std::shared_ptr<Consumer> consumer;
+            RETURN_IF_NOT_OK(TestConsumerSetup("sub1", streamName, w2Client_, promise, consumer));
+            return ConsumeAllClose(consumer);
+        });
+        futs.emplace_back(pool->Submit([this, &futs1, &sfut, streamName]() {
+            std::shared_ptr<Consumer> consumer;
+            SubscriptionConfig config("sub2", SubscriptionType::STREAM);
+            sfut.get();
+            RETURN_IF_NOT_OK(w2Client_->Subscribe(streamName, config, consumer));
+
+            std::vector<Element> outElements;
+            std::unordered_map<std::string, uint64_t> seqNoMap;
+            int recvNum = 0;
+            while (!IsThreadFinished(futs1, 0)) {
+                auto stat = consumer->Receive(1, 0, outElements);
+                if (stat == Status::OK() && !outElements.empty()) {
+                    const auto &e = outElements.back();
+                    LOG(INFO) << "Cursor: " << e.id << ", Sz: " << e.size;
+
+                    ElementView view(std::string((const char *)e.ptr, e.size));
+
+                    RETURN_IF_NOT_OK(view.VerifyIntegrity());
+                    RETURN_IF_NOT_OK(view.VerifyFifoInitOff(seqNoMap));
+                    recvNum++;
+                }
+                consumer->Ack(recvNum);
+            }
+            CHECK_FAIL_RETURN_STATUS(recvNum <= NUM_ELES, StatusCode::K_RUNTIME_ERROR,
+                                     "Received more elements than were produced");
+            return consumer->Close();
+        }));
+        for (auto &fut : futs) {
+            fut.get();
+        }
+        futs1.get();
+        LOG(INFO) << FormatString("Finish: %d", i);
+    }
+    ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK());
+}
+
+// W1: Producer.
+// W2: Consumer, then dynamically add another.
+TEST_F(RemoteSendRecvTest, DISABLED_TestRecvSideAddConsumer)
+{
+    auto rounds = 5;
+    std::unique_ptr<ThreadPool> roundThreads;
+    LOG_IF_EXCEPTION_OCCURS(roundThreads = std::make_unique<ThreadPool>(rounds));
+    for (int round = 0; round < rounds; round++) {
+        roundThreads->Execute([this, round]() { RecvSideAddConsumer(round); });
+    }
+}
+
+Status RemoteSendRecvTest::TestConsumerSetup(const std::string &subName, const std::string &streamName,
+                                             std::shared_ptr<StreamClient> &client, std::promise<void> &promise,
+                                             std::shared_ptr<Consumer> &consumer)
+{
+    SubscriptionConfig config(subName, SubscriptionType::STREAM);
+    RETURN_IF_NOT_OK(client->Subscribe(streamName, config, consumer));
+    promise.set_value();
+    return Status::OK();
+}
+
+Status RemoteSendRecvTest::TestProducerSetup(std::vector<std::shared_future<void>> &sFuts,
+                                             std::shared_ptr<StreamClient> &client, const std::string &streamName,
+                                             const std::string &producerName, std::shared_ptr<Producer> &producer)
+{
+    RETURN_IF_NOT_OK(client->CreateProducer(streamName, producer, defaultProducerConf_));
+
+    // Send.
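+    // Setup barrier: block on every consumer-side shared_future first, so
+    // production starts only after all subscriptions of the round are in place.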
+ for (auto &sFut : sFuts) { + sFut.get(); + } + return Produce(producer, producerName, 10); +} + +Status RemoteSendRecvTest::CheckFuts(std::vector> &futs) +{ + for (auto &fut : futs) { + RETURN_IF_NOT_OK(fut.get()); + } + return Status::OK(); +} + +void RemoteSendRecvTest::SendSideConsumer(int round) +{ + auto streamName = FormatString("%s-%d", streamName_, round); + std::unique_ptr pool; + LOG_IF_EXCEPTION_OCCURS(pool = std::make_unique(10)); + for (int i = 0; i < 10; i++) { + std::vector> futs; + std::vector> promises(3); + std::vector> sFuts; + for (auto &promise : promises) { + sFuts.emplace_back(promise.get_future()); + } + futs.emplace_back(pool->Submit([this, &promises, streamName]() { + std::shared_ptr consumer; + RETURN_IF_NOT_OK(TestConsumerSetup("sub1", streamName, w1Client_, promises[0], consumer)); + + return ConsumeAllClose(consumer); + })); + futs.emplace_back(pool->Submit([this, &promises, streamName]() { + std::shared_ptr consumer; + RETURN_IF_NOT_OK(TestConsumerSetup("sub2", streamName, w1Client_, promises[1], consumer)); + + Timer timer; + uint64_t sz; + RETURN_IF_NOT_OK(ConsumeAllClose(consumer, 5'000, true, &sz)); + auto elapsed = timer.ElapsedSecond(); + LOG(INFO) << FormatString("Elapsed: [%.6lf]s, Throughput: [%.6lf] MB/s", elapsed, + sz / timer.ElapsedSecond() / MB); + return Status::OK(); + })); + futs.emplace_back(pool->Submit([this, &sFuts, streamName]() { + std::shared_ptr producer; + return TestProducerSetup(sFuts, w1Client_, streamName, "producer", producer); + })); + futs.emplace_back(pool->Submit([this, &promises, streamName]() { + std::shared_ptr consumer; + RETURN_IF_NOT_OK(TestConsumerSetup("sub3", streamName, w2Client_, promises[2], consumer)); + return ConsumeAllClose(consumer); + })); + DS_ASSERT_OK(CheckFuts(futs)); + if (i == 9) { // Delete stream in last round + std::this_thread::sleep_for(std::chrono::seconds(2)); + ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK()); + } + LOG(INFO) << FormatString("Stream:%d, Finish: %d", round, i); + } +} + +// W1: Producer, Consumer. +// W2: Consumer. 
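+// The sub2 consumer in SendSideConsumer doubles as a rough throughput probe:
+// ConsumeAllClose() reports the total bytes received through its out-parameter,
+// and the surrounding Timer converts that into the MB/s figure in the log.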
+TEST_F(RemoteSendRecvTest, DISABLED_TestSendSideConsumer) +{ + auto rounds = 5; + std::unique_ptr roundThreads; + LOG_IF_EXCEPTION_OCCURS(roundThreads = std::make_unique(rounds)); + for (int round = 0; round < rounds; round++) { + roundThreads->Execute([this, round]() { SendSideConsumer(round); }); + } +} + +void RemoteSendRecvTest::CreateStream_different_client(std::string base_data, int pagesize, std::string name, + int send_num) { + for (int m = 0; m < 1000; m++) { + std::shared_ptr client, client2; + InitStreamClient(0, client); + InitStreamClient(0, client2); + + ProducerConf conf; + conf.maxStreamSize = 7 * MB; + conf.pageSize = pagesize * KB; + std::shared_ptr producer; + std::shared_ptr consumer; + std::string streamName = "Stream_" + RandomData().GetRandomString(12); + std::string subName = "Sub_" + RandomData().GetRandomString(13); + SubscriptionConfig config(subName, SubscriptionType::STREAM); + DS_ASSERT_OK(client->CreateProducer(streamName, producer, conf)); + DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer)); + + for (int i = 0; i < send_num; i++) { + std::string data = base_data + RandomData().GetRandomString(13); + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(1, 10000, outElements)); + if (outElements.size() == 1) { + std::string actualData(reinterpret_cast(outElements[0].ptr), outElements[0].size); + DS_ASSERT_OK(consumer->Ack(outElements[0].id)); + } else { + LOG(INFO) << name << " outElements.size() is 0"; + } + } + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(consumer->Close()); + DS_ASSERT_OK(client->DeleteStream(streamName)); + + DS_ASSERT_OK(client->ShutDown()); + DS_ASSERT_OK(client2->ShutDown()); + } +} + +TEST_F(RemoteSendRecvTest, DISABLED_TestMemoryAllocationOverflow) +{ + std::unique_ptr roundThreads; + LOG_IF_EXCEPTION_OCCURS(roundThreads = std::make_unique(10)); + std::string base_data = RandomData().GetRandomString(1020 * 1020); + roundThreads->Execute([this, base_data]() { + CreateStream_different_client(base_data, 1024, "continuous_data", 5); + }); +} + +void RemoteSendRecvTest::BothDirectionTestCreateProducers(std::vector> &futs, + std::shared_ptr &pool, + std::vector> &sFuts, + const std::string &streamName, + std::vector> &producers) +{ + futs.emplace_back(pool->Submit([this, &sFuts, streamName, &producers]() { + std::shared_ptr producer; + RETURN_IF_NOT_OK(TestProducerSetup(sFuts, w1Client_, streamName, "producer1", producer)); + producers.emplace_back(std::move(producer)); + return Status::OK(); + })); + futs.emplace_back(pool->Submit([this, &sFuts, streamName, &producers]() { + std::shared_ptr producer; + RETURN_IF_NOT_OK(TestProducerSetup(sFuts, w2Client_, streamName, "producer2", producer)); + producers.emplace_back(std::move(producer)); + return Status::OK(); + })); + futs.emplace_back(pool->Submit([this, &sFuts, streamName, &producers]() { + std::shared_ptr producer; + RETURN_IF_NOT_OK(TestProducerSetup(sFuts, w3Client_, streamName, "producer3", producer)); + producers.emplace_back(std::move(producer)); + return Status::OK(); + })); +} + +void RemoteSendRecvTest::BothDirectionTestCreateConsumers(std::vector> &futs, + std::shared_ptr &pool, + std::vector> &promises, + const std::string &streamName) +{ + futs.emplace_back(pool->Submit([this, &promises, streamName]() { + std::shared_ptr consumer; + RETURN_IF_NOT_OK(TestConsumerSetup("sub1", streamName, w1Client_, promises[0], consumer)); + return 
ConsumeAllClose(consumer, 2000); + })); + futs.emplace_back(pool->Submit([this, &promises, streamName]() { + std::shared_ptr consumer; + RETURN_IF_NOT_OK(TestConsumerSetup("sub2", streamName, w1Client_, promises[1], consumer)); + return ConsumeAllClose(consumer, 2000); + })); + futs.emplace_back(pool->Submit([this, &promises, streamName]() { + std::shared_ptr consumer; + RETURN_IF_NOT_OK(TestConsumerSetup("sub3", streamName, w2Client_, promises[2], consumer)); + return ConsumeAllClose(consumer, 2000); + })); +} + +void RemoteSendRecvTest::BothDirection(int round) +{ + auto streamName = FormatString("%s-%d", streamName_, round); + std::shared_ptr pool; + LOG_IF_EXCEPTION_OCCURS(pool = std::make_shared(10)); + for (int i = 0; i < 10; i++) { + LOG(INFO) << FormatString("===================== [Round: %d] Start =====================", i); + std::vector> producers; + std::vector> futs; + std::vector> promises(3); + std::vector> sFuts; + for (auto &promise : promises) { + sFuts.emplace_back(promise.get_future()); + } + // create producers + BothDirectionTestCreateProducers(futs, pool, sFuts, streamName, producers); + // create consumers + BothDirectionTestCreateConsumers(futs, pool, promises, streamName); + DS_ASSERT_OK(CheckFuts(futs)); + for (auto &producer : producers) { + ASSERT_EQ(producer->Close(), Status::OK()); + } + if (i == 9) { // Delete stream in last round + std::this_thread::sleep_for(std::chrono::seconds(2)); + ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK()); + } + LOG(INFO) << FormatString("Finish: %d", i); + LOG(INFO) << FormatString("===================== [Round: %d] End =====================", i); + } +} + +// W1: Producer, 2Consumer. +// W2: Producer, Consumer. +// W3: Producer. +TEST_F(RemoteSendRecvTest, DISABLED_TestBothDirection) +{ + auto rounds = 5; + std::unique_ptr roundThreads; + LOG_IF_EXCEPTION_OCCURS(roundThreads = std::make_unique(rounds)); + for (int round = 0; round < rounds; round++) { + roundThreads->Execute([this, round]() { BothDirection(round); }); + } +} + +TEST_F(RemoteSendRecvTest, TestProducerCloseAck) +{ + std::shared_ptr producer; + auto streamName = "stream-test-producer"; + w1Client_->CreateProducer(streamName, producer, defaultProducerConf_); + // Send. 
+ Produce(producer, FormatString("producer-stream-test"), 2 * KB); + producer->Close(); + ASSERT_EQ(w1Client_->DeleteStream(streamName), Status::OK()); +} + +TEST_F(RemoteSendRecvTest, TestAutoDeleteWaitFailed) +{ + // Ack thread will get protect delete on stream manager and sleeps for 2secs + // If we get DeleteStreamContext during this time, we will end up waiting for StreamManager to be free + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamManager.AckCursors.delay", "2*sleep(2000)")); + // We set timeout to be 1secs, so we fail the wait in DeleteStreamContext once + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, + "MasterLocalWorkerSCApi.DelStreamContextBroadcast.setTimeout", + "call(1000)")); + std::string streamName = "streamAutoDelWaitFailed"; + defaultProducerConf_.autoCleanup = true; + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + DS_ASSERT_OK(producer->Close()); + // Auto delete retries every 10secs + // and then test if stream is deleted + sleep(10); + // Now stream should have been deleted + ASSERT_EQ(w1Client_->DeleteStream(streamName).GetCode(), StatusCode::K_NOT_FOUND); +} + +TEST_F(RemoteSendRecvTest, DISABLED_TestAutoDeleteDeadlock) +{ + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamManager.AckCursors.delay", "1*sleep(100000)")); + std::string streamName = "stream"; + ThreadPool threadPool(7); + std::vector> futs; + defaultProducerConf_.autoCleanup = true; + for (u_int i = 0; i < 7; i++) { + auto fut = threadPool.Submit([&]() { + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + for (uint32_t j = 0; j < 100000; j++) { + std::string data = "data_" + std::to_string(i); + Element element((uint8_t *)data.data(), data.size()); + DS_ASSERT_OK(producer->Send(element)); + } + DS_ASSERT_OK(producer->Close()); + }); + futs.push_back(std::move(fut)); + } + + for (auto &fut : futs) { + fut.get(); + } +} + +TEST_F(RemoteSendRecvTest, TestRemotePushToMultiNode) +{ + std::string streamName = "streamRemotePushMultiNode"; + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + + std::shared_ptr consumer1; + std::shared_ptr consumer2; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(w2Client_->Subscribe(streamName, config1, consumer1)); + DS_ASSERT_OK(w3Client_->Subscribe(streamName, config2, consumer2)); + + const uint32_t testCount = 100; + const uint32_t waitTime = 3000; // 3s. 
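+    // A single producer on worker 1 fans out to independent subscriptions on
+    // workers 2 and 3; each consumer must observe all testCount elements.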
+ for (uint32_t i = 0; i < testCount; i++) { + std::string data = "data_" + std::to_string(i); + Element element((uint8_t *)data.data(), data.size()); + DS_ASSERT_OK(producer->Send(element)); + } + std::vector outElements1; + DS_ASSERT_OK(consumer1->Receive(testCount, waitTime, outElements1)); + EXPECT_EQ(outElements1.size(), testCount); + + std::vector outElements2; + DS_ASSERT_OK(consumer2->Receive(testCount, waitTime, outElements2)); + EXPECT_EQ(outElements2.size(), testCount); +} + +void SendHelper(std::shared_ptr producer, Element element) +{ + const int DEFAULT_SLEEP_TIME = 300; + int retryLimit = 30; + datasystem::Status rc = producer->Send(element); + if (rc.IsError()) { + while (rc.GetCode() == K_OUT_OF_MEMORY && retryLimit-- > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME)); + rc = producer->Send(element); + } + } + DS_ASSERT_OK(rc); +} + +TEST_F(RemoteSendRecvTest, TestRemotePushToMultiNodeDifferentOrder) +{ + // Create a producer and consumer for a stream on Node1 and Node2 + // Having a consumer for any stream will prevent remoteWorker from getting deleted + std::string streamName1 = "testMultiNodeDiffOrder_s1"; + std::shared_ptr producer1; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName1, producer1, defaultProducerConf_)); + + std::shared_ptr consumer1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(w2Client_->Subscribe(streamName1, config1, consumer1)); + + // Create a new stream for our actual test + std::string streamName2 = "testMultiNodeDiffOrder_s2"; + defaultProducerConf_.maxStreamSize = 2*1024*1024; + defaultProducerConf_.autoCleanup = true; + std::shared_ptr producer2; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName2, producer2, defaultProducerConf_)); + + // Create a consumer on same node + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(w2Client_->Subscribe(streamName2, config2, consumer2)); + + // Close producer and consumer + producer2->Close(); + consumer2->Close(); + sleep(K_TEN); // wait for AutoDelete + + // Now Repeat above steps + std::shared_ptr producer3; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName2, producer3, defaultProducerConf_)); + + // Create a consumer on different node + // Now worker will try to use AckCursor from previous RemoteWorker + std::shared_ptr consumer3; + SubscriptionConfig config3("sub3", SubscriptionType::STREAM); + DS_ASSERT_OK(w3Client_->Subscribe(streamName2, config3, consumer3)); + + const uint32_t testCount = 1000; + const uint32_t eleSz = 2*1024; + const uint32_t waitTime = 3000; // 3s. 
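+    // Receive runs concurrently with send: the consumer thread acks each batch
+    // so the 2 MB stream window does not permanently block the producer (see
+    // SendHelper above, which retries Send() on K_OUT_OF_MEMORY).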
+    ElementGenerator elementGenerator(eleSz + 1, eleSz);
+    auto elements = elementGenerator.GenElements("producer", testCount, 8ul);
+    ThreadPool threadPool(1);
+    uint64_t numElements = 0;
+    auto fut = threadPool.Submit([&]() {
+        while (numElements < testCount) {
+            std::vector<Element> outElements2;
+            DS_ASSERT_OK(consumer3->Receive(100, waitTime, outElements2));
+            if (!outElements2.empty()) {
+                numElements += outElements2.size();
+                LOG(INFO) << "Got num Elements: " << numElements;
+                DS_ASSERT_OK(consumer3->Ack(outElements2.back().id));
+            }
+        }
+    });
+    elements = elementGenerator.GenElements("producer", testCount, 8ul);
+    for (auto ele : elements) {
+        SendHelper(producer3, Element((uint8_t *)(ele.data()), ele.size()));
+    }
+    fut.get();
+    // After the receive loop is done, check that we got the expected count.
+    EXPECT_EQ(numElements, testCount);
+}
+
+TEST_F(RemoteSendRecvTest, DISABLED_TestRemoteSendTimeout)
+{
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    conf.delayFlushTime = 1000;
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    DS_ASSERT_OK(w1Client_->CreateProducer("stream", producer, conf));
+
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w2Client_->Subscribe("stream", config, consumer));
+
+    std::string str("abcabc");
+    Element e((uint8_t *)str.data(), str.length());
+    DS_ASSERT_OK(producer->Send(e));
+
+    const int receiveTimeout = 1000; // 1s.
+
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(1, receiveTimeout, outElements));
+    ASSERT_EQ(outElements.size(), 1ul);
+
+    const int remoteReceiveTimeout = 10000; // 10s.
+    std::this_thread::sleep_for(std::chrono::milliseconds(remoteReceiveTimeout));
+
+    const size_t loopCount = 3;
+    for (size_t i = 0; i < loopCount; i++) {
+        DS_ASSERT_OK(producer->Send(e));
+        std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+    }
+
+    DS_ASSERT_OK(consumer->Receive(loopCount, receiveTimeout, outElements));
+    ASSERT_EQ(outElements.size(), loopCount);
+}
+
+TEST_F(RemoteSendRecvTest, TestFlowControl)
+{
+    std::string streamName = "streamFlowCtrl";
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    conf.pageSize = 16 * 1024;
+    conf.maxStreamSize = 64 * 1024;
+    DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, conf));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w2Client_->Subscribe(streamName, config, consumer));
+    const uint32_t eleSz = 512;
+    const uint32_t eleNum = 200;
+    ElementGenerator elementGenerator(eleSz);
+    auto strs = elementGenerator.GenElements("producer1", eleNum, 1);
+    const int64_t timeoutMs = 1000;
+    for (uint32_t i = 0; i < eleNum; i++) {
+        DS_ASSERT_OK(producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size()), timeoutMs));
+    }
+    const uint32_t recvNum = 100;
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(recvNum, 5000, outElements));
+    ASSERT_EQ(outElements.size(), recvNum);
+    // Hold off acking for a while so the consumer-side worker hits its max stream size.
+    std::this_thread::sleep_for(std::chrono::seconds(5));
+    DS_ASSERT_OK(consumer->Ack(recvNum));
+
+    // Retry the receive on failure, for test-case stability.
+    uint32_t remaining = recvNum;
+    int retryCount = 3;
+    while (remaining > 0 && retryCount-- > 0) {
+        DS_ASSERT_OK(consumer->Receive(remaining, 5000, outElements));
+        uint32_t received = outElements.size();
+        DS_ASSERT_OK(consumer->Ack(received));
+        remaining -= received;
+    }
+    ASSERT_EQ(remaining, 0);
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(TryAndDeleteStream(w1Client_, streamName));
+}
+
+class RemoteSendRecvBigElementTest : public RemoteSendRecvTest
+{
+public:
+    const int minThreads = 20;
+    const int maxThreads = 128;
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        RemoteSendRecvTest::SetClusterSetupOptions(opts);
+        opts.enableDistributedMaster = "true";
+        opts.workerGflagParams =
+            "-shared_memory_size_mb=512 -client_dead_timeout_s=15 -enable_stream_data_verification=true";
+        opts.vLogLevel = SC_INTERNAL_LOG_LEVEL;
+    }
+    void SetUp() override
+    {
+        RemoteSendRecvTest::SetUp();
+        const int numPages = 16;
+        defaultProducerConf_.pageSize = 1 * MB;
+        defaultProducerConf_.maxStreamSize = defaultProducerConf_.pageSize * numPages;
+        defaultProducerConf_.retainForNumConsumers = 1;
+        pool = std::make_unique<ThreadPool>(minThreads, maxThreads);
+        allClients_.push_back(w1Client_.get());
+        allClients_.push_back(w2Client_.get());
+        allClients_.push_back(w3Client_.get());
+    }
+    void TearDown() override
+    {
+        RemoteSendRecvTest::TearDown();
+    }
+
+protected:
+    std::unique_ptr<ThreadPool> pool;
+    std::vector<StreamClient *> allClients_;
+
+    Status FillMemoryUntilOOM(const std::string &streamName, size_t numProducersPerWorker, size_t minEleSz,
+                              size_t maxEleSz, size_t &totalElements)
+    {
+        std::atomic<size_t> numInserted = 0;
+        std::vector<std::future<Status>> fut;
+        for (auto *client : allClients_) {
+            for (size_t i = 0; i < numProducersPerWorker; ++i) {
+                fut.emplace_back(
+                    pool->Submit([this, client, &streamName, &numInserted, &minEleSz, &maxEleSz]() -> Status {
+                        RandomData rand;
+                        std::shared_ptr<Producer> producer;
+                        RETURN_IF_NOT_OK(client->CreateProducer(streamName, producer, defaultProducerConf_));
+                        // Send small elements until OOM.
+                        Status rc;
+                        while (rc.IsOk()) {
+                            auto eleSz = rand.GetRandomUint64(minEleSz, maxEleSz + 1);
+                            auto str = rand.GetRandomString(eleSz);
+                            rc = producer->Send(Element((uint8_t *)str.data(), str.size()));
+                            if (rc.IsOk()) {
+                                numInserted++;
+                            }
+                        }
+                        return rc;
+                    }));
+            }
+        }
+        for (auto &f : fut) {
+            auto res = f.get();
+            CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(res.GetCode() == StatusCode::K_OUT_OF_MEMORY, K_RUNTIME_ERROR,
+                                                 FormatString("Expect OOM but get %s", res.ToString()));
+        }
+        totalElements = numInserted.load(std::memory_order_relaxed);
+        LOG(INFO) << "Total number of small elements inserted: " << totalElements;
+        return Status::OK();
+    }
+
+    void CreateProducersAndPush(std::vector<std::future<Status>> &fut, const std::string &streamName,
+                                int initialNumOfProducersPerWorker, size_t minEleSz, size_t maxEleSz,
+                                size_t totalElements)
+    {
+        for (auto *client : allClients_) {
+            for (auto i = 0; i < initialNumOfProducersPerWorker; ++i) {
+                fut.emplace_back(pool->Submit([this, client, streamName, totalElements, minEleSz,
+                                               maxEleSz]() -> Status {
+                    std::shared_ptr<Producer> producer;
+                    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(client->CreateProducer(streamName, producer, defaultProducerConf_),
+                                                     FormatString("[S:%s] CreateProducer failed.", streamName));
+                    LOG(INFO) << FormatString("[S:%s] CreateProducer success. Number of elements to push %zu",
+                                              streamName, totalElements);
+                    size_t numElementSent = 0;
+                    Status rc;
+                    const int bigEleSz = 2 * MB;
+                    RandomData rand;
+                    const int FREQUENCY = 1000; // 0.1% will be big elements.
+                    while (numElementSent < totalElements) {
+                        auto eleSz = rand.GetRandomUint64(minEleSz, maxEleSz + 1);
+                        size_t sz = (numElementSent > 0 && numElementSent % FREQUENCY == 0) ?
bigEleSz : eleSz; + auto str = rand.GetRandomString(sz); + rc = producer->Send(Element((uint8_t *)str.data(), str.size())); + if (rc.IsOk()) { + numElementSent++; + continue; + } + // rest is error case + LOG(ERROR) << FormatString("[S:%s] Fail to send. rc = %s", streamName, rc.ToString()); + if (rc.GetCode() == K_OUT_OF_MEMORY) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + continue; + } + break; + } + LOG(INFO) << FormatString("[S:%s] %zu number of elements pushed", streamName, numElementSent); + return rc; + })); + } + } + } + + void ConsumeAndAckAll(std::vector> &fut, const std::string &streamName, + const size_t totalElements, Optional rand) + { + int idx = allClients_.size() - 1; + if (rand) { + idx = rand.value().GetRandomIndex(allClients_.size()); + } + auto *streamClient = allClients_.at(idx); + LOG(INFO) << FormatString("[S:%s] Create consumer on worker node %d", streamName, idx); + fut.emplace_back(pool->Submit([streamName, totalElements, streamClient]() -> Status { + std::shared_ptr consumer; + SubscriptionConfig localConfig(streamName + "_sub000", SubscriptionType::STREAM); + RETURN_IF_NOT_OK(streamClient->Subscribe(streamName, localConfig, consumer, true)); + LOG(INFO) << FormatString("[S:%s] Total elements expected to receive: %zu", streamName, totalElements); + size_t numElementsReceived = 0; + while (numElementsReceived < totalElements) { + std::vector out; + RETURN_IF_NOT_OK(consumer->Receive(RPC_TIMEOUT, out)); + if (!out.empty()) { + numElementsReceived += out.size(); + LOG(INFO) << FormatString("[S:%s] Received %zu. Remaining %zu", streamName, numElementsReceived, + totalElements - numElementsReceived); + consumer->Ack(out[out.size() - 1].id); + } + } + return Status::OK(); + })); + } + + void SendOneBigElement(std::vector> &fut, const std::string &streamName) + { + for (auto *client : allClients_) { + fut.emplace_back(pool->Submit([this, client, streamName]() -> Status { + std::shared_ptr producer; + RETURN_IF_NOT_OK(client->CreateProducer(streamName, producer, defaultProducerConf_)); + // Send a big element + const int bigEleSz = 4 * MB; + RandomData rand; + auto str = rand.GetRandomString(bigEleSz); + Status rc = producer->Send(Element((uint8_t *)str.data(), str.size()), RPC_TIMEOUT); + return rc; + })); + } + } +}; + +TEST_F(RemoteSendRecvBigElementTest, TestBigElementFairness1) +{ + const std::string streamName = "BigElementFairness1"; + const int initialNumOfProducersPerWorker = 3; + const size_t minEleSz = 48; + const size_t maxEleSz = 1024; + size_t totalElements = 0; + + // Set up OOM on all the workers without any consumer to consume + DS_ASSERT_OK(FillMemoryUntilOOM(streamName, initialNumOfProducersPerWorker, minEleSz, maxEleSz, totalElements)); + + // Now we create three BigElement producer, and specify a timeout, and then finally + // create a consumer. 
+ std::vector> futs1; + SendOneBigElement(futs1, streamName); + totalElements += allClients_.size(); + ConsumeAndAckAll(futs1, streamName, totalElements, Optional()); + + for (auto &f : futs1) { + auto res = f.get(); + DS_ASSERT_OK(res); + } +} + +TEST_F(RemoteSendRecvBigElementTest, TestBigElementFairness2) +{ + const std::string streamName = "BigElementFairness2"; + const int initialNumOfProducersPerWorker = 3; + const size_t minEleSz = 48; + const size_t maxEleSz = 1024; + size_t totalElements = 0; + // Set up OOM on all the workers without any consumer to consume + DS_ASSERT_OK(FillMemoryUntilOOM(streamName, initialNumOfProducersPerWorker, minEleSz, maxEleSz, totalElements)); + + // Create a consumer to consume all the rows with 3 additional big element rows which come later + std::vector> futs1; + totalElements += allClients_.size(); + ConsumeAndAckAll(futs1, streamName, totalElements, Optional()); + SendOneBigElement(futs1, streamName); + + for (auto &f : futs1) { + auto res = f.get(); + DS_ASSERT_OK(res); + } +} + +TEST_F(RemoteSendRecvBigElementTest, DISABLED_TestBigElementFairness3) +{ + // README + // The numStreams has been decreased from 5 to reduce run time during CI + // To run the intended load locally, edit the value. + const int initialNumOfProducersPerWorker = 4; + const size_t minEleSz = 48; + const size_t maxEleSz = 1024; + size_t totalElements = 12'000; + const int numStreams = 2; + const int width = 3; + + std::vector> futs1; + RandomData rand; + for (int i = 0; i < numStreams; ++i) { + std::stringstream oss; + oss << "stream" << std::setw(width) << std::setfill('0') << i; + const std::string streamName = oss.str(); + LOG(INFO) << "Create stream " << streamName; + // Create a consumer to consume everything. + ConsumeAndAckAll(futs1, streamName, totalElements * allClients_.size() * initialNumOfProducersPerWorker, + Optional(rand)); + // Create a few producers. + CreateProducersAndPush(futs1, streamName, initialNumOfProducersPerWorker, minEleSz, maxEleSz, totalElements); + } + + for (auto &f : futs1) { + auto res = f.get(); + DS_ASSERT_OK(res); + } +} + +TEST_F(RemoteSendRecvBigElementTest, DISABLED_TestBigElementFairness4) +{ + const std::string streamName = "BigElementFairness4"; + const int initialNumOfProducersPerWorker = 3; + const size_t minEleSz = 48; + const size_t maxEleSz = 1024; + size_t totalElements = 0; + // Set up OOM on all the workers without any consumer to consume + DS_ASSERT_OK(FillMemoryUntilOOM(streamName, initialNumOfProducersPerWorker, minEleSz, maxEleSz, totalElements)); + + // Create a consumer to consume a few rows and then exit. + auto pid = fork(); + if (pid == 0) { + std::shared_ptr scClient; + InitStreamClient(2, scClient); // index is 2 + std::shared_ptr consumer; + SubscriptionConfig localConfig(streamName + "_sub000", SubscriptionType::STREAM); + DS_ASSERT_OK(scClient->Subscribe(streamName, localConfig, consumer, true)); + LOG(INFO) << FormatString("[S:%s] Total elements expected to receive: %zu", streamName, totalElements); + size_t numElementsReceived = 0; + const size_t exitThreshold = 500; + while (numElementsReceived < totalElements) { + std::vector out; + if (numElementsReceived >= exitThreshold) { + break; + } + DS_ASSERT_OK(consumer->Receive(RPC_TIMEOUT, out)); + if (!out.empty()) { + numElementsReceived += out.size(); + LOG(INFO) << FormatString("[S:%s] Received %zu. 
Remaining %zu", streamName, numElementsReceived, + totalElements - numElementsReceived); + } + } + _exit(0); + } + ASSERT_TRUE(pid > 0); + int status; + waitpid(pid, &status, 0); + // Wait at least client_dead_timeout_s (15s) + const uint64_t sleepMs = 16'000; + std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs)); + // Create a consumer to consume 3 additional big element rows which come later + std::shared_ptr consumer; + SubscriptionConfig localConfig(streamName + "_sub001", SubscriptionType::STREAM); + DS_ASSERT_OK(w3Client_->Subscribe(streamName, localConfig, consumer, true)); + std::vector> futs1; + totalElements = allClients_.size(); + SendOneBigElement(futs1, streamName); + for (auto &f : futs1) { + auto res = f.get(); + DS_ASSERT_OK(res); + } + std::vector out; + DS_ASSERT_OK(consumer->Receive(totalElements, RPC_TIMEOUT, out)); + ASSERT_EQ(out.size(), totalElements); +} + +TEST_F(RemoteSendRecvBigElementTest, TestReclaimMemoryReuseStream) +{ + const std::string streamName = "ReclaimMemoryReuseStream"; + const int initialNumOfProducersPerWorker = 2; + const size_t minEleSz = 48; + const size_t maxEleSz = 1024; + size_t totalElements = 1'000; + + // Create a consumer to consume everything. + std::vector> futs1; + auto totalCount = totalElements * allClients_.size() * initialNumOfProducersPerWorker; + ConsumeAndAckAll(futs1, streamName, totalCount, Optional()); + + // Create a few producers. + for (auto *client : allClients_) { + for (auto i = 0; i < initialNumOfProducersPerWorker; ++i) { + futs1.emplace_back(pool->Submit([this, client, &streamName, totalElements]() -> Status { + std::shared_ptr producer; + RETURN_IF_NOT_OK(client->CreateProducer(streamName, producer, defaultProducerConf_)); + // 1% will be big elements + size_t numElementSent = 0; + Status rc; + RandomData rand; + while (numElementSent < totalElements) { + // 0.1% will be big element + auto eleSz = rand.GetRandomUint64(minEleSz, maxEleSz + 1); + auto str = rand.GetRandomString(eleSz); + rc = producer->Send(Element((uint8_t *)str.data(), str.size())); + if (rc.IsOk()) { + numElementSent++; + } + } + return rc; + })); + } + } + + for (auto &f : futs1) { + auto res = f.get(); + DS_ASSERT_OK(res); + } + + futs1.clear(); + // All producers/consumers are closed at this point and all memory are released. + // Reuse the same stream and supposedly to resume the last ack point. + futs1.emplace_back(pool->Submit([this, &streamName, totalCount]() { + std::shared_ptr consumer; + SubscriptionConfig localConfig("sub001", SubscriptionType::STREAM); + RETURN_IF_NOT_OK(w3Client_->Subscribe(streamName, localConfig, consumer, true)); + std::vector out; + RETURN_IF_NOT_OK(consumer->Receive(RPC_TIMEOUT, out)); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(!out.empty(), K_RUNTIME_ERROR, FormatString("Expect not empty")); + CHECK_FAIL_RETURN_STATUS_PRINT_ERROR( + out[0].id == totalCount + 1, K_RUNTIME_ERROR, + FormatString("Id mismatch. 
Expect %zu but get %zu", totalCount + 1, out[0].id)); + consumer->Ack(out[0].id); + return Status::OK(); + })); + + futs1.emplace_back(pool->Submit([this, &streamName]() -> Status { + std::shared_ptr producer; + RETURN_IF_NOT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + RandomData rand; + auto eleSz = rand.GetRandomUint64(minEleSz, maxEleSz + 1); + auto str = rand.GetRandomString(eleSz); + RETURN_IF_NOT_OK(producer->Send(Element((uint8_t *)str.data(), str.size()))); + return Status::OK(); + })); + + for (auto &f : futs1) { + auto res = f.get(); + DS_ASSERT_OK(res); + } +} + +TEST_F(RemoteSendRecvBigElementTest, TestBlockedReqTimeout1) +{ + const std::string streamName = "TestBlockedReqTimeout1"; + const int initialNumOfProducersPerWorker = 3; + const size_t minEleSz = 48; + const size_t maxEleSz = 1024; + size_t totalElements = 0; + + // Set up OOM on all the workers without any consumer to consume + DS_ASSERT_OK(FillMemoryUntilOOM(streamName, initialNumOfProducersPerWorker, minEleSz, maxEleSz, totalElements)); + // Construct a blocked request with timeout 0ms + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "AddBlockedCreateRequest.subTimeout", "call()")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "AutoAckImpl.WaitAndRetry", "2*call()")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "HandleBlockedRequestImpl.subTimeout", "call()")); + const int FIVE_S = 5; + std::this_thread::sleep_for(std::chrono::seconds(FIVE_S)); + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + // Send a big element + const int bigEleSz = 4 * MB; + RandomData rand; + auto str = rand.GetRandomString(bigEleSz); + Status rc = producer->Send(Element((uint8_t *)str.data(), str.size()), RPC_TIMEOUT); + LOG(INFO) << rc.ToString(); + DS_ASSERT_NOT_OK(rc); +} + +TEST_F(RemoteSendRecvBigElementTest, LEVEL1_TestBlockedReqTimeout2) +{ + const std::string streamName = "BlockedReqTimeout2"; + const int initialNumOfProducersPerWorker = 3; + const size_t minEleSz = 48; + const size_t maxEleSz = 1024; + size_t totalElements = 0; + + // Set up OOM on all the workers without any consumer to consume + DS_ASSERT_OK(FillMemoryUntilOOM(streamName, initialNumOfProducersPerWorker, minEleSz, maxEleSz, totalElements)); + // Construct a blocked request with timeout 5s + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "HandleBlockedRequestImpl.sleep", "1*sleep(10000)")); + // Send a big element + const int bigEleSz = 4 * MB; + RandomData rand; + auto str = rand.GetRandomString(bigEleSz); + const int TWO_S = 2000; + Status rc = producer->Send(Element((uint8_t *)str.data(), str.size()), TWO_S); + LOG(INFO) << rc.ToString(); + ASSERT_EQ(rc.GetCode(), K_OUT_OF_MEMORY); +} + +TEST_F(RemoteSendRecvBigElementTest, TestBlockedReqTimeout3) +{ + const std::string streamName = "BlockedReqTimeout3"; + const int totalElements = 2; + const int workerInx = 2; + // Simulate the case memory is orphaned + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, workerInx, "StreamManager.AllocBigShmMemoryInternalReq.SetTimeoutMs", + "2*return(0)")); + DS_ASSERT_OK( + cluster_->SetInjectAction(WORKER, workerInx, "StreamManager.AllocBigShmMemory.NoHandShake1", "call()")); + DS_ASSERT_OK( + cluster_->SetInjectAction(WORKER, workerInx, "StreamManager.AllocBigShmMemory.NoHandShake2", "return(K_OK)")); + 
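+    // The inject-action strings appear to follow a small fault-injection DSL
+    // ("N*action" fires for the first N hits, e.g. "2*return(0)" above and
+    // "1*sleep(5000)" elsewhere); this is inferred from usage in these tests.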
DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, workerInx, "ExclusivePageQueue.Ack.Start", "return(K_OK)")); + std::vector> futs1; + ConsumeAndAckAll(futs1, streamName, totalElements, Optional()); + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + std::shared_ptr localProducer; + DS_ASSERT_OK(w3Client_->CreateProducer(streamName, localProducer, defaultProducerConf_)); + const int bigEleSz = 4 * MB; + RandomData rand; + auto str = rand.GetRandomString(bigEleSz); + DS_ASSERT_OK(producer->Send(Element((uint8_t *)str.data(), str.size()))); + // The above set up will create some orphaned big element. + // If we send one more local big element, we will get OOM + const int FIVE_S = 5000; + const int K_2 = 2; + std::this_thread::sleep_for(std::chrono::seconds(K_2)); + Status rc = localProducer->Send(Element((uint8_t *)str.data(), str.size()), FIVE_S); + LOG(INFO) << rc.ToString(); + ASSERT_EQ(rc.GetCode(), K_OUT_OF_MEMORY); + // Need to kill of the consumer which can't get the element sent by the local producer. + // Easier to send one more dummy elements. + std::string a("bye"); + localProducer->Send(Element((uint8_t *)a.data(), a.size())); +} + +TEST_F(RemoteSendRecvBigElementTest, TestBlockedReqTimeout4) +{ + const std::string streamName = "BlockedReqTimeout4"; + const int totalElements = 1; + const int workerInx = 2; + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, workerInx, "ExclusivePageQueue.Ack.Start", "return(K_OK)")); + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, workerInx, "StreamManager.AllocBigShmMemoryInternalReq.sleep", + "1*sleep(5000)")); + std::vector> futs1; + ConsumeAndAckAll(futs1, streamName, totalElements, Optional()); + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + const int bigEleSz = 4 * MB; + RandomData rand; + auto str = rand.GetRandomString(bigEleSz); + DS_ASSERT_OK(producer->Send(Element((uint8_t *)str.data(), str.size()))); + for (auto &f : futs1) { + auto res = f.get(); + DS_ASSERT_OK(res); + } +} + +TEST_F(RemoteSendRecvBigElementTest, TestHandShakeUndo) +{ + const std::string streamName = "testHandShakeUndo"; + const int totalElements = 2; + const int workerInx = 2; + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, workerInx, "ExclusivePageQueue.Ack.Start", "return(K_OK)")); + DS_ASSERT_OK( + cluster_->SetInjectAction(WORKER, workerInx, "BlockedCreateRequest.ReceiverHandShake.sleep", "10*sleep(1000)")); + std::vector> futs1; + ConsumeAndAckAll(futs1, streamName, totalElements, Optional()); + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + std::shared_ptr localProducer; + DS_ASSERT_OK(w3Client_->CreateProducer(streamName, localProducer, defaultProducerConf_)); + const int bigEleSz = 4 * MB; + RandomData rand; + auto str = rand.GetRandomString(bigEleSz); + DS_ASSERT_OK(producer->Send(Element((uint8_t *)str.data(), str.size()))); + const int K_10 = 10000; + std::this_thread::sleep_for(std::chrono::milliseconds(K_10)); + DS_ASSERT_OK(localProducer->Send(Element((uint8_t *)str.data(), str.size()))); +} + +class RemoteSendRecvDuplicateTest : public RemoteSendRecvTest { +public: + + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numEtcd = 1; + opts.numWorkers = 3; + opts.enableDistributedMaster = "false"; + opts.workerGflagParams = " -sc_local_cache_memory_size_mb=1"; + opts.numRpcThreads = 0; + opts.vLogLevel = 2; + 
SCClientCommon::SetClusterSetupOptions(opts); + } +}; + +TEST_F(RemoteSendRecvDuplicateTest, LEVEL1_TestDuplicateSendOOM) +{ + const int DEFAULT_WAIT_TIME = 60'000; + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "RemoteWorker.BatchFlushAsyncRead.rpc.timeout", "2048*call()")); + std::shared_ptr producer; + std::shared_ptr consumer; + + ProducerConf producerConf{ + .delayFlushTime = 20, .pageSize = 512 * KB, .maxStreamSize = 1 * MB, .autoCleanup = false + }; + + std::string streamName("testDupSendOOM"); + + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, producerConf)); + SubscriptionConfig consumerConf("sub", SubscriptionType::STREAM); + DS_ASSERT_OK(w2Client_->Subscribe(streamName, consumerConf, consumer)); + + const int numEle = 2; + const int eleSize = 8 * KB; + std::thread consumerThrd([&consumer]() { + int received = 0; + while (received < numEle) { + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(1, DEFAULT_WAIT_TIME, outElements)); + if (!outElements.empty()) { + received += outElements.size(); + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + } + } + }); + + std::thread producerThrd([&producer]() { + std::vector writeElement = RandomData().RandomBytes(eleSize); + Element element = Element(reinterpret_cast(writeElement.data()), writeElement.size()); + DS_ASSERT_OK(producer->Send(element)); + const int sleepMs = 10'000; + std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs)); + DS_ASSERT_OK(producer->Send(element)); + }); + + producerThrd.join(); + consumerThrd.join(); + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(consumer->Close()); + DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, 0, "RemoteWorker.BatchFlushAsyncRead.rpc.timeout")); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/reset_stream_test.cpp b/tests/st/client/stream_cache/reset_stream_test.cpp new file mode 100644 index 0000000..e5a3bb4 --- /dev/null +++ b/tests/st/client/stream_cache/reset_stream_test.cpp @@ -0,0 +1,477 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include + +#include "common.h" +#include "sc_client_common.h" +#include "client/stream_cache/pub_sub_utils.h" +#include "common/stream_cache/element_generator.h" +#include "common/stream_cache/stream_common.h" +#include "datasystem/common/util/random_data.h" + +using namespace datasystem::client::stream_cache; +namespace datasystem { +namespace st { +class ResetStreamTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numEtcd = 1; + opts.numWorkers = 3; + opts.enableDistributedMaster = "false"; + opts.workerGflagParams = " -v=3"; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + InitTestClientInstance(); + } + + void TearDown() override + { + client_ = nullptr; + client2_ = nullptr; + client3_ = nullptr; + ExternalClusterTest::TearDown(); + } + +protected: + void InitTestClientInstance() + { + int32_t timeoutMs = 60000; // timeout is 60000 ms + InitStreamClient(0, client_, timeoutMs, true); + InitStreamClient(1, client2_, timeoutMs, true); + InitStreamClient(2, client3_, timeoutMs, true); + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + } + + Status DataFlowAfterResetResume(std::shared_ptr producer, std::shared_ptr consumer) + { + // Check if the previous data is received or the new one. + std::string data = "This is different data"; + std::vector outElements; + Element element(reinterpret_cast(&data.front()), data.size()); + RETURN_IF_NOT_OK(producer->Send(element)); + + RETURN_IF_NOT_OK(consumer->Receive(1, 10000, outElements)); + if (outElements.size() == 0) { + return Status(K_RUNTIME_ERROR, "Didn't receive any element"); + } + std::string actualData(reinterpret_cast(outElements[0].ptr), outElements[0].size); + LOG(INFO) << "Original " << data << " actual data " << actualData; + CHECK_FAIL_RETURN_STATUS(data == actualData, K_RUNTIME_ERROR, "received data is different"); + return Status::OK(); + } + + void SendDataContinuously(std::shared_ptr prod1, std::shared_ptr prod2, + std::shared_ptr prod3) + { + std::string data = RandomData().GetRandomString(200); + Element element(reinterpret_cast(&data.front()), data.size()); + Status rc; + while (true) { + rc = prod1->Send(element); + if (rc.IsOk()) { + rc = prod2->Send(element); + } + if (rc.IsOk()) { + rc = prod3->Send(element); + } + if (rc.IsError()) { + EXPECT_EQ(rc.GetCode(), K_SC_STREAM_IN_RESET_STATE); + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + } + + Status MultiHopDataMovement(std::string &data, std::shared_ptr prod1, std::shared_ptr prod2, + std::shared_ptr con1, std::shared_ptr con2) + { + Element element(reinterpret_cast(&data.front()), data.size()); + RETURN_IF_NOT_OK(prod1->Send(element)); + + std::vector outElements; + RETURN_IF_NOT_OK(con1->Receive(1, 10000, outElements)); + CHECK_FAIL_RETURN_STATUS(outElements.size() > (size_t) 0, K_RUNTIME_ERROR, "Didn't receive any data1"); + RETURN_IF_NOT_OK(prod1->Send(element)); + std::string actualData1(reinterpret_cast(outElements[0].ptr), outElements[0].size); + CHECK_FAIL_RETURN_STATUS(data == actualData1, K_RUNTIME_ERROR, "received data is different"); + + Element element2(reinterpret_cast(&actualData1.front()), actualData1.size()); + RETURN_IF_NOT_OK(prod2->Send(element2)); + + RETURN_IF_NOT_OK(con2->Receive(1, 10000, outElements)); + CHECK_FAIL_RETURN_STATUS(outElements.size() > (size_t) 0, K_RUNTIME_ERROR, "Didn't receive any data2"); + RETURN_IF_NOT_OK(prod2->Send(element2)); + 
std::string actualData2(reinterpret_cast(outElements[0].ptr), outElements[0].size); + CHECK_FAIL_RETURN_STATUS(data == actualData2, K_RUNTIME_ERROR, "received data is different"); + return Status::OK(); + } + + Status WorkerRestartAndClientDetect(int workerIdx) + { + // Shutdown worker + cluster_->QuicklyShutdownWorker(workerIdx); + // Restart worker + cluster_->StartNode(WORKER, workerIdx, ""); + cluster_->WaitNodeReady(WORKER, workerIdx); + + // wait for heartbeat interval so client can detect worker restarted + std::this_thread::sleep_for(std::chrono::seconds(6)); + return Status::OK(); + } + void CreatePubSubForClients(std::vector> &clients, int clientCount, + std::vector> &producers, + std::vector> &consumers, + std::vector> &clientStreams); + void CreatePubSubs(std::shared_ptr client, std::vector> &producers, + std::vector> &consumers, int prodStart, int prodEnd, int conStart, + int conEnd, std::vector> &clientStreams); + + std::shared_ptr client_ = nullptr; + std::shared_ptr client2_ = nullptr; + std::shared_ptr client3_ = nullptr; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; +}; + +TEST_F(ResetStreamTest, CloseProducersAndConsumers) +{ + std::shared_ptr producer; + std::string streamName = "testCloseProdsAndCons"; + DS_ASSERT_OK(client_->CreateProducer(streamName, producer, defaultProducerConf_)); + + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer)); + std::vector streamNames; + streamNames.push_back(streamName); + + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(consumer->Close()); + + producer.reset(); + consumer.reset(); +} + +TEST_F(ResetStreamTest, ResetSingleStreamsLocalSubs) +{ + std::shared_ptr producer1; + std::string streamName1 = "testResetStreamLocalSub"; + DS_ASSERT_OK(client_->CreateProducer(streamName1, producer1, defaultProducerConf_)); + + std::shared_ptr consumer1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName1, config1, consumer1)); + + std::vector streamNames; + streamNames.push_back(streamName1); +} + +TEST_F(ResetStreamTest, ResetSingleStreamsRemoteSubs) +{ + std::shared_ptr producer1; + std::string streamName1 = "testResetStreamRemotelSub"; + DS_ASSERT_OK(client_->CreateProducer(streamName1, producer1, defaultProducerConf_)); + + std::shared_ptr consumer1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client2_->Subscribe(streamName1, config1, consumer1)); + + std::vector streamNames; + streamNames.push_back(streamName1); +} + +TEST_F(ResetStreamTest, ResetMultiStreamsSingleReset) +{ + std::shared_ptr producer1; + std::string streamName1 = "testMultiStreamSingleReset"; + DS_ASSERT_OK(client_->CreateProducer(streamName1, producer1, defaultProducerConf_)); + + std::shared_ptr consumer1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName1, config1, consumer1)); + std::shared_ptr producer2; + std::string streamName2 = "test2"; + DS_ASSERT_OK(client_->CreateProducer(streamName2, producer2, defaultProducerConf_)); + + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName2, config2, consumer2)); + + std::shared_ptr producer3; + std::string streamName3 = "test3"; + DS_ASSERT_OK(client_->CreateProducer(streamName3, 
producer3, defaultProducerConf_)); + + std::shared_ptr consumer3; + SubscriptionConfig config3("sub3", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName3, config3, consumer3)); + + std::vector streamNames; + streamNames.push_back(streamName1); + streamNames.push_back(streamName2); + streamNames.push_back(streamName3); + DS_ASSERT_OK(producer1->Close()); + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(producer2->Close()); + DS_ASSERT_OK(consumer2->Close()); + DS_ASSERT_OK(producer3->Close()); + DS_ASSERT_OK(consumer3->Close()); + DS_ASSERT_OK(client_->DeleteStream(streamName1)); +} + +TEST_F(ResetStreamTest, ResetMultiStreamsMultiResets) +{ + std::shared_ptr producer1; + std::string streamName1 = "testMultiStreamMultiReset"; + DS_ASSERT_OK(client_->CreateProducer(streamName1, producer1, defaultProducerConf_)); + + std::shared_ptr consumer1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName1, config1, consumer1)); + std::shared_ptr producer2; + std::string streamName2 = "test2"; + DS_ASSERT_OK(client_->CreateProducer(streamName2, producer2, defaultProducerConf_)); + + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName2, config2, consumer2)); + + std::shared_ptr producer3; + std::string streamName3 = "test3"; + DS_ASSERT_OK(client_->CreateProducer(streamName3, producer3, defaultProducerConf_)); + + std::shared_ptr consumer3; + SubscriptionConfig config3("sub3", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName3, config3, consumer3)); + + std::vector streamNames1; + streamNames1.push_back(streamName1); + streamNames1.push_back(streamName2); + std::vector streamNames2; + streamNames2.push_back(streamName3); + + streamNames1.push_back(streamName3); + DS_ASSERT_OK(producer1->Close()); + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(client_->DeleteStream(streamName1)); +} + +TEST_F(ResetStreamTest, ResetDeleteStream) +{ + std::shared_ptr producer; + std::string streamName = "testResetDelStream"; + DS_ASSERT_OK(client_->CreateProducer(streamName, producer, defaultProducerConf_)); + + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName, config, consumer)); + + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(consumer->Close()); + + std::vector streamNames; + streamNames.push_back(streamName); +} + +TEST_F(ResetStreamTest, ResetStreamsCrossDependencyTest1) +{ + std::shared_ptr producer1, producer2; + std::string streamName1 = "testResetCrossDependency_s1"; + std::string streamName2 = "testResetCrossDependency_s2"; + DS_ASSERT_OK(client_->CreateProducer(streamName1, producer1, defaultProducerConf_)); + + std::shared_ptr consumer1, consumer2; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client_->Subscribe(streamName2, config1, consumer1)); + + std::shared_ptr client1; + InitStreamClient(0, client1); + DS_ASSERT_OK(client1->CreateProducer(streamName2, producer2, defaultProducerConf_)); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client1->Subscribe(streamName1, config2, consumer2)); + + std::vector streamNames1, streamNames2; + streamNames1.push_back(streamName1); + streamNames1.push_back(streamName2); + streamNames2.push_back(streamName2); + streamNames2.push_back(streamName1); +} + +void ResetStreamTest::CreatePubSubs( + std::shared_ptr client, std::vector> 
&producers,
+                                    std::vector<std::shared_ptr<Consumer>> &consumers, int prodStart, int prodEnd, int conStart,
+                                    int conEnd, std::vector<std::vector<std::string>> &clientStreams)
+{
+    std::string streamPrefix = "ResetStreamsCrossDependencyTest2";
+    std::vector<std::string> clientStream;
+    if (prodEnd > 0) {
+        for (int i = prodStart; i <= prodEnd; i++) {
+            std::string streamName = streamPrefix + std::to_string(i);
+            std::shared_ptr<Producer> producer;
+            DS_ASSERT_OK(client->CreateProducer(streamName, producer, defaultProducerConf_));
+            producers.push_back(producer);
+            clientStream.push_back(streamName);
+        }
+    }
+    if (conEnd > 0) {
+        for (int i = conStart; i <= conEnd; i++) {
+            std::string streamName = streamPrefix + std::to_string(i);
+            std::shared_ptr<Consumer> consumer;
+            std::string subName = "sub" + std::to_string(clientStreams.size()) + "_" + std::to_string(i);
+            SubscriptionConfig config(subName, SubscriptionType::STREAM);
+            DS_ASSERT_OK(client->Subscribe(streamName, config, consumer));
+            consumers.push_back(consumer);
+            clientStream.push_back(streamName);
+        }
+    }
+    clientStreams.push_back(clientStream);
+}
+
+void ResetStreamTest::CreatePubSubForClients(std::vector<std::shared_ptr<StreamClient>> &clients, int clientCount,
+                                             std::vector<std::shared_ptr<Producer>> &producers,
+                                             std::vector<std::shared_ptr<Consumer>> &consumers,
+                                             std::vector<std::vector<std::string>> &clientStreams)
+{
+    // There are three blocks of streams: 1-16, 17-20 and 21.
+    const int b1Range = 16;
+    const int b2Range = 20;
+    const int b3Range = 21;
+    int clientIdx = 0;
+
+    // Five clients connect to worker 0 and one client connects to worker 1
+    for (int i = 1; i < clientCount; i++) {
+        std::shared_ptr<StreamClient> client;
+        InitStreamClient(0, client);
+        clients.push_back(client);
+    }
+
+    std::shared_ptr<StreamClient> client;
+    InitStreamClient(1, client);
+    clients.push_back(client);
+    client->Init(false);
+
+    // Create pubs/subs for clients in the specific stream blocks
+    CreatePubSubs(clients[clientIdx++], producers, consumers, 1, b1Range, b1Range + 1, b2Range, clientStreams);
+    CreatePubSubs(clients[clientIdx++], producers, consumers, -1, -1, 1, b1Range, clientStreams);
+    CreatePubSubs(clients[clientIdx++], producers, consumers, 1, b1Range, b1Range + 1, b2Range, clientStreams);
+    CreatePubSubs(clients[clientIdx++], producers, consumers, b2Range + 1, b3Range, -1, -1, clientStreams);
+    CreatePubSubs(clients[clientIdx++], producers, consumers, -1, -1, 1, b1Range, clientStreams);
+    CreatePubSubs(clients[clientIdx++], producers, consumers, b1Range + 1, b2Range, b2Range + 1, b3Range,
+                  clientStreams);
+}
+
+TEST_F(ResetStreamTest, LEVEL1_TestClientLostWorkerSameNode)
+{
+    std::shared_ptr<Producer> prod1;
+    std::string streamName = "workerlosttestsamenode";
+    DS_ASSERT_OK(client_->CreateProducer(streamName, prod1, defaultProducerConf_));
+
+    std::vector<std::string> streamNames;
+    streamNames.push_back(streamName);
+
+    DS_ASSERT_OK(WorkerRestartAndClientDetect(0));
+
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    Status rc = prod1->Send(element);
+    ASSERT_EQ(rc.GetCode(), StatusCode::K_SC_WORKER_WAS_LOST);
+    // Close client and producer
+    client_.reset();
+    prod1.reset();
+
+    InitStreamClient(0, client_);
+    DS_ASSERT_OK(client_->CreateProducer(streamName, prod1, defaultProducerConf_));
+
+    DS_ASSERT_OK(prod1->Send(element));
+}
+
+TEST_F(ResetStreamTest, LEVEL1_TestClientLostWorkerCrossNode)
+{
+    std::shared_ptr<Producer> prod1;
+    std::string streamName = "workerlosttestcrossnode";
+    DS_ASSERT_OK(client_->CreateProducer(streamName, prod1, defaultProducerConf_));
+
+    std::shared_ptr<Consumer> con1;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2_->Subscribe(streamName, config,
con1)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(prod1->Send(element)); + + std::vector streamNames; + streamNames.push_back(streamName); + + DS_ASSERT_OK(WorkerRestartAndClientDetect(1)); + DS_ASSERT_OK(prod1->Send(element)); + + std::vector outElements; + Status rc = con1->Receive(1, 1000, outElements); + ASSERT_EQ(rc.GetCode(), StatusCode::K_SC_WORKER_WAS_LOST); + // close client2 and consumer + client2_.reset(); + con1.reset(); + InitStreamClient(1, client2_); + DS_ASSERT_OK(client2_->Subscribe(streamName, config, con1)); +} + +TEST_F(ResetStreamTest, LEVEL1_TestClientNotTrackLostWorker) +{ + // Set up two clients to not track worker lost between heartbeats. One explicit and one default. + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + std::shared_ptr client1, client2; + InitStreamClient(0, client1); + InitStreamClient(0, client2); + + std::shared_ptr prod1, prod2; + std::string streamName = "workerlosttestnottracked"; + DS_ASSERT_OK(client1->CreateProducer(streamName, prod1, defaultProducerConf_)); + + std::shared_ptr con1, con2; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config, con1)); + + DS_ASSERT_OK(WorkerRestartAndClientDetect(0)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + // After restart, producer Id is erased at the worker and the existing producer can't send. + DS_ASSERT_NOT_OK(prod1->Send(element)); + + std::vector outElements; + Status rc = con1->Receive(1, 1000, outElements); + // After restart, consumer is lost at the worker. However, the error code is not K_SC_WORKER_WAS_LOST. + DS_ASSERT_NOT_OK(rc); + ASSERT_NE(rc.GetCode(), StatusCode::K_SC_WORKER_WAS_LOST); + + // However, we can add new producers and consumers as the worker lost is not tracked. + DS_ASSERT_OK(client1->CreateProducer(streamName, prod2, defaultProducerConf_)); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config2, con2)); + + DS_ASSERT_OK(DataFlowAfterResetResume(prod2, con2)); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/retain_data_test.cpp b/tests/st/client/stream_cache/retain_data_test.cpp new file mode 100644 index 0000000..986e02f --- /dev/null +++ b/tests/st/client/stream_cache/retain_data_test.cpp @@ -0,0 +1,1199 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+/**
+ * Description: Unit test for the stream cache retain data feature
+ */
+#include <chrono>
+#include <thread>
+
+#include "common.h"
+#include "datasystem/common/encrypt/secret_manager.h"
+#include "common/stream_cache/stream_common.h"
+#include "sc_client_common.h"
+#include "common/stream_cache/element_generator.h"
+#include "datasystem/common/inject/inject_point.h"
+#include "datasystem/stream_client.h"
+#include "datasystem/client/stream_cache/stream_client_impl.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream/consumer.h"
+#include "client/stream_cache/pub_sub_utils.h"
+
+using namespace datasystem::client::stream_cache;
+namespace datasystem {
+namespace st {
+constexpr int K_TWO = 2;
+constexpr int K_TEN = 10;
+class RetainDataTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.enableDistributedMaster = "false";
+        opts.masterIdx = 1;
+        opts.numWorkers = WORKER_COUNT;
+        opts.vLogLevel = logLevel;
+        opts.workerGflagParams += FormatString(" -node_timeout_s=%d -node_dead_timeout_s=%d -client_reconnect_wait_s=2",
+                                               nodeTimeout, nodeDead);
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+protected:
+    Status InitClient(int index, std::shared_ptr<StreamClient> &client)
+    {
+        InitStreamClient(index, client);
+        return Status::OK();
+    }
+
+    Status CreateConsumer(std::shared_ptr<StreamClient> client, const std::string &streamName,
+                          const std::string &subName, std::shared_ptr<Consumer> &consumer)
+    {
+        SubscriptionConfig config(subName, SubscriptionType::STREAM);
+        return client->Subscribe(streamName, config, consumer);
+    }
+
+    ProducerConf GetDefaultConf()
+    {
+        const int maxStreamSize = 10 * MB;
+        ProducerConf conf;
+        conf.maxStreamSize = maxStreamSize;
+        conf.pageSize = 1 * MB;
+        return conf;
+    }
+
+    void CheckCount(std::shared_ptr<StreamClient> client, const std::string &streamName, int producerCount,
+                    int consumerCount)
+    {
+        uint64_t result = 0;
+        if (producerCount >= 0) {
+            DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, result));
+            EXPECT_EQ(result, static_cast<uint64_t>(producerCount));
+            result = 0;
+        }
+        if (consumerCount >= 0) {
+            DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, result));
+            EXPECT_EQ(result, static_cast<uint64_t>(consumerCount));
+            result = 0;
+        }
+    }
+
+    void CreateElement(size_t elementSize, Element &element, std::vector<uint8_t> &writeElement)
+    {
+        writeElement = RandomData().RandomBytes(elementSize);
+        element = Element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+    }
+
+    Status TryAndDeleteStream(std::shared_ptr<StreamClient> spClient, std::string streamName)
+    {
+        // If the delete hits pending notifications, retry it
+        Status rc = Status::OK();
+        do {
+            rc = spClient->DeleteStream(streamName);
+            if (rc.IsError()) {
+                sleep(K_TWO);
+            }
+        } while (rc.GetCode() == StatusCode::K_SC_STREAM_NOTIFICATION_PENDING);
+        return rc;
+    }
+
+    const int nodeTimeout = 4;  // 4s
+    const int nodeDead = nodeTimeout * 3;
+    const int waitNodeTimeout = nodeTimeout + 2;
+    const int waitNodeDead = nodeDead + 4;
+    const int logLevel = 2;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+    const uint32_t WORKER_COUNT = 3;
+    const int DEFAULT_WAIT_TIME = 5000;
+    const int WAIT_TIME = 1000;
+    const int DEFAULT_NUM_ELEMENT = 100;
+    const int SMALL_NUM_ELEMENT = 10;
+    const size_t TEST_ELEMENT_SIZE = 4 * KB;
+};
+
+// We do not retain data when retainForNumConsumers == 0
+// Should have no impact on the existing flow
+TEST_F(RetainDataTest, TestNotRetainData)
+{
+    std::shared_ptr<StreamClient>
client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+    std::string streamName = "NotRetainData";
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = 0;
+
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    for (int i = 0; i < SMALL_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(SMALL_NUM_ELEMENT, WAIT_TIME, outElements));
+    // No data received by the late consumer
+    ASSERT_EQ(outElements.size(), 0);
+}
+
+// We retain data when retainForNumConsumers > 0
+TEST_F(RetainDataTest, TestRetainDataSPSCLocalConsumer)
+{
+    // Test that data is retained when producers start to send before a local consumer is created
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = 1;  // retain data until one consumer attaches
+
+    std::shared_ptr<Producer> producer;
+    // Create producer and send data
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    // Create a late consumer
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer));
+    std::vector<Element> outElements;
+    // Now it should get the data
+    DS_ASSERT_OK(consumer->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer->Ack(size_t(DEFAULT_NUM_ELEMENT)));
+
+    outElements.clear();
+    // Create a 2nd late consumer
+    std::shared_ptr<Consumer> consumer2;
+    SubscriptionConfig config1("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config1, consumer2));
+    // Data should not be present, as we only retain for one consumer
+    // retainForNumConsumers == 1
+    DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), 0);
+}
+
+TEST_F(RetainDataTest, TestRetainDataSPSCRemoteConsumer)
+{
+    // Test that data is retained when producers start to send before a remote consumer is created
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = 1;
+
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+    std::vector<Element> outElements;
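+    // The retain gate in one line: with retainForNumConsumers == 1 the pages written
+    // above stay pinned until the first consumer attaches, so a late subscriber still
+    // observes the full history. Roughly:
+    //
+    //     Send() x N  ->  [pages retained]  ->  Subscribe()  ->  Receive() yields N
+    //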
DS_ASSERT_OK(consumer->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer->Ack(size_t(DEFAULT_NUM_ELEMENT)));
+    DS_ASSERT_OK(consumer->Close());
+
+    outElements.clear();
+    // Data should not be present, as we only retain for one consumer
+    // retainForNumConsumers == 1
+    std::shared_ptr<Consumer> consumer2;
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config2, consumer2));
+    DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), 0);
+}
+
+TEST_F(RetainDataTest, TestRetainDataMPMC1)
+{
+    // Node1: Producer1 Consumer1 Consumer2
+    // Node2: Producer2
+    // retainForNumConsumers == 2
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+
+    std::shared_ptr<StreamClient> client2;
+    DS_ASSERT_OK(InitClient(1, client2));
+
+    // Test config
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = K_TWO;
+
+    // On node 1, create producer1
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer1, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer1->Send(element));
+    }
+
+    // On node 2, create producer2
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_OK(client2->CreateProducer(streamName, producer2, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer2->Send(element));
+    }
+
+    // On node 1, create Consumer1
+    std::shared_ptr<Consumer> consumer1;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer1));
+
+    // Get the local producer's data on node 1
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer1->Ack(size_t(DEFAULT_NUM_ELEMENT)));
+    outElements.clear();
+
+    // On node 1, create Consumer2, which satisfies the retain threshold
+    std::shared_ptr<Consumer> consumer2;
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config2, consumer2));
+
+    // Consumer2 gets the data from both producers
+    DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT * K_TWO, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT) * K_TWO);
+    outElements.clear();
+
+    // Consumer1 gets the remaining data (produced on node 2)
+    DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer1->Ack(size_t(DEFAULT_NUM_ELEMENT)));
+}
+
+TEST_F(RetainDataTest, TestRetainDataMPMC2)
+{
+    // Node1: Producer1 Consumer1
+    // Node2: Consumer2
+    // retainForNumConsumers == 2
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = K_TWO;
+
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1",
SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    // The remote consumer receives none of the elements before the expected number of consumers have been created
+    ASSERT_EQ(outElements.size(), size_t(0));
+
+    std::shared_ptr<Consumer> consumer2;
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config2, consumer2));
+
+    // Now the remote consumer can get the elements
+    DS_ASSERT_OK(consumer->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT));
+
+    DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT));
+}
+
+TEST_F(RetainDataTest, TestRetainDataMPMC3)
+{
+    // Test that the node with no producer still gets to release pages
+    // In other words, test the INIT state
+    // Node1: Producer1 Consumer2
+    // Node2: Consumer1 Consumer3
+    // retainForNumConsumers == 2
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = K_TWO;
+
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements));
+    // The remote consumer receives none of the elements before the expected number of consumers have been created
+    ASSERT_EQ(outElements.size(), size_t(0));
+
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    std::shared_ptr<Consumer> consumer2;
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config2, consumer2));
+    DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer2->Close());
+
+    DS_ASSERT_OK(consumer->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT));
+    LOG(INFO) << "Acking element id " << outElements.back().id;
+    DS_ASSERT_OK(consumer->Ack(outElements.back().id));
+
+    std::shared_ptr<Consumer> consumer3;
+    SubscriptionConfig config3("sub3", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config3, consumer3));
+    DS_ASSERT_OK(consumer3->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), size_t(0));
+}
+
+TEST_F(RetainDataTest, TestRetainDataCreateLocalConsumerFirst)
+{
+    // Test creating new local producers when a subscriber has already been created
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    std::string streamName = "Stream_" +
RandomData().GetRandomString(10); + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + // Retain only for one consumer + ProducerConf conf = GetDefaultConf(); + conf.retainForNumConsumers = 1; + + // Create that one consumer + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer)); + + // Local producer should not have data retained + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + + // We should get data in remote + std::shared_ptr consumer2; + std::vector outElements; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config2, consumer2)); + DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT)); + DS_ASSERT_OK(consumer2->Close()); +} + +TEST_F(RetainDataTest, LEVEL1_TestRetainDataMPMC4) +{ + // Test that the node with no producer still gets to release pages + // In other words, test that Subscribe does not set INIT to RETAIN + // Node1: Producer1 Consumer3 + // Node2: Consumer1 Consumer2 Consumer4 + // retainForNumConsumers == 3 + std::shared_ptr client1; + std::shared_ptr client2; + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + std::string streamName = "Stream_" + RandomData().GetRandomString(10); + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + ProducerConf conf = GetDefaultConf(); + const int expectedNumConsumer = 3; + conf.retainForNumConsumers = expectedNumConsumer; + + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer)); + + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config2, consumer2)); + + std::shared_ptr consumer3; + SubscriptionConfig config3("sub3", SubscriptionType::STREAM); + DS_ASSERT_OK(client1->Subscribe(streamName, config3, consumer3)); + + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT)); + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + + DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT)); + DS_ASSERT_OK(consumer2->Ack(outElements.back().id)); + + std::shared_ptr consumer4; + SubscriptionConfig config4("sub4", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config4, consumer4)); + DS_ASSERT_OK(consumer4->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(0)); +} + +TEST_F(RetainDataTest, TestCloseProducerWhileRetainData1) +{ + // Test that retained data can be received by local consumer after producer is closed + // Create a producer with retainForNumConsumers == 1 
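+    // A minimal sketch of the flow verified below (assuming the fixture helpers):
+    //
+    //     producer->Send(element);   // data lands in retained stream pages
+    //     producer->Close();         // writer detaches, pages stay pinned
+    //     client->Subscribe(...);    // first (late) consumer attaches
+    //     consumer->Receive(...);    // and still sees every element
+    //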
+ std::shared_ptr client1; + DS_ASSERT_OK(InitClient(0, client1)); + std::string streamName = "Stream_" + RandomData().GetRandomString(10); + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + ProducerConf conf = GetDefaultConf(); + conf.retainForNumConsumers = 1; + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + // Close only producer + DS_ASSERT_OK(producer->Close()); + + // Create a local consumer + // It should be able to get the retained data + std::vector outElements; + std::shared_ptr consumer1; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client1->Subscribe(streamName, config2, consumer1)); + DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT)); + DS_ASSERT_OK(consumer1->Close()); +} + +TEST_F(RetainDataTest, TestCloseProducerWhileRetainData2) +{ + // Test that retained data can be received by remote consumer after producer is closed + // Create a producer with retainForNumConsumers == 1 + std::shared_ptr client1; + std::shared_ptr client2; + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + std::string streamName = "Stream_" + RandomData().GetRandomString(10); + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + ProducerConf conf = GetDefaultConf(); + conf.retainForNumConsumers = 1; + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + // Close only producer + DS_ASSERT_OK(producer->Close()); + + // Create a remote consumer + // It should be able to get the retained data + std::vector outElements; + std::shared_ptr consumer1; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config2, consumer1)); + DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT)); + DS_ASSERT_OK(consumer1->Close()); +} + +TEST_F(RetainDataTest, TestCloseProducerWhileRetainData3) +{ + // Test that retained data can be received by remote consumer after producer is closed + // Create a producer with retainForNumConsumers == 2 + std::shared_ptr client1; + std::shared_ptr client2; + std::shared_ptr client3; + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + DS_ASSERT_OK(InitClient(K_TWO, client3)); + std::string streamName = "Stream_" + RandomData().GetRandomString(10); + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + std::shared_ptr consumer1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client3->Subscribe(streamName, config1, consumer1)); + + ProducerConf conf = GetDefaultConf(); + conf.retainForNumConsumers = 2; + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + for (int i = 0; i < SMALL_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + // Close only producer + 
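+    // Close() only detaches the writer; the DeleteWhileRetainData cases below show
+    // that DeleteStream(), by contrast, takes priority over retention and drops the
+    // pinned pages outright.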
DS_ASSERT_OK(producer->Close()); + + // Create a remote consumer + // It should be able to get the retained data + std::vector outElements; + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config2, consumer2)); + + // Get data from 2nd client + outElements.clear(); + DS_ASSERT_OK(consumer2->Receive(SMALL_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(SMALL_NUM_ELEMENT)); + DS_ASSERT_OK(consumer2->Close()); + + // Get data from 3rd client + outElements.clear(); + DS_ASSERT_OK(consumer1->Receive(SMALL_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(SMALL_NUM_ELEMENT)); + DS_ASSERT_OK(consumer1->Close()); +} + +TEST_F(RetainDataTest, TestDeleteWhileRetainDataSameNode) +{ + // Create a producer with retainForNumConsumers == 1 + // And Delete the stream + // We should give higher priority to delete stream and allow it + + std::shared_ptr client1; + DS_ASSERT_OK(InitClient(0, client1)); + std::string streamName = "Stream_" + RandomData().GetRandomString(10); + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + ProducerConf conf = GetDefaultConf(); + conf.retainForNumConsumers = 1; + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + for (int i = 0; i < SMALL_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(client1->DeleteStream(streamName)); + + // Create a consumer + // It should not get any data + std::vector outElements; + std::shared_ptr consumer1; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client1->Subscribe(streamName, config2, consumer1)); + DS_ASSERT_OK(consumer1->Receive(SMALL_NUM_ELEMENT, WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), 0); + DS_ASSERT_OK(consumer1->Close()); +} + +TEST_F(RetainDataTest, TestDeleteWhileRetainDataRemoteNode) +{ + // Create a producer with retainForNumConsumers == 1 + // And Delete the stream in a different node + // We should give higher priority to delete stream and allow it + + std::shared_ptr client1; + DS_ASSERT_OK(InitClient(0, client1)); + std::shared_ptr client2; + DS_ASSERT_OK(InitClient(1, client2)); + std::string streamName = "Stream_" + RandomData().GetRandomString(10); + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + ProducerConf conf = GetDefaultConf(); + conf.retainForNumConsumers = 1; + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + DS_ASSERT_OK(producer->Close()); + // Delete in a different node + DS_ASSERT_OK(client2->DeleteStream(streamName)); + + // Create a consumer + // It should not get any data + std::vector outElements; + std::shared_ptr consumer1; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client1->Subscribe(streamName, config2, consumer1)); + DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), 0); + DS_ASSERT_OK(consumer1->Close()); +} + +TEST_F(RetainDataTest, LEVEL2_TestProducerCloseRemoteDeleteWhileRetainDataRemoteNode) +{ + // Create a consumer before producer + 
// Then create a producer with retainForNumConsumers == 2 + // And Delete the stream in a different node + // We should give higher priority to delete stream and allow it + // Delayed ClearAllConsumers should be invoked and Flush is a no-op + + std::shared_ptr client1; + DS_ASSERT_OK(InitClient(0, client1)); + std::shared_ptr client2; + DS_ASSERT_OK(InitClient(1, client2)); + std::string streamName = "Stream_" + RandomData().GetRandomString(10); + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + std::shared_ptr consumer1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config1, consumer1)); + + ProducerConf conf = GetDefaultConf(); + conf.retainForNumConsumers = 2; + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(producer->Close()); + // Now close consumer sends async update topo notifications + // So, we need to wait sometime before doing delete stream + // Delete in a different node + DS_ASSERT_OK(TryAndDeleteStream(client2, streamName)); + + // Create a consumer + // It should not get any data + std::vector outElements; + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client1->Subscribe(streamName, config2, consumer2)); + DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), 0); + DS_ASSERT_OK(consumer2->Close()); + // Sleep for auto-delete logic to go through + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_WAIT_TIME)); +} + +TEST_F(RetainDataTest, LEVEL1_TestProducerCloseLocalDeleteWhileRetainDataRemoteNode) +{ + // Create a consumer before producer + // Then create a producer with retainForNumConsumers == 2 + // And Delete the stream in the same node as producer + // We should give higher priority to delete stream and allow it + // Delayed ClearAllConsumers should be invoked and Flush is a no-op + + std::shared_ptr client1; + DS_ASSERT_OK(InitClient(0, client1)); + std::shared_ptr client2; + DS_ASSERT_OK(InitClient(1, client2)); + std::string streamName = "Stream_" + RandomData().GetRandomString(10); + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + std::shared_ptr consumer1; + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config1, consumer1)); + + ProducerConf conf = GetDefaultConf(); + conf.retainForNumConsumers = 2; + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(producer->Close()); + // Now close consumer sends async update topo notifications + // So, we need to wait sometime before doing delete stream + // Delete is called on the same node where ClearAllConsumer is delayed + DS_ASSERT_OK(TryAndDeleteStream(client1, streamName)); + + // Create a consumer + // It should not get any data + std::vector outElements; + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + 
DS_ASSERT_OK(client1->Subscribe(streamName, config2, consumer2));
+    DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), 0);
+    DS_ASSERT_OK(consumer2->Close());
+    // Sleep for the auto-delete logic to go through
+    std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_WAIT_TIME));
+}
+
+TEST_F(RetainDataTest, TestAutoDeleteWhileRetainData1)
+{
+    // Create a producer
+    // with retainForNumConsumers == 2
+    // and AutoDelete enabled
+    // We should give higher priority to auto delete and allow it
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    // Shorten the delay for auto delete, so the auto delete goes through
+    for (uint32_t i = 0; i < WORKER_COUNT; i++) {
+        DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "AutoCleanup.AdjustDelay", "call(3000)"));
+    }
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = K_TWO;
+    conf.autoCleanup = true;
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+    // Close only the producer
+    DS_ASSERT_OK(producer->Close());
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_WAIT_TIME));
+    // Create a consumer
+    // It should not get any data
+    std::vector<Element> outElements;
+    std::shared_ptr<Consumer> consumer1;
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config2, consumer1));
+    DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), 0);
+    DS_ASSERT_OK(consumer1->Close());
+}
+
+TEST_F(RetainDataTest, TestRetainDataSPMCConsumersComeAndGo)
+{
+    // In this test case, consumers come and go:
+    // retention is satisfied by the lifetime consumer count,
+    // not by the number of currently attached consumers
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+
+    // Config
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = K_TWO;
+
+    // Create a new producer - this data should be retained
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    // Create a consumer - it should see the data from the producer
+    std::vector<Element> outElements;
+    std::shared_ptr<Consumer> consumer1;
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client1->Subscribe(streamName, config1, consumer1));
+    // Try to receive data
+    DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), DEFAULT_NUM_ELEMENT);
+    DS_ASSERT_OK(consumer1->Ack(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer1->Close());
+    outElements.clear();
+
+    // Create a new consumer on the other node - it should see the data from the producer
+    // Here the current count == 1 but the lifetime consumer count == 2
+    outElements.clear();
+    std::shared_ptr<Consumer> consumer2;
+    SubscriptionConfig config2("sub2",
SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config2, consumer2));
+    // Try to receive data
+    DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), DEFAULT_NUM_ELEMENT);
+    DS_ASSERT_OK(consumer2->Ack(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer2->Close());
+    outElements.clear();
+}
+
+// DFX testcases - Master crash
+TEST_F(RetainDataTest, LEVEL1_TestMasterRestartWhileRetainData)
+{
+    // In this test case,
+    // we need to check if a master restart
+    // restores the state
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+
+    // Config
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = 1;
+
+    // Create a new producer - this data should be retained
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    // Now do a master restart - then check that the retainData status remains the same
+    // Check that we have at least one producer and no consumers
+    CheckCount(client1, streamName, 1, 0);
+    CheckCount(client2, streamName, 1, 0);
+    // Restart the master (worker 1 hosts the master, see masterIdx above)
+    cluster_->QuicklyShutdownWorker(1);
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1));
+    // Extend the sleep time for test case stability purposes
+    std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout));
+    // Check if the data is restored in the master
+    CheckCount(client1, streamName, 1, 0);
+    CheckCount(client2, streamName, 1, 0);
+
+    // Create a new consumer on the restarted node - it should see the data from the producer
+    std::vector<Element> outElements;
+    std::shared_ptr<Consumer> consumer2, consumer3;
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+
+    DS_ASSERT_OK(client2->Subscribe(streamName, config1, consumer2));
+    // Try to receive data
+    DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), DEFAULT_NUM_ELEMENT);
+    DS_ASSERT_OK(consumer2->Ack(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer2->Close());
+    outElements.clear();
+
+    // As retainForNumConsumers == 1, the data should be gone now.
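+    // The gate counts consumers over the stream's lifetime: consumer2 attaching and
+    // acking satisfied the threshold of one, so the retained pages became reclaimable
+    // and the next subscriber starts from an empty stream.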
+    DS_ASSERT_OK(client2->Subscribe(streamName, config2, consumer3));
+    // Try to receive data
+    DS_ASSERT_OK(consumer3->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), 0);
+    DS_ASSERT_OK(consumer3->Close());
+}
+
+// Worker crash with consumer
+TEST_F(RetainDataTest, TestConsumerWorkerCrashStopRemotePush)
+{
+    LOG(INFO) << "TestConsumerWorkerCrashStopRemotePush start!";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(K_TWO, client2));
+
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = K_TWO;
+
+    // Create a new producer - this data should be retained
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    // Create a consumer on the node that will be restarted
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+
+    // The node with the consumer crashes
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, K_TWO);
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, K_TWO, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, K_TWO));
+
+    // After the restart, a new consumer on that node should still see the retained data
+    std::vector<Element> outElements;
+    std::shared_ptr<Consumer> consumer1;
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config1, consumer1));
+    // Try to receive data
+    DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), DEFAULT_NUM_ELEMENT);
+    DS_ASSERT_OK(consumer1->Ack(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer1->Close());
+    outElements.clear();
+
+    // A further consumer should not see the data any more
+    std::shared_ptr<Consumer> consumer2;
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config2, consumer2));
+    // Try to receive data
+    DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), 0);
+    DS_ASSERT_OK(consumer2->Close());
+    outElements.clear();
+}
+// Worker crash with producer
+TEST_F(RetainDataTest, LEVEL1_TestProducerWorkerCrashStopRemotePush)
+{
+    LOG(INFO) << "TestProducerWorkerCrashStopRemotePush start!";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(K_TWO, client2));
+
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = K_TWO;
+
+    // Create a new producer - this data will be gone due to the crash
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    // Create a consumer before the crash
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+
+    // The node with the producer crashes
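+    // QuicklyShutdownWorker() simulates a crash rather than a graceful stop, so the
+    // elements sent before this point are expected to be lost, while elements produced
+    // after the restart fall under the same retainForNumConsumers rule again.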
+    cluster_->QuicklyShutdownWorker(0);
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 0, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0));
+
+    // Create a new producer - this data should be retained
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer1, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer1->Send(element));
+    }
+
+    // A new consumer should see the data sent after the restart
+    std::shared_ptr<Consumer> consumer1;
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config1, consumer1));
+
+    // The existing consumer reads the new data
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), DEFAULT_NUM_ELEMENT);
+    DS_ASSERT_OK(consumer->Ack(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer->Close());
+    outElements.clear();
+    // Try to receive data
+    DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), DEFAULT_NUM_ELEMENT);
+    DS_ASSERT_OK(consumer1->Ack(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer1->Close());
+    outElements.clear();
+
+    // A further consumer should not see the data any more
+    std::shared_ptr<Consumer> consumer2;
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config2, consumer2));
+    // Try to receive data
+    DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), 0);
+    DS_ASSERT_OK(consumer2->Close());
+    outElements.clear();
+}
+
+TEST_F(RetainDataTest, TestCreateConsumerRollback)
+{
+    // Test that when a notification fails in the middle of CreateConsumer and triggers a rollback,
+    // the retain data state on the producers is handled correctly
+    // Node1: Producer 1
+    // Node2: Producer 2
+    // Node3: Consumer1
+    // retainForNumConsumers == 1
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+    std::shared_ptr<StreamClient> client2;
+    DS_ASSERT_OK(InitClient(1, client2));
+    std::shared_ptr<StreamClient> client3;
+    DS_ASSERT_OK(InitClient(K_TWO, client3));
+
+    // Test config
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf = GetDefaultConf();
+    conf.retainForNumConsumers = 1;
+
+    // On node 1, create producer1
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer1, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer1->Send(element));
+    }
+
+    // On node 2, create producer2
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_OK(client2->CreateProducer(streamName, producer2, conf));
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer2->Send(element));
+    }
+
+    // On node 3, create Consumer1, but fail with the second NotifyNewConsumer notification
+    for (uint32_t i = 0; i < WORKER_COUNT; i++) {
+        DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, i,
+                                               "master.SubIncreaseNodeImpl.afterSendNotification",
+                                               "1*return(K_RUNTIME_ERROR)"));
+    }
+    std::shared_ptr<Consumer> consumer1;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_NOT_OK(client3->Subscribe(streamName, config, consumer1));
+
+    // On node 3, re-create Consumer1
+    DS_ASSERT_OK(client3->Subscribe(streamName, config,
consumer1)); + // Make sure we can get all the data + std::vector outElements; + DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT * K_TWO, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT) * K_TWO); +} + +TEST_F(RetainDataTest, TestCreateProducerRollback) +{ + // Test that the retain data state in StreamManager is not set when CreateProducer fails and rolls back + // Node1: Producer 1 Consumer1 + // Node2: Producer 2 + std::shared_ptr client1; + std::shared_ptr client2; + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + std::string streamName = "Stream_" + RandomData().GetRandomString(10); + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + ProducerConf conf = GetDefaultConf(); + conf.retainForNumConsumers = 1; + + for (uint32_t i = 0; i < WORKER_COUNT; i++) { + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, i, + "master.PubIncreaseNodeImpl.beforeSendNotification", + "1*return(K_RUNTIME_ERROR)")); + } + + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer)); + + // The retain data state should remain in INIT for this test to pass + std::shared_ptr producer1; + DS_ASSERT_NOT_OK(client1->CreateProducer(streamName, producer1, conf)); + + std::shared_ptr producer2; + DS_ASSERT_OK(client2->CreateProducer(streamName, producer2, conf)); + for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer2->Send(element)); + } + + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT)); + DS_ASSERT_OK(consumer->Ack(size_t(DEFAULT_NUM_ELEMENT))); + DS_ASSERT_OK(consumer->Close()); + + // Data should be released + std::shared_ptr consumer2; + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + DS_ASSERT_OK(client1->Subscribe(streamName, config2, consumer2)); + DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), 0); +} + +TEST_F(RetainDataTest, DISABLED_TestResetWhileRetainData) +{ + std::vector outElements; + std::shared_ptr client1; + std::shared_ptr client2; + std::shared_ptr client3; + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + DS_ASSERT_OK(InitClient(K_TWO, client3)); + std::string streamName = "Stream_" + RandomData().GetRandomString(10); + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + ProducerConf conf = GetDefaultConf(); + conf.retainForNumConsumers = 3; + + std::shared_ptr consumer1; + DS_ASSERT_OK(CreateConsumer(client2, streamName, "sub1", consumer1)); + + // Create a new producer - this data should be retain + std::shared_ptr producer; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + + // Create a consumer should not get any data yet + std::shared_ptr consumer2; + DS_ASSERT_OK(CreateConsumer(client3, streamName, "sub2", consumer2)); + DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), 0); + + // After Reset and Resume there should be no data though retain data condition is met + 
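+    // (This DISABLED case documents the intended interaction of reset with retention:
+    // a reset drops the retained pages even though retainForNumConsumers has not been
+    // satisfied yet, so the consumers below start empty.)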
std::shared_ptr consumer3; + DS_ASSERT_OK(CreateConsumer(client1, streamName, "sub3", consumer3)); + DS_ASSERT_OK(consumer3->Receive(DEFAULT_NUM_ELEMENT, WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), 0); + + // Now let the producer add new data + for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + DS_ASSERT_OK(producer->Close()); + + // All three consumers should get this data + DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), DEFAULT_NUM_ELEMENT); + DS_ASSERT_OK(consumer1->Close()); + + DS_ASSERT_OK(consumer2->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), DEFAULT_NUM_ELEMENT); + DS_ASSERT_OK(consumer2->Close()); + + DS_ASSERT_OK(consumer3->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), DEFAULT_NUM_ELEMENT); + DS_ASSERT_OK(consumer3->Close()); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/sc_client_aksk_auth_test.cpp b/tests/st/client/stream_cache/sc_client_aksk_auth_test.cpp new file mode 100644 index 0000000..4c85319 --- /dev/null +++ b/tests/st/client/stream_cache/sc_client_aksk_auth_test.cpp @@ -0,0 +1,118 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "datasystem/context/context.h" +#include "datasystem/stream/producer.h" + +#include "common.h" +#include "common/stream_cache/stream_common.h" +#include "datasystem/utils/sensitive_value.h" +#include "datasystem/common/log/log.h" +#include "sc_client_common.h" +#include "datasystem/client/mmap_manager.h" +#include "datasystem/client/stream_cache/client_worker_api.h" +#include "datasystem/stream_client.h" +#include "datasystem/common/util/random_data.h" +#include "datasystem/common/util/rpc_util.h" +#include "datasystem/common/util/uuid_generator.h" +#include "datasystem/stream/consumer.h" + +DS_DECLARE_uint32(page_size); +using namespace datasystem::client::stream_cache; + +namespace datasystem { +namespace st { +class SCClientAkSkAuthTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.workerGflagParams = + " -authorization_enable=true "; + opts.numEtcd = 1; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + InitTest(); + } + + void TearDown() override + { + spClient_ = nullptr; + ExternalClusterTest::TearDown(); + } + + static void ArrToStr(void *data, size_t sz, std::string &str) + { + str.assign(reinterpret_cast(data), sz); + } + +protected: + void InitTest() + { + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.akauth", "return(accessKey,secretKey,tenant1)")); + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + ConnectOptions connectOptions = { .host = workerAddress.Host(), + .port = workerAddress.Port(), + .connectTimeoutMs = 60 * 1000, // 60s + .clientPublicKey = "", + .clientPrivateKey = "", + .serverPublicKey = "", + .accessKey = "accessKey", + .secretKey = "secretKey", + .tenantId = "tenant1" }; + spClient_ = std::make_shared(connectOptions); + DS_ASSERT_OK(spClient_->Init()); + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + } + std::shared_ptr spClient_ = nullptr; + ProducerConf defaultProducerConf_; + + Status CreateElement(size_t elementSize, Element &element, std::vector &writeElement) + { + writeElement = RandomData().RandomBytes(elementSize); + element = Element(reinterpret_cast(writeElement.data()), writeElement.size()); + return Status::OK(); + } +}; + +TEST_F(SCClientAkSkAuthTest, TestAkSkAuth) +{ + // Subscribe before send. 
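+    // spClient_ was built in InitTest() with ConnectOptions whose accessKey/secretKey/
+    // tenantId match the credentials injected into the worker via the "worker.akauth"
+    // inject point. A minimal sketch of that setup (see InitTest() above):
+    //
+    //     ConnectOptions opts = { .host = host, .port = port };
+    //     opts.accessKey = "accessKey";   // must match the worker-side injection
+    //     opts.secretKey = "secretKey";
+    //     opts.tenantId = "tenant1";
+    //     auto client = std::make_shared<StreamClient>(opts);
+    //     DS_ASSERT_OK(client->Init());
+    //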
+ std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient_->Subscribe("test", config, consumer)); + + size_t testSize = 4ul * 1024ul; + Element element; + std::vector writeElement; + std::shared_ptr producer; + DS_ASSERT_OK(spClient_->CreateProducer("test", producer, defaultProducerConf_)); + DS_ASSERT_OK(CreateElement(testSize, element, writeElement)); + ASSERT_EQ(producer->Send(element), Status::OK()); + + std::vector outElements; + ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), size_t(1)); + ASSERT_EQ(outElements[0].id, size_t(1)); + std::vector actualData(outElements[0].ptr, outElements[0].ptr + outElements[0].size); + std::vector data(writeElement.data(), writeElement.data() + writeElement.size()); + EXPECT_EQ(data, actualData); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/sc_client_common.h b/tests/st/client/stream_cache/sc_client_common.h new file mode 100644 index 0000000..2e0f5e5 --- /dev/null +++ b/tests/st/client/stream_cache/sc_client_common.h @@ -0,0 +1,59 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+/**
+ * Description: common class of stream client test
+ */
+
+#ifndef DATASYSTEM_UT_SC_CLIENT_COMMON_H
+#define DATASYSTEM_UT_SC_CLIENT_COMMON_H
+
+#include "common.h"
+#include "datasystem/stream_client.h"
+namespace datasystem {
+namespace st {
+class SCClientCommon : public ExternalClusterTest {
+public:
+protected:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        for (size_t i = 0; i < opts.numWorkers; ++i) {
+            auto port = GetFreePort();
+            opts.workerSpecifyGflagParams.emplace(i, FormatString("-sc_worker_worker_direct_port=%d", port));
+        }
+        opts.isStreamCacheCase = true;
+    }
+
+    void InitStreamClient(uint32_t index, std::shared_ptr<StreamClient> &client, int32_t timeoutMs = 60000,
+                          bool reportWorkerLost = false)
+    {
+        HostPort workerAddress;
+        ASSERT_TRUE(index < cluster_->GetWorkerNum());
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(index, workerAddress));
+        LOG(INFO) << "worker index " << index << ": " << workerAddress.ToString();
+        ConnectOptions connectOptions;
+        connectOptions = { .host = workerAddress.Host(), .port = workerAddress.Port(), .connectTimeoutMs = timeoutMs };
+        connectOptions.accessKey = "QTWAOYTTINDUT2QVKYUC";
+        connectOptions.secretKey = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+        client = std::make_shared<StreamClient>(connectOptions);
+        DS_ASSERT_OK(client->Init(reportWorkerLost));
+    }
+
+private:
+};
+} // namespace st
+} // namespace datasystem
+#endif // DATASYSTEM_UT_SC_CLIENT_COMMON_H
diff --git a/tests/st/client/stream_cache/sc_client_evict_object_test.cpp b/tests/st/client/stream_cache/sc_client_evict_object_test.cpp
new file mode 100644
index 0000000..9d19ecb
--- /dev/null
+++ b/tests/st/client/stream_cache/sc_client_evict_object_test.cpp
@@ -0,0 +1,151 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +/** + * Description: Unit test for stream cache + */ +#include +#include + +#include "common.h" +#include "client/object_cache/oc_client_common.h" +#include "common/stream_cache/stream_common.h" +#include "datasystem/kv_client.h" +#include "datasystem/utils/connection.h" +#include "sc_client_common.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream_client.h" + +using namespace datasystem::client::stream_cache; +namespace datasystem { +namespace st { +class SCClientEvictObjectTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.enableSpill = true; + opts.numEtcd = 1; + opts.workerGflagParams = workerConf_; + opts.injectActions = "worker.Spill.Sync:return()"; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + LOG(INFO) << "start worker for test"; + } + void TearDown() override + { + client_ = nullptr; + ExternalClusterTest::TearDown(); + } + + void StartClusters() + { + ExternalClusterTest::SetUp(); + InitTest(); + } + +protected: + void InitTest() + { + InitStreamClient(0, client_); + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + ConnectOptions connectOptions = { .host = workerAddress.Host(), .port = workerAddress.Port() }; + connectOptions.accessKey = "QTWAOYTTINDUT2QVKYUC"; + connectOptions.secretKey = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; + client1_ = std::make_shared(connectOptions); + DS_ASSERT_OK(client1_->Init()); + defaultProducerConf_.maxStreamSize = 25 * 1024 * 1024; // max stream size is 25 * 1024 * 1024 + } + std::shared_ptr client_ = nullptr; + std::shared_ptr client1_ = nullptr; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; + std::string workerConf_ = ""; +}; + +TEST_F(SCClientEvictObjectTest, TestEvictObject) +{ + workerConf_ = + "--shared_memory_size_mb=50 --sc_shm_threshold_percentage=50 --oc_shm_threshold_percentage=100 -v=3 "; + StartClusters(); + size_t size = 1 * 1024 * 1024; + std::string prifixKey = "object_data_"; + std::string data = randomData_.GetRandomString(size); + for (int i = 0; i < 40; i++) { // object size is 40 + DS_ASSERT_OK(client1_->Set(prifixKey + std::to_string(i), data)); + } + std::shared_ptr producer; + DS_ASSERT_OK(client_->CreateProducer("stream1", producer, defaultProducerConf_)); + for (int i = 0; i < 20; i++) { // stream element num is 20 + Element element1(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element1)); + } +} + +TEST_F(SCClientEvictObjectTest, TestStreamSizeMax) +{ + workerConf_ = + "--shared_memory_size_mb=50 --sc_shm_threshold_percentage=50 --oc_shm_threshold_percentage=100 -v=2"; + StartClusters(); + size_t size = 6 * 1024 * 1024; + std::string prifixKey = "object_data_"; + std::string data = randomData_.GetRandomString(size); + std::shared_ptr producer; + defaultProducerConf_.maxStreamSize = 50 * 1024 * 1024; // max stream size is 50 * 1024 * 1024 + defaultProducerConf_.pageSize = 10 * 1024 * 1024; // page size is 10 * 1024 * 1024 + DS_ASSERT_OK(client_->CreateProducer("stream1", producer, defaultProducerConf_)); + Status status; + while (status.IsOk()) { // stream element num is 10 + Element element1(reinterpret_cast(&data.front()), data.size()); + status = producer->Send(element1); + } + LOG(INFO) << status.GetMsg(); + ASSERT_TRUE(status.GetMsg().find("Stream cache memory 
size overflow, maxStreamSize") != std::string::npos);
+}
+
+TEST_F(SCClientEvictObjectTest, TestEvictObjNotSpill)
+{
+    workerConf_ =
+        "--shared_memory_size_mb=10 --sc_shm_threshold_percentage=100 --oc_shm_threshold_percentage=50 -v=2";
+    StartClusters();
+    size_t size = 1 * 1024 * 1024;
+    std::string prefixKey = "object_data_";
+    std::string data = randomData_.GetRandomString(size);
+    for (int i = 0; i < 3; i++) { // object count is 3
+        DS_ASSERT_OK(client1_->Set(prefixKey + std::to_string(i), data));
+    }
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer;
+    ProducerConf conf;
+    conf.maxStreamSize = 10 * MB; // maxStreamSize is 10 MB
+    conf.pageSize = 1 * MB;
+    conf.delayFlushTime = 2000; // delay flush time is 2000 ms.
+    DS_ASSERT_OK(client_->CreateProducer("stream1", producer, conf));
+    SubscriptionConfig config("subName", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client_->Subscribe("stream1", config, consumer));
+    Status status;
+    for (int i = 0; i < 8; i++) { // stream element num is 8
+        Element element1(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+        DS_ASSERT_OK(producer->Send(element1));
+    }
+}
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/sc_client_token_auth_test.cpp b/tests/st/client/stream_cache/sc_client_token_auth_test.cpp
new file mode 100644
index 0000000..1869443
--- /dev/null
+++ b/tests/st/client/stream_cache/sc_client_token_auth_test.cpp
@@ -0,0 +1,300 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "datasystem/context/context.h"
+#include "datasystem/stream/producer.h"
+
+#include "common.h"
+#include "common/stream_cache/stream_common.h"
+#include "datasystem/utils/sensitive_value.h"
+#include "datasystem/common/log/log.h"
+#include "sc_client_common.h"
+#include "datasystem/client/mmap_manager.h"
+#include "datasystem/client/stream_cache/client_worker_api.h"
+#include "datasystem/common/util/random_data.h"
+#include "datasystem/common/util/rpc_util.h"
+#include "datasystem/common/util/uuid_generator.h"
+#include "datasystem/kv_client.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/stream_client.h"
+
+DS_DECLARE_uint32(page_size);
+using namespace datasystem::client::stream_cache;
+
+namespace datasystem {
+namespace st {
+class SCClientTokenAuthTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.workerGflagParams = " -authorization_enable=true ";
+        opts.numEtcd = 1;
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+        InitTest();
+    }
+
+    void TearDown() override
+    {
+        spClient_ = nullptr;
+        ExternalClusterTest::TearDown();
+    }
+
+    static void ArrToStr(void *data, size_t sz, std::string &str)
+    {
+        str.assign(reinterpret_cast<const char *>(data), sz);
+    }
+
+    void PubSubElement(std::shared_ptr<StreamClient> spClient1)
+    {
+        std::shared_ptr<Consumer> consumer;
+        SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+        DS_ASSERT_OK(spClient1->Subscribe("test", config, consumer));
+        size_t testSize = 4ul * 1024ul;
+        Element element;
+        std::vector<uint8_t> writeElement;
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(spClient1->CreateProducer("test", producer, defaultProducerConf_));
+        DS_ASSERT_OK(CreateElement(testSize, element, writeElement));
+        ASSERT_EQ(producer->Send(element), Status::OK());
+
+        std::vector<Element> outElements;
+        ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK());
+        uint64_t consumerNum, producerNum;
+        spClient1->QueryGlobalConsumersNum("test", consumerNum);
+        spClient1->QueryGlobalProducersNum("test", producerNum);
+
+        ASSERT_EQ(consumerNum, uint64_t(1));
+        ASSERT_EQ(producerNum, uint64_t(1));
+        ASSERT_EQ(outElements.size(), size_t(1));
+        ASSERT_EQ(outElements[0].id, size_t(1));
+        DS_ASSERT_OK(producer->Close());
+        DS_ASSERT_OK(consumer->Close());
+
+        producer.reset();
+        consumer.reset();
+    }
+
+protected:
+    void InitTest()
+    {
+        DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.auth", "return(token1,tenant1)"));
+        HostPort workerAddress;
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+        ConnectOptions connectOptions = { .host = workerAddress.Host(),
+                                          .port = workerAddress.Port(),
+                                          .connectTimeoutMs = 60 * 1000, // 60s
+                                          .clientPublicKey = "",
+                                          .clientPrivateKey = "",
+                                          .serverPublicKey = "",
+                                          .accessKey = "QTWAOYTTINDUT2QVKYUC",
+                                          .secretKey = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc" };
+        spClient_ = std::make_shared<StreamClient>(connectOptions);
+        DS_ASSERT_OK(spClient_->Init());
+        defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE;
+    }
+    std::shared_ptr<StreamClient> spClient_ = nullptr;
+    ProducerConf defaultProducerConf_;
+
+    Status CreateElement(size_t elementSize, Element &element, std::vector<uint8_t> &writeElement)
+    {
+        writeElement = RandomData().RandomBytes(elementSize);
+        element = Element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+        return Status::OK();
+    }
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+TEST_F(SCClientTokenAuthTest,
TestTokenAuth) +{ + // Subscribe before send. + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient_->Subscribe("test", config, consumer)); + + size_t testSize = 4ul * 1024ul; + Element element; + std::vector writeElement; + std::shared_ptr producer; + DS_ASSERT_OK(spClient_->CreateProducer("test", producer, defaultProducerConf_)); + DS_ASSERT_OK(CreateElement(testSize, element, writeElement)); + ASSERT_EQ(producer->Send(element), Status::OK()); + + std::vector outElements; + ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), size_t(1)); + ASSERT_EQ(outElements[0].id, size_t(1)); + std::string actualData(reinterpret_cast(outElements[0].ptr), outElements[0].size); + std::string data(reinterpret_cast(writeElement.data()), writeElement.size()); + EXPECT_EQ(data, actualData); +} + +TEST_F(SCClientTokenAuthTest, TestClientWithTenantIds) +{ + // Subscribe before send. + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + ConnectOptions connectOptions; + connectOptions.host = workerAddress.Host(); + connectOptions.port = workerAddress.Port(); + connectOptions.accessKey = accessKey_; + connectOptions.secretKey = secretKey_; + connectOptions.tenantId = "akskTenantId"; + std::shared_ptr spClient1 = std::make_shared(connectOptions); + DS_ASSERT_OK(spClient1->Init()); + std::thread thread1([&spClient1, this] { + Context::SetTenantId(""); + PubSubElement(spClient1); + }); + + std::thread thread2([&spClient1, this] { + Context::SetTenantId("tenantId1"); + PubSubElement(spClient1); + }); + thread1.join(); + thread2.join(); +} + +TEST_F(SCClientTokenAuthTest, TestClientResetWithTenant) +{ + // Subscribe before send. + HostPort workerAddress; + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress)); + ConnectOptions connectOptions; + connectOptions.host = workerAddress.Host(); + connectOptions.port = workerAddress.Port(); + connectOptions.accessKey = accessKey_; + connectOptions.secretKey = secretKey_; + connectOptions.tenantId = ""; + std::shared_ptr spClient1 = std::make_shared(connectOptions); + DS_ASSERT_OK(spClient1->Init()); + std::shared_ptr consumer; + SubscriptionConfig config("sub10", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient1->Subscribe("test", config, consumer)); + std::thread thread1([&spClient1, this] { + Context::SetTenantId("tenant2"); + PubSubElement(spClient1); + }); + + std::thread thread2([&spClient1, this] { + Context::SetTenantId("tenantId1"); + PubSubElement(spClient1); + }); + thread1.join(); + thread2.join(); +} + +TEST_F(SCClientTokenAuthTest, TestClientTenant) +{ + // Subscribe before send. 
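+    // Aside (illustrative sketch, not part of the original change): these tests treat
+    // Context::SetTenantId() as thread-local state — two threads can pin different tenants
+    // and share one StreamClient, as TestClientWithTenantIds does above. The same pattern,
+    // condensed into a helper (PubSubElement and spClient_ are defined in this fixture):
+    auto runAsTenant = [this](const std::string &tenant) {
+        Context::SetTenantId(tenant); // applies to the calling thread only
+        PubSubElement(spClient_);     // pub/sub round trip under that tenant
+    };
+    // e.g. std::thread(runAsTenant, "tenantA") and std::thread(runAsTenant, "tenantB")
+    // can run concurrently against the same client.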
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+    ConnectOptions connectOptions;
+    connectOptions.host = workerAddress.Host();
+    connectOptions.port = workerAddress.Port();
+    connectOptions.SetAkSkAuth(accessKey_, secretKey_, "qqqqq");
+    std::shared_ptr<StreamClient> spClient1 = std::make_shared<StreamClient>(connectOptions);
+    DS_ASSERT_OK(spClient1->Init());
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(spClient1->Subscribe("test", config, consumer));
+
+    size_t testSize = 4ul * 1024ul;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(spClient1->CreateProducer("test", producer, defaultProducerConf_));
+    DS_ASSERT_OK(CreateElement(testSize, element, writeElement));
+}
+
+TEST_F(SCClientTokenAuthTest, TestCheckoutTenantWhenDefaultTenantIsEmpty)
+{
+    std::shared_ptr<StreamClient> client1;
+    std::string tenantId1 = "";
+    std::string tenantId2 = "tenantId1";
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+    ConnectOptions connectOptions;
+    connectOptions.host = workerAddress.Host();
+    connectOptions.port = workerAddress.Port();
+    connectOptions.accessKey = accessKey_;
+    connectOptions.secretKey = secretKey_;
+    std::shared_ptr<StreamClient> spClient1 = std::make_shared<StreamClient>(connectOptions);
+    DS_ASSERT_OK(spClient1->Init());
+    // Subscribe before send.
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(spClient_->Subscribe("test", config, consumer));
+
+    size_t testSize = 4ul * 1024ul;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(spClient_->CreateProducer("test", producer, defaultProducerConf_));
+    DS_ASSERT_OK(CreateElement(testSize, element, writeElement));
+    Context::SetTenantId("tenantId1");
+    ASSERT_EQ(producer->Send(element), Status::OK());
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), size_t(1));
+    ASSERT_EQ(outElements[0].id, size_t(1));
+}
+
+TEST_F(SCClientTokenAuthTest, TestReceiveChangeTenant)
+{
+    std::shared_ptr<StreamClient> client1;
+    std::string tenantId1 = "";
+    std::string tenantId2 = "tenantId1";
+    HostPort workerAddress;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress));
+    ConnectOptions connectOptions;
+    connectOptions.host = workerAddress.Host();
+    connectOptions.port = workerAddress.Port();
+    connectOptions.accessKey = accessKey_;
+    connectOptions.secretKey = secretKey_;
+    Context::SetTenantId(tenantId2);
+    std::shared_ptr<StreamClient> spClient1 = std::make_shared<StreamClient>(connectOptions);
+    DS_ASSERT_OK(spClient1->Init());
+    std::shared_ptr<StreamClient> spClient2 = std::make_shared<StreamClient>(connectOptions);
+    DS_ASSERT_OK(spClient2->Init());
+    // Subscribe before send.
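+    // Aside (illustrative sketch, not part of the original change): Receive(n, timeoutMs, out)
+    // blocks until n elements arrive or the timeout expires, which is why the receiving thread
+    // below can start before the producer sends. A bounded retry loop is an alternative to one
+    // long timeout; this uses only the Consumer API already exercised in this patch:
+    auto receiveUpTo = [](std::shared_ptr<Consumer> &c, size_t want, int perWaitMs, int rounds) {
+        std::vector<Element> got;
+        while (got.size() < want && rounds-- > 0) {
+            std::vector<Element> batch;
+            if (c->Receive(want - got.size(), perWaitMs, batch).IsOk()) {
+                got.insert(got.end(), batch.begin(), batch.end());
+            }
+        }
+        return got.size() == want; // true once all expected elements arrived
+    };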
+ std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::thread t1([&spClient2, &config, &consumer, tenantId2] { + Context::SetTenantId(tenantId2); + DS_ASSERT_OK(spClient2->Subscribe("test", config, consumer)); + std::vector outElements; + Context::SetTenantId("tenantId2"); + ASSERT_EQ(consumer->Receive(1, 10000, outElements), Status::OK()); // timeout is 10000 ms + ASSERT_EQ(outElements.size(), size_t(1)); + ASSERT_EQ(outElements[0].id, size_t(1)); + }); + sleep(2); // wait 2 s to send + size_t testSize = 4ul * 1024ul * 1024ul; + Element element; + std::vector writeElement; + std::shared_ptr producer; + DS_ASSERT_OK(spClient1->CreateProducer("test", producer, defaultProducerConf_)); + DS_ASSERT_OK(CreateElement(testSize, element, writeElement)); + ASSERT_EQ(producer->Send(element), Status::OK()); + t1.join(); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/sc_metrics_test.cpp b/tests/st/client/stream_cache/sc_metrics_test.cpp new file mode 100644 index 0000000..f72b770 --- /dev/null +++ b/tests/st/client/stream_cache/sc_metrics_test.cpp @@ -0,0 +1,1100 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * Description: Unit test for stream cache metrics + */ + +#include +#include + +#include "common.h" +#include "common/stream_cache/element_generator.h" +#include "common/stream_cache/stream_common.h" +#include "sc_client_common.h" +#include "datasystem/stream_client.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/client/stream_cache/client_worker_api.h" +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/worker/stream_cache/stream_manager.h" + +DS_DECLARE_string(log_dir); + +using namespace datasystem::client::stream_cache; +namespace datasystem { +namespace st { +class SCMetricsTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numEtcd = 1; + opts.enableDistributedMaster = "false"; + opts.numWorkers = WORKER_NUM; + opts.vLogLevel = VLOG_LEVEL; + opts.masterIdx = 0; + std::string workerGflags = "-sc_local_cache_memory_size_mb=20 -log_monitor=true -sc_metrics_log_interval_s=" + + std::to_string(PRINT_INTERVAL) + " -sc_cache_pages=" + std::to_string(CACHE_PAGES); + opts.workerGflagParams = workerGflags; + SCClientCommon::SetClusterSetupOptions(opts); + } + +protected: + Status InitClient(int index, std::shared_ptr &client) + { + InitStreamClient(index, client); + return Status::OK(); + } + + Status CreateProducerAndConsumer(std::shared_ptr &client, + std::vector> producerDesc, + std::vector> &producers, + std::vector> consumerDesc, + std::vector> &consumers, bool autoCleanup) + { + ProducerConf conf; + conf.delayFlushTime = DELAY_FLUSH_TIME; + conf.pageSize = PAGE_SIZE; // 4K + conf.maxStreamSize = MAX_STREAM_SIZE; + conf.autoCleanup = autoCleanup; + for (const auto &kv : producerDesc) { + for (size_t i = 0; i < kv.second; i++) { + std::shared_ptr producer; + RETURN_IF_NOT_OK(client->CreateProducer(kv.first, producer, conf)); + producers.emplace_back(producer); + } + } + + for (const auto &kv : consumerDesc) { + std::shared_ptr consumer; + SubscriptionConfig config(kv.second, SubscriptionType::STREAM); + RETURN_IF_NOT_OK(client->Subscribe(kv.first, config, consumer, false)); + consumers.emplace_back(consumer); + } + return Status::OK(); + } + + Status CloseAllProducerAndConsumer(std::vector> &producers, + std::vector> &consumers) + { + for (auto &producer : producers) { + RETURN_IF_NOT_OK(producer->Close()); + } + for (auto &consumer : consumers) { + RETURN_IF_NOT_OK(consumer->Close()); + } + return Status::OK(); + } + + void CheckCount(std::shared_ptr client, const std::string &streamName, int producerCount, + int consumerCount) + { + uint64_t result = 0; + if (producerCount >= 0) { + DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, result)); + EXPECT_EQ(result, static_cast(producerCount)); + result = 0; + } + if (consumerCount >= 0) { + DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, result)); + EXPECT_EQ(result, static_cast(consumerCount)); + result = 0; + } + } + + Status Produce(std::shared_ptr &producer, std::string producerName, int numEle, uint64_t eleSz, + int timeout = 0) + { + Status stat = Status::OK(); + ElementGenerator elementGenerator(eleSz, eleSz); + auto strs = elementGenerator.GenElements(producerName, numEle, 1); + Status rc; + + for (int i = 0; i < numEle; i++) { + if (timeout) { + rc = producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size()), timeout); + } else { + rc = producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size())); + } + if (rc.IsError()) { + stat = rc; + 
} + } + return stat; + } + + Status ConsumeAll(std::shared_ptr &consumer, int timeout = 5000, bool checkFIFO = true, + uint64_t *res = nullptr, int producerNum = 1, bool ack = true) + { + std::vector outElements; + size_t expectNum = DEFAULT_NUM_ELEMENT * producerNum; + RETURN_IF_NOT_OK(consumer->Receive(expectNum, timeout, outElements)); + if (ack) { + RETURN_IF_NOT_OK(consumer->Ack(outElements.back().id)); + } + LOG(INFO) << FormatString("Stream Consumer Receive %d elements.", outElements.size()); + std::unordered_map seqNoMap; + uint64_t eleTotalSz = 0; + for (const auto &element : outElements) { + ElementView view(std::string((const char *)element.ptr, element.size)); + RETURN_IF_NOT_OK(view.VerifyIntegrity()); + if (checkFIFO) { + RETURN_IF_NOT_OK(view.VerifyFifo(seqNoMap, 0)); + } + eleTotalSz += element.size; + } + if (res != nullptr) { + *res = eleTotalSz; + } + return Status::OK(); + } + + void VerifyAllStreamMetrics(const std::vector &workerMetrics, + const std::vector> &allScMetrics, + const std::vector &expectedWM, + const std::vector> expectedScM) + { + ASSERT_EQ(workerMetrics.size(), expectedWM.size()); + for (size_t i = 0; i < workerMetrics.size(); i++) { + ASSERT_EQ(workerMetrics[i], expectedWM[i]); + } + + ASSERT_EQ(allScMetrics.size(), expectedScM.size()); + for (size_t i = 0; i < allScMetrics.size(); i++) { + auto scMetrics = allScMetrics[i]; + auto expected = expectedScM[i]; + ASSERT_EQ(scMetrics.size(), expected.size()); + for (size_t j = 0; j < scMetrics.size(); j++) { + ASSERT_EQ(scMetrics[j], expected[j]); + } + } + } + + void GetStreamMetrics(int index, const std::string &fileName, + std::unordered_map> &allScMetrics) + { + allScMetrics.clear(); + std::string fullName = FormatString("%s/../worker%d/log/%s", FLAGS_log_dir.c_str(), index, fileName); + std::ifstream ifs(fullName); + ASSERT_TRUE(ifs.is_open()); + std::string line; + std::streampos metric_start_pos = ifs.tellg(); + std::streampos pos = ifs.tellg(); + bool found = false; + int64_t oldTime = 0; + std::string timeFormat = "%Y-%m-%dT%H:%M:%S"; + // Find the end of the second last set of stream metrics + while (std::getline(ifs, line)) { + std::string timestamp = Split(line, "|")[0]; + std::tm tm = {}; + std::stringstream ss(timestamp); + ss >> std::get_time(&tm, timeFormat.c_str()); + std::chrono::system_clock::time_point point = std::chrono::system_clock::from_time_t(std::mktime(&tm)); + int64_t time = std::chrono::duration_cast(point.time_since_epoch()).count(); + if (time - oldTime >= PRINT_INTERVAL) { + metric_start_pos = pos; + found = true; + oldTime = time; + } + pos = ifs.tellg(); // stores the last position + } + ASSERT_TRUE(found); + ifs.clear(); + ifs.seekg(metric_start_pos); + while (std::getline(ifs, line)) { + if ((line.find(" exit") == std::string::npos)) { + auto scMetrics = Split(line, "/"); + ASSERT_TRUE(scMetrics.size() > 0); + auto streamName = scMetrics[0].substr(scMetrics[0].find_last_of("|") + 2); + scMetrics.erase(scMetrics.begin()); + allScMetrics.emplace(streamName, scMetrics); + } + } + } + + std::string GetScMetric(const std::vector &scMetrics, StreamMetric metric) + { + // Subtract index by NumLocalProducers since it is the first sc metric, add one to account for stream name + int index = (int)metric - (int)StreamMetric::NumLocalProducers; + return scMetrics[index]; + } + + void VerifyStreamMetrics(const std::unordered_map> &scMetrics, + const std::unordered_map> &expected, + const std::vector &metricsToVerify) + { + ASSERT_EQ(scMetrics.size(), expected.size()); + for (auto 
&metric : expected) { + // Verify stream name + ASSERT_TRUE(scMetrics.count(metric.first) == 1); + auto scMetric = scMetrics.at(metric.first); + LOG(INFO) << "Verifying stream: " << metric.first; + for (size_t i = 0; i < metricsToVerify.size(); i++) { + LOG(INFO) << "Verifying metric: " << (int)metricsToVerify[i]; + ASSERT_EQ(GetScMetric(scMetric, metricsToVerify[i]), metric.second[i]); + } + } + } + + void CreateOneWorkerMetricsScenario(std::shared_ptr &client1, + std::vector> &producers, + std::vector> &consumers, + std::string streamName1, std::string streamName2) + { + // worker 0 + // stream1: 3 producer, 1 consumer + // stream2: 2 producer, 2 consumer + DS_ASSERT_OK(InitClient(0, client1)); + + DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { streamName1, 3 }, { streamName2, 2 } }, producers, + { { streamName1, "sub1" }, { streamName2, "sub1" }, + { streamName2, "sub2" } }, consumers, false)); + } + + void CreateTwoWorkerMetricsScenario(std::shared_ptr &client1, std::shared_ptr &client2, + std::vector> &producers, + std::vector> &consumers, + std::string streamName1, std::string streamName2, + std::string streamName3) + { + // worker 0 + // stream1: 3 producer, 2 consumer + // stream2: 2 consumer + // stream3: 2 producer + // worker 1 + // stream2: 3 producer + // stream3: 1 producer, 1 consumer + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + + DS_ASSERT_OK(CreateProducerAndConsumer( + client1, { { streamName1, 3 }, { streamName3, 2 } }, producers, + { { streamName1, "sub1" }, { streamName2, "sub1" }, { streamName1, "sub2" }, { streamName2, "sub2" } }, + consumers, true)); + DS_ASSERT_OK(CreateProducerAndConsumer(client2, { { streamName2, 3 }, { streamName3, 1 } }, producers, + { { streamName3, "sub1" } }, consumers, true)); + } + + void SetUp() + { + ExternalClusterTest::SetUp(); + } + + void TearDown() + { + ExternalClusterTest::TearDown(); + } + +protected: + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; + const int WORKER_NUM = 2; + const int VLOG_LEVEL = 2; + const int PRINT_INTERVAL = 2; + const int DEFAULT_WAIT_TIME = 1000; + const int DEFAULT_NUM_ELEMENT = 20; + const int TEST_ELEMENT_SIZE = 2 * KB - 128; + const int TEST_ELEMENT2_SIZE = 4 * KB - 256; + const int MAX_STREAM_SIZE = 2 * MB; + const int CACHE_PAGES = 16; + const int LOG_LEVEL = 2; + const int STREAM_SIZE_MB = 2; + const int SLEEP_TIME = 2; + const int LONG_SLEEP_TIME = 5; + const int RELEASE_PAGE_SLEEP_TIME = 15; + const int BIG_ELEMENT_SIZE = 8 * KB; + const int DELAY_FLUSH_TIME = 3000; +}; + +TEST_F(SCMetricsTest, NumLocalProducersConsumers) +{ + std::shared_ptr client1; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::vector metricsToVerify = { StreamMetric::NumLocalProducers, StreamMetric::NumLocalConsumers }; + std::string streamName1 = "TestMetricsNumLocalProdCon_s1"; + std::string streamName2 = "TestMetricsNumLocalProdCon_s2"; + std::unordered_map> expected = { { streamName1, { "3", "1" } }, + { streamName2, { "2", "2" } } }; + + CreateOneWorkerMetricsScenario(client1, producers, consumers, streamName1, streamName2); + sleep(SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); + // Close some + const int prodCloseNum = 4; + const int conCloseNum = 2; + for (int i = 0; i < prodCloseNum; i++) { + // 3 s1, 1 s2 + DS_ASSERT_OK(producers[i]->Close()); + } + for (int i = 0; i < 
conCloseNum; i++) { + // 1 s1, 1 s2 + DS_ASSERT_OK(consumers[i]->Close()); + } + + sleep(SLEEP_TIME); + expected = { + { streamName1, { "0", "0" } }, + { streamName2, { "1", "1" } }, + }; + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); + + // Close rest and delete stream + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); + DS_ASSERT_OK(client1->DeleteStream(streamName1)); + DS_ASSERT_OK(client1->DeleteStream(streamName2)); +} + +TEST_F(SCMetricsTest, NumRemoteProducersConsumers) +{ + std::shared_ptr client1; + std::shared_ptr client2; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::unordered_map> sc1Metrics; + std::vector metricsToVerify = { StreamMetric::NumLocalProducers, StreamMetric::NumRemoteProducers, + StreamMetric::NumLocalConsumers, StreamMetric::NumRemoteConsumers }; + std::string s1 = "testMetricsRemoteProdCon_s1"; + std::string s2 = "testMetricsRemoteProdCon_s2"; + std::string s3 = "testMetricsRemoteProdCon_s3"; + + std::unordered_map> expected1 = { { s1, { "3", "0", "2", "0" } }, + { s2, { "0", "1", "2", "0" } }, + { s3, { "2", "0", "0", "1" } } }; + + std::unordered_map> expected2 = { { s2, { "3", "0", "0", "2" } }, + { s3, { "1", "1", "1", "0" } } }; + + CreateTwoWorkerMetricsScenario(client1, client2, producers, consumers, s1, s2, s3); + sleep(SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + + // Close some + const int prodIndex = 8; // s3 + const int conIndex = 4; // s3 + DS_ASSERT_OK(producers[prodIndex]->Close()); + DS_ASSERT_OK(consumers[conIndex]->Close()); + sleep(SLEEP_TIME); + expected1 = { { s1, { "3", "0", "2", "0" } }, + { s2, { "0", "1", "2", "0" } }, + { s3, { "2", "0", "0", "0" } } }; + expected2 = { { s2, { "3", "0", "0", "2" } }, { s3, { "0", "0", "0", "0" } } }; + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, SharedMemoryUsed) +{ + std::shared_ptr client1; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::vector metricsToVerify = { StreamMetric::SharedMemoryUsed }; + std::string streamName1 = "TestMetricsSharedMemUsed_s1"; + std::string streamName2 = "TestMetricsSharedMemUsed_s2"; + std::unordered_map> expected = { + { streamName1, { std::to_string(3 * 40 * KB + 4 * 64) } }, { streamName2, + { std::to_string(2 * 40 * KB + 4 * 64) } } + }; + + CreateOneWorkerMetricsScenario(client1, producers, consumers, streamName1, streamName2); + + int i = 0; + // Each producer sends 20 elements -> 40KB + // Each element is slightly less than 2KB to account for page header overhead + for (auto &producer : producers) { + i++; + DS_ASSERT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, TEST_ELEMENT_SIZE)); + } + + sleep(SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); + + for (auto &consumer : consumers) { + DS_ASSERT_OK(ConsumeAll(consumer, DEFAULT_WAIT_TIME, true, nullptr, producers.size())); + } + + 
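+    // Aside (illustrative, not part of the original change): the SharedMemoryUsed expectation
+    // above is plain arithmetic — each producer contributes 20 elements of ~2KB (40KB of pages)
+    // plus a small fixed term (4 * 64 bytes in these expectations; the patch does not spell out
+    // its breakdown, so the name below is invented for readability):
+    auto expectedShmBytes = [](uint64_t numProducers) {
+        const uint64_t pageOverheadBytes = 4 * 64; // fixed term taken from the expectations above
+        return numProducers * 40 * KB + pageOverheadBytes;
+    };
+    EXPECT_EQ(expectedShmBytes(3), uint64_t(3 * 40 * KB + 4 * 64)); // streamName1: 3 producers
+    EXPECT_EQ(expectedShmBytes(2), uint64_t(2 * 40 * KB + 4 * 64)); // streamName2: 2 producers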
DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, LocalMemoryUsed) +{ + std::shared_ptr client1; + std::shared_ptr client2; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::unordered_map> sc1Metrics; + std::vector metricsToVerify = { StreamMetric::LocalMemoryUsed }; + std::string s1 = "testMetricsLocalMemUsed_s1"; + std::string s2 = "testMetricsLocalMemUsed_s2"; + std::string s3 = "testMetricsLocalMemUsed_s3"; + + // Streams with remote producer will have local memory usage + std::unordered_map> expected1 = { { s1, { "0" } }, + { s2, { std::to_string(40 * KB) } }, + { s3, { "0" } } }; + + std::unordered_map> expected2 = { + { s2, { "0" } }, { s3, { std::to_string(40 * KB) } } + }; + + CreateTwoWorkerMetricsScenario(client1, client2, producers, consumers, s1, s2, s3); + + int i = 0; + // Each producer sends 20 elements -> 40KB + // Each element is slightly less than 2KB to account for page header overhead + for (auto &producer : producers) { + i++; + DS_ASSERT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, TEST_ELEMENT_SIZE)); + } + sleep(SLEEP_TIME); + + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, NumTotalElementsSent) +{ + std::shared_ptr client1; + std::shared_ptr client2; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::unordered_map> sc1Metrics; + std::vector metricsToVerify = { StreamMetric::NumTotalElementsSent, + StreamMetric::NumTotalElementsReceived, + StreamMetric::NumTotalElementsAcked }; + std::string s1 = "testMetricsTotalEleSent_s1"; + std::string s2 = "testMetricsTotalEleSent_s2"; + std::string s3 = "testMetricsTotalEleSent_s3"; + + std::unordered_map> expected1 = { { s1, { "60", "0", "0" } }, + { s2, { "0", "0", "0" } }, + { s3, { "40", "0", "0" } } }; + + std::unordered_map> expected2 = { { s2, { "60", "0", "0" } }, + { s3, { "20", "0", "0" } } }; + + CreateTwoWorkerMetricsScenario(client1, client2, producers, consumers, s1, s2, s3); + + int i = 0; + // Each producer sends 20 elements -> 40KB + // Each element is slightly less than 2KB to account for page header overhead + for (auto &producer : producers) { + i++; + DS_ASSERT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, TEST_ELEMENT_SIZE)); + } + sleep(SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, NumTotalElementsSentProducerClose) +{ + std::shared_ptr client1; + std::shared_ptr client2; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::unordered_map> sc1Metrics; + std::vector metricsToVerify = { StreamMetric::NumTotalElementsSent, + StreamMetric::NumTotalElementsReceived, + StreamMetric::NumTotalElementsAcked }; + std::string s1 = "testMetricsEleSentProdClose_s1"; + std::string s2 = "testMetricsEleSentProdClose_s2"; + std::string s3 = "testMetricsEleSentProdClose_s3"; + + std::unordered_map> expected1 = { { s1, { "60", "0", "0" } }, + { s2, { "0", "0", "0" } }, + { s3, { "40", "0", "0" } } }; + + std::unordered_map> expected2 = { { s2, { "60", "0", "0" } }, + { s3, { "20", "0", "0" } } }; + + CreateTwoWorkerMetricsScenario(client1, client2, producers, 
consumers, s1, s2, s3); + + int i = 0; + // Each producer sends 20 elements -> 40KB + // Each element is slightly less than 2KB to account for page header overhead + for (auto &producer : producers) { + i++; + DS_ASSERT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, TEST_ELEMENT_SIZE)); + producer->Close(); + } + sleep(SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); +} + +TEST_F(SCMetricsTest, NumTotalElementsReceived) +{ + std::shared_ptr client1; + std::shared_ptr client2; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::unordered_map> sc1Metrics; + std::vector metricsToVerify = { StreamMetric::NumTotalElementsSent, + StreamMetric::NumTotalElementsReceived, + StreamMetric::NumTotalElementsAcked }; + std::string s1 = "testMetricsEleRecv_s1"; + std::string s2 = "testMetricsEleRecv_s2"; + std::string s3 = "testMetricsEleRecv_s3"; + + std::unordered_map> expected1 = { { s1, { "60", "60", "0" } }, + { s2, { "0", "60", "0" } }, + { s3, { "40", "0", "0" } } }; + + std::unordered_map> expected2 = { { s2, { "60", "0", "0" } }, + { s3, { "20", "60", "0" } } }; + + CreateTwoWorkerMetricsScenario(client1, client2, producers, consumers, s1, s2, s3); + + int i = 0; + // Each producer sends 20 elements -> 40KB + // Each element is slightly less than 2KB to account for page header overhead + for (auto &producer : producers) { + i++; + DS_ASSERT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, TEST_ELEMENT_SIZE)); + } + + for (auto &consumer : consumers) { + DS_ASSERT_OK(ConsumeAll(consumer, DEFAULT_WAIT_TIME, true, nullptr, producers.size(), false)); + } + + sleep(SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, NumTotalEleRecvLateConsumer) +{ + std::shared_ptr client1; + std::shared_ptr client2; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::unordered_map> sc1Metrics; + std::vector metricsToVerify = { StreamMetric::NumTotalElementsSent, + StreamMetric::NumTotalElementsReceived, + StreamMetric::NumTotalElementsAcked }; + std::string s1 = "testMetricsRecvLateCon_s1"; + std::string s2 = "testMetricsRecvLateCon_s2"; + std::string s3 = "testMetricsRecvLateCon_s3"; + + std::unordered_map> expected1 = { { s1, { "60", "40", "40" } }, + { s2, { "0", "40", "40" } }, + { s3, { "40", "0", "0" } } }; + + std::unordered_map> expected2 = { { s2, { "60", "0", "0" } }, + { s3, { "20", "40", "40" } } }; + + CreateTwoWorkerMetricsScenario(client1, client2, producers, consumers, s1, s2, s3); + + int i = 0; + int numProdToRecv = 2; + // Each producer sends 20 elements -> 40KB + // Each element is slightly less than 2KB to account for page header overhead + for (auto &producer : producers) { + i++; + DS_ASSERT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, TEST_ELEMENT_SIZE)); + } + + for (auto &consumer : consumers) { + DS_ASSERT_OK(ConsumeAll(consumer, DEFAULT_WAIT_TIME, true, nullptr, numProdToRecv, true)); + } + + sleep(SLEEP_TIME); + + GetStreamMetrics(0, 
"sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + + // Create late consumers + DS_ASSERT_OK(CreateProducerAndConsumer(client1, {}, producers, { { s1, "sub3" }, { s2, "sub3" } }, + consumers, true)); + numProdToRecv = 1; + for (auto &consumer : consumers) { + DS_ASSERT_OK(ConsumeAll(consumer, DEFAULT_WAIT_TIME, false, nullptr, numProdToRecv, false)); + } + + expected1 = { { s1, { "60", "60", "40" } }, + { s2, { "0", "60", "40" } }, + { s3, { "40", "0", "0" } } }; + + expected2 = { { s2, { "60", "0", "0" } }, { s3, { "20", "60", "40" } } }; + + sleep(SLEEP_TIME); + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, NumTotalElementsAcked) +{ + std::shared_ptr client1; + std::shared_ptr client2; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::unordered_map> sc1Metrics; + std::vector metricsToVerify = { StreamMetric::NumTotalElementsSent, + StreamMetric::NumTotalElementsReceived, + StreamMetric::NumTotalElementsAcked }; + std::string s1 = "testMetricsEleAcked_s1"; + std::string s2 = "testMetricsEleAcked_s2"; + std::string s3 = "testMetricsEleAcked_s3"; + + std::unordered_map> expected1 = { { s1, { "60", "60", "60" } }, + { s2, { "0", "60", "60" } }, + { s3, { "40", "0", "0" } } }; + + std::unordered_map> expected2 = { { s2, { "60", "0", "0" } }, + { s3, { "20", "60", "60" } } }; + + CreateTwoWorkerMetricsScenario(client1, client2, producers, consumers, s1, s2, s3); + + int i = 0; + // Each producer sends 20 elements -> 40KB + // Each element is slightly less than 2KB to account for page header overhead + for (auto &producer : producers) { + i++; + DS_ASSERT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, TEST_ELEMENT_SIZE)); + } + + for (auto &consumer : consumers) { + DS_ASSERT_OK(ConsumeAll(consumer, DEFAULT_WAIT_TIME, true, nullptr, producers.size())); + } + + sleep(SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, NumSendReceiveRequests) +{ + std::shared_ptr client1; + std::shared_ptr client2; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::unordered_map> sc1Metrics; + std::vector metricsToVerify = { StreamMetric::NumSendRequests, StreamMetric::NumReceiveRequests }; + std::string s1 = "testMetricsSendRecvReq_s1"; + std::string s2 = "testMetricsSendRecvReq_s2"; + std::string s3 = "testMetricsSendRecvReq_s3"; + // Each request executed twice, one sucess one failure + // prodNum * eleNum * times, consNum * eleNum * times + std::unordered_map> expected1 = { + { s1, { "120", "4" } }, // 3 * 20 * 2, 2 * 2 + { s2, { "0", "4" } }, // 0, 2 * 2 + { s3, { "80", "0" } } // 2 * 20 * 2, 0 + }; + + std::unordered_map> expected2 = { + { s2, { "120", "0" } }, // 3 * 20 * 2 + { s3, { "40", "2" } } // 20 * 2, 1 * 2 + }; + + CreateTwoWorkerMetricsScenario(client1, client2, 
producers, consumers, s1, s2, s3); + + int i = 0; + // Each producer sends 20 elements -> 40KB + // Each element is slightly less than 2KB to account for page header overhead + for (auto &producer : producers) { + i++; + DS_ASSERT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, TEST_ELEMENT_SIZE)); + DS_ASSERT_NOT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, TEST_ELEMENT_SIZE, -1)); + } + + for (auto &consumer : consumers) { + DS_ASSERT_OK(ConsumeAll(consumer, DEFAULT_WAIT_TIME, true, nullptr, producers.size())); + DS_ASSERT_OK(ConsumeAll(consumer, 0, true, nullptr, producers.size(), false)); + } + + sleep(SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, LEVEL1_NumPages) +{ + std::shared_ptr client1; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::vector metricsToVerify = { StreamMetric::NumPagesCreated, StreamMetric::NumPagesReleased, + StreamMetric::NumPagesInUse, StreamMetric::NumPagesCached }; + std::string streamName1 = "TestMetricsPages_s1"; + std::string streamName2 = "TestMetricsPages_s2"; + std::unordered_map> expected = { { streamName1, { "30", "0", "30", "0" } }, + { streamName2, { "20", "0", "20", "0" } } }; + + CreateOneWorkerMetricsScenario(client1, producers, consumers, streamName1, streamName2); + + int i = 0; + // Each producer sends 20 elements -> 40KB + // Each element is slightly less than 2KB to account for page header overhead + for (auto &producer : producers) { + i++; + DS_ASSERT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, TEST_ELEMENT_SIZE)); + } + + sleep(SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); + + // One page in use still + expected = { { streamName1, { "30", std::to_string(30 - CACHE_PAGES - 1), "1", std::to_string(CACHE_PAGES) } }, + { streamName2, { "20", std::to_string(20 - CACHE_PAGES - 1), "1", std::to_string(CACHE_PAGES) } } }; + + for (auto &consumer : consumers) { + DS_ASSERT_OK(ConsumeAll(consumer, DEFAULT_WAIT_TIME, true, nullptr, producers.size())); + } + + sleep(RELEASE_PAGE_SLEEP_TIME); // wait until pages are cleaned up + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); + + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, LEVEL1_NumBigPages) +{ + std::shared_ptr client1; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::vector metricsToVerify = { StreamMetric::NumPagesCreated, StreamMetric::NumPagesReleased, + StreamMetric::NumPagesCached, StreamMetric::NumBigPagesCreated, + StreamMetric::NumBigPagesReleased }; + std::string streamName1 = "TestMetricsBigPages_s1"; + std::string streamName2 = "TestMetricsBigPages_s2"; + std::unordered_map> expected = { + // all big pages, except 1 normal page from CreatePageZero() + { streamName1, { "1", "0", "0", "60", "0" } }, + { streamName2, { "1", "0", "0", "40", "0" } } + }; + + CreateOneWorkerMetricsScenario(client1, producers, consumers, streamName1, streamName2); + int i = 0; + // Each producer sends 20 8KB elements -> BigElement since pageSize is 4KB 
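+    // Aside (illustrative, not part of the original change): with pageSize = 4KB, an 8KB element
+    // cannot share a page, so the worker takes the big-element path and appears to allocate one
+    // dedicated big page per element. That is where the expectations above come from, plus the
+    // single normal page from CreatePageZero():
+    auto expectedBigPages = [this](int numProducers) {
+        return std::to_string(numProducers * DEFAULT_NUM_ELEMENT); // one big page per element
+    };
+    ASSERT_EQ(expectedBigPages(3), "60"); // streamName1: 3 producers * 20 elements
+    ASSERT_EQ(expectedBigPages(2), "40"); // streamName2: 2 producers * 20 elements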
+ for (auto &producer : producers) { + i++; + DS_ASSERT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, BIG_ELEMENT_SIZE)); + } + sleep(SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); + + for (auto &consumer : consumers) { + DS_ASSERT_OK(ConsumeAll(consumer, DEFAULT_WAIT_TIME, true, nullptr, producers.size())); + } + // all big pages, except 1 normal page from CreatePageZero() + expected = { { streamName1, { "1", "0", "0", "60", "60" } }, { streamName2, { "1", "0", "0", "40", "40" } } }; + sleep(RELEASE_PAGE_SLEEP_TIME); // wait until pages are cleaned up + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); + + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, NumLocalProducersBlocked) +{ + std::shared_ptr client1; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::string streamName1 = "testMetricsLocalProdBlocked_s1"; + std::string streamName2 = "testMetricsLocalProdBlocked_s2"; + std::vector metricsToVerify = { StreamMetric::NumLocalProducersBlocked }; + + std::unordered_map> expected = { { streamName1, { "1" } }, + { streamName2, { "0" } } }; + + DS_ASSERT_OK(InitClient(0, client1)); + + DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { streamName1, 3 }, { streamName2, 2 } }, producers, + { { streamName1, "sub1" }, { streamName2, "sub1" }, + { streamName2, "sub2" } }, consumers, true)); + const int eleNum = 550; + const int waitTime = 30000; + std::thread sendThread([&]() { Produce(producers[0], "producer", eleNum, TEST_ELEMENT2_SIZE, waitTime); }); + + sleep(LONG_SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); + + const int expectNum = 100; + const int numRecvCons = 1; + for (int i = 0; i < numRecvCons; i++) { + std::vector outElements; + DS_ASSERT_OK(consumers[i]->Receive(expectNum, DEFAULT_WAIT_TIME, outElements)); + DS_ASSERT_OK(consumers[i]->Ack(outElements.back().id)); + } + + sendThread.join(); + + expected = { { streamName1, { "0" } }, { streamName2, { "0" } } }; + sleep(SLEEP_TIME); + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); + + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, NumRemoteProducersConsumersBlocked) +{ + std::shared_ptr client1; + std::shared_ptr client2; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::unordered_map> sc1Metrics; + std::string streamName1 = "testMetricsRemoteProConBlocked_s1"; + std::string streamName2 = "testMetricsRemoteProConBlocked_s2"; + std::vector metricsToVerify = { StreamMetric::NumRemoteProducersBlocked, + StreamMetric::NumRemoteConsumersBlocking }; + + std::unordered_map> expected1 = { { streamName1, { "0", "2" } }, + { streamName2, { "", "" } } }; + + std::unordered_map> expected2 = { { streamName1, { "1", "0" } }, + { streamName2, { "0", "0" } } }; + + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + + DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { streamName1, 1 } }, producers, {}, consumers, true)); + DS_ASSERT_OK(CreateProducerAndConsumer(client2, { { streamName1, 1 }, { streamName2, 1 } }, producers, + { { streamName1, "sub1" }, { streamName1, "sub2" }, + { streamName2, "sub1" } }, consumers, true)); + + 
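+    // Aside (illustrative sketch, not part of the original change): this scenario exercises
+    // backpressure — worker2's stream memory fills, so worker1's producer blocks on its remote
+    // sends: worker1's log reports the two remote consumers holding it up
+    // (NumRemoteConsumersBlocking = 2) while worker2's log reports the blocked remote producer
+    // (NumRemoteProducersBlocked = 1). A send loop that tolerates that blocking, using only
+    // Producer::Send(element, timeoutMs) as already exercised by Produce() above:
+    auto sendWithBackpressure = [](std::shared_ptr<Producer> &p, const Element &e, int timeoutMs, int retries) {
+        Status rc = p->Send(e, timeoutMs);
+        while (rc.IsError() && retries-- > 0) {
+            rc = p->Send(e, timeoutMs); // a blocked send times out; retry after consumers ack
+        }
+        return rc;
+    };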
const int waitTime = 30000; + // First send 200 elements to use up some memory in worker2 + Produce(producers[1], "producer2", 200, TEST_ELEMENT2_SIZE, waitTime); + + // Send 400 elements to worker1, so worker2 will block worker1 producer + int eleNum = 400; + Produce(producers[0], "producer1", eleNum, TEST_ELEMENT2_SIZE, waitTime); + + sleep(LONG_SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + + const int expectNum = 100; + const int numRecvCons = 2; + for (int i = 0; i < numRecvCons; i++) { + std::vector outElements; + DS_ASSERT_OK(consumers[i]->Receive(expectNum, DEFAULT_WAIT_TIME, outElements)); + DS_ASSERT_OK(consumers[i]->Ack(outElements.back().id)); + } + + expected1 = { { streamName1, { "0", "0" } }, { streamName2, { "", "" } } }; + + expected2 = { { streamName1, { "0", "0" } }, { streamName2, { "0", "0" } } }; + sleep(SLEEP_TIME); + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, RetainDataState) +{ + std::shared_ptr client1; + std::unordered_map> sc0Metrics; + std::unordered_map> sc1Metrics; + std::string streamName = "testMetricsRetainDataState"; + std::vector metricsToVerify = { + StreamMetric::RetainDataState, + }; + + // Streams with remote producer will have local memory usage + std::unordered_map> expected = { + { streamName, { std::to_string(RetainDataState::RETAIN) } } + }; + + DS_ASSERT_OK(InitClient(0, client1)); + + ProducerConf conf; + conf.delayFlushTime = DELAY_FLUSH_TIME; + conf.pageSize = PAGE_SIZE; // 4K + conf.maxStreamSize = MAX_STREAM_SIZE; + conf.autoCleanup = true; + conf.retainForNumConsumers = 1; // retain data until one consumer + + std::shared_ptr producer; + // Create producer and send data + std::vector writeElement = RandomData().RandomBytes(TEST_ELEMENT_SIZE); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) { + DS_ASSERT_OK(producer->Send(element)); + } + sleep(SLEEP_TIME); + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); + + // Create a late consumer + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer)); + std::vector outElements; + // Now should get the data + DS_ASSERT_OK(consumer->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(DEFAULT_NUM_ELEMENT)); + + expected = { { streamName, { std::to_string(RetainDataState::NOT_RETAIN) } } }; + + sleep(SLEEP_TIME); + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); +} + +TEST_F(SCMetricsTest, NumProducersConsumersMaster) +{ + std::shared_ptr client1; + std::shared_ptr client2; + std::vector> producers; + std::vector> consumers; + std::unordered_map> sc0Metrics; + std::unordered_map> sc1Metrics; + std::string s1 = "TestMetricsProdConMaster_s1"; + std::string s2 = "TestMetricsProdConMaster_s2"; + std::string s3 
= "TestMetricsProdConMaster_s3"; + // StreamMetric::NumProducersMaster: Number of worker that have at least 1 producer for the stream. + // If there are 2 or more local producer for the same stream on the same worker, that count as 1. + std::vector metricsToVerify = { + StreamMetric::NumProducersMaster, + StreamMetric::NumConsumersMaster, + }; + + std::unordered_map> expected1 = { { s1, { "1", "2" } }, + { s2, { "1", "2" } }, + { s3, { "2", "1" } } }; + + std::unordered_map> expected2 = { { s2, { "", "" } }, + { s3, { "", "" } } }; + + CreateTwoWorkerMetricsScenario(client1, client2, producers, consumers, s1, s2, s3); + sleep(SLEEP_TIME); + + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + + // Close some + const int prodIndex = 8; // s3 + const int conIndex = 4; // s3 + DS_ASSERT_OK(producers[prodIndex]->Close()); + DS_ASSERT_OK(consumers[conIndex]->Close()); + sleep(SLEEP_TIME); + expected1 = { { s1, { "1", "2" } }, { s2, { "1", "2" } }, { s3, { "1", "0" } } }; + expected2 = { { s2, { "", "" } }, { s3, { "", "" } } }; + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected1, metricsToVerify); + GetStreamMetrics(1, "sc_metrics.log", sc1Metrics); + VerifyStreamMetrics(sc1Metrics, expected2, metricsToVerify); + DS_ASSERT_OK(CloseAllProducerAndConsumer(producers, consumers)); +} + +TEST_F(SCMetricsTest, NumPagesInit) +{ + std::shared_ptr client1; + std::unordered_map> sc0Metrics; + std::vector metricsToVerify = { StreamMetric::NumPagesCreated, StreamMetric::NumPagesReleased, + StreamMetric::NumPagesCached, StreamMetric::NumBigPagesCreated, + StreamMetric::NumBigPagesReleased }; + std::string streamName = "testMetricsNumPgsInit"; + std::unordered_map> expected = { { streamName, { "1", "0", "0", "0", "0" } }}; + + DS_ASSERT_OK(InitClient(0, client1)); + + ProducerConf conf; + conf.delayFlushTime = DELAY_FLUSH_TIME; + conf.pageSize = PAGE_SIZE; // 4K + conf.maxStreamSize = MAX_STREAM_SIZE; + const int numIterations = 3; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + // Verify values are properly init + for (int i = 0; i < numIterations; i++) { + std::shared_ptr producer; + std::shared_ptr consumer; + DS_ASSERT_OK(client1->Subscribe(streamName, config, consumer, true)); + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + sleep(SLEEP_TIME); + GetStreamMetrics(0, "sc_metrics.log", sc0Metrics); + VerifyStreamMetrics(sc0Metrics, expected, metricsToVerify); + producer->Close(); + consumer->Close(); + client1->DeleteStream(streamName); + } +} +} // namespace st +} // namespace datasystem \ No newline at end of file diff --git a/tests/st/client/stream_cache/shared_page_send_recv_test.cpp b/tests/st/client/stream_cache/shared_page_send_recv_test.cpp new file mode 100644 index 0000000..a3dc664 --- /dev/null +++ b/tests/st/client/stream_cache/shared_page_send_recv_test.cpp @@ -0,0 +1,914 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Remote send test with shared page queue enabled. + */ +#include +#include + +#include "common.h" +#include "common/stream_cache/element_generator.h" +#include "common/stream_cache/stream_common.h" +#include "datasystem/common/util/random_data.h" +#include "sc_client_common.h" +#include "datasystem/stream_client.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/client/stream_cache/client_worker_api.h" +namespace datasystem { +namespace st { +using namespace datasystem::client::stream_cache; +constexpr int K_TWO = 2; +constexpr int K_TEN = 10; +constexpr int K_TWENTY = 20; +class SharedPageSendRecvTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override; + + void SetUp() override; + + void TearDown() override; + +protected: + Status TryAndDeleteStream(std::shared_ptr &spClient, std::string streamName) + { + // if pending notifications retry delete + Status rc = Status::OK(); + do { + rc = spClient->DeleteStream(streamName); + if (rc.IsError()) { + sleep(K_TWO); + } + } while (rc.GetCode() == StatusCode::K_SC_STREAM_NOTIFICATION_PENDING); + return rc; + } + + Status SendHelper(std::shared_ptr producer, size_t numElements, Element element); + Status SendRandomHelper(std::shared_ptr producer, size_t numElements, const std::string &data, + size_t minSize = 1024); + Status ReceiveHelper(std::shared_ptr consumer, size_t numElements, const std::string &expectedData = ""); + + // Mock producer worker. + HostPort w1Addr_; + HostPort w2Addr_; + HostPort w3Addr_; + + std::shared_ptr w1Client_ = nullptr; + std::shared_ptr w2Client_ = nullptr; + std::shared_ptr w3Client_ = nullptr; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; + const int DEFAULT_WAIT_TIME = 5000; + const int DEFAULT_WORKER_NUM = 3; + const int DEFAULT_LOG_LEVEL = 2; +}; + +void SharedPageSendRecvTest::SetClusterSetupOptions(ExternalClusterOptions &opts) +{ + opts.numEtcd = 1; + opts.numWorkers = DEFAULT_WORKER_NUM; + opts.enableDistributedMaster = "true"; + opts.numRpcThreads = 0; + opts.vLogLevel = DEFAULT_LOG_LEVEL; + opts.workerGflagParams = " -shared_memory_size_mb=2048 "; + + SCClientCommon::SetClusterSetupOptions(opts); +} + +void SharedPageSendRecvTest::SetUp() +{ + ExternalClusterTest::SetUp(); + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, w1Addr_)); + DS_ASSERT_OK(cluster_->GetWorkerAddr(1, w2Addr_)); + DS_ASSERT_OK(cluster_->GetWorkerAddr(K_TWO, w3Addr_)); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 1)); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, K_TWO)); + // Worker 1. 
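+ // One stream client is attached to each of the three workers; the tests below place + // producers and consumers on specific nodes by picking the matching client.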
+ InitStreamClient(0, w1Client_); + InitStreamClient(1, w2Client_); + InitStreamClient(K_TWO, w3Client_); + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; +} + +void SharedPageSendRecvTest::TearDown() +{ + w1Client_ = nullptr; + w2Client_ = nullptr; + w3Client_ = nullptr; + ExternalClusterTest::TearDown(); +} + +Status SharedPageSendRecvTest::SendHelper(std::shared_ptr producer, size_t numElements, Element element) +{ + const int DEFAULT_SLEEP_TIME = 300; + int retryLimit = 30; + for (size_t i = 0; i < numElements; i++) { + datasystem::Status rc = producer->Send(element); + if (rc.IsError()) { + while (rc.GetCode() == K_OUT_OF_MEMORY && retryLimit-- > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME)); + rc = producer->Send(element); + } + } + RETURN_IF_NOT_OK(rc); + } + return Status::OK(); +} + +Status SharedPageSendRecvTest::SendRandomHelper(std::shared_ptr producer, size_t numElements, + const std::string &data, size_t minSize) +{ + const int DEFAULT_SLEEP_TIME = 300; + const int DEFAULT_RETRY_TIME = 60; + Timer timer; + size_t maxSize = data.size(); + minSize = std::min(maxSize, minSize); + for (size_t i = 0; i < numElements; i++) { + size_t sizeElement = RandomData().GetRandomUint32(minSize, maxSize); + std::string writeElement = data.substr(0, sizeElement); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + datasystem::Status rc = producer->Send(element); + if (rc.IsError()) { + while (rc.GetCode() == K_OUT_OF_MEMORY && timer.ElapsedSecond() < DEFAULT_RETRY_TIME) { + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME)); + rc = producer->Send(element); + } + } + if (rc.IsError()) { + LOG(INFO) << "send failed exits."; + } + RETURN_IF_NOT_OK(rc); + LOG(INFO) << "send count:" << i; + } + return Status::OK(); +} + +Status SharedPageSendRecvTest::ReceiveHelper(std::shared_ptr consumer, size_t numElements, + const std::string &expectedData) +{ + Timer timer; + size_t remaining = numElements; + int round = 0; + const int PER_RECEIVE_NUM = 1; + const int DEFAULT_WAIT_TIME = 1000; + const int DEFAULT_RETRY_TIME = 30; + while (remaining > 0 && timer.ElapsedSecond() < DEFAULT_RETRY_TIME) { + std::vector outElements; + RETURN_IF_NOT_OK(consumer->Receive(PER_RECEIVE_NUM, DEFAULT_WAIT_TIME, outElements)); + LOG(INFO) << "remaining num : " << remaining << ", receive num : " << outElements.size() << " ;" << round++; + if (!outElements.empty()) { + remaining -= outElements.size(); + RETURN_IF_NOT_OK(consumer->Ack(outElements.back().id)); + if (!expectedData.empty()) { + std::string actualData(reinterpret_cast(outElements[0].ptr), outElements[0].size); + CHECK_FAIL_RETURN_STATUS(expectedData == actualData, K_RUNTIME_ERROR, + "expected data does not match actual data."); + } + } + } + CHECK_FAIL_RETURN_STATUS(remaining == 0, K_RUNTIME_ERROR, "failed to receive all data"); + return Status::OK(); +} + +TEST_F(SharedPageSendRecvTest, TestBasicSingleStream) +{ + // Test that a single stream with shared page enabled, send and receive are working as expected. + // Start consumers first for now. 
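+ // Subscribing before the producer exists guarantees that the consumer observes every + // element; a late subscriber would have to rely on retainForNumConsumers instead + // (see the RetainDataState metrics test above).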
+ SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer1; + DS_ASSERT_OK(w2Client_->Subscribe("stream1", config, consumer1)); + + ProducerConf conf; + conf.maxStreamSize = TEST_STREAM_SIZE; + conf.streamMode = StreamMode::SPSC; + std::shared_ptr producer1; + DS_ASSERT_OK(w1Client_->CreateProducer("stream1", producer1, conf)); + + const size_t sizeElement = 1 * KB; + std::string writeElement1 = RandomData().GetRandomString(sizeElement); + Element element1(reinterpret_cast(writeElement1.data()), writeElement1.size()); + + const int threadNum = 2; + // shared page is default of size 4MB, so apply ~10MB data to test that multiple pages would work + const size_t numElements = 10000; + ThreadPool pool(threadNum); + std::vector> futs; + futs.push_back( + pool.Submit([this, producer1, &element1]() { return SendHelper(producer1, numElements, element1); })); + futs.push_back(pool.Submit( + [this, consumer1, &writeElement1]() { return ReceiveHelper(consumer1, numElements, writeElement1); })); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } + DS_ASSERT_OK(producer1->Close()); + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream1")); +} + +TEST_F(SharedPageSendRecvTest, TestBasicMultiStream) +{ + // Test that with 3 streams with consumer on the same node, the shared page logic is working fine. + + // Start consumers first for now. + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer1; + DS_ASSERT_OK(w2Client_->Subscribe("stream1", config, consumer1)); + std::shared_ptr consumer2; + DS_ASSERT_OK(w2Client_->Subscribe("stream2", config, consumer2)); + std::shared_ptr consumer3; + DS_ASSERT_OK(w2Client_->Subscribe("stream3", config, consumer3)); + + ProducerConf conf; + conf.maxStreamSize = TEST_STREAM_SIZE; + conf.streamMode = StreamMode::SPSC; + std::shared_ptr producer1; + DS_ASSERT_OK(w1Client_->CreateProducer("stream1", producer1, conf)); + std::shared_ptr producer2; + DS_ASSERT_OK(w1Client_->CreateProducer("stream2", producer2, conf)); + std::shared_ptr producer3; + DS_ASSERT_OK(w1Client_->CreateProducer("stream3", producer3, conf)); + + const size_t sizeElement = 1 * KB; + std::string writeElement1 = RandomData().GetRandomString(sizeElement); + Element element1(reinterpret_cast(writeElement1.data()), writeElement1.size()); + std::string writeElement2 = RandomData().GetRandomString(sizeElement); + Element element2(reinterpret_cast(writeElement2.data()), writeElement2.size()); + std::string writeElement3 = RandomData().GetRandomString(sizeElement); + Element element3(reinterpret_cast(writeElement3.data()), writeElement3.size()); + + const int threadNum = 6; + const size_t numElements = 10000; + ThreadPool pool(threadNum); + std::vector> futs; + futs.push_back( + pool.Submit([this, producer1, &element1]() { return SendHelper(producer1, numElements, element1); })); + futs.push_back( + pool.Submit([this, producer2, &element2]() { return SendHelper(producer2, numElements, element2); })); + futs.push_back( + pool.Submit([this, producer3, &element3]() { return SendHelper(producer3, numElements, element3); })); + futs.push_back(pool.Submit( + [this, consumer1, &writeElement1]() { return ReceiveHelper(consumer1, numElements, writeElement1); })); + futs.push_back(pool.Submit( + [this, consumer2, &writeElement2]() { return ReceiveHelper(consumer2, numElements, writeElement2); })); + futs.push_back(pool.Submit( + [this, consumer3, &writeElement3]() { return ReceiveHelper(consumer3, numElements, 
writeElement3); })); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } + DS_ASSERT_OK(producer1->Close()); + DS_ASSERT_OK(producer2->Close()); + DS_ASSERT_OK(producer3->Close()); + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(consumer2->Close()); + DS_ASSERT_OK(consumer3->Close()); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream1")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream2")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream3")); +} + +TEST_F(SharedPageSendRecvTest, TestCloseOneStreamConsumer) +{ + // Test that if one of the consumer is closed in shared page case, the buffers are discarded correctly. + + // Start consumers first for now. + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer1; + DS_ASSERT_OK(w2Client_->Subscribe("stream1", config, consumer1)); + std::shared_ptr consumer2; + DS_ASSERT_OK(w2Client_->Subscribe("stream2", config, consumer2)); + std::shared_ptr consumer3; + DS_ASSERT_OK(w2Client_->Subscribe("stream3", config, consumer3)); + + ProducerConf conf; + conf.maxStreamSize = TEST_STREAM_SIZE; + conf.streamMode = StreamMode::SPSC; + std::shared_ptr producer1; + DS_ASSERT_OK(w1Client_->CreateProducer("stream1", producer1, conf)); + std::shared_ptr producer2; + DS_ASSERT_OK(w1Client_->CreateProducer("stream2", producer2, conf)); + std::shared_ptr producer3; + DS_ASSERT_OK(w1Client_->CreateProducer("stream3", producer3, conf)); + + const size_t sizeElement = 1 * KB; + std::string writeElement1 = RandomData().GetRandomString(sizeElement); + Element element1(reinterpret_cast(writeElement1.data()), writeElement1.size()); + std::string writeElement2 = RandomData().GetRandomString(sizeElement); + Element element2(reinterpret_cast(writeElement2.data()), writeElement2.size()); + std::string writeElement3 = RandomData().GetRandomString(sizeElement); + Element element3(reinterpret_cast(writeElement3.data()), writeElement3.size()); + + const int threadNum = 6; + const size_t numElements = 10000; + ThreadPool pool(threadNum); + std::vector> futs; + futs.push_back( + pool.Submit([this, producer1, &element1]() { return SendHelper(producer1, numElements, element1); })); + futs.push_back( + pool.Submit([this, producer2, &element2]() { return SendHelper(producer2, numElements, element2); })); + futs.push_back( + pool.Submit([this, producer3, &element3]() { return SendHelper(producer3, numElements, element3); })); + futs.push_back(pool.Submit( + [this, consumer1, &writeElement1]() { return ReceiveHelper(consumer1, numElements, writeElement1); })); + futs.push_back(pool.Submit( + [this, consumer2, &writeElement2]() { return ReceiveHelper(consumer2, numElements, writeElement2); })); + + sleep(1); + DS_ASSERT_OK(consumer3->Close()); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } + DS_ASSERT_OK(producer1->Close()); + DS_ASSERT_OK(producer2->Close()); + DS_ASSERT_OK(producer3->Close()); + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(consumer2->Close()); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream1")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream2")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream3")); +} + +TEST_F(SharedPageSendRecvTest, TestDeleteOneStream) +{ + // Test that if one of the streams is deleted in shared page case, the buffers are discarded correctly. + + // Start consumers first for now. 
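+ // stream3's elements share pages with stream1 and stream2 data; deleting stream3 + // mid-run must discard only its own buffers while the other two streams keep receiving.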
+ SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer1; + DS_ASSERT_OK(w2Client_->Subscribe("stream1", config, consumer1)); + std::shared_ptr consumer2; + DS_ASSERT_OK(w2Client_->Subscribe("stream2", config, consumer2)); + std::shared_ptr consumer3; + DS_ASSERT_OK(w2Client_->Subscribe("stream3", config, consumer3)); + + ProducerConf conf; + conf.maxStreamSize = TEST_STREAM_SIZE; + conf.streamMode = StreamMode::SPSC; + std::shared_ptr producer1; + DS_ASSERT_OK(w1Client_->CreateProducer("stream1", producer1, conf)); + std::shared_ptr producer2; + DS_ASSERT_OK(w1Client_->CreateProducer("stream2", producer2, conf)); + std::shared_ptr producer3; + DS_ASSERT_OK(w1Client_->CreateProducer("stream3", producer3, conf)); + + const size_t sizeElement = 1 * KB; + std::string writeElement1 = RandomData().GetRandomString(sizeElement); + Element element1(reinterpret_cast(writeElement1.data()), writeElement1.size()); + std::string writeElement2 = RandomData().GetRandomString(sizeElement); + Element element2(reinterpret_cast(writeElement2.data()), writeElement2.size()); + std::string writeElement3 = RandomData().GetRandomString(sizeElement); + Element element3(reinterpret_cast(writeElement3.data()), writeElement3.size()); + + const int threadNum = 6; + const size_t numElements = 10000; + // It will be slow if the stream to be deleted sends too many elements, + // so only send 1000 elements to get discarded. + const size_t numElementsForDeleteStream = 1000; + ThreadPool pool(threadNum); + std::vector> futs; + futs.push_back( + pool.Submit([this, producer1, &element1]() { return SendHelper(producer1, numElements, element1); })); + futs.push_back( + pool.Submit([this, producer2, &element2]() { return SendHelper(producer2, numElements, element2); })); + futs.push_back(pool.Submit([this, producer3, consumer3, &element3]() { + RETURN_IF_NOT_OK(SendHelper(producer3, numElementsForDeleteStream, element3)); + // Delete stream after the send is done, so the data should be discarded. + RETURN_IF_NOT_OK(producer3->Close()); + RETURN_IF_NOT_OK(consumer3->Close()); + RETURN_IF_NOT_OK(TryAndDeleteStream(w1Client_, "stream3")); + return Status::OK(); + })); + futs.push_back(pool.Submit( + [this, consumer1, &writeElement1]() { return ReceiveHelper(consumer1, numElements, writeElement1); })); + futs.push_back(pool.Submit( + [this, consumer2, &writeElement2]() { return ReceiveHelper(consumer2, numElements, writeElement2); })); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } + DS_ASSERT_OK(producer1->Close()); + DS_ASSERT_OK(producer2->Close()); + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(consumer2->Close()); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream1")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream2")); +} + +TEST_F(SharedPageSendRecvTest, TestMultiStreamStreamNoTiming) +{ + // Test that with 3 streams with parallel create producer, stream numbers are generated correctly. + // Injection is used for the timing. + DS_ASSERT_OK(cluster_->SetInjectAction( + WORKER, 0, "ClientWorkerSCServiceImpl.CreateStreamManagerImpl.StreamNo_Sleep", "sleep(2000)")); + // Start consumers first for now. 
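+ // The injected 2000 ms sleep in CreateStreamManagerImpl widens the race window so that + // the three parallel CreateProducer calls further below overlap while stream numbers are assigned.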
+ SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer1; + DS_ASSERT_OK(w2Client_->Subscribe("stream1", config, consumer1)); + std::shared_ptr consumer2; + DS_ASSERT_OK(w2Client_->Subscribe("stream2", config, consumer2)); + std::shared_ptr consumer3; + DS_ASSERT_OK(w2Client_->Subscribe("stream3", config, consumer3)); + + ProducerConf conf; + conf.maxStreamSize = TEST_STREAM_SIZE; + conf.streamMode = StreamMode::SPSC; + std::shared_ptr producer1; + std::shared_ptr producer2; + std::shared_ptr producer3; + const int NUM_PRODUCERS = 3; + ThreadPool createProducerPool(NUM_PRODUCERS); + std::vector> futs; + // CreateProducer in parallel to trigger the potential timing. + futs.emplace_back(createProducerPool.Submit( + [this, &producer1, conf]() { return w1Client_->CreateProducer("stream1", producer1, conf); })); + futs.emplace_back(createProducerPool.Submit( + [this, &producer2, conf]() { return w1Client_->CreateProducer("stream2", producer2, conf); })); + futs.emplace_back(createProducerPool.Submit( + [this, &producer3, conf]() { return w1Client_->CreateProducer("stream3", producer3, conf); })); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } + futs.clear(); + + const size_t sizeElement = 1 * KB; + std::string writeElement1 = RandomData().GetRandomString(sizeElement); + Element element1(reinterpret_cast(writeElement1.data()), writeElement1.size()); + std::string writeElement2 = RandomData().GetRandomString(sizeElement); + Element element2(reinterpret_cast(writeElement2.data()), writeElement2.size()); + std::string writeElement3 = RandomData().GetRandomString(sizeElement); + Element element3(reinterpret_cast(writeElement3.data()), writeElement3.size()); + + const int threadNum = 6; + const size_t numElements = 10000; + ThreadPool sendRecvPool(threadNum); + futs.push_back( + sendRecvPool.Submit([this, producer1, &element1]() { return SendHelper(producer1, numElements, element1); })); + futs.push_back( + sendRecvPool.Submit([this, producer2, &element2]() { return SendHelper(producer2, numElements, element2); })); + futs.push_back( + sendRecvPool.Submit([this, producer3, &element3]() { return SendHelper(producer3, numElements, element3); })); + futs.push_back(sendRecvPool.Submit( + [this, consumer1, &writeElement1]() { return ReceiveHelper(consumer1, numElements, writeElement1); })); + futs.push_back(sendRecvPool.Submit( + [this, consumer2, &writeElement2]() { return ReceiveHelper(consumer2, numElements, writeElement2); })); + futs.push_back(sendRecvPool.Submit( + [this, consumer3, &writeElement3]() { return ReceiveHelper(consumer3, numElements, writeElement3); })); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } + DS_ASSERT_OK(producer1->Close()); + DS_ASSERT_OK(producer2->Close()); + DS_ASSERT_OK(producer3->Close()); + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(consumer2->Close()); + DS_ASSERT_OK(consumer3->Close()); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream1")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream2")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream3")); +} + +TEST_F(SharedPageSendRecvTest, TestShutdownWithoutDelStream) +{ + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer1; + DS_ASSERT_OK(w2Client_->Subscribe("stream1", config, consumer1)); + + ProducerConf conf; + conf.maxStreamSize = TEST_STREAM_SIZE; + conf.streamMode = StreamMode::SPSC; + std::shared_ptr producer1; + DS_ASSERT_OK(w1Client_->CreateProducer("stream1", producer1, conf)); +} + 
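+// Note: TestShutdownWithoutDelStream above intentionally ends with the producer and +// consumer still open; TearDown() releases the clients and the cluster has to clean up +// the remaining stream state on its own.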
+TEST_F(SharedPageSendRecvTest, TestReuseSharedPageQueue) +{ + // Test that shared page queue can be reused, when stream is deleted and re-created. + auto func = [this]() { + // Start consumers first for now. + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer1; + RETURN_IF_NOT_OK(w2Client_->Subscribe("stream1", config, consumer1)); + + ProducerConf conf; + conf.maxStreamSize = TEST_STREAM_SIZE; + conf.streamMode = StreamMode::SPSC; + std::shared_ptr producer1; + RETURN_IF_NOT_OK(w1Client_->CreateProducer("stream1", producer1, conf)); + const size_t sizeElement = 1 * KB; + std::string writeElement1 = RandomData().GetRandomString(sizeElement); + Element element1(reinterpret_cast(writeElement1.data()), writeElement1.size()); + + const int threadNum = 2; + // shared page is default of size 4MB, so apply ~10MB data to test that multiple pages would work + const size_t numElements = 10000; + ThreadPool pool(threadNum); + std::vector> futs; + futs.push_back( + pool.Submit([this, producer1, &element1]() { return SendHelper(producer1, numElements, element1); })); + futs.push_back(pool.Submit( + [this, consumer1, &writeElement1]() { return ReceiveHelper(consumer1, numElements, writeElement1); })); + + for (auto &fut : futs) { + RETURN_IF_NOT_OK(fut.get()); + } + RETURN_IF_NOT_OK(producer1->Close()); + RETURN_IF_NOT_OK(consumer1->Close()); + RETURN_IF_NOT_OK(TryAndDeleteStream(w1Client_, "stream1")); + return Status::OK(); + }; + DS_ASSERT_OK(func()); + // Recreate stream so it uses the same shared page queue. + DS_ASSERT_OK(func()); +} + +TEST_F(SharedPageSendRecvTest, TestOneStreamBlocking1) +{ + // Test that if one of the streams got OOM and blocked, + // the data will be moved to separate shm blocks, + // and the shared pages can be acked and freed. + // In this testcase, trigger blocking first and then send data of other streams. + + // Start consumers first for now. + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer1; + DS_ASSERT_OK(w2Client_->Subscribe("stream1", config, consumer1)); + std::shared_ptr consumer2; + DS_ASSERT_OK(w2Client_->Subscribe("stream2", config, consumer2)); + std::shared_ptr consumer3; + DS_ASSERT_OK(w2Client_->Subscribe("stream3", config, consumer3)); + + ProducerConf conf; + // Restrict the stream size, so that blocking happens earlier for stream3. + const uint64_t maxStreamSize = 8 * MB; + conf.maxStreamSize = maxStreamSize; + conf.streamMode = StreamMode::SPSC; + std::shared_ptr producer1; + DS_ASSERT_OK(w1Client_->CreateProducer("stream1", producer1, conf)); + std::shared_ptr producer2; + DS_ASSERT_OK(w1Client_->CreateProducer("stream2", producer2, conf)); + std::shared_ptr producer3; + DS_ASSERT_OK(w1Client_->CreateProducer("stream3", producer3, conf)); + + const size_t sizeElement = 250 * KB; + std::string writeElement1 = RandomData().GetRandomString(sizeElement); + Element element1(reinterpret_cast(writeElement1.data()), writeElement1.size()); + std::string writeElement2 = RandomData().GetRandomString(sizeElement); + Element element2(reinterpret_cast(writeElement2.data()), writeElement2.size()); + std::string writeElement3 = RandomData().GetRandomString(sizeElement); + Element element3(reinterpret_cast(writeElement3.data()), writeElement3.size()); + + const int threadNum = 6; + const size_t numElements = 100; + ThreadPool pool(threadNum); + std::vector> futs; + // Send stream3 data first to trigger blocking before hand. 
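+ // With maxStreamSize = 8 MB and 250 KB elements, stream3 fills up after roughly 32 + // sends, so SendHelper keeps retrying on K_OUT_OF_MEMORY until consumer3 starts + // receiving and acking at the end of the test.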
+ auto s3pFut = pool.Submit([this, producer3, &element3]() { return SendHelper(producer3, numElements, element3); }); + // Wait for the push and move to happen and then continue with the other streams. + sleep(1); + futs.push_back( + pool.Submit([this, producer1, &element1]() { return SendHelper(producer1, numElements, element1); })); + futs.push_back( + pool.Submit([this, producer2, &element2]() { return SendHelper(producer2, numElements, element2); })); + futs.push_back(pool.Submit( + [this, consumer1, &writeElement1]() { return ReceiveHelper(consumer1, numElements, writeElement1); })); + futs.push_back(pool.Submit( + [this, consumer2, &writeElement2]() { return ReceiveHelper(consumer2, numElements, writeElement2); })); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } + // Test that the blocked data can still all be received once unblocked. + DS_ASSERT_OK(ReceiveHelper(consumer3, numElements, writeElement3)); + DS_ASSERT_OK(s3pFut.get()); + + DS_ASSERT_OK(producer1->Close()); + DS_ASSERT_OK(producer2->Close()); + DS_ASSERT_OK(producer3->Close()); + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(consumer2->Close()); + DS_ASSERT_OK(consumer3->Close()); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream1")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream2")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream3")); +} + +TEST_F(SharedPageSendRecvTest, TestOneStreamBlocking2) +{ + // Test that if one of the streams got OOM and blocked, + // the data will be moved to separate shm blocks, + // and the shared pages can be acked and freed. + // In this test case, elements of all streams are sent together. + + // Start consumers first for now. + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr<Consumer> consumer1; + DS_ASSERT_OK(w2Client_->Subscribe("stream1", config, consumer1)); + std::shared_ptr<Consumer> consumer2; + DS_ASSERT_OK(w2Client_->Subscribe("stream2", config, consumer2)); + std::shared_ptr<Consumer> consumer3; + DS_ASSERT_OK(w2Client_->Subscribe("stream3", config, consumer3)); + + ProducerConf conf; + // Restrict the stream size, so that blocking happens earlier for stream3. + const uint64_t maxStreamSize = 8 * MB; + conf.maxStreamSize = maxStreamSize; + conf.streamMode = StreamMode::SPSC; + std::shared_ptr<Producer> producer1; + DS_ASSERT_OK(w1Client_->CreateProducer("stream1", producer1, conf)); + std::shared_ptr<Producer> producer2; + DS_ASSERT_OK(w1Client_->CreateProducer("stream2", producer2, conf)); + std::shared_ptr<Producer> producer3; + DS_ASSERT_OK(w1Client_->CreateProducer("stream3", producer3, conf)); + + const size_t sizeElement = 250 * KB; + std::string writeElement1 = RandomData().GetRandomString(sizeElement); + Element element1(reinterpret_cast<uint8_t *>(writeElement1.data()), writeElement1.size()); + std::string writeElement2 = RandomData().GetRandomString(sizeElement); + Element element2(reinterpret_cast<uint8_t *>(writeElement2.data()), writeElement2.size()); + std::string writeElement3 = RandomData().GetRandomString(sizeElement); + Element element3(reinterpret_cast<uint8_t *>(writeElement3.data()), writeElement3.size()); + + const int threadNum = 6; + const size_t numElements = 100; + ThreadPool pool(threadNum); + std::vector<std::future<Status>> futs; + // Send elements from all three streams concurrently; stream3 still exceeds its stream size and blocks. + futs.push_back( + pool.Submit([this, producer1, &element1]() { return SendHelper(producer1, numElements, element1); })); + futs.push_back( + pool.Submit([this, producer2, &element2]() { return SendHelper(producer2, numElements, element2); })); + auto s3pFut = pool.Submit([this, producer3, &element3]() { return SendHelper(producer3, numElements, element3); }); + futs.push_back(pool.Submit( + [this, consumer1, &writeElement1]() { return ReceiveHelper(consumer1, numElements, writeElement1); })); + futs.push_back(pool.Submit( + [this, consumer2, &writeElement2]() { return ReceiveHelper(consumer2, numElements, writeElement2); })); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } + // Test that the blocked data can still all be received once unblocked. + DS_ASSERT_OK(ReceiveHelper(consumer3, numElements, writeElement3)); + DS_ASSERT_OK(s3pFut.get()); + DS_ASSERT_OK(producer1->Close()); + DS_ASSERT_OK(producer2->Close()); + DS_ASSERT_OK(producer3->Close()); + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(consumer2->Close()); + DS_ASSERT_OK(consumer3->Close()); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream1")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream2")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream3")); +} + +TEST_F(SharedPageSendRecvTest, TestProducerAndConsumerAtOneNode) +{ + ProducerConf conf; + conf.maxStreamSize = TEST_STREAM_SIZE; + conf.streamMode = StreamMode::MPSC; + std::shared_ptr<Producer> producer1; + std::shared_ptr<Producer> producer2; + DS_ASSERT_OK(w1Client_->CreateProducer("stream1", producer1, conf)); + DS_ASSERT_OK(w2Client_->CreateProducer("stream1", producer2, conf)); + + // Both consumers live on worker1, co-located with one producer of each stream. + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr<Consumer> consumer1; + DS_ASSERT_OK(w1Client_->Subscribe("stream1", config, consumer1)); + std::shared_ptr<Consumer> consumer2; + DS_ASSERT_OK(w1Client_->Subscribe("stream2", config, consumer2)); + + std::shared_ptr<Producer> producer3; + std::shared_ptr<Producer> producer4; + DS_ASSERT_OK(w1Client_->CreateProducer("stream2", producer3, conf)); + DS_ASSERT_OK(w2Client_->CreateProducer("stream2", producer4, conf)); + + const size_t sizeElement = 1 * KB; + std::string writeElement1 = RandomData().GetRandomString(sizeElement); + Element element1(reinterpret_cast<uint8_t *>(writeElement1.data()), writeElement1.size()); + std::string writeElement2 = RandomData().GetRandomString(sizeElement); + Element element2(reinterpret_cast<uint8_t *>(writeElement2.data()), writeElement2.size()); + + const int threadNum = 6; + const size_t numElements = 10000; + const size_t recvNumElements = numElements * 2; + ThreadPool pool(threadNum); + std::vector<std::future<Status>> futs; + futs.push_back( + pool.Submit([this, producer1, &element1]() { return SendHelper(producer1, numElements, element1); })); + futs.push_back( + pool.Submit([this, producer2, &element1]() { return SendHelper(producer2, numElements, element1); })); + futs.push_back( + pool.Submit([this, producer3, &element2]() { return SendHelper(producer3, numElements, element2); })); + futs.push_back( + pool.Submit([this, producer4, &element2]() { return SendHelper(producer4, numElements, element2); })); + futs.push_back(pool.Submit( + [this, consumer1, &writeElement1]() { return ReceiveHelper(consumer1, recvNumElements, writeElement1); })); + futs.push_back(pool.Submit( + [this, consumer2, &writeElement2]() { return ReceiveHelper(consumer2, recvNumElements, writeElement2); })); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } + DS_ASSERT_OK(producer1->Close()); + DS_ASSERT_OK(producer2->Close()); +
DS_ASSERT_OK(producer3->Close()); + DS_ASSERT_OK(producer4->Close()); + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(consumer2->Close()); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream1")); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream2")); +} + +TEST_F(SharedPageSendRecvTest, TestCloseLastProducerWhenDataSending) +{ + ProducerConf prodConf; + prodConf.maxStreamSize = TEST_STREAM_SIZE; + prodConf.streamMode = StreamMode::MPSC; + SubscriptionConfig subsConfig("sub1", SubscriptionType::STREAM); + std::vector clients = { w1Client_.get(), w2Client_.get(), w3Client_.get() }; + + const size_t sizeElement = 50 * KB; + const size_t numElements = 300; + std::string writeElement = RandomData().GetRandomString(sizeElement); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + + auto start = [this, &clients, &prodConf, &subsConfig, &writeElement, &element](const std::string &streamName, + size_t consumerIndex) { + std::vector> producers; + std::shared_ptr consumer; + for (size_t index = 0; index < clients.size(); index++) { + auto client = clients[index]; + if (consumerIndex == index) { + RETURN_IF_NOT_OK(client->Subscribe(streamName, subsConfig, consumer)); + } else { + std::shared_ptr producer; + RETURN_IF_NOT_OK(client->CreateProducer(streamName, producer, prodConf)); + producers.emplace_back(std::move(producer)); + } + } + ThreadPool pool(clients.size()); + std::vector> futs; + for (auto &producer : producers) { + futs.push_back(pool.Submit([this, producer, &element]() { + RETURN_IF_NOT_OK(SendHelper(producer, numElements, element)); + RETURN_IF_NOT_OK(producer->Close()); + return Status::OK(); + })); + } + size_t recvNumElements = numElements * producers.size(); + futs.push_back(pool.Submit([this, consumer, &writeElement, recvNumElements]() { + return ReceiveHelper(consumer, recvNumElements, writeElement); + })); + Status lastRc; + for (auto &fut : futs) { + auto rc = fut.get(); + if (rc.IsError()) { + lastRc = rc; + } + } + return lastRc; + }; + + const int testStreamCount = 3; + for (int i = 0; i < testStreamCount; i++) { + std::string streamName = "stream-" + std::to_string(i); + DS_ASSERT_OK(start(streamName, i % clients.size())); + } +} + +TEST_F(SharedPageSendRecvTest, TestRemoteRecvIncUsageFailed) +{ + const int maxStreamSize = 4 * MB; + const int pageSize = 1 * MB; + ProducerConf prodConf; + prodConf.maxStreamSize = maxStreamSize; + prodConf.pageSize = pageSize; + prodConf.retainForNumConsumers = 1; + prodConf.streamMode = StreamMode::MPSC; + SubscriptionConfig subsConfig("sub1", SubscriptionType::STREAM); + std::vector clients = { w1Client_.get(), w2Client_.get(), w3Client_.get() }; + + const size_t numElements = 300; + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamMetaShm.TryIncUsage", "5%return(K_OUT_OF_MEMORY)")); + auto start = [this, &clients, &prodConf, &subsConfig](const std::string &streamName, size_t consumerIndex) { + std::vector> producers; + std::shared_ptr consumer; + for (size_t index = 0; index < clients.size(); index++) { + auto client = clients[index]; + if (consumerIndex == index) { + RETURN_IF_NOT_OK(client->Subscribe(streamName, subsConfig, consumer)); + } + std::shared_ptr producer; + RETURN_IF_NOT_OK(client->CreateProducer(streamName, producer, prodConf)); + producers.emplace_back(std::move(producer)); + } + int threadCount = producers.size() + 1; + ThreadPool pool(threadCount); + std::vector> futs; + int index = 0; + for (auto &producer : producers) { + futs.push_back(pool.Submit([this, producer, index]() { + 
LOG(INFO) << "producer index:" << index; + size_t dataSize = 600 * 1024; + auto data = RandomData().GetRandomString(dataSize); + RETURN_IF_NOT_OK(SendRandomHelper(producer, numElements, data)); + RETURN_IF_NOT_OK(producer->Close()); + return Status::OK(); + })); + index++; + } + size_t recvNumElements = numElements * producers.size(); + futs.push_back( + pool.Submit([this, consumer, recvNumElements]() { return ReceiveHelper(consumer, recvNumElements); })); + Status lastRc; + for (auto &fut : futs) { + auto rc = fut.get(); + if (rc.IsError()) { + lastRc = rc; + } + } + return lastRc; + }; + + std::string streamName = "stream-0"; + DS_ASSERT_OK(start(streamName, 0)); +} + +class SharedPageSendRecvOOMTest : public SharedPageSendRecvTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override; +}; + +void SharedPageSendRecvOOMTest::SetClusterSetupOptions(ExternalClusterOptions &opts) +{ + opts.numEtcd = 1; + opts.numWorkers = DEFAULT_WORKER_NUM; + opts.enableDistributedMaster = "true"; + opts.numRpcThreads = 0; + opts.vLogLevel = DEFAULT_LOG_LEVEL; + // Increase the zmq_chunk_sz to beyond shared page size, so then it can batch elements from multiple pages. + opts.workerGflagParams = " -zmq_chunk_sz=10485760"; + SCClientCommon::SetClusterSetupOptions(opts); +} + +TEST_F(SharedPageSendRecvOOMTest, TestSingleStreamOOM) +{ + // Test that a single stream with shared page enabled, requests are retried if necessary upon OOM. + // OOM is injected to guarantee consistency. + // Start consumers first for now. + DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "worker.UsageMonitor.CheckOverUsedForStream.MockError", + "5*return(K_OUT_OF_MEMORY)")); + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer1; + DS_ASSERT_OK(w2Client_->Subscribe("stream1", config, consumer1)); + + ProducerConf conf; + conf.maxStreamSize = TEST_STREAM_SIZE; + conf.pageSize = 1 * MB; + conf.streamMode = StreamMode::SPSC; + std::shared_ptr producer1; + DS_ASSERT_OK(w1Client_->CreateProducer("stream1", producer1, conf)); + + const size_t sizeElement = 100 * KB; + std::string writeElement1 = RandomData().GetRandomString(sizeElement); + Element element1(reinterpret_cast(writeElement1.data()), writeElement1.size()); + + const int threadNum = 2; + const size_t numElements = 500; + ThreadPool pool(threadNum); + std::vector> futs; + futs.push_back( + pool.Submit([this, producer1, &element1]() { return SendHelper(producer1, numElements, element1); })); + futs.push_back(pool.Submit( + [this, consumer1, &writeElement1]() { return ReceiveHelper(consumer1, numElements, writeElement1); })); + + for (auto &fut : futs) { + DS_ASSERT_OK(fut.get()); + } + DS_ASSERT_OK(producer1->Close()); + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(TryAndDeleteStream(w1Client_, "stream1")); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/single_consuemr_topo_test.cpp b/tests/st/client/stream_cache/single_consuemr_topo_test.cpp new file mode 100644 index 0000000..a3e2b3f --- /dev/null +++ b/tests/st/client/stream_cache/single_consuemr_topo_test.cpp @@ -0,0 +1,360 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Test single consumer topo test. + */ +#include +#include +#include + +#include "common.h" +#include "common/stream_cache/element_generator.h" +#include "common/stream_cache/stream_common.h" +#include "datasystem/utils/status.h" +#include "sc_client_common.h" +#include "datasystem/stream_client.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/client/stream_cache/client_worker_api.h" +namespace datasystem { +namespace st { +using namespace datasystem::client::stream_cache; +constexpr int K_TWO = 2; +class SingleConsumerTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numEtcd = 1; + opts.numWorkers = DEFAULT_WORKER_NUM; + opts.enableDistributedMaster = "true"; + opts.workerGflagParams = " -log_monitor=true "; + opts.numRpcThreads = 0; + opts.vLogLevel = DEFAULT_LOG_LEVEL; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, w1Addr_)); + DS_ASSERT_OK(cluster_->GetWorkerAddr(1, w2Addr_)); + DS_ASSERT_OK(cluster_->GetWorkerAddr(K_TWO, w3Addr_)); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 1)); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, K_TWO)); + // Worker 1. + InitStreamClient(0, w1Client_); + InitStreamClient(1, w2Client_); + InitStreamClient(2, w3Client_); // index is 2 + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + } + + void TearDown() override + { + w1Client_ = nullptr; + w2Client_ = nullptr; + w3Client_ = nullptr; + ExternalClusterTest::TearDown(); + } + +protected: + // Mock producer worker. + HostPort w1Addr_; + HostPort w2Addr_; + HostPort w3Addr_; + + std::shared_ptr w1Client_ = nullptr; + std::shared_ptr w2Client_ = nullptr; + std::shared_ptr w3Client_ = nullptr; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; + const int DEFAULT_WAIT_TIME = 5000; + const int DEFAULT_WORKER_NUM = 3; + const int DEFAULT_LOG_LEVEL = 2; +}; + +class MPSCTest : public SingleConsumerTest {}; + +// create producer -> create consumer -> create consumer +TEST_F(MPSCTest, TestCreateMultiConsumer1) +{ + ProducerConf producerConf; + producerConf.streamMode = StreamMode::MPSC; + producerConf.maxStreamSize = TEST_STREAM_SIZE; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + + { + // same worker. + std::string stream1("singleStream1"); + std::shared_ptr producer; + std::shared_ptr consumer; + std::shared_ptr consumer2; + DS_ASSERT_OK(w1Client_->CreateProducer(stream1, producer, producerConf)); + DS_ASSERT_OK(w1Client_->Subscribe(stream1, config, consumer)); + DS_ASSERT_NOT_OK(w1Client_->Subscribe(stream1, config2, consumer2)); + } + + { + // diff worker. 
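+ // The single-consumer rule is enforced cluster-wide rather than per worker: the + // second Subscribe fails even though it is issued from a different node.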
+ std::string stream2("singleStream2"); + std::shared_ptr producerW1; + std::shared_ptr consumerW2; + std::shared_ptr consumerW3; + DS_ASSERT_OK(w1Client_->CreateProducer(stream2, producerW1, producerConf)); + DS_ASSERT_OK(w2Client_->Subscribe(stream2, config, consumerW2)); + DS_ASSERT_NOT_OK(w3Client_->Subscribe(stream2, config2, consumerW3)); + } +} + +// create consumer -> create producer -> create consumer +TEST_F(MPSCTest, TestCreateMultiConsumer2) +{ + ProducerConf producerConf; + producerConf.streamMode = StreamMode::MPSC; + producerConf.maxStreamSize = TEST_STREAM_SIZE; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + + { + std::string stream1("singleStream1"); + std::shared_ptr consumer; + std::shared_ptr producer; + std::shared_ptr consumer2; + DS_ASSERT_OK(w1Client_->Subscribe(stream1, config, consumer)); + DS_ASSERT_OK(w1Client_->CreateProducer(stream1, producer, producerConf)); + DS_ASSERT_NOT_OK(w1Client_->Subscribe(stream1, config2, consumer2)); + } + + { + std::string stream2("singleStream2"); + std::shared_ptr consumerW1; + std::shared_ptr producerW2; + std::shared_ptr consumerW3; + DS_ASSERT_OK(w1Client_->Subscribe(stream2, config, consumerW1)); + DS_ASSERT_OK(w2Client_->CreateProducer(stream2, producerW2, producerConf)); + DS_ASSERT_NOT_OK(w3Client_->Subscribe(stream2, config2, consumerW3)); + } +} + +// create consumer -> create consumer -> create producer +TEST_F(MPSCTest, TestCreateMultiConsumer3) +{ + ProducerConf producerConf; + producerConf.streamMode = StreamMode::MPSC; + producerConf.maxStreamSize = TEST_STREAM_SIZE; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + { + std::string stream1("singleStream1"); + std::shared_ptr consumer; + std::shared_ptr consumer2; + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->Subscribe(stream1, config, consumer)); + DS_ASSERT_OK(w1Client_->Subscribe(stream1, config2, consumer2)); + DS_ASSERT_NOT_OK(w1Client_->CreateProducer(stream1, producer, producerConf)); + } + + { + std::string stream2("singleStream2"); + std::shared_ptr consumerW1; + std::shared_ptr consumerW2; + std::shared_ptr producerW3; + DS_ASSERT_OK(w1Client_->Subscribe(stream2, config, consumerW1)); + DS_ASSERT_OK(w2Client_->Subscribe(stream2, config2, consumerW2)); + DS_ASSERT_NOT_OK(w3Client_->CreateProducer(stream2, producerW3, producerConf)); + } +} + +TEST_F(MPSCTest, TestCreateMultiProducer) +{ + ProducerConf producerConf; + producerConf.streamMode = StreamMode::MPSC; + producerConf.maxStreamSize = TEST_STREAM_SIZE; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + { + std::string stream1("singleStream1"); + std::shared_ptr consumer; + std::shared_ptr producer1; + std::shared_ptr producer2; + DS_ASSERT_OK(w1Client_->Subscribe(stream1, config, consumer)); + DS_ASSERT_OK(w1Client_->CreateProducer(stream1, producer1, producerConf)); + DS_ASSERT_OK(w1Client_->CreateProducer(stream1, producer2, producerConf)); + } + + { + std::string stream2("singleStream2"); + std::shared_ptr consumerW1; + std::shared_ptr producer1W1; + std::shared_ptr producer2W2; + std::shared_ptr producer3W3; + DS_ASSERT_OK(w1Client_->Subscribe(stream2, config, consumerW1)); + DS_ASSERT_OK(w1Client_->CreateProducer(stream2, producer1W1, producerConf)); + DS_ASSERT_OK(w2Client_->CreateProducer(stream2, producer2W2, producerConf)); + DS_ASSERT_OK(w3Client_->CreateProducer(stream2, producer3W3, producerConf)); + } +} + 
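+// Stream modes exercised by these topology tests, as the assertions above and below +// establish them: MPSC allows any number of producers but exactly one consumer +// cluster-wide; SPSC allows one producer and one consumer; MPMC (see ShmReserveTest +// below) has no single-consumer restriction. For example, on an MPSC stream a second +// subscription is rejected no matter which worker issues it: +// DS_ASSERT_NOT_OK(w3Client_->Subscribe(stream, config2, consumer2));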
+class SPSCTest : public SingleConsumerTest {}; + +// test create multi producer in SPSC mode. +TEST_F(SPSCTest, TestCreateMultiProducer) +{ + ProducerConf producerConf; + producerConf.streamMode = StreamMode::SPSC; + producerConf.maxStreamSize = TEST_STREAM_SIZE; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + + std::string stream1("singleStream1"); + std::shared_ptr consumer; + std::shared_ptr producer1; + std::shared_ptr producer2; + DS_ASSERT_OK(w1Client_->Subscribe(stream1, config, consumer)); + DS_ASSERT_OK(w1Client_->CreateProducer(stream1, producer1, producerConf)); + DS_ASSERT_NOT_OK(w1Client_->CreateProducer(stream1, producer2, producerConf)); + DS_ASSERT_NOT_OK(w2Client_->CreateProducer(stream1, producer2, producerConf)); +} + +class ShmReserveTest : public SingleConsumerTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + SingleConsumerTest::SetClusterSetupOptions(opts); + opts.workerGflagParams = " -shared_memory_size_mb=32 "; + } + +protected: + const int maxStreamSize_ = 16 * MB; +}; + +TEST_F(ShmReserveTest, TestMPMCProducerCreateOOM) +{ + ProducerConf producerConf; + producerConf.streamMode = StreamMode::MPMC; + producerConf.maxStreamSize = maxStreamSize_; + const int streamCount = 33; + std::vector> producers; + Status rc; + int i = 0; + for (; i < streamCount; i++) { + auto streamName = "stream-" + std::to_string(i); + std::shared_ptr producer; + rc = w1Client_->CreateProducer(streamName, producer, producerConf); + if (rc.IsError()) { + break; + } + producers.emplace_back(std::move(producer)); + } + LOG(INFO) << "stream count:" << i; + ASSERT_EQ(rc.GetCode(), K_SC_STREAM_RESOURCE_ERROR); +} + +TEST_F(ShmReserveTest, TestMPSCDiffNodeProducerNotOOM) +{ + ProducerConf producerConf; + producerConf.streamMode = StreamMode::MPSC; + producerConf.maxStreamSize = maxStreamSize_; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + const int streamCount = 33; + const int producerCountPerStream = 3; + std::vector> producers; + std::vector> consumers; + for (int i = 0; i < streamCount; i++) { + auto streamName = "stream-" + std::to_string(i); + std::shared_ptr consumer; + const int subsWorkerCount = 2; + if (i % subsWorkerCount == 0) { + DS_ASSERT_OK(w2Client_->Subscribe(streamName, config, consumer)); + } else { + DS_ASSERT_OK(w3Client_->Subscribe(streamName, config, consumer)); + } + consumers.emplace_back(std::move(consumer)); + + for (int j = 0; j < producerCountPerStream; j++) { + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, producerConf)); + producers.emplace_back(std::move(producer)); + } + } +} + +TEST_F(ShmReserveTest, TestMPSCSomeNodeSubscribeOOM) +{ + ProducerConf producerConf; + producerConf.streamMode = StreamMode::MPSC; + producerConf.maxStreamSize = maxStreamSize_; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + const int streamCount = 33; + const int producerCountPerStream = 3; + std::vector> producers; + std::vector> consumers; + Status rc; + int i = 0; + for (; i < streamCount; i++) { + auto streamName = "stream-" + std::to_string(i); + for (int j = 0; j < producerCountPerStream; j++) { + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, producerConf)); + producers.emplace_back(std::move(producer)); + } + std::shared_ptr consumer; + rc = (w1Client_->Subscribe(streamName, config, consumer)); + if (rc.IsError()) { + break; + } + consumers.emplace_back(std::move(consumer)); + } + LOG(INFO) << "stream 
count:" << i; + ASSERT_EQ(rc.GetCode(), K_SC_STREAM_RESOURCE_ERROR); +} + +TEST_F(ShmReserveTest, TestMPSCSomeNodeProducerOOM) +{ + ProducerConf producerConf; + producerConf.streamMode = StreamMode::MPSC; + producerConf.maxStreamSize = maxStreamSize_; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + const int streamCount = 33; + const int producerCountPerStream = 3; + std::vector<std::shared_ptr<Producer>> producers; + std::vector<std::shared_ptr<Consumer>> consumers; + Status rc; + int i = 0; + for (; i < streamCount; i++) { + auto streamName = "stream-" + std::to_string(i); + std::shared_ptr<Consumer> consumer; + DS_ASSERT_OK(w1Client_->Subscribe(streamName, config, consumer)); + consumers.emplace_back(std::move(consumer)); + for (int j = 0; j < producerCountPerStream; j++) { + std::shared_ptr<Producer> producer; + rc = w1Client_->CreateProducer(streamName, producer, producerConf); + if (rc.IsError()) { + break; + } + producers.emplace_back(std::move(producer)); + } + if (rc.IsError()) { + break; + } + } + LOG(INFO) << "stream count:" << i; + ASSERT_EQ(rc.GetCode(), K_SC_STREAM_RESOURCE_ERROR); +} + +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/stream_client_replica_test.cpp b/tests/st/client/stream_cache/stream_client_replica_test.cpp new file mode 100644 index 0000000..999eac2 --- /dev/null +++ b/tests/st/client/stream_cache/stream_client_replica_test.cpp @@ -0,0 +1,1177 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Stream client multi replica tests.
+ */ + +#include +#include +#include +#include +#include + +#include +#include + +#include "common.h" +#include "common_distributed_ext.h" +#include "client/stream_cache/sc_client_common.h" +#include "common/stream_cache/stream_common.h" +#include "datasystem/common/kvstore/etcd/etcd_store.h" +#include "datasystem/common/util/format.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/common/util/random_data.h" +#include "datasystem/common/util/strings_util.h" +#include "datasystem/common/util/uuid_generator.h" +#include "datasystem/common/log/log.h" +#include "datasystem/stream_client.h" +#include "datasystem/utils/status.h" +#include "datasystem/worker/hash_ring/hash_ring_allocator.h" + +namespace datasystem { +namespace st { +struct StreamEntry { + std::string name; + std::shared_ptr<Producer> producer; + std::shared_ptr<Consumer> consumer; +}; + +class StreamReplicaTest : public SCClientCommon, public CommonDistributedExt { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + const int workerCount = 5; + opts.numEtcd = 1; + opts.numWorkers = workerCount; + opts.enableDistributedMaster = "true"; + opts.workerGflagParams = + " -v=1 -shared_memory_size_mb=2048 -node_timeout_s=3 -enable_meta_replica=true -log_monitor=true"; + opts.waitWorkerReady = false; + SCClientCommon::SetClusterSetupOptions(opts); + opts.disableRocksDB = false; + } + + void SetUp() override + { + CommonTest::SetUp(); + DS_ASSERT_OK(Init()); + ASSERT_TRUE(cluster_ != nullptr); + DS_ASSERT_OK(cluster_->StartEtcdCluster()); + externalCluster_ = dynamic_cast<ExternalCluster *>(cluster_.get()); + } + + void TearDown() override + { + clients_.clear(); + ExternalClusterTest::TearDown(); + } + + void InitClients(int count) + { + for (int i = 0; i < count; i++) { + std::shared_ptr<StreamClient> client; + InitStreamClient(i, client); + clients_.emplace_back(std::move(client)); + } + } + + void InitClients(const std::vector<int> &indexes) + { + for (auto index : indexes) { + std::shared_ptr<StreamClient> client; + InitStreamClient(index, client); + clients_.emplace_back(std::move(client)); + } + } + + void BasicTest(const int clientCount = 3) + { + int streamCount = 50; + int maxStreamSize = 10 * 1024 * 1024; + std::vector<std::string> streamNames; + std::vector<std::shared_ptr<Producer>> producers; + std::vector<std::shared_ptr<Consumer>> consumers; + for (int i = 0; i < streamCount; i++) { + std::string streamName = GetStringUuid(); + auto pubClient = clients_[i % clientCount]; + auto subClient = clients_[(i + 1) % clientCount]; + std::shared_ptr<Producer> producer; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + DS_ASSERT_OK(pubClient->CreateProducer(streamName, producer)); + SubscriptionConfig config; + std::shared_ptr<Consumer> consumer; + DS_ASSERT_OK(subClient->Subscribe(streamName, config, consumer)); + streamNames.emplace_back(std::move(streamName)); + producers.emplace_back(std::move(producer)); + consumers.emplace_back(std::move(consumer)); + } + + RandomData random; + for (auto &s : streamNames) { + auto client = clients_[random.GetRandomUint32() % clientCount]; + uint64_t producerCount; + DS_ASSERT_OK(client->QueryGlobalProducersNum(s, producerCount)); + ASSERT_EQ(producerCount, 1ul); + uint64_t consumerCount; + DS_ASSERT_OK(client->QueryGlobalConsumersNum(s, consumerCount)); + ASSERT_EQ(consumerCount, 1ul); + } + + for (auto &p : producers) { + DS_ASSERT_OK(p->Close()); + } + for (auto &c : consumers) { + DS_ASSERT_OK(c->Close()); + } + for (auto &s : streamNames) { + auto client = clients_[random.GetRandomUint32() % clientCount]; + DS_ASSERT_OK(client->DeleteStream(s)); + } + } + + BaseCluster *GetCluster() override + { + return cluster_.get(); + } + + void CreateStreams(std::vector<StreamEntry> &streams, int streamCount) + { + int maxStreamSize = 10 * 1024 * 1024; + auto clientCount = clients_.size(); + ASSERT_GT(clientCount, 0) << "clients_ not initialized."; + for (int i = 0; i < streamCount; i++) { + std::string streamName = GetStringUuid(); + auto pubClient = clients_[i % clientCount]; + auto subClient = clients_[(i + 1) % clientCount]; + std::shared_ptr<Producer> producer; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + DS_ASSERT_OK(pubClient->CreateProducer(streamName, producer)); + SubscriptionConfig config; + std::shared_ptr<Consumer> consumer; + DS_ASSERT_OK(subClient->Subscribe(streamName, config, consumer)); + streams.emplace_back(StreamEntry{ streamName, producer, consumer }); + } + } + + void VerifyStream(std::vector<StreamEntry> &streams) + { + auto clientCount = clients_.size(); + ASSERT_GT(clientCount, 0) << "clients_ not initialized."; + RandomData random; + for (auto &entry : streams) { + auto s = entry.name; + auto client = clients_[random.GetRandomUint32() % clientCount]; + uint64_t producerCount; + DS_ASSERT_OK(client->QueryGlobalProducersNum(s, producerCount)); + ASSERT_EQ(producerCount, 1ul); + uint64_t consumerCount; + DS_ASSERT_OK(client->QueryGlobalConsumersNum(s, consumerCount)); + ASSERT_EQ(consumerCount, 1ul); + } + } + + Status SetNotReadyInject(const std::initializer_list<int> indexes) + { + auto injects = { "master.CreateProducer", "master.CloseProducer", "master.Subscribe", + "master.CloseConsumer", "master.DeleteStream", "master.QueryGlobalProducersNum", + "master.QueryGlobalConsumersNum" }; + + for (auto index : indexes) { + for (auto inject : injects) { + RETURN_IF_NOT_OK(cluster_->SetInjectAction(WORKER, index, inject, "10%return(K_REPLICA_NOT_READY)")); + } + RETURN_IF_NOT_OK( + cluster_->SetInjectAction(WORKER, index, "master.MigrateSCMetadata", "3*return(K_REPLICA_NOT_READY)")); + } + return Status::OK(); + } + +protected: + ExternalCluster *externalCluster_ = nullptr; + std::vector<std::shared_ptr<StreamClient>> clients_; +}; + +class StreamReplicaRouterTest : public StreamReplicaTest {}; + +TEST_F(StreamReplicaRouterTest, DISABLED_LEVEL1_BasicPubSubTest) +{ + // replica layout: + // worker0: { primary: [worker0], backup: [worker1] } + // worker1: { primary: [worker1, worker2], backup: []} + // worker2: { primary: [], backup: [worker0, worker2]} + DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady( + { 0, 1, 2 }, { { 2, " -inject_actions=worker.ClusterInitFinish:return()" } })); + InjectSyncCap({ 0, 1, 2 }, UINT16_MAX, 1); + DS_ASSERT_OK(SetNotReadyInject({ 0, 1, 2 })); + const int clientCount = 3; + InitClients(clientCount); + + int streamCount = 50; + std::vector<StreamEntry> streams; + CreateStreams(streams, streamCount); + VerifyStream(streams); + + RandomData random; + + for (auto &entry : streams) { + auto s = entry.name; + auto client = clients_[random.GetRandomUint32() % clientCount]; + DS_ASSERT_OK(entry.producer->Close()); + DS_ASSERT_OK(entry.consumer->Close()); + DS_ASSERT_OK(client->DeleteStream(s)); + } +} + +TEST_F(StreamReplicaRouterTest, LEVEL2_SubscribeFirstSendReceive) +{ + DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady( + { 0, 1, 2 }, { { 2, " -inject_actions=worker.ClusterInitFinish:return()" } })); + // Remove this sleep once retry logic is added.
+ int timeout = 3; + sleep(timeout); + const int clientCount = 3; + InitClients(clientCount); + + int streamCount = 5; + int maxStreamSize = 10 * 1024 * 1024; + std::vector streamNames; + std::vector> producers; + std::vector> consumers; + for (int i = 0; i < streamCount; i++) { + std::string streamName = GetStringUuid(); + auto pubClient = clients_[i % clientCount]; + auto subClient = clients_[(i + 1) % clientCount]; + SubscriptionConfig config; + std::shared_ptr consumer; + DS_ASSERT_OK(subClient->Subscribe(streamName, config, consumer)); + std::shared_ptr producer; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + DS_ASSERT_OK(pubClient->CreateProducer(streamName, producer)); + std::string data = randomData_.GetPartRandomString(100, 0); + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + std::vector elements; + DS_ASSERT_OK(consumer->Receive(1, 10000, elements)); // wait time is 10000 ms + ASSERT_EQ(elements.size(), (size_t)1); + streamNames.emplace_back(std::move(streamName)); + producers.emplace_back(std::move(producer)); + consumers.emplace_back(std::move(consumer)); + } + + RandomData random; + for (auto &s : streamNames) { + auto client = clients_[random.GetRandomUint32() % clientCount]; + uint64_t producerCount; + DS_ASSERT_OK(client->QueryGlobalProducersNum(s, producerCount)); + ASSERT_EQ(producerCount, 1ul); + uint64_t consumerCount; + DS_ASSERT_OK(client->QueryGlobalConsumersNum(s, consumerCount)); + ASSERT_EQ(consumerCount, 1ul); + } +} + +TEST_F(StreamReplicaRouterTest, LEVEL2_CloseProducer) +{ + DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady( + { 0, 1, 2 }, { { 2, " -inject_actions=worker.ClusterInitFinish:return()" } })); + // remove sleep after add retry logic. 
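+ // The 3 s pause is presumably sized to the -node_timeout_s=3 gflag configured for this + // fixture, giving replica routing time to settle after the workers start.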
+ int timeout = 3; + sleep(timeout); + const int clientCount = 3; + InitClients(clientCount); + + int maxStreamSize = 10 * 1024 * 1024; + + std::string streamName = GetStringUuid(); + auto pubClient = clients_[0]; + auto subClient = clients_[1]; + SubscriptionConfig config; + std::shared_ptr consumer; + DS_ASSERT_OK(subClient->Subscribe(streamName, config, consumer)); + std::shared_ptr producer; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + DS_ASSERT_OK(pubClient->CreateProducer(streamName, producer)); + std::string data = randomData_.GetPartRandomString(100, 0); + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + std::vector elements; + DS_ASSERT_OK(consumer->Receive(1, 10000, elements)); // wait time is 10000 ms + ASSERT_EQ(elements.size(), (size_t)1); + + RandomData random; + auto client = clients_[random.GetRandomUint32() % clientCount]; + uint64_t producerCount; + DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount)); + ASSERT_EQ(producerCount, 1ul); + uint64_t consumerCount; + DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, consumerCount)); + ASSERT_EQ(consumerCount, 1ul); + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(consumer->Close()); + DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount)); + ASSERT_EQ(producerCount, 0ul); + DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, consumerCount)); + ASSERT_EQ(consumerCount, 0ul); + + DS_ASSERT_OK(client->DeleteStream(streamName)); +} + +TEST_F(StreamReplicaRouterTest, LEVEL2_CloseProducerAndConsumer) +{ + DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady( + { 0, 1, 2 }, { { 2, " -inject_actions=worker.ClusterInitFinish:return()" } })); + // remove sleep after add retry logic. 
+ int timeout = 3; + sleep(timeout); + const int clientCount = 3; + InitClients(clientCount); + + int maxStreamSize = 10 * 1024 * 1024; + + std::string streamName = GetStringUuid(); + auto pubClient = clients_[0]; + auto subClient = clients_[1]; + SubscriptionConfig config; + std::shared_ptr consumer; + DS_ASSERT_OK(subClient->Subscribe(streamName, config, consumer)); + std::shared_ptr producer; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + DS_ASSERT_OK(pubClient->CreateProducer(streamName, producer)); + std::string data = randomData_.GetPartRandomString(100, 0); + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + std::vector elements; + DS_ASSERT_OK(consumer->Receive(1, 10000, elements)); // wait time is 10000 ms + ASSERT_EQ(elements.size(), (size_t)1); + + RandomData random; + + auto client = clients_[random.GetRandomUint32() % clientCount]; + uint64_t producerCount; + DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount)); + ASSERT_EQ(producerCount, 1ul); + uint64_t consumerCount; + DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, consumerCount)); + ASSERT_EQ(consumerCount, 1ul); + DS_ASSERT_OK(producer->Close()); + DS_ASSERT_OK(consumer->Close()); + DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount)); + ASSERT_EQ(producerCount, 0ul); + DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, consumerCount)); + ASSERT_EQ(consumerCount, 0ul); + + DS_ASSERT_OK(client->DeleteStream(streamName)); +} + +TEST_F(StreamReplicaRouterTest, DISABLED_LEVEL2_TestResetResumeStream) +{ + DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady( + { 0, 1, 2 }, { { 2, " -inject_actions=worker.ClusterInitFinish:return()" } })); + + // remove sleep after add retry logic. + int timeout = 3; + sleep(timeout); + const int clientCount = 3; + InitClients(clientCount); + int maxStreamSize = 10 * 1024 * 1024; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + std::shared_ptr producer; + std::string streamName = "testResetAndResumeStream"; + DS_ASSERT_OK(clients_[0]->CreateProducer(streamName, producer, conf)); + + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(clients_[0]->Subscribe(streamName, config, consumer)); + std::vector streamNames; + streamNames.push_back(streamName); + + std::string data = "Hello World 1"; + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(1, 0, outElements)); + EXPECT_EQ(outElements.size(), (size_t)1); + + DS_ASSERT_NOT_OK(producer->Send(element)); + + DS_ASSERT_NOT_OK(consumer->Receive(1, 0, outElements)); + DS_ASSERT_NOT_OK(consumer->Ack(outElements.back().id)); + + DS_ASSERT_OK(producer->Send(element)); + std::vector elements; + DS_ASSERT_OK(consumer->Receive(1, 10000, elements)); // wait time is 10000 ms + ASSERT_EQ(elements.size(), (size_t)1); +} + +TEST_F(StreamReplicaRouterTest, LEVEL2_TestMPSC) +{ + DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady( + { 0, 1, 2 }, { { 2, " -inject_actions=worker.ClusterInitFinish:return()" } })); + // wait for cluster replica to finish init + // remove sleep after add retry logic. 
+    int timeout = 3;
+    sleep(timeout);
+
+    const int clientCount = 3;
+    InitClients(clientCount);
+
+    int maxStreamSize = 10 * 1024 * 1024;
+    ProducerConf conf;
+    conf.maxStreamSize = maxStreamSize;
+    std::shared_ptr<Producer> producer, producer1;
+    std::string streamName = "testMPSC";
+    DS_ASSERT_OK(clients_[0]->CreateProducer(streamName, producer, conf));
+    DS_ASSERT_OK(clients_[0]->CreateProducer(streamName, producer1, conf));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients_[0]->Subscribe(streamName, config, consumer));
+    std::vector<std::string> streamNames;
+    streamNames.push_back(streamName);
+
+    std::string data = "Hello World 1";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+    DS_ASSERT_OK(producer1->Send(element));
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(2, 0, outElements)); // two producers sent one element each
+
+    EXPECT_EQ(outElements.size(), (size_t)2); // expect both elements
+    auto client = clients_[1];
+    uint64_t producerCount;
+    // Counts workers that host at least one producer for the stream; two or more
+    // producers for the same stream on the same worker still count as one.
+    DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount));
+    ASSERT_EQ(producerCount, 1ul);
+    uint64_t consumerCount;
+    DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, consumerCount));
+    ASSERT_EQ(consumerCount, 1ul);
+}
+
+TEST_F(StreamReplicaRouterTest, LEVEL2_TestMPSCClose)
+{
+    DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady(
+        { 0, 1, 2 }, { { 2, " -inject_actions=worker.ClusterInitFinish:return()" } }));
+    // Wait for the cluster replicas to finish init.
+    // Remove this sleep once client-side retry logic is added.
+    int timeout = 3;
+    sleep(timeout);
+
+    const int clientCount = 3;
+    InitClients(clientCount);
+    int maxStreamSize = 10 * 1024 * 1024;
+    ProducerConf conf;
+    conf.maxStreamSize = maxStreamSize;
+    std::shared_ptr<Producer> producer, producer1;
+    std::string streamName = "testMPSCClose";
+    DS_ASSERT_OK(clients_[0]->CreateProducer(streamName, producer, conf));
+    DS_ASSERT_OK(clients_[0]->CreateProducer(streamName, producer1, conf));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients_[0]->Subscribe(streamName, config, consumer));
+    std::vector<std::string> streamNames;
+    streamNames.push_back(streamName);
+
+    std::string data = "Hello World 1";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+    DS_ASSERT_OK(producer1->Send(element));
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(2, 0, outElements)); // two producers sent one element each
+
+    EXPECT_EQ(outElements.size(), (size_t)2); // expect both elements
+    auto client = clients_[1];
+    uint64_t producerCount;
+    DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount));
+    // Counts workers that host at least one producer for the stream; two or more
+    // producers for the same stream on the same worker still count as one.
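+    // Illustration of that per-worker counting (both producers live on clients_[0]'s worker):
+    //   worker hosting producer and producer1 -> contributes 1
+    //   every other worker                    -> contributes 0
+    //   => QueryGlobalProducersNum returns 1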
+    ASSERT_EQ(producerCount, 1ul);
+    uint64_t consumerCount;
+    DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, consumerCount));
+    ASSERT_EQ(consumerCount, 1ul);
+
+    DS_ASSERT_OK(producer1->Close());
+    DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount));
+    ASSERT_EQ(producerCount, 1ul); // producer is still open on the same worker
+
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount));
+    ASSERT_EQ(producerCount, 0ul);
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, consumerCount));
+    ASSERT_EQ(consumerCount, 0ul);
+}
+
+TEST_F(StreamReplicaRouterTest, LEVEL2_TestSPMC)
+{
+    DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady(
+        { 0, 1, 2 }, { { 2, " -inject_actions=worker.ClusterInitFinish:return()" } }));
+    // Wait for the cluster replicas to finish init.
+    // Remove this sleep once client-side retry logic is added.
+    int timeout = 3;
+    sleep(timeout);
+
+    const int clientCount = 3;
+    InitClients(clientCount);
+    int maxStreamSize = 10 * 1024 * 1024;
+    ProducerConf conf;
+    conf.maxStreamSize = maxStreamSize;
+    std::shared_ptr<Producer> producer;
+    std::string streamName = "testSPMC";
+    DS_ASSERT_OK(clients_[0]->CreateProducer(streamName, producer, conf));
+    std::shared_ptr<Consumer> consumer, consumer1;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients_[0]->Subscribe(streamName, config, consumer));
+    SubscriptionConfig config1("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients_[0]->Subscribe(streamName, config1, consumer1));
+    std::vector<std::string> streamNames;
+    streamNames.push_back(streamName);
+
+    std::string data = "Hello World 1";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+
+    std::vector<Element> outElements, outElements1;
+    DS_ASSERT_OK(consumer->Receive(1, 0, outElements));
+    DS_ASSERT_OK(producer->Send(element));
+    DS_ASSERT_OK(consumer->Receive(1, 0, outElements1));
+    EXPECT_EQ(outElements.size(), (size_t)1);
+    EXPECT_EQ(outElements1.size(), (size_t)1);
+    auto client = clients_[1];
+    uint64_t producerCount;
+    DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount));
+    ASSERT_EQ(producerCount, 1ul);
+    uint64_t consumerCount;
+    DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, consumerCount));
+    ASSERT_EQ(consumerCount, 2ul);
+
+    DS_ASSERT_OK(consumer1->Close());
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount));
+    ASSERT_EQ(producerCount, 0ul);
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, consumerCount));
+    ASSERT_EQ(consumerCount, 0ul);
+}
+
+TEST_F(StreamReplicaRouterTest, LEVEL2_TestMPMC)
+{
+    DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady(
+        { 0, 1, 2 }, { { 2, " -inject_actions=worker.ClusterInitFinish:return()" } }));
+    // Wait for the cluster replicas to finish init.
+    // Remove this sleep once client-side retry logic is added.
+    int timeout = 3;
+    sleep(timeout);
+
+    const int clientCount = 2;
+    InitClients(clientCount);
+    int maxStreamSize = 10 * 1024 * 1024;
+    ProducerConf conf;
+    conf.maxStreamSize = maxStreamSize;
+    std::shared_ptr<Producer> producer;
+    std::string streamName = "TestMPMC";
+    DS_ASSERT_OK(clients_[0]->CreateProducer(streamName, producer, conf));
+    std::shared_ptr<Consumer> consumer, consumer1;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients_[0]->Subscribe(streamName, config, consumer));
+    SubscriptionConfig config1("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(clients_[0]->Subscribe(streamName, config1, consumer1));
+
+    std::string data = "Hello World";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producer->Send(element));
+
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(1, 0, outElements));
+    EXPECT_EQ(outElements.size(), (size_t)1);
+
+    auto client = clients_[1];
+    uint64_t producerCount, consumerCount;
+    DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount));
+    ASSERT_EQ(producerCount, 1ul);
+    DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, consumerCount));
+    ASSERT_EQ(consumerCount, 2ul);
+
+    DS_ASSERT_OK(consumer1->Close());
+    DS_ASSERT_OK(producer->Close());
+    DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, producerCount));
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, consumerCount));
+    ASSERT_EQ(consumerCount, producerCount); // both counts should have dropped to 0
+}
+
+class StreamReplicaScaleTest : public StreamReplicaTest {};
+
+TEST_F(StreamReplicaScaleTest, DISABLED_TestScaleUp)
+{
+    // A worker holds two primary replicas and migrates data to the new node.
+    DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady(
+        { 0, 1, 2 }, { { 2, " -inject_actions=worker.ClusterInitFinish:return()" } }));
+    InjectSyncCap({ 0, 1, 2 }, UINT16_MAX, 1);
+
+    // Remove this sleep once client-side retry logic is added.
+    int timeout = 3;
+    sleep(timeout);
+    const int clientCount = 3;
+    InitClients(clientCount);
+
+    int streamCount = 50;
+    std::vector<StreamEntry> streams;
+    CreateStreams(streams, streamCount);
+    VerifyStream(streams);
+
+    DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady(
+        { 3 }, " -inject_actions=master.MigrateSCMetadata:3*return(K_REPLICA_NOT_READY)"));
+
+    const int nodeCount = 4;
+    WaitAllNodesJoinIntoHashRing(nodeCount);
+    VerifyStream(streams);
+}
+
+TEST_F(StreamReplicaScaleTest, DISABLED_TestVoluntaryScaleDown)
+{
+    // Worker0 switches its replica.
+    DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady(
+        { 0, 1, 2, 3 }, { { 0, " -inject_actions=worker.ClusterInitFinish:return()" } }));
+    InjectSyncCap({ 0, 1, 2 }, UINT16_MAX, 1);
+
+    // Remove this sleep once client-side retry logic is added.
+    int timeout = 3;
+    sleep(timeout);
+
+    // Get the next worker of worker0 in the hash ring.
+    std::vector<int> indexes = { 0, 1, 2, 3 };
+    std::vector<int> clientIndexes = indexes;
+    InitWorkersInfoMap(indexes);
+    auto nextWorkerOfWorker0 = workersInfo_[0].nextIndex;
+
+    // Get the scale-down node index.
+    indexes.erase(
+        std::remove_if(indexes.begin(), indexes.end(), [&](const int &index) { return index == nextWorkerOfWorker0; }),
+        indexes.end());
+
+    const int scaleDownIndex = indexes.back();
+
+    // Init client.
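+    // Exclude the scale-down node first so no test client is attached to a worker
+    // that is about to exit (rationale inferred from the erase below).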
+    clientIndexes.erase(std::remove_if(clientIndexes.begin(), clientIndexes.end(),
+                                       [&](const int &index) { return index == scaleDownIndex; }),
+                        clientIndexes.end());
+    LOG(INFO) << "init client for index:" << VectorToString(clientIndexes);
+    InitClients(clientIndexes);
+
+    int streamCount = 50;
+    std::vector<StreamEntry> streams;
+    CreateStreams(streams, streamCount);
+    VerifyStream(streams);
+
+    VoluntaryScaleDownInject(scaleDownIndex);
+
+    const int nodeCount = 3;
+    WaitAllNodesJoinIntoHashRing(nodeCount);
+    VerifyStream(streams);
+}
+
+TEST_F(StreamReplicaScaleTest, DISABLED_TestScaleDownSourceSwitchReplica)
+{
+    // Test migration when the data source node has already switched its replica.
+    const int workerIndex = 3;
+    const int maxStartTimeoutSec = 30;
+    const int nodeTimeout = 5;
+    const int nodeDeadTimeoutS = 8;
+    // Worker3 switches its replica and is then scaled down.
+    DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady(
+        { 0, 1, 2, 3 }, { { workerIndex, " -inject_actions=worker.ClusterInitFinish:return()" } }, maxStartTimeoutSec,
+        FormatString("-node_timeout_s=%d -node_dead_timeout_s=%d -v=2", nodeTimeout, nodeDeadTimeoutS)));
+    InjectSyncCap({ 0, 1, 2, 3 }, UINT16_MAX, 1);
+
+    // Remove this sleep once client-side retry logic is added.
+    int timeout = 3;
+    sleep(timeout);
+    const int clientCount = 3;
+    InitClients(clientCount);
+
+    int streamCount = 50;
+    std::vector<StreamEntry> streams;
+    CreateStreams(streams, streamCount);
+    VerifyStream(streams);
+
+    DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, workerIndex));
+
+    const int nodeCount = 3;
+    WaitAllNodesJoinIntoHashRing(nodeCount);
+    VerifyStream(streams);
+}
+
+TEST_F(StreamReplicaScaleTest, DISABLED_TestScaleDownTargetSwitchReplica)
+{
+    // Test migration when the target node has already switched its replica.
+    const int nodeTimeout = 5;
+    const int nodeDeadTimeoutS = 8;
+    // Two workers switch their replicas, then one of them is scaled down.
+    DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady(
+        { 0, 1, 2, 3 },
+        FormatString(" -node_timeout_s=%d -node_dead_timeout_s=%d -v=2", nodeTimeout, nodeDeadTimeoutS)));
+    InjectSyncCap({ 0, 1, 2, 3 }, UINT16_MAX, 1);
+
+    // Build the worker info map to find hash-ring neighbors.
+    std::vector<int> indexes = { 0, 1, 2, 3 };
+    std::vector<int> clientIndexes = indexes;
+    InitWorkersInfoMap(indexes);
+
+    int indexOfWorkera, indexOfWorkerb;
+    if (!GetTwoWorkerNotBackupEachOther(indexOfWorkera, indexOfWorkerb)) {
+        LOG(INFO) << "Cannot find two workers that don't back up each other";
+        return;
+    }
+    LOG(INFO) << "a:" << indexOfWorkera << ", b:" << indexOfWorkerb;
+
+    // Switch replica for workera and workerb.
+    auto uuidOfWorkera = workersInfo_[indexOfWorkera].uuid;
+    auto nextWorkerUuidOfWorkera = workersInfo_[indexOfWorkera].nextUuid;
+    LOG(INFO) << "workera:" << uuidOfWorkera << ", " << nextWorkerUuidOfWorkera;
+
+    auto uuidOfWorkerb = workersInfo_[indexOfWorkerb].uuid;
+    auto nextWorkerUuidOfWorkerb = workersInfo_[indexOfWorkerb].nextUuid;
+
+    LOG(INFO) << "workerb:" << uuidOfWorkerb << ", " << nextWorkerUuidOfWorkerb;
+
+    DS_ASSERT_OK(SetWaitingElection(uuidOfWorkera, nextWorkerUuidOfWorkera));
+    DS_ASSERT_OK(SetWaitingElection(uuidOfWorkerb, nextWorkerUuidOfWorkerb));
+
+    WaitReplicaNotInCurrentNode(indexOfWorkera);
+    WaitReplicaNotInCurrentNode(indexOfWorkerb);
+
+    // Init client.
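+    // Only create clients on surviving nodes; workera is shut down later in this test.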
+ int timeout = 3; + sleep(timeout); + clientIndexes.erase(std::remove_if(clientIndexes.begin(), clientIndexes.end(), + [&](const int &index) { return index == indexOfWorkera; }), + clientIndexes.end()); + InitClients(clientIndexes); + + int streamCount = 50; + std::vector streams; + CreateStreams(streams, streamCount); + VerifyStream(streams); + + DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, indexOfWorkera)); + + const int nodeCount = 3; + WaitAllNodesJoinIntoHashRing(nodeCount); + VerifyStream(streams); +} + +class StreamUpdateToReplicaTest : public StreamReplicaRouterTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + const int workerCount = 5; + opts.numEtcd = 1; + opts.numWorkers = workerCount; + opts.enableDistributedMaster = "true"; + opts.workerGflagParams = " -v=1 -shared_memory_size_mb=2048 -node_timeout_s=3 -log_monitor=true"; + opts.waitWorkerReady = false; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + CommonTest::SetUp(); + DS_ASSERT_OK(Init()); + ASSERT_TRUE(cluster_ != nullptr); + DS_ASSERT_OK(cluster_->StartEtcdCluster()); + externalCluster_ = dynamic_cast(cluster_.get()); + } +}; + +TEST_F(StreamUpdateToReplicaTest, LEVEL1_TestUpdateToReplicaEnable) +{ + DS_ASSERT_OK(externalCluster_->StartWorkerAndWaitReady({ 0, 1 })); + const int clientCount = 2; + InitClients(clientCount); + BasicTest(2); // index is 2 + ThreadPool pool(2); + auto fut1 = pool.Submit([this]() { DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 0)); }); + auto fut2 = pool.Submit([this]() { DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 1)); }); + fut1.get(); + fut2.get(); + DS_ASSERT_OK(cluster_->StartNode(WORKER, 1, "-enable_meta_replica=true")); + DS_ASSERT_OK(cluster_->StartNode(WORKER, 0, "-enable_meta_replica=true")); + fut1 = pool.Submit([this]() { DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); }); + fut2 = pool.Submit([this]() { DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 1)); }); + fut1.get(); + fut2.get(); + + BasicTest(2); // index is 2 +} + +const std::string HOST_IP_PREFIX = "127.0.0.1"; +constexpr size_t DEFAULT_WORKER_NUM = 2; +class StreamClientWriteRocksdbTest : public StreamReplicaTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numEtcd = 1; + opts.numWorkers = DEFAULT_WORKER_NUM; + opts.enableDistributedMaster = "true"; + opts.workerGflagParams = " -v=1 -shared_memory_size_mb=2048 "; + SCClientCommon::SetClusterSetupOptions(opts); + opts.disableRocksDB = false; + for (size_t i = 0; i < DEFAULT_WORKER_NUM; i++) { + opts.workerConfigs.emplace_back(HOST_IP_PREFIX + std::to_string(i), GetFreePort()); + workerHost_.emplace_back(HOST_IP_PREFIX + std::to_string(i)); + workerAddress_.emplace_back(opts.workerConfigs.back().ToString()); + } + } + + void InitTestEtcdInstance() + { + std::string etcdAddress; + for (size_t i = 0; i < cluster_->GetEtcdNum(); ++i) { + std::pair addrs; + cluster_->GetEtcdAddrs(i, addrs); + if (!etcdAddress.empty()) { + etcdAddress += ","; + } + etcdAddress += addrs.first.ToString(); + } + FLAGS_etcd_address = etcdAddress; + LOG(INFO) << "The etcd address is:" << FLAGS_etcd_address << std::endl; + db_ = std::make_unique(etcdAddress); + if ((db_ != nullptr) && (db_->Init().IsOk())) { + db_->DropTable(ETCD_RING_PREFIX); + // We don't check rc here. If table to drop does not exist, it's fine. 
+ LOG(INFO) << "create table"; + (void)db_->CreateTable(ETCD_RING_PREFIX, ETCD_RING_PREFIX); + (void)db_->CreateTable(std::string(ETCD_GLOBAL_CACHE_TABLE_PREFIX) + ETCD_HASH_SUFFIX, + std::string(ETCD_GLOBAL_CACHE_TABLE_PREFIX) + ETCD_HASH_SUFFIX); + (void)db_->CreateTable(std::string(ETCD_GLOBAL_CACHE_TABLE_PREFIX) + ETCD_WORKER_SUFFIX, + std::string(ETCD_GLOBAL_CACHE_TABLE_PREFIX) + ETCD_WORKER_SUFFIX); + (void)db_->CreateTable(std::string(ETCD_LOCATION_TABLE_PREFIX) + ETCD_HASH_SUFFIX, + std::string(ETCD_LOCATION_TABLE_PREFIX) + ETCD_HASH_SUFFIX); + (void)db_->CreateTable(std::string(ETCD_LOCATION_TABLE_PREFIX) + ETCD_WORKER_SUFFIX, + std::string(ETCD_LOCATION_TABLE_PREFIX) + ETCD_WORKER_SUFFIX); + (void)db_->CreateTable(std::string(ETCD_ASYNC_WORKER_OP_TABLE_PREFIX) + ETCD_HASH_SUFFIX, + std::string(ETCD_ASYNC_WORKER_OP_TABLE_PREFIX) + ETCD_HASH_SUFFIX); + } + } + + void SetUp() override + { + CommonTest::SetUp(); + DS_ASSERT_OK(Init()); + ASSERT_TRUE(cluster_ != nullptr); + DS_ASSERT_OK(cluster_->StartEtcdCluster()); + InitTestEtcdInstance(); + externalCluster_ = dynamic_cast(cluster_.get()); + } + + void TearDown() override + { + db_.reset(); + ExternalClusterTest::TearDown(); + } + + void GetHashOnWorker(size_t workerNum = 2) + { + std::string value; + db_->Get(ETCD_RING_PREFIX, "", value); + HashRingPb ring; + ring.ParseFromString(value); + LOG(INFO) << "ring: " << ring.DebugString(); + for (size_t i = 0; i < workerNum; ++i) { + auto tokens = ring.workers().at(workerAddress_[i]).hash_tokens(); + workerHashValue_.emplace_back(*tokens.begin() - 1); + LOG(INFO) << FormatString("workerAddress_ %s, workerHashValue_ %d", workerAddress_[i], *tokens.begin() - 1); + } + ASSERT_EQ(workerHashValue_.size(), workerNum); + } + + void GetWorkerUuids() + { + std::string value; + DS_ASSERT_OK(db_->Get(ETCD_RING_PREFIX, "", value)); + HashRingPb ring; + ring.ParseFromString(value); + for (auto worker : ring.workers()) { + HostPort workerAddr; + DS_ASSERT_OK(workerAddr.ParseString(worker.first)); + uuidMap_.emplace(std::move(workerAddr), worker.second.worker_uuid()); + } + } + + void SetWorkerHashInjection(std::vector injectNode = std::vector{}) + { + if (injectNode.size() == 0) { + for (size_t i = 0; i < DEFAULT_WORKER_NUM; ++i) { + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, i, "MurmurHash3", "return()")); + } + return; + } + + for (auto i : injectNode) { + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, i, "MurmurHash3", "return()")); + } + } + + void UnsetWorkerHashInjection(std::vector injectNode = std::vector{}) + { + if (injectNode.size() == 0) { + for (size_t i = 0; i < DEFAULT_WORKER_NUM; ++i) { + DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, i, "MurmurHash3")); + } + return; + } + + for (auto i : injectNode) { + DS_ASSERT_OK(cluster_->ClearInjectAction(WORKER, i, "MurmurHash3")); + } + } + + void StartWorkerAndWaitReady(std::initializer_list indexes, + const std::unordered_map &workerFlags = {}, int maxWaitTimeSec = 20) + { + for (auto i : indexes) { + std::string flags; + auto iter = workerFlags.find(i); + if (iter != workerFlags.end()) { + flags = " " + iter->second; + } + ASSERT_TRUE(externalCluster_->StartWorker(i, HostPort(), flags).IsOk()) << i; + } + for (auto i : indexes) { + ASSERT_TRUE(cluster_->WaitNodeReady(WORKER, i, maxWaitTimeSec).IsOk()) << i; + } + for (auto i : indexes) { + // When the scale-in scenario is tested, the scale-in failure may not be determined correctly. + // Therefore, the scale-in failure is directly exited. 
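+            // "abort()" is assumed to terminate the worker process at the inject point,
+            // turning a silent scale-in failure into an immediately visible crash.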
+ DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "Hashring.Scaletask.Fail", "abort()")); + } + InitWorkersInfoMap(indexes); + } + + void StartWorkerAndWaitReady(std::initializer_list indexes, const std::string &flags, int maxWaitTimeSec = 20) + { + std::unordered_map workerFlags; + for (auto i : indexes) { + workerFlags.emplace(i, flags); + } + StartWorkerAndWaitReady(indexes, workerFlags, maxWaitTimeSec); + } + +protected: + Status InitClient(int index, std::shared_ptr &client) + { + InitStreamClient(index, client); + return Status::OK(); + } + + Status CreateConsumer(std::shared_ptr client, const std::string &streamName, + const std::string &subName, std::shared_ptr &consumer) + { + SubscriptionConfig config(subName, SubscriptionType::STREAM); + return client->Subscribe(streamName, config, consumer); + } + + Status CreateProducer(std::shared_ptr client, const std::string &streamName, + std::shared_ptr &producer) + { + const int64_t autoFlushTime = 10 * 1000; // 10s; + ProducerConf conf = { .delayFlushTime = autoFlushTime, + .pageSize = 20 * 1024, + .maxStreamSize = TEST_STREAM_SIZE }; + return client->CreateProducer(streamName, producer, conf); + } + + std::shared_ptr client1_; + std::shared_ptr client2_; + std::unique_ptr db_; + std::vector workerAddress_; + std::vector workerHost_; + std::vector workerHashValue_; + std::unordered_map uuidMap_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; +}; + +TEST_F(StreamClientWriteRocksdbTest, TestNodeRestartWithNoneMode) +{ + StartWorkerAndWaitReady({ 0, 1 }, "-rocksdb_write_mode=none"); + GetHashOnWorker(DEFAULT_WORKER_NUM); + SetWorkerHashInjection(); + DS_ASSERT_OK(InitClient(0, client1_)); + std::shared_ptr producer; + std::shared_ptr consumer1; + int index = 0; + int num = 1; + std::string streamName = "a_key_hash_to_" + std::to_string(workerHashValue_[index] - num); + // std::string streamName = "test"; + DS_ASSERT_OK(CreateProducer(client1_, streamName, producer)); + DS_ASSERT_OK(CreateConsumer(client1_, streamName, "subname1", consumer1)); + + std::string str = "hello world!"; + Element element(reinterpret_cast((uint8_t *)str.data()), str.length()); + DS_ASSERT_OK(producer->Send(element)); + + std::vector outElements; + uint32_t expectRecvNum = 10; + ASSERT_EQ(consumer1->Receive(expectRecvNum, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), 1ul); + DS_ASSERT_OK(consumer1->Ack(outElements.back().id)); + + DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 0)); + DS_ASSERT_OK(cluster_->StartNode(WORKER, 0, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); + + ASSERT_EQ(producer->Send(element).GetCode(), StatusCode::K_SC_ALREADY_CLOSED); + ASSERT_EQ(consumer1->Receive(expectRecvNum, 0, outElements).GetCode(), StatusCode::K_SC_ALREADY_CLOSED); +} + +TEST_F(StreamClientWriteRocksdbTest, TestNodeRestartWithNoneMode2) +{ + StartWorkerAndWaitReady({ 0, 1 }, "-rocksdb_write_mode=none"); + GetHashOnWorker(DEFAULT_WORKER_NUM); + DS_ASSERT_OK(InitClient(0, client1_)); + SetWorkerHashInjection(); + std::shared_ptr producer; + std::shared_ptr consumer1; + int index = 1; + LOG(INFO) << "workerHashValue_[index] " << workerHashValue_[index]; + std::string streamName = "a_key_hash_to_" + std::to_string(workerHashValue_[index] - index); + DS_ASSERT_OK(CreateProducer(client1_, streamName, producer)); + DS_ASSERT_OK(CreateConsumer(client1_, streamName, "subname1", consumer1)); + + std::string str = "hello world!"; + Element element(reinterpret_cast((uint8_t 
*)str.data()), str.length()); + DS_ASSERT_OK(producer->Send(element)); + + std::vector outElements; + uint32_t expectRecvNum = 10; + ASSERT_EQ(consumer1->Receive(expectRecvNum, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), 1ul); + DS_ASSERT_OK(consumer1->Ack(outElements.back().id)); + DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 1)); + DS_ASSERT_OK(cluster_->StartNode(WORKER, 1, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 1)); + + std::shared_ptr producer2; + std::shared_ptr consumer2; + DS_ASSERT_OK(CreateProducer(client1_, streamName, producer2)); + DS_ASSERT_OK(CreateConsumer(client1_, streamName, "subname2", consumer2)); + + std::string str2 = "hello world 2"; + Element element2(reinterpret_cast((uint8_t *)str2.data()), str2.length()); + DS_ASSERT_OK(producer->Send(element2)); + outElements.clear(); + ASSERT_EQ(consumer1->Receive(expectRecvNum, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), 1ul); + std::string actualData(reinterpret_cast(outElements[0].ptr), outElements[0].size); + EXPECT_EQ(str2, actualData); + outElements.clear(); + ASSERT_EQ(consumer2->Receive(expectRecvNum, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), 1ul); + std::string actualData2(reinterpret_cast(outElements[0].ptr), outElements[0].size); + EXPECT_EQ(str2, actualData2); + + std::string str3 = "hello world 3"; + Element element3(reinterpret_cast((uint8_t *)str3.data()), str3.length()); + DS_ASSERT_OK(producer2->Send(element3)); + outElements.clear(); + ASSERT_EQ(consumer2->Receive(expectRecvNum, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), 1ul); + std::string actualData3(reinterpret_cast(outElements[0].ptr), outElements[0].size); + EXPECT_EQ(str3, actualData3); + + outElements.clear(); + ASSERT_EQ(consumer1->Receive(expectRecvNum, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), 1ul); + std::string actualData4(reinterpret_cast(outElements[0].ptr), outElements[0].size); + EXPECT_EQ(str3, actualData4); +} + +TEST_F(StreamClientWriteRocksdbTest, TestNodeRestartWithNoneMode3) +{ + StartWorkerAndWaitReady({ 0, 1 }, "-rocksdb_write_mode=none"); + int waitHashSecond = 2; + sleep(waitHashSecond); + GetHashOnWorker(DEFAULT_WORKER_NUM); + SetWorkerHashInjection(); + DS_ASSERT_OK(InitClient(0, client1_)); + std::shared_ptr producer; + std::shared_ptr consumer1; + int index = 1; + LOG(INFO) << "workerHashValue_[index] " << workerHashValue_[index]; + std::string streamName = "a_key_hash_to_" + std::to_string(workerHashValue_[index] - index); + DS_ASSERT_OK(CreateProducer(client1_, streamName, producer)); + DS_ASSERT_OK(CreateConsumer(client1_, streamName, "subname1", consumer1)); + + std::string str = "hello world!"; + Element element(reinterpret_cast((uint8_t *)str.data()), str.length()); + DS_ASSERT_OK(producer->Send(element)); + + std::vector outElements; + uint32_t expectRecvNum = 1; + ASSERT_EQ(consumer1->Receive(expectRecvNum, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), 1ul); + DS_ASSERT_OK(consumer1->Ack(outElements.back().id)); + DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 1)); + DS_ASSERT_OK(cluster_->StartNode(WORKER, 1, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 1)); + + DS_ASSERT_OK(InitClient(1, client2_)); + std::shared_ptr producer2; + std::shared_ptr consumer2; + DS_ASSERT_OK(CreateProducer(client2_, streamName, producer2)); + DS_ASSERT_OK(CreateConsumer(client2_, streamName, "subname3", consumer2)); + + std::string str2 = "hello world 2"; + Element element2(reinterpret_cast((uint8_t 
*)str2.data()), str2.length()); + DS_ASSERT_OK(producer->Send(element2)); + outElements.clear(); + + ASSERT_EQ(consumer2->Receive(expectRecvNum, 0, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), 0ul); + + std::string str3 = "hello world 3"; + Element element3(reinterpret_cast((uint8_t *)str3.data()), str3.length()); + DS_ASSERT_OK(producer2->Send(element3)); + outElements.clear(); + int timeoutMs = 1000; + ASSERT_EQ(consumer2->Receive(expectRecvNum + 1, timeoutMs, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), 2ul); + std::string actualData0(reinterpret_cast(outElements[0].ptr), outElements[0].size); + EXPECT_EQ(str3, actualData0); + std::string actualData(reinterpret_cast(outElements[1].ptr), outElements[1].size); + EXPECT_EQ(str2, actualData); + + outElements.clear(); + ASSERT_EQ(consumer1->Receive(expectRecvNum + 1, timeoutMs, outElements), Status::OK()); + ASSERT_EQ(outElements.size(), 2ul); + std::string actualData2(reinterpret_cast(outElements[0].ptr), outElements[0].size); + EXPECT_EQ(str2, actualData2); + std::string actualData3(reinterpret_cast(outElements[1].ptr), outElements[1].size); + EXPECT_EQ(str3, actualData3); +} + +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/stream_client_scale_test.cpp b/tests/st/client/stream_cache/stream_client_scale_test.cpp new file mode 100644 index 0000000..f293848 --- /dev/null +++ b/tests/st/client/stream_cache/stream_client_scale_test.cpp @@ -0,0 +1,1874 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Stream cache client scale tests. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "common.h" +#include "common_distributed_ext.h" +#include "common/stream_cache/stream_common.h" +#include "client/stream_cache/sc_client_common.h" +#include "datasystem/common/util/format.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/common/util/thread_pool.h" +#include "datasystem/common/util/uuid_generator.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream_client.h" +#include "datasystem/utils/status.h" + +DS_DECLARE_string(etcd_address); +DS_DECLARE_string(log_dir); + +namespace datasystem { +namespace st { +constexpr int K_TWO = 2; +constexpr int K_TEN = 10; +constexpr int K_THIRTY = 30; +constexpr int SCALE_UP_WAIT_TIME = 3; +constexpr int SCALE_DOWN_WAIT_TIME = 3; +constexpr int NODE_DEAD_TIMEOUT = 8; +const std::string HOST_IP = "127.0.0.1"; +class StreamClientScaleTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numEtcd = 1; + opts.numWorkers = workerNum_; + opts.enableDistributedMaster = "true"; + opts.workerGflagParams = FormatString(" -v=2 -node_timeout_s=%d -node_dead_timeout_s=%d -log_monitor=true", + nodeTimeoutS_, nodeDeadTimeoutS_); + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + InitStreamClient(0, w1Client_); + InitStreamClient(1, w2Client_); + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + + InitTestEtcdInstance(); + } + + void TearDown() override + { + w1Client_.reset(); + w2Client_.reset(); + ExternalClusterTest::TearDown(); + } + + void CheckCount(std::shared_ptr client, const std::string &streamName, int producerCount, + int consumerCount) + { + uint64_t result = 0; + if (producerCount >= 0) { + DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, result)); + EXPECT_EQ(result, static_cast(producerCount)); + result = 0; + } + if (consumerCount >= 0) { + DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, result)); + EXPECT_EQ(result, static_cast(consumerCount)); + result = 0; + } + } + + Status AddNode() + { + const int newWorkerIdx = workerNum_++; + HostPort workerAddr(HOST_IP, GetFreePort()); + HostPort masterAddr; + RETURN_IF_NOT_OK(cluster_->GetWorkerAddr(0, masterAddr)); + RETURN_IF_NOT_OK(cluster_->AddNode(masterAddr, workerAddr.ToString(), GetFreePort())); + RETURN_IF_NOT_OK(cluster_->WaitNodeReady(WORKER, newWorkerIdx)); + return Status::OK(); + } + + void VoluntaryScaleDownInject(int workerIdx) + { + std::string checkFilePath = FLAGS_log_dir.c_str(); + std::string client = "client"; + checkFilePath = checkFilePath.substr(0, checkFilePath.length() - client.length()) + "/worker" + + std::to_string(workerIdx) + "/log/worker-status"; + std::ofstream ofs(checkFilePath); + if (!ofs.is_open()) { + LOG(ERROR) << "Can not open worker status file in " << checkFilePath + << ", voluntary scale in will not start, errno: " << errno; + } else { + ofs << "voluntary scale in\n"; + } + ofs.close(); + kill(cluster_->GetWorkerPid(workerIdx), SIGTERM); + } + + void InitTestEtcdInstance() + { + if (db_ != nullptr) { + return; + } + std::string etcdAddress; + for (size_t i = 0; i < cluster_->GetEtcdNum(); ++i) { + std::pair addrs; + cluster_->GetEtcdAddrs(i, addrs); + if (!etcdAddress.empty()) { + etcdAddress += ","; + } + etcdAddress += addrs.first.ToString(); + } + FLAGS_etcd_address = etcdAddress; + db_ = std::make_unique(etcdAddress); + 
DS_ASSERT_OK(db_->Init()); + (void)db_->CreateTable(ETCD_RING_PREFIX, ETCD_RING_PREFIX); + (void)db_->CreateTable(ETCD_CLUSTER_TABLE, "/" + std::string(ETCD_CLUSTER_TABLE)); + } + + bool CheckScaleDownFinished(const std::string &workerAddr) + { + std::string value; + auto status = db_->Get(ETCD_CLUSTER_TABLE, workerAddr, value); + if (status.GetCode() == K_NOT_FOUND) { + return true; + } + return false; + } + + void WaitForVoluntaryDownFinished(int workerIndex, int timeoutS = 50) + { + HostPort workerAddr; + DS_ASSERT_OK(cluster_->GetWorkerAddr(workerIndex, workerAddr)); + Timer timer; + while (timer.ElapsedSecond() < timeoutS) { + if (CheckScaleDownFinished(workerAddr.ToString()) && !cluster_->CheckWorkerProcess(workerIndex)) { + LOG(INFO) << "scale down finish time: " << timer.ElapsedSecond() + << " worker: " << workerAddr.ToString(); + return; + } + auto interval = 100; + std::this_thread::sleep_for(std::chrono::milliseconds(interval)); + } + ASSERT_TRUE(false) << "Voluntary scaling down is not completed: " << workerAddr.ToString(); + } + + /** + * @brief Creates streamNum producers and consumers on w1Client, placing them into streams + * @param[in] streams The map of stream names to pairs of prod/cons + * @param[in] streamNum The number of streams to create + * @param[in] sameWorker Whether to create consumers on w2Client instead + */ + void CreateNProducerAndConsumer(std::map, std::shared_ptr>> &streams, + int streamNum, std::string streamName, + bool sameWorker = true) + { + for (int i = 0; i < streamNum; ++i) { + std::string strmName = streamName + std::to_string(i); + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(strmName, producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub" + std::to_string(i), SubscriptionType::STREAM); + if (sameWorker) { + DS_ASSERT_OK(w1Client_->Subscribe(strmName, config, consumer)); + } else { + DS_ASSERT_OK(w2Client_->Subscribe(strmName, config, consumer)); + } + streams.emplace(strmName, std::make_pair(producer, consumer)); + CheckCount(w1Client_, strmName, 1, 1); + } + } + + void WaitAllNodesJoinIntoHashRing(int num, uint64_t timeoutSec = 60, std::string azName = "") + { + int S2Ms = 1000; + WaitHashRingChange( + [&](const HashRingPb &hashRing) { + if (hashRing.workers_size() != num || hashRing.add_node_info_size() != 0 + || hashRing.del_node_info_size() != 0) { + return false; + } + for (auto &worker : hashRing.workers()) { + if (worker.second.state() != WorkerPb::ACTIVE) { + return false; + } + } + return true; + }, + timeoutSec * S2Ms, azName); + sleep(WORKER_RECEIVE_DELAY); + } + + template + void WaitHashRingChange(F &&f, uint64_t timeoutMs = 30'000, std::string azName = "") + { + if (!db_) { + InitTestEtcdInstance(); + } + auto timeOut = std::chrono::steady_clock::now() + std::chrono::milliseconds(timeoutMs); + bool flag = false; + HashRingPb ring; + while (std::chrono::steady_clock::now() < timeOut) { + std::string hashRingStr; + auto trueRingTable = azName.empty() ? ETCD_RING_PREFIX : '/' + azName + ETCD_RING_PREFIX; + DS_ASSERT_OK(db_->Get(trueRingTable, "", hashRingStr)); + ASSERT_TRUE(ring.ParseFromString(hashRingStr)); + if (f(ring)) { + flag = true; + break; + } + const int interval = 100; // 100ms; + std::this_thread::sleep_for(std::chrono::milliseconds(interval)); + } + LOG(INFO) << "Check " << (flag ? 
"success" : "failed") + << ", Ring info:" << worker::HashRingToJsonString(ring); + ASSERT_TRUE(flag); + } + + void TestSendRecv(std::shared_ptr &producer, std::shared_ptr &consumer) + { + // Produce element + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + + // Read the element to make sure other requests can go through + std::vector outElements; + const int DEFAULT_WAIT_TIME = 5000; + DS_ASSERT_OK(consumer->Receive(1, DEFAULT_WAIT_TIME, outElements)); + ASSERT_EQ(outElements.size(), size_t(1)); + DS_ASSERT_OK(consumer->Ack(1)); + } + + void SendHelper(std::shared_ptr producer) + { + const int DEFAULT_SLEEP_TIME = 300; + std::vector writeElement = RandomData().RandomBytes(TEST_SIZE); + Element element = Element(reinterpret_cast(writeElement.data()), writeElement.size()); + for (size_t i = 0; i < SEND_COUNT; i++) { + Status rc = producer->Send(element); + int retryCount = 30; + while (rc.GetCode() == K_OUT_OF_MEMORY && retryCount-- > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME)); + rc = producer->Send(element); + } + DS_ASSERT_OK(rc); + } + } + + void ReceiveHelper(std::shared_ptr consumer) + { + const int DEFAULT_RETRY_TIME = 100; + Timer timer; + std::vector outElements; + int sendCount = SEND_COUNT; + const int DEFAULT_WAIT_TIME = 1000; + while (sendCount > 0 && timer.ElapsedSecond() < DEFAULT_RETRY_TIME) { + DS_ASSERT_OK(consumer->Receive(1, DEFAULT_WAIT_TIME, outElements)); + if (!outElements.empty()) { + DS_ASSERT_OK(consumer->Ack(outElements.back().id)); + sendCount -= outElements.size(); + } + } + } + + /** + * @brief Adds 10 streams with producer/consumer in each one, test it, + * and then ensures that all streams are functional by creating remote producers + * and consumers + * @param[in] streams The map of stream names to pairs of prod/cons + * @param[in] remoteClient The remote client to use for producer + */ + void PostScaleTest( + std::map, std::shared_ptr>> &streams, + std::shared_ptr &remoteClient) + { + const int K_TEN = 10; + int streamNum = streams.size(); + // Add 10 streams and producers/consumers after scale up + for (int i = 0; i < K_TEN; ++i) { + std::string streamName = "stream" + std::to_string(i + streamNum); + std::shared_ptr producer; + DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_)); + std::shared_ptr consumer; + SubscriptionConfig config("sub" + std::to_string(i + streamNum), SubscriptionType::STREAM); + DS_ASSERT_OK(w1Client_->Subscribe(streamName, config, consumer)); + streams.emplace(streamName, std::make_pair(producer, consumer)); + CheckCount(w1Client_, streamName, 1, 1); + } + + // Make sure later requests get redirected and handled correctly + for (auto &stream : streams) { + const auto &streamName = stream.first; + LOG(INFO) << "handle stream: " << streamName; + + // Add new remote producer and consumer after scale up/down + std::shared_ptr producer2; + DS_ASSERT_OK(remoteClient->CreateProducer(streamName, producer2, defaultProducerConf_)); + + std::shared_ptr consumer2; + SubscriptionConfig config("remote_sub_" + streamName, SubscriptionType::STREAM); + DS_ASSERT_OK(remoteClient->Subscribe(streamName, config, consumer2)); + if (w1Client_ == remoteClient) { + // if both producers are created by same client/worker + // then the producer count in master will be 1 + CheckCount(remoteClient, streamName, 1, K_TWO); + } else { + CheckCount(remoteClient, streamName, K_TWO, K_TWO); + } + auto 
&producer1 = stream.second.first;
+            auto &consumer1 = stream.second.second;
+
+            // Test sending elements through every producer/consumer pairing.
+            TestSendRecv(producer1, consumer1);
+            TestSendRecv(producer1, consumer2);
+            TestSendRecv(producer2, consumer1);
+            TestSendRecv(producer2, consumer2);
+
+            // Close producers and consumers, and then delete the stream.
+            DS_ASSERT_OK(producer1->Close());
+            DS_ASSERT_OK(consumer1->Close());
+            DS_ASSERT_OK(producer2->Close());
+            DS_ASSERT_OK(consumer2->Close());
+            CheckCount(remoteClient, streamName, 0, 0);
+            // We can't guarantee that the close-consumer notification finishes before
+            // the delete-stream call. If the notification is still slow, DeleteStream
+            // will report that the stream is still in use.
+            Status rc = remoteClient->DeleteStream(streamName);
+            if (rc.GetCode() != K_SC_STREAM_NOTIFICATION_PENDING) {
+                DS_ASSERT_OK(rc);
+            }
+        }
+    }
+
+protected:
+    Status TryAndDeleteStream(std::shared_ptr<StreamClient> spClient, std::string streamName)
+    {
+        // If notifications are still pending, retry the delete.
+        Status rc = Status::OK();
+        do {
+            rc = spClient->DeleteStream(streamName);
+            if (rc.IsError()) {
+                sleep(K_TWO);
+            }
+        } while (rc.GetCode() == StatusCode::K_SC_STREAM_NOTIFICATION_PENDING);
+        return rc;
+    }
+
+    int workerNum_ = 2;
+    HostPort w1Addr_;
+    HostPort w2Addr_;
+
+    std::shared_ptr<StreamClient> w1Client_ = nullptr;
+    std::shared_ptr<StreamClient> w2Client_ = nullptr;
+    ProducerConf defaultProducerConf_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+    std::unique_ptr<EtcdStore> db_;
+    int nodeTimeoutS_ = 3;
+    int nodeDeadTimeoutS_ = 5;
+    const size_t SEND_COUNT = 100000;
+    const size_t TEST_SIZE = 1 * KB;
+};
+
+TEST_F(StreamClientScaleTest, TestConsumerCanRecvEleAfterScaleDown)
+{
+    // The target scenario is only hit when the stream metadata hashes to worker2,
+    // so run the case over 10 streams to make that likely.
+    int streamNum = 10;
+    int consumerRecvTimeoutMs = 10'000;
+
+    DS_ASSERT_OK(AddNode()); // Now we have 3 nodes.
+    std::vector<std::tuple<std::shared_ptr<Consumer>, std::shared_ptr<Producer>, std::shared_ptr<Producer>>> cache;
+    cache.resize(streamNum);
+    // Construct the element.
+    const size_t sizeElement = 1 * KB;
+    std::string writeElement = RandomData().GetRandomString(sizeElement);
+    Element element(reinterpret_cast<uint8_t *>(&writeElement.front()), writeElement.size());
+    // Init a client on worker3.
+    HostPort w3Addr;
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(2, w3Addr)); // 2 is the index of the node.
+    std::shared_ptr<StreamClient> w3Client;
+    InitStreamClient(2, w3Client); // 2 is the index of the node.
+    // Create producer/consumer -> send/recv one element -> close producer
+    for (int i = 0; i < streamNum; ++i) {
+        auto &consumer = std::get<0>(cache[i]);
+        auto &producer1 = std::get<1>(cache[i]);
+        auto streamName = "stream" + std::to_string(i);
+        SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+        DS_ASSERT_OK(w1Client_->Subscribe(streamName, config, consumer));
+        DS_ASSERT_OK(w3Client->CreateProducer(streamName, producer1, defaultProducerConf_));
+        DS_ASSERT_OK(producer1->Send(element));
+        DS_ASSERT_OK(producer1->Close());
+        std::vector<Element> eles;
+        DS_ASSERT_OK(consumer->Receive(1, consumerRecvTimeoutMs, eles));
+        ASSERT_EQ(eles.size(), 1ul);
+    }
+    // Scale down worker2.
+    w2Client_.reset();
+    DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 1));
+    sleep(nodeDeadTimeoutS_ + 1);
+    // Consumers can still receive elements successfully.
+    for (int i = 0; i < streamNum; ++i) {
+        auto &consumer = std::get<0>(cache[i]);
+        auto &producer2 = std::get<2>(cache[i]); // producer2 sits at tuple index 2.
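+        // A fresh producer forces a new metadata lookup through the rebuilt hash ring,
+        // so a successful send/receive here shows the stream survived the scale-down.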
+        auto streamName = "stream" + std::to_string(i);
+        DS_ASSERT_OK(w3Client->CreateProducer(streamName, producer2, defaultProducerConf_));
+        DS_ASSERT_OK(producer2->Send(element));
+        DS_ASSERT_OK(producer2->Close());
+        std::vector<Element> eles;
+        DS_ASSERT_OK(consumer->Receive(1, consumerRecvTimeoutMs, eles));
+        ASSERT_EQ(eles.size(), 1ul);
+    }
+}
+
+TEST_F(StreamClientScaleTest, TestAutoDeleteStreamAfterScaleDown)
+{
+    // The target scenario is only hit when the stream metadata hashes to worker2,
+    // so run the case over 10 streams to make that likely.
+    DS_ASSERT_OK(
+        cluster_->SetInjectAction(WORKER, 1, "SCNotifyWorkerManager.DeleteStreams", "return(K_RPC_UNAVAILABLE)"));
+    int streamNum = 10;
+
+    DS_ASSERT_OK(AddNode()); // Now we have 3 nodes.
+    std::vector<std::tuple<std::shared_ptr<Consumer>, std::shared_ptr<Producer>, std::shared_ptr<Producer>>> cache;
+    cache.resize(streamNum);
+
+    std::shared_ptr<StreamClient> w3Client;
+    InitStreamClient(2, w3Client); // index is 2
+    // Create producer/consumer -> close both, triggering auto-cleanup.
+    defaultProducerConf_.autoCleanup = true;
+    for (int i = 0; i < streamNum; ++i) {
+        auto &consumer = std::get<0>(cache[i]);
+        auto &producer1 = std::get<1>(cache[i]);
+        auto streamName = "stream" + std::to_string(i);
+        SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+        DS_ASSERT_OK(w1Client_->Subscribe(streamName, config, consumer));
+        DS_ASSERT_OK(w3Client->CreateProducer(streamName, producer1, defaultProducerConf_));
+        DS_ASSERT_OK(producer1->Close());
+        DS_ASSERT_OK(consumer->Close());
+    }
+    // Scale down worker2.
+    w2Client_.reset();
+    DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 1));
+    sleep(nodeDeadTimeoutS_ + 1);
+    // Modify the stream configuration to confirm that the stream has been deleted.
+    defaultProducerConf_.maxStreamSize += 1;
+    // Producers can be re-created successfully once the old streams are gone.
+    for (int i = 0; i < streamNum; ++i) {
+        auto &producer2 = std::get<2>(cache[i]); // producer2 sits at tuple index 2.
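+        // If auto-delete had not run, re-creating the stream with the bumped
+        // maxStreamSize would fail as a config mismatch; success proves deletion.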
+ auto streamName = "stream" + std::to_string(i); + DS_ASSERT_OK(w3Client->CreateProducer(streamName, producer2, defaultProducerConf_)); + } +} + +TEST_F(StreamClientScaleTest, TestSimpleScaleUp) +{ + LOG(INFO) << "TestSimpleScaleUp start!"; + // Test the scale up and metadata migrate logic + // In this case 7 streams will get metadata migrated at scale up, + // including stream2, stream3, stream6, stream8, stream9, stream10, stream11 + + // Initialize producers and consumers + const int streamNum = 16; + std::map, std::shared_ptr>> streams; + std::string streamName = "testSimpleScaleUp"; + CreateNProducerAndConsumer(streams, streamNum, streamName); + + // Add new worker node to trigger scale up and metadata migration + DS_ASSERT_OK(AddNode()); + + // Make sure later requests get redirected and handled correctly + for (auto &stream : streams) { + const auto &streamName = stream.first; + // Add new remote consumer after scale up + std::shared_ptr consumer; + SubscriptionConfig config("sub_" + streamName, SubscriptionType::STREAM); + DS_ASSERT_OK(w2Client_->Subscribe(streamName, config, consumer)); + CheckCount(w2Client_, streamName, 1, K_TWO); + + auto &producer = stream.second.first; + // Produce element + TestSendRecv(producer, consumer); + DS_ASSERT_OK(stream.second.first->Close()); + DS_ASSERT_OK(stream.second.second->Close()); + DS_ASSERT_OK(consumer->Close()); + CheckCount(w2Client_, streamName, 0, 0); + DS_ASSERT_OK(w2Client_->DeleteStream(streamName)); + } + LOG(INFO) << "TestSimpleScaleUp finish!"; +} + +TEST_F(StreamClientScaleTest, LEVEL1_TestScaleUpCrashWorker1) +{ + // Test the scale up and metadata migrate logic by closing one of the workers + // Initialize producers and consumers + const int streamNum = 16; + std::map, std::shared_ptr>> streams; + std::string streamName = "testScaleUpCrashWorker1"; + CreateNProducerAndConsumer(streams, streamNum, streamName); + + // Add new worker node to trigger scale up and metadata migration + const int newWorkerIdx = workerNum_; + DS_ASSERT_OK(AddNode()); + // Wait for scale up and migration done + sleep(SCALE_UP_WAIT_TIME); + std::shared_ptr w3Client; + InitStreamClient(newWorkerIdx, w3Client); + + // Kill one of the workers (instead of voluntary shutdown) to make sure that scale up is handled well + std::set STREAMS_ON_WORKER2 = { "stream2", "stream3", "stream6", "stream8", + "stream9", "stream10", "stream11" }; + // close consumers and producers before kill so RPC_UNAVAILABLE is avoided at shutdown + for (auto &stream : streams) { + const auto &streamName = stream.first; + if (STREAMS_ON_WORKER2.find(streamName) == STREAMS_ON_WORKER2.end()) { + DS_ASSERT_OK(streams[streamName].first->Close()); + DS_ASSERT_OK(streams[streamName].second->Close()); + } + } + sleep(K_TWO); + // Delete streams in seperate loop after to avoid running into pending notification failure + for (auto &stream : streams) { + const auto &streamName = stream.first; + if (STREAMS_ON_WORKER2.find(streamName) == STREAMS_ON_WORKER2.end()) { + DS_ASSERT_OK(w1Client_->DeleteStream(streamName)); + } + } + DS_ASSERT_OK(static_cast(cluster_.get())->KillWorker(1)); + + // Make sure requests for migrated streams are still handled correctly after worker1 crashes + for (auto &stream : streams) { + const auto &streamName = stream.first; + if (STREAMS_ON_WORKER2.find(streamName) == STREAMS_ON_WORKER2.end()) { + continue; + } + std::shared_ptr producer; + DS_ASSERT_OK(w3Client->CreateProducer(streamName, producer, defaultProducerConf_)); + std::shared_ptr consumer; + 
SubscriptionConfig config("sub_" + streamName, SubscriptionType::STREAM);
+        DS_ASSERT_OK(w3Client->Subscribe(streamName, config, consumer));
+        CheckCount(w3Client, streamName, K_TWO, K_TWO);
+        DS_ASSERT_OK(stream.second.first->Close());
+        DS_ASSERT_OK(stream.second.second->Close());
+        DS_ASSERT_OK(producer->Close());
+        DS_ASSERT_OK(consumer->Close());
+        CheckCount(w3Client, streamName, 0, 0);
+    }
+    sleep(K_TWO);
+    // Delete streams in a separate loop afterwards to avoid running into pending notification failure
+    for (auto &stream : streams) {
+        const auto &streamName = stream.first;
+        if (STREAMS_ON_WORKER2.find(streamName) != STREAMS_ON_WORKER2.end()) {
+            DS_ASSERT_OK(w3Client->DeleteStream(streamName));
+        }
+    }
+}
+
+TEST_F(StreamClientScaleTest, LEVEL2_TestScaleUpCrashWorker2)
+{
+    LOG(INFO) << "TestScaleUpCrashWorker2 start!";
+    // Test the scale up and metadata migrate logic after the new node crashes and restarts
+    // Essentially the purpose is to make sure it can recover itself from RocksDB
+    // Initialize producers and consumers
+    const int streamNum = 5;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    std::string streamName = "testScaleUpCrashWorker2";
+    CreateNProducerAndConsumer(streams, streamNum, streamName);
+
+    // Add new worker node to trigger scale up and metadata migration
+    const int newWorkerIdx = workerNum_;
+    DS_ASSERT_OK(AddNode());
+    // Wait for scale up and migration done
+    sleep(SCALE_UP_WAIT_TIME);
+    // Use kill so it is not a voluntary scale down
+    DS_ASSERT_OK(static_cast<ExternalCluster *>(cluster_.get())->KillWorker(newWorkerIdx));
+    // Wait for the process kill so it does not timeout after 60s on GcovFlush
+    sleep(1);
+    DS_ASSERT_OK(cluster_->StartNode(WORKER, newWorkerIdx, {}));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, newWorkerIdx));
+    std::shared_ptr<StreamClient> w3Client;
+    InitStreamClient(newWorkerIdx, w3Client);
+    // Make sure requests are still handled correctly after worker2 crash and restart
+    for (auto &stream : streams) {
+        const auto &streamName = stream.first;
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(w3Client->CreateProducer(streamName, producer, defaultProducerConf_));
+        std::shared_ptr<Consumer> consumer;
+        SubscriptionConfig config("sub_" + streamName, SubscriptionType::STREAM);
+        DS_ASSERT_OK(w3Client->Subscribe(streamName, config, consumer));
+        CheckCount(w3Client, streamName, K_TWO, K_TWO);
+        DS_ASSERT_OK(stream.second.first->Close());
+        DS_ASSERT_OK(stream.second.second->Close());
+        CheckCount(w3Client, streamName, 1, 1);
+        DS_ASSERT_OK(producer->Close());
+        DS_ASSERT_OK(consumer->Close());
+        CheckCount(w3Client, streamName, 0, 0);
+    }
+    sleep(K_TWO);
+    // Delete streams in a separate loop afterwards to avoid running into pending notification failure
+    for (auto &stream : streams) {
+        const auto &streamName = stream.first;
+        DS_ASSERT_OK(w3Client->DeleteStream(streamName));
+    }
+    // Shutdown the new node to avoid problem caused by kill with signal 9
+    (void)cluster_->ShutdownNode(WORKER, newWorkerIdx);
+    LOG(INFO) << "TestScaleUpCrashWorker2 finish!";
+}
+
+TEST_F(StreamClientScaleTest, LEVEL2_TestVoluntaryScaleDown)
+{
+    LOG(INFO) << "TestVoluntaryScaleDown start!";
+    // Test the voluntary scale down and the related metadata migrate logic
+    // In this case after worker2 shuts down, worker1 should be able to take over all the streams
+    // Test will take around a minute due to the 16 streams being created
+
+    // Initialize 16 producers and consumers
+    const int streamNum = 16;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    std::string streamName = "testVoluntaryScaleDown";
+    CreateNProducerAndConsumer(streams, streamNum, streamName);
+
+    // Shutdown worker2 to trigger voluntary scale down and metadata migration
+    w2Client_.reset();
+    VoluntaryScaleDownInject(1);
+    // Wait for voluntary scale down to finish
+    sleep(SCALE_DOWN_WAIT_TIME);
+
+    PostScaleTest(streams, w1Client_);
+    LOG(INFO) << "TestVoluntaryScaleDown finish!";
+}
+
+TEST_F(StreamClientScaleTest, LEVEL1_TestScaleDownAutoDeleteStream1)
+{
+    std::string streamName = "testScaleDownAutoDelStream";
+    defaultProducerConf_.autoCleanup = true;
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w2Client_->Subscribe(streamName, config, consumer));
+
+    // Shutdown worker 1
+    DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 1));
+    sleep(K_TWO);
+
+    // Close producer to invoke auto delete
+    DS_ASSERT_OK(producer->Close());
+    sleep(K_TWO);
+    // Try to create a producer with different configs to test if stream was deleted, and can be recreated
+    std::shared_ptr<Producer> producer1;
+    defaultProducerConf_.autoCleanup = false;
+    DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer1, defaultProducerConf_));
+}
+
+TEST_F(StreamClientScaleTest, LEVEL2_TestScaleDownWhileRetainingData)
+{
+    int streamNum = 10;
+    std::string streamNameBase = "testScaleDownAutoDelStream";
+    defaultProducerConf_.autoCleanup = false;
+    defaultProducerConf_.retainForNumConsumers = 1;
+
+    for (int i = 0; i < streamNum; i++) {
+        auto streamName = streamNameBase + std::to_string(i);
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_));
+        DS_ASSERT_OK(producer->Close());
+    }
+
+    DS_ASSERT_OK(
+        cluster_->SetInjectAction(WORKER, 0, "EtcdKeepAlive.SendKeepAliveMessage", "return(K_RPC_UNAVAILABLE)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "worker.RunKeepAliveTask", "return(K_RPC_UNAVAILABLE)"));
+
+    WaitAllNodesJoinIntoHashRing(1);
+
+    auto externalCluster = dynamic_cast<ExternalCluster *>(cluster_.get());
+    DS_ASSERT_OK(externalCluster->StartWorkerAndWaitReady({0}));
+
+    WaitAllNodesJoinIntoHashRing(2); // 2 workers online
+
+    for (int i = 0; i < streamNum; i++) {
+        auto streamName = streamNameBase + std::to_string(i);
+        std::shared_ptr<Consumer> consumer;
+        SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+        DS_ASSERT_OK(w2Client_->Subscribe(streamName, config, consumer));
+    }
+}
+
+TEST_F(StreamClientScaleTest, DISABLED_LEVEL1_TestScaleDownDeleteStream)
+{
+    LOG(INFO) << "TestScaleDownDeleteStream start!";
+    // Test that the related node info is kept for DeleteStream cleanup purposes
+    // First shutdown worker2
+    w2Client_.reset();
+    DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 1));
+    sleep(nodeDeadTimeoutS_ + 1);
+
+    // Initialize producers and consumers
+    // Fewer streams so the test case takes less time
+    const int streamNum = 3;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    std::string streamName = "testScaleDownDelStream";
+    CreateNProducerAndConsumer(streams, streamNum, streamName);
+    for (auto &stream : streams) {
+        DS_ASSERT_OK(stream.second.first->Close());
+        DS_ASSERT_OK(stream.second.second->Close());
+        CheckCount(w1Client_, stream.first, 0, 0);
+    }
+
+    // And then restart worker2 and scale down worker1, so all metadata gets migrated to worker2
+    DS_ASSERT_OK(cluster_->StartNode(WORKER, 1, {}));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 1));
+    sleep(SCALE_UP_WAIT_TIME);
+
+    InitStreamClient(1, w2Client_);
+    VoluntaryScaleDownInject(0);
+    sleep(SCALE_DOWN_WAIT_TIME);
+
+    // Nodes that got shutdown will be detected and ignored so delete stream should be OK
+    for (auto &stream : streams) {
+        const auto &streamName = stream.first;
+        CheckCount(w2Client_, streamName, 0, 0);
+        DS_ASSERT_OK(w2Client_->DeleteStream(streamName));
+    }
+    LOG(INFO) << "TestScaleDownDeleteStream finish!";
+}
+
+TEST_F(StreamClientScaleTest, LEVEL1_TestScaleUpAndDown)
+{
+    LOG(INFO) << "LEVEL1_TestScaleUpAndDown start!";
+    // Test a mixture of scale up and down operations
+    // Initialize producers and consumers
+    const int streamNum = 16;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    std::string streamName = "testScaleUpAndDown";
+    CreateNProducerAndConsumer(streams, streamNum, streamName);
+
+    // A mixture of operations with enough time in between:
+    // Add new worker node to trigger scale up and metadata migration
+    // Scale down and Restart worker1
+    DS_ASSERT_OK(AddNode());
+    WaitAllNodesJoinIntoHashRing(3); // 3 workers online
+    VoluntaryScaleDownInject(0);
+    VoluntaryScaleDownInject(1);
+    w1Client_.reset();
+    w2Client_.reset();
+    for (auto &stream : streams) {
+        stream.second = { nullptr, nullptr };
+    }
+    WaitAllNodesJoinIntoHashRing(1); // 1 worker online
+    DS_ASSERT_OK(cluster_->StartNode(WORKER, 0, {}));
+    DS_ASSERT_OK(cluster_->StartNode(WORKER, 1, {}));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 1));
+    WaitAllNodesJoinIntoHashRing(3); // 3 workers online
+
+    InitStreamClient(0, w1Client_);
+    InitStreamClient(1, w2Client_);
+    // Make sure all requests can still be handled correctly
+    for (auto &stream : streams) {
+        const auto &streamName = stream.first;
+        LOG(INFO) << "Processing stream: " << streamName;
+        // Producers and consumers related to worker1 got cleaned up at scale down
+        CheckCount(w1Client_, streamName, 0, 0);
+        // Add new producer and consumer
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, defaultProducerConf_));
+        std::shared_ptr<Consumer> consumer;
+        SubscriptionConfig config("sub_" + streamName, SubscriptionType::STREAM);
+        DS_ASSERT_OK(w1Client_->Subscribe(streamName, config, consumer));
+        CheckCount(w1Client_, streamName, 1, 1);
+        DS_ASSERT_OK(producer->Close());
+        DS_ASSERT_OK(consumer->Close());
+        CheckCount(w1Client_, streamName, 0, 0);
+        DS_ASSERT_OK(w1Client_->DeleteStream(streamName));
+    }
+    LOG(INFO) << "LEVEL1_TestScaleUpAndDown finish!";
+}
+
+TEST_F(StreamClientScaleTest, DISABLED_LEVEL1_TestScaleDownProducerCount)
+{
+    LOG(INFO) << "LEVEL1_TestScaleDownProducerCount start!";
+    // A simplified version of StreamClientScaleTest.TestScaleUpAndDown to monitor the producer count after migration.
+    // Initialize producers and consumers
+    const int streamNum = 16;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    std::string streamName = "TestScaleDownProducerCount";
+    CreateNProducerAndConsumer(streams, streamNum, streamName);
+
+    // A mixture of operations with enough time in between:
+    // Add new worker node to trigger scale up and metadata migration
+    // Scale down and Restart worker1
+    DS_ASSERT_OK(AddNode());
+    WaitAllNodesJoinIntoHashRing(3); // 3 workers online
+    VoluntaryScaleDownInject(0);
+    w1Client_.reset();
+    for (auto &stream : streams) {
+        stream.second = { nullptr, nullptr };
+    }
+    WaitAllNodesJoinIntoHashRing(2); // 2 workers online
+
+    for (auto &stream : streams) {
+        const auto &streamName = stream.first;
+        LOG(INFO) << "Processing stream: " << streamName;
+        // Producers and consumers related to worker1 got cleaned up at scale down
+        CheckCount(w2Client_, streamName, 0, 0);
+    }
+}
+
+TEST_F(StreamClientScaleTest, DISABLED_LEVEL1_TestLargeScaleUp)
+{
+    LOG(INFO) << "LEVEL1_TestLargeScaleUp start!";
+    // Test the scale up and metadata migrate logic, with opening new streams and
+    // testing sending data across and on different nodes
+    // Test will take around a minute due to the 90 streams being created
+
+    // Initialize 90 producers and consumers
+    const int streamNum = 90;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    std::string streamName = "testLargeScaleUp";
+    CreateNProducerAndConsumer(streams, streamNum, streamName);
+
+    // Add new worker node to trigger scale up and metadata migration
+    DS_ASSERT_OK(AddNode());
+
+    PostScaleTest(streams, w2Client_);
+    LOG(INFO) << "LEVEL1_TestLargeScaleUp finish!";
+}
+
+TEST_F(StreamClientScaleTest, LEVEL2_TestUnlimitedAutodelete1)
+{
+    // Test that the node-lost etcd event clears metadata, so auto-delete is not retried when the related node is lost.
+    std::string streamName = "TestUnlimitedAutodelete1";
+    ProducerConf oldConf;
+    oldConf.autoCleanup = true;
+    oldConf.pageSize = 4 * MB; // page size is 4 MB
+    ProducerConf newConf;
+    newConf.pageSize = 3 * MB; // page size is 3 MB
+
+    std::vector<std::string> streamVec;
+    std::vector<std::shared_ptr<Producer>> producerVec;
+    const int streamNum = 10;
+    for (int i = 0; i < streamNum; i++) {
+        auto tmpStreamName = streamName + std::to_string(i);
+
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(w1Client_->CreateProducer(tmpStreamName, producer, oldConf));
+        producerVec.emplace_back(producer);
+        streamVec.emplace_back(tmpStreamName);
+    }
+
+    // Close all producers, but do not handle auto-delete yet.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "master.ProcessDeleteStreams", "1*sleep(10000)"));
+    for (auto &producer : producerVec) {
+        producer->Close();
+    }
+
+    // Shutdown worker 0
+    DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 0));
+    const int WAIT_FOR_DELETION = 15;
+    sleep(WAIT_FOR_DELETION);
+
+    for (auto streamName : streamVec) {
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(w2Client_->CreateProducer(streamName, producer, newConf));
+    }
+}
+
+TEST_F(StreamClientScaleTest, LEVEL2_TestUnlimitedAutodelete2)
+{
+    // Test that even if metadata is not cleared, auto-delete will not be retried indefinitely when the node is found to be lost.
+    std::string streamName = "TestUnlimitedAutodelete2";
+    ProducerConf oldConf;
+    oldConf.autoCleanup = true;
+    oldConf.pageSize = 4 * MB; // page size is 4 MB
+    ProducerConf newConf;
+    newConf.pageSize = 3 * MB; // page size is 3 MB
+
+    std::vector<std::string> streamVec;
+    std::vector<std::shared_ptr<Producer>> producerVec;
+    const int streamNum = 10;
+    for (int i = 0; i < streamNum; i++) {
+        auto tmpStreamName = streamName + std::to_string(i);
+
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(w1Client_->CreateProducer(tmpStreamName, producer, oldConf));
+        producerVec.emplace_back(producer);
+        streamVec.emplace_back(tmpStreamName);
+    }
+
+    // Close all producers, but do not handle auto-delete yet.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "master.ProcessDeleteStreams", "1*sleep(10000)"));
+    for (auto &producer : producerVec) {
+        producer->Close();
+    }
+
+    // Shutdown worker 0, also inject to simulate the scenario that metadata is not cleared.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "SCMetadataManager.SkipClearEmptyMeta", "call()"));
+    DS_ASSERT_OK(cluster_->ShutdownNode(WORKER, 0));
+    const int WAIT_FOR_DELETION = 15;
+    sleep(WAIT_FOR_DELETION);
+
+    for (auto streamName : streamVec) {
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(w2Client_->CreateProducer(streamName, producer, newConf));
+    }
+}
+
+TEST_F(StreamClientScaleTest, LEVEL1_ScaleWhenSyncConsumerNode)
+{
+    int streamNum = 10;
+    std::vector<std::string> streams;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    for (int i = 0; i < streamNum; i++) {
+        std::string streamName = RandomData().GetRandomString(10); // stream name len is 10
+        streams.emplace_back(streamName);
+        SubscriptionConfig config("sub_" + streamName, SubscriptionType::STREAM);
+        std::shared_ptr<Consumer> consumer;
+        DS_ASSERT_OK(w2Client_->Subscribe(streamName, config, consumer));
+        consumers.emplace_back(consumer);
+    }
+    DS_ASSERT_OK(
+        cluster_->SetInjectAction(WORKER, 1, "MasterWorkerSCServiceImpl.UpdateTopoNotification.begin", "sleep(15000)"));
+
+    const int threadNum = 20;
+    ThreadPool threadPool(threadNum);
+    std::vector<std::future<void>> futs;
+    auto createProducer = [this](std::string stream) {
+        std::shared_ptr<Producer> producer;
+        w1Client_->CreateProducer(stream, producer, defaultProducerConf_);
+    };
+    auto closeConsumer = [](std::shared_ptr<Consumer> consumer) {
+        sleep(1);
+        DS_ASSERT_OK(consumer->Close());
+    };
+    for (const auto &stream : streams) {
+        futs.emplace_back(threadPool.Submit(createProducer, stream));
+    }
+    for (const auto &consumer : consumers) {
+        futs.emplace_back(threadPool.Submit(closeConsumer, consumer));
+    }
+    sleep(1);
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "SCMetadataManager.GetMetasMatch.timeout", "call(5)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "SCMetadataManager.GetMetasMatch.timeout", "call(5)"));
+    DS_ASSERT_OK(AddNode());
+    WaitAllNodesJoinIntoHashRing(3, 10); // wait 10s for worker 3 join
+    for (const auto &fut : futs) {
+        fut.wait();
+    }
+}
+
+TEST_F(StreamClientScaleTest, LEVEL1_ContinuousRedirection)
+{
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    DS_ASSERT_OK(AddNode());
+    int worker3Idx = 2;
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, worker3Idx, "SCMetadataManager.Subscribe.wait", "2*sleep(3000)"));
+    const int threadNum = 2;
+    ThreadPool threadPool(threadNum);
+    auto fut1 = threadPool.Submit([this, &consumers]() {
+        int streamNum = 16;
+        for (int i = 0; i < streamNum; i++) {
+            std::string streamName = RandomData().GetRandomString(10); // stream name len is 10
+            SubscriptionConfig config("sub_" + streamName, SubscriptionType::STREAM);
+            std::shared_ptr<Consumer> consumer;
+            DS_ASSERT_OK(w1Client_->Subscribe(streamName, config, consumer));
+            consumers.emplace_back(consumer);
+        }
+    });
+
+    auto fut2 = threadPool.Submit([this, worker3Idx]() {
+        VoluntaryScaleDownInject(worker3Idx);
+        WaitAllNodesJoinIntoHashRing(2, 20); // wait 20s for w3 scale down
+    });
+
+    fut1.wait();
+    fut2.wait();
+
+    for (const auto &consumer : consumers) {
+        DS_ASSERT_OK(consumer->Close());
+    }
+}
+
+class StreamClientPassiveScaleTest : public StreamClientScaleTest, public CommonDistributedExt {
+public:
+    BaseCluster *GetCluster() override
+    {
+        return cluster_.get();
+    }
+
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        // Start 3 workers
+        opts.numWorkers = ++workerNum_;
+        opts.enableDistributedMaster = "true";
+        opts.vLogLevel = 1;
+        // Set up node_dead_timeout_s and auto_del_dead_node flags, so that a new meta owner master can be reselected
+        // for passive scale down purposes
+        opts.workerGflagParams = FormatString(
+            " -v=2 -node_timeout_s=3 -node_dead_timeout_s=%d -auto_del_dead_node=true -shared_memory_size_mb=10240",
+            NODE_DEAD_TIMEOUT);
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+};
+
+TEST_F(StreamClientPassiveScaleTest, DISABLED_TestRestartDuringEtcdCrash)
+{
+    auto externalCluster = dynamic_cast<ExternalCluster *>(cluster_.get());
+    DS_ASSERT_OK(externalCluster->ShutdownEtcds());
+
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(TEST_SIZE);
+    Element element = Element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+    int receiveMaxWaitTimeMs = 10'000;
+    std::vector<Element> outElements;
+
+    auto streamNamePreRestart = "streamPreRestart";
+    std::shared_ptr<Consumer> consumerPreRestart;
+    SubscriptionConfig configPreRestart("subPreRestart", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w2Client_->Subscribe(streamNamePreRestart, configPreRestart, consumerPreRestart));
+    std::shared_ptr<Producer> producerPreRestart;
+    DS_ASSERT_OK(w1Client_->CreateProducer(streamNamePreRestart, producerPreRestart, defaultProducerConf_));
+
+    DS_ASSERT_OK(producerPreRestart->Send(element));
+    DS_ASSERT_OK(consumerPreRestart->Receive(1, receiveMaxWaitTimeMs, outElements));
+
+    DS_ASSERT_OK(externalCluster->RestartWorkerAndWaitReadyOneByOne({ 1 }));
+    std::shared_ptr<StreamClient> w22Client;
+    InitStreamClient(1, w22Client);
+
+    auto streamNamePostRestart = "streamPostRestart";
+    std::shared_ptr<Consumer> consumerPostRestart;
+    SubscriptionConfig configPostRestart("subPostRestart", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w22Client->Subscribe(streamNamePostRestart, configPostRestart, consumerPostRestart));
+    std::shared_ptr<Producer> producerPostRestart;
+    DS_ASSERT_OK(w1Client_->CreateProducer(streamNamePostRestart, producerPostRestart, defaultProducerConf_));
+
+    DS_ASSERT_OK(producerPostRestart->Send(element));
+    DS_ASSERT_OK(consumerPostRestart->Receive(1, receiveMaxWaitTimeMs, outElements));
+
+    DS_ASSERT_OK(externalCluster->SetInjectAction(WORKER, 1, "WorkerOCServiceImpl.Reconciliation.SkipWait", "call()"));
+    DS_ASSERT_OK(externalCluster->StartEtcdCluster());
+
+    int waitReconciliationSec = 5;
+    sleep(waitReconciliationSec);
+
+    DS_ASSERT_OK(producerPreRestart->Send(element));
+    DS_ASSERT_NOT_OK(consumerPreRestart->Receive(1, receiveMaxWaitTimeMs, outElements));
+
+    DS_ASSERT_OK(producerPostRestart->Send(element));
+    DS_ASSERT_OK(consumerPostRestart->Receive(1, receiveMaxWaitTimeMs, outElements));
+}
+
+TEST_F(StreamClientPassiveScaleTest, DISABLED_LEVEL1_TestPassiveScaleDown)
+{
+    LOG(INFO) << "LEVEL1_TestPassiveScaleDown start!";
+    // Test passive scale down, that is, another worker can take over and recover the metadata upon node crash
+
+    // Use worker3 for crash purposes
+    const int worker3Index = 2;
+    // Initialize 40 streams with producers and consumers
+    const int streamNum = 40;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    std::string streamName = "testPassiveScaleDown";
+    // Create Producer on worker1 and Consumer on worker2
+    CreateNProducerAndConsumer(streams, streamNum, streamName, false);
+
+    // Kill worker3, and sleep to trigger passive scale down logic
+    DS_ASSERT_OK(static_cast<ExternalCluster *>(cluster_.get())->KillWorker(worker3Index));
+    sleep(NODE_DEAD_TIMEOUT + 1);
+
+    // First verify that the existing producer->consumer can still go through
+    // This does not involve master logic
+    for (auto &stream : streams) {
+        CheckCount(w1Client_, stream.first, 1, 1);
+        TestSendRecv(stream.second.first, stream.second.second);
+    }
+
+    // Add a new worker4 for new consumer
+    const int worker4Index = workerNum_;
+    DS_ASSERT_OK(AddNode());
+    std::shared_ptr<StreamClient> w4Client;
+    InitStreamClient(worker4Index, w4Client);
+
+    PostScaleTest(streams, w2Client_);
+
+    LOG(INFO) << "LEVEL1_TestPassiveScaleDown finish!";
+}
+
+TEST_F(StreamClientPassiveScaleTest, LEVEL1_TestRestartPassiveScaleDown)
+{
+    LOG(INFO) << "TestRestartPassiveScaleDown start!";
+    // Test when the passive scale down node gets restarted
+    // It starts with passive scale down, and then a mixture of reconciliation and scale up metadata migration
+    // In this case K_DUPLICATED is ignored at metadata migration, so it will not retry indefinitely
+
+    // Use worker3 for crash purposes
+    const int worker3Index = 2;
+    // Initialize producers and consumers
+    const int streamNum = 16;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    std::string streamName = "testRestartPassiveScaleDown";
+    // Create Producer on worker1 and Consumer on worker2
+    CreateNProducerAndConsumer(streams, streamNum, streamName, false);
+
+    // Kill worker3, and sleep to trigger passive scale down logic
+    DS_ASSERT_OK(static_cast<ExternalCluster *>(cluster_.get())->KillWorker(worker3Index));
+    sleep(NODE_DEAD_TIMEOUT + 1);
+    // Then restart worker3 to trigger scale up metadata migration logic
+    DS_ASSERT_OK(cluster_->StartNode(WORKER, worker3Index, {}));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, worker3Index));
+    std::shared_ptr<StreamClient> w3Client;
+    InitStreamClient(worker3Index, w3Client);
+
+    // Make sure requests can still be handled correctly for the new Consumer
+    for (auto &stream : streams) {
+        const auto &streamName = stream.first;
+        auto &producer = stream.second.first;
+        LOG(INFO) << "handle stream: " << streamName;
+        // Add new consumer on worker3
+        std::shared_ptr<Consumer> consumer;
+        SubscriptionConfig config("sub_" + streamName, SubscriptionType::STREAM);
+        DS_ASSERT_OK(w3Client->Subscribe(streamName, config, consumer));
+        CheckCount(w3Client, streamName, 1, K_TWO);
+
+        // Make sure producer -> consumer can go through for the new consumer on worker3
+        // This involves master related topo change logic
+        TestSendRecv(producer, consumer);
+
+        DS_ASSERT_OK(stream.second.first->Close());
+        DS_ASSERT_OK(stream.second.second->Close());
+        DS_ASSERT_OK(consumer->Close());
+        CheckCount(w3Client, streamName, 0, 0);
+        DS_ASSERT_OK(w3Client->DeleteStream(streamName));
+    }
+    // Shutdown the worker3 to avoid problem caused by kill with signal 9
+    (void)cluster_->ShutdownNode(WORKER, worker3Index);
+    LOG(INFO) << "TestRestartPassiveScaleDown finish!";
+}
+
+class StreamClientVoluntaryScaleDownTest : public StreamClientScaleTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        // Start 3 workers
+        opts.numWorkers = ++workerNum_;
+        opts.enableDistributedMaster = "true";
+        // Set up node_dead_timeout_s and auto_del_dead_node flags, so that a new meta owner master can be reselected
+        // for passive scale down purposes
+        opts.workerGflagParams = FormatString(
+            " -v=2 -node_timeout_s=3 -node_dead_timeout_s=%d -auto_del_dead_node=true -shared_memory_size_mb=10240"
+            " -log_monitor=true -log_monitor_interval_ms=1000 "
+            "-sc_metrics_log_interval_s=1",
+            NODE_DEAD_TIMEOUT);
+        opts.skipWorkerPreShutdown = false;
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        StreamClientScaleTest::SetUp();
+        defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE << 1;
+        InitTestEtcdInstance();
+    }
+
+    void GetHashOnWorker(int workerIndex, int64_t &hash)
+    {
+        hash = -1;
+        ASSERT_NE(db_, nullptr) << "The etcd store instance is not initialized";
+        HostPort workerAddr;
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(workerIndex, workerAddr));
+        std::string value;
+        DS_ASSERT_OK(db_->Get(ETCD_RING_PREFIX, "", value));
+        HashRingPb ring;
+        ring.ParseFromString(value);
+        auto tokens = ring.workers().at(workerAddr.ToString()).hash_tokens();
+        ASSERT_GT(tokens.size(), 1) << "A node should have multiple tokens";
+        hash = tokens[0] != 0 ? tokens[0] - 1 : tokens[1] - 1;
+    }
+
+    void InitClientsHelper()
+    {
+        InitStreamClient(0, w1Client_);
+        InitStreamClient(1, w2Client_);
+    }
+};
+
+// If the node to be scaled in has data to be sent by the Producer, the node exits only after the data is sent.
+TEST_F(StreamClientVoluntaryScaleDownTest, LEVEL1_TestVoluntaryScaleDownWithUnfinishedTask1)
+{
+    DS_ASSERT_OK(inject::Set("ListenWorker.CheckHeartbeat.heartbeat_interval_ms", "call(500)"));
+    // The producer is on worker3 and the consumer is on worker1.
+    const int worker3Index = 2;
+    std::shared_ptr<StreamClient> w3Client;
+    InitStreamClient(worker3Index, w3Client);
+
+    std::string streamName = "testVolScaleDownUnfinishedTask1";
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(w3Client->CreateProducer(streamName, producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w1Client_->Subscribe(streamName, config, consumer));
+    CheckCount(w1Client_, streamName, 1, 1);
+
+    // This injection will result in the producer being slower to push data to the remote worker.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, worker3Index, "ExclusivePageQueue.ScanAndEval", "sleep(3000)"));
+
+    std::string eleContent = "hello";
+    Element element(reinterpret_cast<uint8_t *>(&eleContent.front()), eleContent.size(), ULONG_MAX);
+    int totalEleNum = 3;
+    ThreadPool threadPool(1);
+    auto producerFuture = threadPool.Submit([&] {
+        for (int i = 0; i < totalEleNum; i++) {
+            RETURN_IF_NOT_OK(producer->Send(element));
+            std::this_thread::sleep_for(std::chrono::milliseconds(500));
+        }
+        return Status::OK();
+    });
+    // Shutdown worker3 to trigger voluntary scale down.
+    VoluntaryScaleDownInject(worker3Index);
+    DS_ASSERT_OK(producerFuture.get());
+    producer.reset();
+    w3Client.reset();
+    // Wait for voluntary scale down to finish
+    WaitForVoluntaryDownFinished(worker3Index);
+    // Consumer can receive all of the data.
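+    // The scale-down path is expected to drain the producer's pending data before the worker
+    // exits, so all totalEleNum elements sent above must still arrive with their payload intact.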
+    std::vector<Element> outElements;
+    uint32_t timeOutMs = 2'000;
+    DS_ASSERT_OK(consumer->Receive(totalEleNum, timeOutMs, outElements));
+    ASSERT_EQ(outElements.size(), totalEleNum);
+    for (auto ele : outElements) {
+        std::string actualData(reinterpret_cast<const char *>(ele.ptr), ele.size);
+        ASSERT_EQ(actualData, eleContent);
+    }
+}
+
+TEST_F(StreamClientVoluntaryScaleDownTest, TestVoluntaryScaleDownWithTasksShouldBeDiscarded)
+{
+    std::string streamName = "testVoluntaryScaleDown";
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w1Client_->Subscribe(streamName, config, consumer));
+
+    std::shared_ptr<Producer> producer;
+    ProducerConf pConf;
+    pConf.autoCleanup = true;
+    DS_ASSERT_OK(w2Client_->CreateProducer(streamName, producer, pConf));
+
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "BufferPool.BatchAsyncFlush", "return(K_OK)"));
+
+    std::string eleContent = "hello";
+    Element element(reinterpret_cast<uint8_t *>(&eleContent.front()), eleContent.size());
+    DS_ASSERT_OK(producer->Send(element));
+
+    sleep(1); // wait for the element scan to succeed
+
+    consumer->Close();
+    producer->Close();
+
+    DS_ASSERT_OK(w2Client_->ShutDown());
+
+    VoluntaryScaleDownInject(1);
+
+    // Wait for voluntary scale down to finish
+    WaitForVoluntaryDownFinished(1);
+
+    w1Client_.reset();
+    w2Client_.reset();
+}
+
+TEST_F(StreamClientVoluntaryScaleDownTest, LEVEL1_TestScaleDownAutoDeleteStream2)
+{
+    InitClientsHelper();
+    // Test that with voluntary shutdown, pending auto-delete can be handled after migration.
+    std::string streamName = "testScaleDownAutoDelStream2";
+    // Inject so that auto delete cannot finish on worker1.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "master.ProcessDeleteStreams", "sleep(10000)"));
+    defaultProducerConf_.autoCleanup = true;
+    const int streamNum = 10;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    CreateNProducerAndConsumer(streams, streamNum, streamName);
+    // Close producers and consumers so auto delete can be triggered.
+    for (auto &stream : streams) {
+        DS_ASSERT_OK(stream.second.first->Close());
+        DS_ASSERT_OK(stream.second.second->Close());
+        CheckCount(w1Client_, stream.first, 0, 0);
+    }
+    streams.clear();
+    // Voluntarily scale down worker 1, the pending auto delete should be initiated from the new meta owner master.
+    VoluntaryScaleDownInject(1);
+    sleep(SCALE_DOWN_WAIT_TIME);
+    // Try to create a producer with different configs to test if stream was deleted, and can be recreated.
+    defaultProducerConf_.autoCleanup = false;
+    CreateNProducerAndConsumer(streams, streamNum, streamName);
+}
+
+TEST_F(StreamClientVoluntaryScaleDownTest, LEVEL2_TestScaleDownNotifications1)
+{
+    InitClientsHelper();
+    // Test that with voluntary shutdown, add pub and add sub notifications can be re-added.
+    // The stop data retention notification is also involved.
+    std::string streamName = "testScaleDownNotifications1";
+    defaultProducerConf_.retainForNumConsumers = 1;
+    // Inject so that notifications all become async on worker3, and async notifications are not handled.
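+    // (Inferred from the injection point names) ForceAsyncNotification queues the pub/sub
+    // notifications instead of delivering them synchronously, and the ProcessAsyncNotify sleep
+    // keeps the queue from draining, simulating notifications still pending at scale-down time.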
+    const int worker3Index = 2;
+    DS_ASSERT_OK(
+        cluster_->SetInjectAction(WORKER, worker3Index, "SCNotifyWorkerManager.ForceAsyncNotification", "call()"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, worker3Index, "master.ProcessAsyncNotify", "sleep(10000)"));
+    const int streamNum = 10;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    CreateNProducerAndConsumer(streams, streamNum, streamName, false);
+    // Voluntarily scale down worker3, metadata will get migrated and notification/reconciliation
+    // logic should be triggered.
+    VoluntaryScaleDownInject(worker3Index);
+    sleep(SCALE_DOWN_WAIT_TIME);
+    for (auto &stream : streams) {
+        CheckCount(w1Client_, stream.first, 1, 1);
+        TestSendRecv(stream.second.first, stream.second.second);
+    }
+}
+
+TEST_F(StreamClientVoluntaryScaleDownTest, LEVEL2_TestScaleDownNotifications2)
+{
+    InitClientsHelper();
+    // Test that with voluntary shutdown, del pub and del sub notifications can be re-added.
+    std::string streamName = "testScaleDownNotifications2";
+    const int streamNum = 6;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    CreateNProducerAndConsumer(streams, streamNum, streamName, false);
+    // Inject so that notifications all become async on worker3, and async notifications are not handled.
+    const int worker3Index = 2;
+    DS_ASSERT_OK(
+        cluster_->SetInjectAction(WORKER, worker3Index, "SCNotifyWorkerManager.ForceAsyncNotification", "call()"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, worker3Index, "master.ProcessAsyncNotify", "sleep(10000)"));
+    for (auto &stream : streams) {
+        // Close the remote consumer, but the notifications on worker3 are not handled.
+        stream.second.second->Close();
+    }
+    // Voluntarily scale down worker3,
+    // metadata will get migrated and notification/reconciliation logic should be triggered.
+    VoluntaryScaleDownInject(worker3Index);
+    sleep(SCALE_DOWN_WAIT_TIME);
+    // Force worker2 to fail to accept any elements with RPC_UNAVAILABLE.
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "PushElementsCursors.begin", "return(K_RPC_UNAVAILABLE)"));
+    for (auto &stream : streams) {
+        auto &streamName = stream.first;
+        auto &producer = stream.second.first;
+        auto &consumer = stream.second.second;
+        SubscriptionConfig config("sub_" + streamName, SubscriptionType::STREAM);
+        DS_ASSERT_OK(w1Client_->Subscribe(streamName, config, consumer));
+        // Send an amount of data larger than the max stream size, to make sure the procedure is fine.
+        std::thread producerThrd([this, &producer]() { SendHelper(producer); });
+        std::thread consumerThrd([this, &consumer]() { ReceiveHelper(consumer); });
+        producerThrd.join();
+        consumerThrd.join();
+    }
+}
+
+TEST_F(StreamClientVoluntaryScaleDownTest, TestScaleDownNotifications3)
+{
+    InitClientsHelper();
+    // Test that with voluntary shutdown, the lifetime consumer count is correctly maintained,
+    // so the stop retain notification is correctly generated by the new consumer.
+    std::string streamName = "testScaleDownNotifications3";
+    defaultProducerConf_.retainForNumConsumers = 2;
+    const int worker3Index = 2;
+    const int streamNum = 10;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    CreateNProducerAndConsumer(streams, streamNum, streamName, false);
+    for (auto &stream : streams) {
+        DS_ASSERT_OK(stream.second.second->Close());
+    }
+    // Voluntarily scale down worker3,
+    // metadata will get migrated and notification/reconciliation logic should be triggered.
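+    // With retainForNumConsumers = 2 and only one consumer seen so far, the stream keeps
+    // retaining data across the migration; the lifetime consumer count must survive it so that
+    // the second subscription below generates the stop-retention notification.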
+    VoluntaryScaleDownInject(worker3Index);
+    sleep(SCALE_DOWN_WAIT_TIME);
+    for (auto &stream : streams) {
+        auto &streamName = stream.first;
+        SubscriptionConfig config("sub_" + streamName, SubscriptionType::STREAM);
+        DS_ASSERT_OK(w2Client_->Subscribe(streamName, config, stream.second.second));
+        CheckCount(w2Client_, stream.first, 1, 1);
+        TestSendRecv(stream.second.first, stream.second.second);
+    }
+}
+
+TEST_F(StreamClientVoluntaryScaleDownTest, LEVEL1_ScaleDownWhenMetaResidue)
+{
+    DS_ASSERT_OK(
+        cluster_->SetInjectAction(WORKER, 0, "SCMetadataManager.CreateStreamMetadata",
+                                  "1*call(stream_residue_1)->1*call(stream_residue_2)->1*call(stream_residue_3)"));
+    const int streamNum = 16;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    std::string streamName = "testScaleDownWhenMetaResidue";
+    CreateNProducerAndConsumer(streams, streamNum, streamName);
+
+    w1Client_->ShutDown();
+    VoluntaryScaleDownInject(0);
+    int timeoutS = 20;
+    WaitForVoluntaryDownFinished(0, timeoutS);
+}
+
+TEST_F(StreamClientPassiveScaleTest, LEVEL2_TestSyncConsumerNode)
+{
+    // Test that during the reconciliation from passive scale down,
+    // the SyncConsumerNode would skip the duplicates
+    // Use worker3 for crash purposes
+    const int worker3Index = 2;
+    // Initialize producers and consumers
+    const int streamNum = 10;
+    struct PlaceHolder {
+        PlaceHolder(std::shared_ptr<Producer> producer, std::shared_ptr<Consumer> consumer,
+                    std::shared_ptr<std::thread> producerThrd, std::shared_ptr<std::thread> consumerThrd)
+            : producer_(producer), consumer_(consumer), producerThrd_(producerThrd), consumerThrd_(consumerThrd)
+        {
+        }
+        std::shared_ptr<Producer> producer_;
+        std::shared_ptr<Consumer> consumer_;
+        std::shared_ptr<std::thread> producerThrd_;
+        std::shared_ptr<std::thread> consumerThrd_;
+    };
+    std::map<std::string, PlaceHolder> streams;
+
+    // Create Producer on worker1 and Consumer on worker2
+    for (int i = 0; i < streamNum; ++i) {
+        std::string streamName = "testSyncConNode" + std::to_string(i);
+        std::shared_ptr<Producer> producer;
+        ProducerConf conf;
+        const int TEST_PAGE_SIZE = 16 * KB;
+        const int TEST_STREAM_MAX_SIZE = 5 * MB;
+        conf.pageSize = TEST_PAGE_SIZE;
+        conf.maxStreamSize = TEST_STREAM_MAX_SIZE;
+        DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, conf));
+        std::shared_ptr<Consumer> consumer;
+        SubscriptionConfig config("sub" + std::to_string(i), SubscriptionType::STREAM);
+        DS_ASSERT_OK(w2Client_->Subscribe(streamName, config, consumer));
+        auto producerThrd = std::make_shared<std::thread>([this, producer]() { SendHelper(producer); });
+        auto consumerThrd = std::make_shared<std::thread>([this, consumer]() { ReceiveHelper(consumer); });
+        streams.emplace(streamName, PlaceHolder(producer, consumer, producerThrd, consumerThrd));
+        CheckCount(w1Client_, streamName, 1, 1);
+    }
+
+    // Kill worker3, and sleep to trigger passive scale down logic
+    DS_ASSERT_OK(static_cast<ExternalCluster *>(cluster_.get())->KillWorker(worker3Index));
+    sleep(NODE_DEAD_TIMEOUT + 1);
+
+    // Make sure requests can still be handled correctly between producer and consumer
+    for (auto &stream : streams) {
+        stream.second.producerThrd_->join();
+        stream.second.consumerThrd_->join();
+    }
+}
+
+TEST_F(StreamClientPassiveScaleTest, LEVEL1_TestClearAsyncNotifyTask)
+{
+    ObtainHashTokens();
+    std::string streamName;
+    int masterIndex = 2;
+    int timeoutSec = 10;
+    Timer timer;
+    while (timer.ElapsedSecond() < timeoutSec) {
+        std::string tmpStreamName = "stream-" + GetStringUuid();
+        WorkerEntry masterEntry;
+        GetMetaLocationById(tmpStreamName, { 0, 1, 2 }, masterEntry);
+        if (masterEntry.index == masterIndex) {
+            streamName = tmpStreamName;
+            break;
+        }
+    }
+    ASSERT_TRUE(!streamName.empty());
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    const int TEST_PAGE_SIZE = 16 * KB;
+    const int TEST_STREAM_MAX_SIZE = 5 * MB;
+    conf.pageSize = TEST_PAGE_SIZE;
+    conf.maxStreamSize = TEST_STREAM_MAX_SIZE;
+    DS_ASSERT_OK(w1Client_->CreateProducer(streamName, producer, conf));
+
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w2Client_->Subscribe(streamName, config, consumer));
+
+    // Kill worker1
+    DS_ASSERT_OK(static_cast<ExternalCluster *>(cluster_.get())->KillWorker(0));
+    DS_ASSERT_OK(consumer->Close());
+    // Sleep to trigger passive scale down logic
+    sleep(NODE_DEAD_TIMEOUT + 1);
+    DS_ASSERT_OK(w2Client_->DeleteStream(streamName));
+}
+
+class DataVerificationStreamClientScaleTest : public StreamClientScaleTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        StreamClientScaleTest::SetClusterSetupOptions(opts);
+        opts.workerGflagParams += " -enable_stream_data_verification=true ";
+    }
+
+    void CreateStreams(uint numOfStream, uint producerPerStream, uint consumerPerStream,
+                       std::vector<std::pair<std::vector<std::shared_ptr<Producer>>,
+                                             std::vector<std::shared_ptr<Consumer>>>> &streams,
+                       std::string streamName)
+    {
+        streams.resize(numOfStream);
+        CreateNProducerAndMConsumerForEachStream(producerPerStream, consumerPerStream, streams, streamName);
+    }
+
+    void CreateNProducerAndMConsumerForEachStream(uint producerPerStream, uint consumerPerStream,
+                                                  std::vector<std::pair<std::vector<std::shared_ptr<Producer>>,
+                                                                        std::vector<std::shared_ptr<Consumer>>>> &streams,
+                                                  std::string streamName)
+    {
+        for (uint i = 0; i < streams.size(); ++i) {
+            std::string streamName_ = streamName + std::to_string(i);
+            auto &producers = streams.at(i).first;
+            for (uint j = 0; j < producerPerStream; ++j) {
+                std::shared_ptr<Producer> producer;
+                DS_ASSERT_OK(w1Client_->CreateProducer(streamName_, producer, defaultProducerConf_));
+                producers.emplace_back(producer);
+            }
+            auto &consumers = streams.at(i).second;
+            uint consumersSize = consumers.size();
+            for (uint j = 0; j < consumerPerStream; ++j) {
+                std::shared_ptr<Consumer> consumer;
+                SubscriptionConfig config("sub" + std::to_string(consumersSize + j), SubscriptionType::STREAM);
+                DS_ASSERT_OK(w1Client_->Subscribe(streamName_, config, consumer));
+                consumers.emplace_back(consumer);
+            }
+        }
+    }
+
+    void TestSendRecv(uint numOfElementPerProducer, uint numOfElementPerConsumer,
+                      std::vector<std::pair<std::vector<std::shared_ptr<Producer>>,
+                                            std::vector<std::shared_ptr<Consumer>>>> &streams)
+    {
+        for (auto &stream : streams) {
+            auto &producers = stream.first;
+            for (uint i = 0; i < producers.size(); ++i) {
+                auto &producer = producers.at(i);
+                if (producer) {
+                    std::string data = "producer" + std::to_string(i + 1);
+                    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+                    for (uint j = 0; j < numOfElementPerProducer; ++j) {
+                        DS_ASSERT_OK(producer->Send(element));
+                    }
+                }
+            }
+            if (numOfElementPerConsumer == 0) {
+                continue;
+            }
+            for (auto &consumer : stream.second) {
+                if (consumer) {
+                    std::vector<Element> outElements;
+                    DS_ASSERT_OK(consumer->Receive(numOfElementPerConsumer, RPC_TIMEOUT, outElements));
+                    ASSERT_EQ(outElements.size(), numOfElementPerConsumer);
+                    outElements.clear();
+                }
+            }
+        }
+    }
+
+    void CloseProducer(std::vector<uint> producerIndex,
+                       std::vector<std::pair<std::vector<std::shared_ptr<Producer>>,
+                                             std::vector<std::shared_ptr<Consumer>>>> &streams)
+    {
+        for (auto &stream : streams) {
+            for (auto &idx : producerIndex) {
+                if (stream.first.at(idx)) {
+                    DS_ASSERT_OK(stream.first.at(idx)->Close());
+                    stream.first.at(idx).reset();
+                }
+            }
+        }
+    }
+
+    void CloseConsumers(std::vector<uint> consumerIndex,
+                        std::vector<std::pair<std::vector<std::shared_ptr<Producer>>,
+                                              std::vector<std::shared_ptr<Consumer>>>> &streams)
+    {
+        for (auto &stream : streams) {
+            for (auto &idx : consumerIndex) {
+                if (stream.second.at(idx)) {
+                    DS_ASSERT_OK(stream.second.at(idx)->Close());
+                    stream.second.at(idx).reset();
+                }
+            }
+        }
+    }
+
+    void DeleteStreams(std::vector<std::pair<std::vector<std::shared_ptr<Producer>>,
+                                             std::vector<std::shared_ptr<Consumer>>>> &streams,
+                       std::string streamName)
+    {
+        for (uint i = 0; i < streams.size(); ++i) {
+            auto &stream = streams.at(i);
+            for (auto &producer : stream.first) {
+                if (producer) {
+                    DS_ASSERT_OK(producer->Close());
+                }
+            }
+            for (auto &consumer : stream.second) {
+                if (consumer) {
+                    DS_ASSERT_OK(consumer->Close());
+                }
+            }
+            std::string streamName_ = streamName + std::to_string(i);
+            DS_ASSERT_OK(w1Client_->DeleteStream(streamName_));
+        }
+    }
+};
+
+TEST_F(DataVerificationStreamClientScaleTest, TestVoluntaryScaleDown)
+{
+    LOG(INFO) << "TestVoluntaryScaleDown start!";
+
+    const uint numOfStream = 16;
+    uint producerPerStream = 3;
+    const uint consumerPerStream = 1;
+    std::vector<std::pair<std::vector<std::shared_ptr<Producer>>,
+                          std::vector<std::shared_ptr<Consumer>>>> streams;
+    std::string streamName = "VoluntaryScaleDown";
+    // Create 3 Producers and 1 Consumer for each stream.
+    CreateStreams(numOfStream, producerPerStream, consumerPerStream, streams, streamName);
+
+    // Normal Send and Recv
+    const uint numOfElementPerProducer = 10;
+    uint numOfElementPerConsumer = numOfElementPerProducer * producerPerStream;
+    datasystem::inject::Set("VerifyProducerNo", "return()");
+    TestSendRecv(numOfElementPerProducer, numOfElementPerConsumer, streams);
+
+    // Close the 1st and 3rd producers.
+    std::vector<uint> producerIndex = {0, 2};
+    CloseProducer(producerIndex, streams);
+    producerPerStream -= producerIndex.size();
+
+    // Shutdown worker2 to trigger voluntary scale down and metadata migration
+    w2Client_.reset();
+    VoluntaryScaleDownInject(1);
+    // Wait for voluntary scale down to finish
+    sleep(SCALE_DOWN_WAIT_TIME);
+
+    // Create 3 more producers per stream
+    uint newProducerPerStream = 3;
+    CreateNProducerAndMConsumerForEachStream(newProducerPerStream, 0, streams, streamName);
+    producerPerStream += newProducerPerStream;
+
+    // Normal Send and Recv again
+    numOfElementPerConsumer = numOfElementPerProducer * producerPerStream;
+    TestSendRecv(numOfElementPerProducer, numOfElementPerConsumer, streams);
+    datasystem::inject::Clear("VerifyProducerNo");
+
+    // Cleanup
+    DeleteStreams(streams, streamName);
+
+    LOG(INFO) << "TestVoluntaryScaleDown finish!";
+}
+
+TEST_F(DataVerificationStreamClientScaleTest, TestVoluntaryScaleUp)
+{
+    LOG(INFO) << "TestVoluntaryScaleUp start!";
+
+    const uint numOfStream = 16;
+    uint producerPerStream = 3;
+    const uint consumerPerStream = 1;
+    std::vector<std::pair<std::vector<std::shared_ptr<Producer>>,
+                          std::vector<std::shared_ptr<Consumer>>>> streams;
+    std::string streamName = "VoluntaryScaleUp";
+    // Create 3 Producers and 1 Consumer for each stream.
+    CreateStreams(numOfStream, producerPerStream, consumerPerStream, streams, streamName);
+
+    // Normal Send and Recv
+    const uint numOfElementPerProducer = 10;
+    uint numOfElementPerConsumer = numOfElementPerProducer * producerPerStream;
+    datasystem::inject::Set("VerifyProducerNo", "return()");
+    TestSendRecv(numOfElementPerProducer, numOfElementPerConsumer, streams);
+
+    // Close the 1st and 3rd producers.
+    std::vector<uint> producerIndex = {0, 2};
+    CloseProducer(producerIndex, streams);
+    producerPerStream -= producerIndex.size();
+
+    // Add new worker node to trigger scale up and metadata migration
+    DS_ASSERT_OK(AddNode());
+
+    // Create 3 more producers per stream
+    uint newProducerPerStream = 3;
+    CreateNProducerAndMConsumerForEachStream(newProducerPerStream, 0, streams, streamName);
+    producerPerStream += newProducerPerStream;
+
+    // Normal Send and Recv again
+    numOfElementPerConsumer = numOfElementPerProducer * producerPerStream;
+    TestSendRecv(numOfElementPerProducer, numOfElementPerConsumer, streams);
+    datasystem::inject::Clear("VerifyProducerNo");
+
+    // Cleanup
+    DeleteStreams(streams, streamName);
+
+    LOG(INFO) << "TestVoluntaryScaleUp finish!";
+}
+
+class DataVerificationStreamClientPassiveScaleTest : public DataVerificationStreamClientScaleTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        // Start 3 workers
+        opts.numWorkers = ++workerNum_;
+        opts.enableDistributedMaster = "true";
+        opts.vLogLevel = 1;
+        // Set up node_dead_timeout_s and auto_del_dead_node flags, so that a new meta owner master can be reselected
+        // for passive scale down purposes
+        opts.workerGflagParams = FormatString(
+            " -v=2 -node_timeout_s=3 -node_dead_timeout_s=%d -auto_del_dead_node=true -shared_memory_size_mb=10240 "
+            "-enable_stream_data_verification=true",
+            NODE_DEAD_TIMEOUT);
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+};
+
+TEST_F(DataVerificationStreamClientPassiveScaleTest, TestPassiveScaleDown)
+{
+    LOG(INFO) << "TestPassiveScaleDown start!";
+
+    const uint worker3Index = 2;
+    const uint numOfStream = 10;
+    uint producerPerStream = 10;
+    const uint consumerPerStream = 1;
+    uint numOfNotReceiveElementPerConsumer = 0;
+    std::vector<std::pair<std::vector<std::shared_ptr<Producer>>,
+                          std::vector<std::shared_ptr<Consumer>>>> streams;
+    std::string streamName = "PassiveScaleDown";
+    // Create 10 Producers and 1 Consumer for each stream.
+    CreateStreams(numOfStream, producerPerStream, consumerPerStream, streams, streamName);
+
+    // Normal Send and Recv.
+    const uint numOfElementPerProducer = 10;
+    uint numOfElementPerConsumer = 0;
+    TestSendRecv(numOfElementPerProducer, numOfElementPerConsumer, streams);
+    numOfNotReceiveElementPerConsumer += numOfElementPerProducer * producerPerStream;
+
+    // Close 5 producers.
+    std::vector<uint> producerIndex = {0, 6, 7, 8, 9};
+    CloseProducer(producerIndex, streams);
+    producerPerStream -= producerIndex.size();
+
+    // Kill worker 3.
+    DS_ASSERT_OK(static_cast<ExternalCluster *>(cluster_.get())->KillWorker(worker3Index));
+    sleep(NODE_DEAD_TIMEOUT + 1);
+
+    // Create 3 producers per stream.
+    uint newProducerPerStream = 3;
+    CreateNProducerAndMConsumerForEachStream(newProducerPerStream, 0, streams, streamName);
+    producerPerStream += newProducerPerStream;
+
+    // Normal Send and Recv.
+    // 10 * 10 + 10 * (5 + 3) = 180
+    datasystem::inject::Set("VerifyProducerNo", "return()");
+    numOfNotReceiveElementPerConsumer += numOfElementPerProducer * producerPerStream;
+    TestSendRecv(numOfElementPerProducer, numOfNotReceiveElementPerConsumer, streams);
+    datasystem::inject::Clear("VerifyProducerNo");
+
+    // Cleanup.
+    DeleteStreams(streams, streamName);
+
+    LOG(INFO) << "TestPassiveScaleDown finish!";
+}
+
+TEST_F(DataVerificationStreamClientPassiveScaleTest, TestRestartPassiveScaleDown)
+{
+    LOG(INFO) << "TestRestartPassiveScaleDown start!";
+
+    const uint worker3Index = 2;
+    const uint numOfStream = 10;
+    uint producerPerStream = 10;
+    const uint consumerPerStream = 1;
+    uint numOfNotReceiveElementPerConsumer = 0;
+    std::vector<std::pair<std::vector<std::shared_ptr<Producer>>,
+                          std::vector<std::shared_ptr<Consumer>>>> streams;
+    std::string streamName = "RestartPassiveScaleDown";
+    // Create 10 Producers and 1 Consumer for each stream.
+    CreateStreams(numOfStream, producerPerStream, consumerPerStream, streams, streamName);
+
+    // Normal Send and Recv.
+    const uint numOfElementPerProducer = 10;
+    uint numOfElementPerConsumer = 0;
+    TestSendRecv(numOfElementPerProducer, numOfElementPerConsumer, streams);
+    numOfNotReceiveElementPerConsumer += numOfElementPerProducer * producerPerStream;
+
+    // Close 5 producers.
+    std::vector<uint> producerIndex = {0, 6, 7, 8, 9};
+    CloseProducer(producerIndex, streams);
+    producerPerStream -= producerIndex.size();
+
+    // Kill worker3, and sleep to trigger passive scale down logic.
+    DS_ASSERT_OK(static_cast<ExternalCluster *>(cluster_.get())->KillWorker(worker3Index));
+    sleep(NODE_DEAD_TIMEOUT + 1);
+    // Then restart worker3 to trigger scale up metadata migration logic.
+    DS_ASSERT_OK(cluster_->StartNode(WORKER, worker3Index, {}));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, worker3Index));
+    std::shared_ptr<StreamClient> w3Client;
+    InitStreamClient(worker3Index, w3Client);
+
+    // Create 3 producers per stream
+    uint newProducerPerStream = 3;
+    CreateNProducerAndMConsumerForEachStream(newProducerPerStream, 0, streams, streamName);
+    producerPerStream += newProducerPerStream;
+
+    // Normal Send and Recv.
+    // 10 * 10 + 10 * (5 + 3) = 180
+    datasystem::inject::Set("VerifyProducerNo", "return()");
+    numOfNotReceiveElementPerConsumer += numOfElementPerProducer * producerPerStream;
+    TestSendRecv(numOfElementPerProducer, numOfNotReceiveElementPerConsumer, streams);
+    datasystem::inject::Clear("VerifyProducerNo");
+
+    // Cleanup.
+    DeleteStreams(streams, streamName);
+
+    LOG(INFO) << "TestRestartPassiveScaleDown finish!";
+}
+
+class StreamClientScaleDfxTest : public StreamClientScaleTest {
+};
+
+TEST_F(StreamClientScaleDfxTest, LEVEL2_ScaleUpWhenMetaResidue)
+{
+    const int streamNum = 16;
+    std::map<std::string, std::pair<std::shared_ptr<Producer>, std::shared_ptr<Consumer>>> streams;
+    std::string streamName = "testScaleUpWhenMetaResidue";
+    CreateNProducerAndConsumer(streams, streamNum, streamName, false);
+    datasystem::inject::Set("StreamClient.ShutDown.skip", "return()");
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "BatchMigrateMetadata.finish", "1*sleep(2000)"));
+
+    VoluntaryScaleDownInject(1);
+    sleep(1); // wait for the hash ring change
+    kill(cluster_->GetWorkerPid(1), SIGKILL);
+    w2Client_.reset();
+    sleep(6); // wait 6s for worker passive reduction
+
+    DS_ASSERT_OK(AddNode());
+    WaitAllNodesJoinIntoHashRing(2, 10); // wait 10s for 2 workers online
+}
+
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/stream_data_encryption_test.cpp b/tests/st/client/stream_cache/stream_data_encryption_test.cpp
new file mode 100644
index 0000000..14238d6
--- /dev/null
+++ b/tests/st/client/stream_cache/stream_data_encryption_test.cpp
@@ -0,0 +1,349 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Unit test for stream data encryption support
+ */
+
+#include <thread>
+
+#include <gtest/gtest.h>
+
+#include "common.h"
+#include "common/stream_cache/stream_common.h"
+#include "sc_client_common.h"
+#include "datasystem/common/encrypt/secret_manager.h"
+#include "datasystem/common/inject/inject_point.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/stream/element.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream_client.h"
+
+DS_DECLARE_string(encrypt_kit);
+using namespace datasystem::client::stream_cache;
+namespace datasystem {
+namespace st {
+class StreamDataEncryptionTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.numWorkers = NUM_WORKERS;
+        opts.systemAccessKey = "";
+        opts.systemSecretKey = "";
+        // Set the encrypted key for stream data encryption
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+        defaultConf_.pageSize = DEFAULT_PAGE_SIZE;
+        defaultConf_.maxStreamSize = DEFAULT_MAX_STREAM_SIZE;
+        // Enable encryptStream by default for test purposes.
+        defaultConf_.encryptStream = true;
+    }
+
+    void TearDown() override
+    {
+        ExternalClusterTest::TearDown();
+    }
+
+    /**
+     * @brief Creates a stream client at the given worker num
+     * @param[in] workerNum The worker num to create the stream against
+     * @param[out] spClient Shared pointer to the stream client
+     * @return status
+     */
+    Status CreateClient(int workerNum, std::shared_ptr<StreamClient> &spClient)
+    {
+        HostPort workerAddress;
+        RETURN_IF_NOT_OK(cluster_->GetWorkerAddr(workerNum, workerAddress));
+        ConnectOptions options;
+        options.host = workerAddress.Host();
+        options.port = workerAddress.Port();
+        spClient = std::make_shared<StreamClient>(options);
+        RETURN_IF_NOT_OK(spClient->Init());
+        return Status::OK();
+    }
+
+    Status CreateElement(size_t elementSize, Element &element, std::vector<uint8_t> &writeElement)
+    {
+        writeElement = RandomData().RandomBytes(elementSize);
+        element = Element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+        return Status::OK();
+    }
+
+    std::shared_ptr<std::thread> TestSendRecv(std::shared_ptr<Producer> producer, std::shared_ptr<Consumer> consumer)
+    {
+        return std::make_shared<std::thread>([this, producer, consumer]() {
+            const int numElements = 500;
+            std::vector<std::vector<uint8_t>> elements(numElements);
+            std::thread producerThrd([this, producer, &elements]() {
+                const int DEFAULT_SLEEP_TIME = 300;
+                auto randomData = RandomData();
+                for (int i = 0; i < numElements; i++) {
+                    size_t testSize = randomData.GetRandomIndex(10) == 0 ? DEFAULT_BIG_SIZE : DEFAULT_SMALL_SIZE;
+                    Element element;
+                    int retryLimit = 30;
+                    DS_ASSERT_OK(CreateElement(testSize, element, elements[i]));
+                    datasystem::Status rc = producer->Send(element);
+                    while (rc.GetCode() == K_OUT_OF_MEMORY && retryLimit-- > 0) {
+                        std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME));
+                        rc = producer->Send(element);
+                    }
+                    DS_ASSERT_OK(rc);
+                }
+            });
+
+            // Receiver should get both small elements and big elements correctly
+            std::vector<Element> outElements;
+            int received = 0;
+            while (received < numElements) {
+                DS_ASSERT_OK(consumer->Receive(1, DEFAULT_WAIT_TIME, outElements));
+                ASSERT_EQ(outElements.size(), 1);
+                ASSERT_EQ(outElements[0].size, elements[received].size());
+                ASSERT_EQ(memcmp(outElements[0].ptr, elements[received].data(), elements[received].size()), 0);
+                DS_ASSERT_OK(consumer->Ack(outElements.back().id));
+                received++;
+            }
+            producerThrd.join();
+        });
+    }
+
+protected:
+    const int NUM_WORKERS = 2;
+    const int DEFAULT_WAIT_TIME = 10000;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+    std::vector<std::string> validKeys = { "sjdoifjoidjfgfunsjdoifjoidjfgfun", "iusdhfgiojshddagiusdhfgiojshddag" };
+    ProducerConf defaultConf_;
+    const int DEFAULT_SMALL_SIZE = 10 * KB;
+    const int DEFAULT_BIG_SIZE = 60 * KB;
+    const int DEFAULT_PAGE_SIZE = 40 * KB;
+    const int DEFAULT_MAX_STREAM_SIZE = 50 * MB;
+};
+
+TEST_F(StreamDataEncryptionTest, TestStreamSendRecv1)
+{
+    // Test the basic Stream Data Encryption support.
+    // That is, test that remote push functions correctly when
+    // ProducerConf and FLAGS_sc_encrypt_secret_key are configured correctly
+    std::shared_ptr<StreamClient> spClient0;
+    DS_ASSERT_OK(CreateClient(0, spClient0));
+
+    std::shared_ptr<StreamClient> spClient1;
+    DS_ASSERT_OK(CreateClient(1, spClient1));
+
+    // Create a Producer
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(spClient0->CreateProducer("StreamSendRecv1", producer, defaultConf_));
+
+    // Create a Consumer on a different node
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(spClient1->Subscribe("StreamSendRecv1", config, consumer));
+
+    auto thrd = TestSendRecv(producer, consumer);
+    thrd->join();
+}
+
+TEST_F(StreamDataEncryptionTest, TestStreamSendRecv2)
+{
+    // Test that even if the worker is set up with FLAGS_sc_encrypt_secret_key,
+    // streams can still disable encryption, as this is a per-stream setting.
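+    // Only the stream whose ProducerConf enables encryptStream should be encrypted on the wire;
+    // the stream created with encryptStream = false below is expected to transfer as-is.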
+    std::shared_ptr<StreamClient> spClient0;
+    DS_ASSERT_OK(CreateClient(0, spClient0));
+
+    std::shared_ptr<StreamClient> spClient1;
+    DS_ASSERT_OK(CreateClient(1, spClient1));
+
+    // Create 2 Producers, only one of them enables encryption
+    ProducerConf conf = defaultConf_;
+    conf.encryptStream = false;
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(spClient0->CreateProducer("StreamSendRecv2_1", producer1, conf));
+
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_OK(spClient0->CreateProducer("StreamSendRecv2_2", producer2, defaultConf_));
+
+    // Create Consumers on a different node
+    std::shared_ptr<Consumer> consumer1;
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(spClient1->Subscribe("StreamSendRecv2_1", config1, consumer1));
+
+    std::shared_ptr<Consumer> consumer2;
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    DS_ASSERT_OK(spClient1->Subscribe("StreamSendRecv2_2", config2, consumer2));
+
+    auto stream1Thrd = TestSendRecv(producer1, consumer1);
+    auto stream2Thrd = TestSendRecv(producer2, consumer2);
+    stream1Thrd->join();
+    stream2Thrd->join();
+}
+
+TEST_F(StreamDataEncryptionTest, TestStreamSendRecv3)
+{
+    // Test that a consumer created before the producer also gets the correct stream fields updated.
+    std::shared_ptr<StreamClient> spClient0;
+    DS_ASSERT_OK(CreateClient(0, spClient0));
+
+    std::shared_ptr<StreamClient> spClient1;
+    DS_ASSERT_OK(CreateClient(1, spClient1));
+
+    // Create Consumer first
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(spClient0->Subscribe("StreamSendRecv3", config, consumer));
+
+    // Create a Producer on a different node
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(spClient1->CreateProducer("StreamSendRecv3", producer, defaultConf_));
+
+    auto thrd = TestSendRecv(producer, consumer);
+    thrd->join();
+}
+
+TEST_F(StreamDataEncryptionTest, TestSharedPageSendRecv)
+{
+    // Test that when shared page is enabled, stream encryption is performed correctly.
+    std::shared_ptr<StreamClient> spClient0;
+    DS_ASSERT_OK(CreateClient(0, spClient0));
+
+    std::shared_ptr<StreamClient> spClient1;
+    DS_ASSERT_OK(CreateClient(1, spClient1));
+
+    // Create Consumer first
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(spClient0->Subscribe("SharedPageSendRecv", config, consumer));
+
+    // Create a Producer on a different node
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    conf.pageSize = DEFAULT_PAGE_SIZE;
+    conf.maxStreamSize = DEFAULT_MAX_STREAM_SIZE;
+    // Enable encryptStream and shared page for test purposes.
+    conf.encryptStream = true;
+    conf.streamMode = StreamMode::SPSC;
+    DS_ASSERT_OK(spClient1->CreateProducer("SharedPageSendRecv", producer, conf));
+
+    auto thrd = TestSendRecv(producer, consumer);
+    thrd->join();
+}
+
+TEST_F(StreamDataEncryptionTest, TestCreateProducerFailure)
+{
+    // Test that CreateProducer fails if the encryptStream settings mismatch.
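+    // Encryption is fixed at stream creation by the first producer; a later CreateProducer with a
+    // conflicting encryptStream value must be rejected with a non-OK status instead of silently
+    // changing the stream's setting.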
+    std::shared_ptr<StreamClient> spClient0;
+    DS_ASSERT_OK(CreateClient(0, spClient0));
+
+    std::shared_ptr<StreamClient> spClient1;
+    DS_ASSERT_OK(CreateClient(1, spClient1));
+
+    // Create stream producer with encrypt
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(spClient0->CreateProducer("testDiffProdConfig", producer1, defaultConf_));
+
+    ProducerConf conf = defaultConf_;
+    conf.encryptStream = false;
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_NOT_OK(spClient1->CreateProducer("testDiffProdConfig", producer2, conf));
+}
+
+class StreamDataEncryptionEmptyKeyTest : public StreamDataEncryptionTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.numWorkers = NUM_WORKERS;
+        opts.systemAccessKey = "";
+        opts.systemSecretKey = "";
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+};
+
+TEST_F(StreamDataEncryptionEmptyKeyTest, TestStreamSendRecv)
+{
+    // Test that if workers are configured with empty FLAGS_sc_encrypt_secret_key,
+    // stream data can still be sent correctly.
+    // In this case, encryption/decryption is not performed.
+    std::shared_ptr<StreamClient> spClient0;
+    DS_ASSERT_OK(CreateClient(0, spClient0));
+
+    std::shared_ptr<StreamClient> spClient1;
+    DS_ASSERT_OK(CreateClient(1, spClient1));
+
+    // Create a Producer
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(spClient0->CreateProducer("testStreamSendRecv", producer, defaultConf_));
+
+    // Create a Consumer on a different node
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(spClient1->Subscribe("testStreamSendRecv", config, consumer));
+
+    auto thrd = TestSendRecv(producer, consumer);
+    thrd->join();
+}
+
+class StreamDataEncryptionPlainTextTest : public StreamDataEncryptionTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.numWorkers = NUM_WORKERS;
+        opts.systemAccessKey = "";
+        opts.systemSecretKey = "";
+        opts.workerGflagParams = "-encrypt_kit=plaintext";
+        for (size_t i = 0; i < opts.numWorkers; ++i) {
+            auto port = GetFreePort();
+            std::string encryptedKey;
+            DS_ASSERT_OK(SecretManager::Instance()->Encrypt(validKeys[i], encryptedKey));
+            opts.workerSpecifyGflagParams.emplace(
+                i, FormatString("-sc_encrypt_secret_key=%s -sc_worker_worker_direct_port=%d", encryptedKey, port));
+        }
+        opts.isStreamCacheCase = true;
+    }
+};
+
+TEST_F(StreamDataEncryptionPlainTextTest, TestStreamSendRecv)
+{
+    // Test that if workers are configured with the default FLAGS_encrypt_kit = "plaintext",
+    // stream data can still be sent correctly.
+    // In this case, encryption/decryption is not performed.
+    // This is guaranteed by send and recv success while keys are set to mismatch.
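+    // With -encrypt_kit=plaintext the configured keys are stored but never applied, so the
+    // deliberately mismatched per-worker keys above cannot cause a decryption failure.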
+ std::shared_ptr spClient0; + DS_ASSERT_OK(CreateClient(0, spClient0)); + + std::shared_ptr spClient1; + DS_ASSERT_OK(CreateClient(1, spClient1)); + + // Create a Producer + std::shared_ptr producer; + DS_ASSERT_OK(spClient0->CreateProducer("testDataEncryptPlainText", producer, defaultConf_)); + + // Create a Consumer on a different node + std::shared_ptr consumer; + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + DS_ASSERT_OK(spClient1->Subscribe("testDataEncryptPlainText", config, consumer)); + + auto thrd = TestSendRecv(producer, consumer); + thrd->join(); +} +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/stream_dfx_send_recv_test.cpp b/tests/st/client/stream_cache/stream_dfx_send_recv_test.cpp new file mode 100644 index 0000000..e988125 --- /dev/null +++ b/tests/st/client/stream_cache/stream_dfx_send_recv_test.cpp @@ -0,0 +1,376 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Unit test for stream cache + */ + +#include + +#include + +#include "common.h" +#include "common/stream_cache/stream_common.h" +#include "sc_client_common.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/stream/element.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream_client.h" + +using namespace datasystem::client::stream_cache; +namespace datasystem { +namespace st { +const uint32_t EXPECT_RECV_NUM = 10; +const uint32_t TEST_PAGE_SIZE = 20 * 1024; +constexpr int K_TWO = 2; +constexpr int K_FOUR = 4; +constexpr int K_TEN = 10; +class StreamDfxSendRecvTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + const int vLogLevel = 3; + const uint32_t workerCout = 2; + opts.numEtcd = 1; + opts.numWorkers = workerCout; + opts.vLogLevel = vLogLevel; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + DS_ASSERT_OK(InitClient(0, client1_)); + DS_ASSERT_OK(InitClient(1, client2_)); + } + + void TearDown() override + { + client1_ = nullptr; + client2_ = nullptr; + ExternalClusterTest::TearDown(); + } + +protected: + Status InitClient(int index, std::shared_ptr &client) + { + HostPort workerAddress; + RETURN_IF_NOT_OK(cluster_->GetWorkerAddr(index, workerAddress)); + LOG(INFO) << "worker index " << index << ": " << workerAddress.ToString(); + ConnectOptions options; + options.accessKey = accessKey_; + options.secretKey = secretKey_; + options.host = workerAddress.Host(); + options.port = workerAddress.Port(); + client = std::make_shared(options); + return client->Init(); + } + + Status CreateConsumer(std::shared_ptr client, const std::string &streamName, + const std::string &subName, std::shared_ptr &consumer) + { + SubscriptionConfig config(subName, SubscriptionType::STREAM); + return 
+    }
+
+    Status CreateProducer(std::shared_ptr<StreamClient> client, const std::string &streamName,
+                          std::shared_ptr<Producer> &producer)
+    {
+        const int64_t autoFlushTime = 10 * 1000;  // 10s
+        ProducerConf conf = { .delayFlushTime = autoFlushTime,
+                              .pageSize = TEST_PAGE_SIZE,
+                              .maxStreamSize = TEST_STREAM_SIZE };
+        return client->CreateProducer(streamName, producer, conf);
+    }
+
+    Status TryAndDeleteStream(std::shared_ptr<StreamClient> &spClient, std::string streamName)
+    {
+        // If notifications are still pending, retry the delete
+        Status rc = Status::OK();
+        do {
+            rc = spClient->DeleteStream(streamName);
+            if (rc.IsError()) {
+                sleep(K_TWO);
+            }
+        } while (rc.GetCode() == StatusCode::K_SC_STREAM_NOTIFICATION_PENDING);
+        return rc;
+    }
+
+    std::shared_ptr<StreamClient> client1_;
+    std::shared_ptr<StreamClient> client2_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+TEST_F(StreamDfxSendRecvTest, TestRecv)
+{
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer1;
+    std::shared_ptr<Consumer> consumer2;
+    std::string streamName = "test_stream_recv";
+    DS_ASSERT_OK(CreateProducer(client1_, streamName, producer));
+    DS_ASSERT_OK(CreateConsumer(client1_, streamName, "subname1", consumer1));
+    DS_ASSERT_OK(CreateConsumer(client1_, streamName, "subname2", consumer2));
+
+    std::string str = "hello hello";
+    Element element(reinterpret_cast<uint8_t *>(const_cast<char *>(str.data())), str.length());
+    DS_ASSERT_OK(producer->Send(element));
+    DS_ASSERT_OK(producer->Send(element));
+
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer1->Receive(EXPECT_RECV_NUM, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), 2ul);
+    DS_ASSERT_OK(consumer1->Ack(outElements.back().id));
+
+    ASSERT_EQ(consumer2->Receive(EXPECT_RECV_NUM, 0, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), 2ul);
+    DS_ASSERT_OK(consumer2->Ack(outElements.back().id));
+    DS_ASSERT_OK(consumer2->Ack(outElements.back().id));
+}
+
+TEST_F(StreamDfxSendRecvTest, TestRecvWorkerCrash)
+{
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer1;
+    std::shared_ptr<Consumer> consumer2;
+
+    // Create a producer on node1
+    std::string streamName = "RecvWorkerCrash";
+    DS_ASSERT_OK(CreateProducer(client1_, streamName, producer));
+
+    // Create a consumer on node2
+    DS_ASSERT_OK(CreateConsumer(client2_, streamName, "subname1", consumer1));
+    DS_ASSERT_OK(CreateConsumer(client2_, streamName, "subname2", consumer2));
+
+    // Delay the recv on node2
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "PushElementsCursors.begin",
+                                           "sleep(3000)"));
+
+    // Send a lot of data in the producer
+    std::string str = "hello hello";
+    Element element(reinterpret_cast<uint8_t *>(const_cast<char *>(str.data())), str.length());
+    for (int i = 0; i < 10; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer1->Receive(10, 5000, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), 10);
+    DS_ASSERT_OK(consumer1->Ack(10));
+
+    // Restart node2
+    // Node with consumers - crash happens
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, 1);
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1));
+}
+
+TEST_F(StreamDfxSendRecvTest, TestPendingRecvWorkerCrash)
+{
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer1;
+
+    // Create a producer on node1
+    std::string streamName = "test_stream_pendingRecvWorkerCrash";
+    DS_ASSERT_OK(CreateProducer(client1_, streamName, producer));
+
+    // Create a consumer on node2
+    DS_ASSERT_OK(CreateConsumer(client2_, streamName, "subname1", consumer1));
+
+    // Send a lot of data in the producer
+    std::string str = "hello hello";
+    Element element(reinterpret_cast<uint8_t *>(const_cast<char *>(str.data())), str.length());
+    for (int i = 0; i < 100; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    int threadNum = 1;
+    ThreadPool pool(threadNum);
+    auto fut1 = pool.Submit([this, consumer1, element]() {
+        // Make sure receive is still blocked while we shut down the node
+        std::vector<Element> outElements;
+        ASSERT_EQ(consumer1->Receive(1000, 50000, outElements).GetCode(), StatusCode::K_SC_ALREADY_CLOSED);
+        ASSERT_EQ(outElements.size(), 0);
+    });
+
+    // Restart node2
+    // Node with consumers - crash happens
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, 1);
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1));
+
+    // Wait for the receive
+    fut1.get();
+}
+
+TEST_F(StreamDfxSendRecvTest, LEVEL1_TestWorkerCrashLateEtcdNotification)
+{
+    LOG(INFO) << "TestWorkerCrashLateEtcdNotification start!";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+
+    const int DEFAULT_NUM_ELEMENT = 100;
+    const int DEFAULT_WAIT_TIME = 5000;
+
+    std::string streamName = "Stream_" + RandomData().GetRandomString(10);
+    std::vector<uint8_t> writeElement = RandomData().RandomBytes(K_FOUR * KB);
+    Element element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+
+    ProducerConf conf;
+    conf.maxStreamSize = K_TEN * MB;
+    conf.pageSize = 1 * MB;
+
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "StreamMetadata.ClearPubSubMetaData.sleep", "sleep(5000)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 1, "StreamMetadata.ClearPubSubMetaData.sleep", "sleep(5000)"));
+
+    // Create a new producer - its data should be retained
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+
+    // Node with the producer - crash happens
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, 0);
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 0, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0));
+
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer1, conf));
+
+    std::shared_ptr<Consumer> consumer1;
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config1, consumer1));
+
+    // Try sending and getting data
+    for (int i = 0; i < DEFAULT_NUM_ELEMENT; i++) {
+        DS_ASSERT_OK(producer1->Send(element));
+    }
+
+    // Try to receive data
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer1->Receive(DEFAULT_NUM_ELEMENT, DEFAULT_WAIT_TIME, outElements));
+    ASSERT_EQ(outElements.size(), DEFAULT_NUM_ELEMENT);
+    DS_ASSERT_OK(consumer1->Ack(DEFAULT_NUM_ELEMENT));
+    DS_ASSERT_OK(consumer1->Close());
+    DS_ASSERT_OK(producer1->Close());
+    DS_ASSERT_OK(consumer->Close());
+    outElements.clear();
+    DS_ASSERT_OK(TryAndDeleteStream(client1, streamName));
+}
+
+TEST_F(StreamDfxSendRecvTest, LEVEL1_TestResendDataAfterProducerWorkerCrash)
+{
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer1;
+
+    // Create a producer on node1
+    std::string streamName = "test_stream_resendAfterProdWorkerCrash";
+    DS_ASSERT_OK(CreateProducer(client1_, streamName, producer));
+
+    // Create a consumer on node2
+    DS_ASSERT_OK(CreateConsumer(client2_, streamName, "subname1", consumer1));
+
+    // Send a lot of data in the producer
+    std::string str = "hello hello";
+    Element element(reinterpret_cast<uint8_t *>(const_cast<char *>(str.data())), str.length());
+    for (int i = 0; i < 100; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    // Let the recv worker get some data
+    std::vector<Element> outElements;
+    const int numElements = 100;
+    ASSERT_EQ(consumer1->Receive(numElements, 5000, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), numElements);
+    DS_ASSERT_OK(consumer1->Ack(numElements));
+
+    // Restart node1 - producer worker
+    // Node with the producer - crash happens
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, 0);
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 0, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0));
+
+    // Create a producer on node1 again for the same stream
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(CreateProducer(client1_, streamName, producer1));
+
+    // Send the next batch after the restart
+    for (int i = 0; i < 100; i++) {
+        DS_ASSERT_OK(producer1->Send(element));
+    }
+
+    // We should be able to get this data
+    outElements.clear();
+    ASSERT_EQ(consumer1->Receive(numElements, 5000, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), numElements);
+    DS_ASSERT_OK(consumer1->Ack(numElements));
+}
+
+TEST_F(StreamDfxSendRecvTest, LEVEL1_TestResendDataAfterConsumerWorkerCrash)
+{
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer1;
+
+    // Create a producer on node1
+    std::string streamName = "test_stream_resendAfterConWorkerCrash";
+    DS_ASSERT_OK(CreateProducer(client1_, streamName, producer));
+
+    // Create a consumer on node2
+    DS_ASSERT_OK(CreateConsumer(client2_, streamName, "subname1", consumer1));
+
+    // Send a lot of data in the producer
+    std::string str = "hello hello";
+    Element element(reinterpret_cast<uint8_t *>(const_cast<char *>(str.data())), str.length());
+    for (int i = 0; i < 100; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    // Let the recv worker get some data
+    const int numElements = 100;
+    std::vector<Element> outElements;
+    ASSERT_EQ(consumer1->Receive(numElements, 5000, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), numElements);
+    DS_ASSERT_OK(consumer1->Ack(numElements));
+
+    // Restart node2 - consumer worker
+    // Node with the consumer - crash happens
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, 1);
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1));
+
+    // Create a consumer on node2 again for the same stream
+    std::shared_ptr<Consumer> consumer2;
+    DS_ASSERT_OK(CreateConsumer(client2_, streamName, "subname1", consumer2));
+
+    for (int i = 0; i < 100; i++) {
+        DS_ASSERT_OK(producer->Send(element));
+    }
+
+    // We should be able to get this data
+    outElements.clear();
+    ASSERT_EQ(consumer2->Receive(numElements, 5000, outElements), Status::OK());
+    ASSERT_EQ(outElements.size(), numElements);
+    DS_ASSERT_OK(consumer2->Ack(numElements));
+}
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/stream_dfx_test.cpp b/tests/st/client/stream_cache/stream_dfx_test.cpp
new file mode 100644
index 0000000..fe7d650
--- /dev/null
+++ b/tests/st/client/stream_cache/stream_dfx_test.cpp
@@ -0,0 +1,2670 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Unit test for stream cache
+ */
+#include <gtest/gtest.h>
+
+#include <fstream>
+#include <string>
+#include <sys/wait.h>
+#include <thread>
+#include <unistd.h>
+
+#include "client/stream_cache/pub_sub_utils.h"
+#include "cluster/base_cluster.h"
+#include "cluster/external_cluster.h"
+#include "common.h"
+#include "common/stream_cache/element_generator.h"
+#include "common/stream_cache/stream_common.h"
+#include "sc_client_common.h"
+#include "datasystem/common/metrics/res_metric_collector.h"
+#include "datasystem/common/util/file_util.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/common/util/strings_util.h"
+#include "datasystem/common/inject/inject_point.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream_client.h"
+#include "datasystem/stream/stream_config.h"
+
+DS_DECLARE_string(log_dir);
+
+using namespace datasystem::client::stream_cache;
+namespace datasystem {
+namespace st {
+constexpr int K_TWO = 2;
+constexpr uint32_t CONNECT_TIMEOUT_MS = 10000;
+constexpr int WORKER_COUNT = 3;
+class StreamDfxTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.enableDistributedMaster = "false";
+        opts.numWorkers = WORKER_COUNT;
+        opts.vLogLevel = logLevel;
+        opts.workerGflagParams += FormatString(
+            " -node_timeout_s=%d -node_dead_timeout_s=%d -client_reconnect_wait_s=2 -log_monitor=true "
+            "-log_monitor_interval_ms=500",
+            nodeTimeout, nodeDead);
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+protected:
+    Status InitClient(int index, std::shared_ptr<StreamClient> &client, uint32_t timeoutMs = CONNECT_TIMEOUT_MS)
+    {
+        HostPort workerAddress;
+        RETURN_IF_NOT_OK(cluster_->GetWorkerAddr(index, workerAddress));
+        LOG(INFO) << "worker index " << index << ": " << workerAddress.ToString();
+        ConnectOptions connectOptions;
+        connectOptions.host = workerAddress.Host();
+        connectOptions.port = workerAddress.Port();
+        connectOptions.accessKey = accessKey_;
+        connectOptions.secretKey = secretKey_;
+        connectOptions.connectTimeoutMs = timeoutMs;
+        client = std::make_shared<StreamClient>(connectOptions);
+        return client->Init();
+    }
+
+    Status CreateProducerAndConsumer(std::shared_ptr<StreamClient> &client,
+                                     std::vector<std::pair<std::string, size_t>> producerDesc,
+                                     std::vector<std::shared_ptr<Producer>> &producers,
+                                     std::vector<std::pair<std::string, std::string>> consumerDesc,
+                                     std::vector<std::shared_ptr<Consumer>> &consumers)
+    {
+        const int64_t delayFlushTime = 3 * 1000;  // 3s
+        ProducerConf conf;
+        conf.delayFlushTime = delayFlushTime;
+        conf.maxStreamSize = TEST_STREAM_SIZE;
+        conf.autoCleanup = true;
+        for (const auto &kv : producerDesc) {
+            for (size_t i = 0; i < kv.second; i++) {
+                std::shared_ptr<Producer> producer;
+                RETURN_IF_NOT_OK(client->CreateProducer(kv.first, producer, conf));
+                producers.emplace_back(producer);
+            }
+        }
+
+        for (const auto &kv : consumerDesc) {
+            std::shared_ptr<Consumer> consumer;
+            SubscriptionConfig config(kv.second, SubscriptionType::STREAM);
+            RETURN_IF_NOT_OK(client->Subscribe(kv.first, config, consumer));
+            consumers.emplace_back(consumer);
+        }
+        return Status::OK();
+    }
+
+    Status CreateConsumer(std::shared_ptr<StreamClient> client, const std::string &streamName,
+                          const std::string &subName, std::shared_ptr<Consumer> &consumer)
+    {
+        SubscriptionConfig config(subName, SubscriptionType::STREAM);
+        return client->Subscribe(streamName, config, consumer);
+    }
+
+    Status CreateProducer(std::shared_ptr<StreamClient> client, const std::string &streamName,
+                          std::shared_ptr<Producer> &producer)
+    {
+        ProducerConf conf;
+        conf.maxStreamSize = TEST_STREAM_SIZE;
+        return client->CreateProducer(streamName, producer, conf);
+    }
+
+    void CheckCount(std::shared_ptr<StreamClient> client, const std::string &streamName, int producerCount,
+                    int consumerCount)
+    {
+        uint64_t result = 0;
+        if (producerCount >= 0) {
+            DS_ASSERT_OK(client->QueryGlobalProducersNum(streamName, result));
+            EXPECT_EQ(result, static_cast<uint64_t>(producerCount));
+            result = 0;
+        }
+        if (consumerCount >= 0) {
+            DS_ASSERT_OK(client->QueryGlobalConsumersNum(streamName, result));
+            EXPECT_EQ(result, static_cast<uint64_t>(consumerCount));
+            result = 0;
+        }
+    }
+
+    void CreateElement(size_t elementSize, Element &element, std::vector<uint8_t> &writeElement)
+    {
+        writeElement = RandomData().RandomBytes(elementSize);
+        element = Element(reinterpret_cast<uint8_t *>(writeElement.data()), writeElement.size());
+    }
+
+    void GetResMonitorLogInfo(int index, const std::string &fileName, std::vector<std::string> &infos)
+    {
+        std::string fullName = FormatString("%s/../worker%d/log/%s", FLAGS_log_dir.c_str(), index, fileName);
+        std::ifstream ifs(fullName);
+        ASSERT_TRUE(ifs.is_open());
+        std::string line;
+        std::streampos prev = ifs.tellg();
+        std::streampos pos = ifs.tellg();
+        // Get the last line
+        while (std::getline(ifs, line)) {
+            prev = pos;
+            pos = ifs.tellg();
+        }
+        ifs.clear();
+        ifs.seekg(prev);
+        std::getline(ifs, line);
+        infos = Split(line, " | ");
+        const int ignoreCount = 7;
+        ASSERT_TRUE(infos.size() == static_cast<size_t>(ResMetricName::RES_METRICS_END) + ignoreCount);
+        infos.erase(infos.begin(), infos.begin() + ignoreCount);
+    }
+
+    const int nodeTimeout = 4;  // 4s
+    const int nodeDead = 6;     // 6s
+    const int waitNodeTimeout = nodeTimeout + 2;
+    const int waitNodeDead = nodeDead + 4;
+    const int logLevel = 1;
+    const int K_3 = 3;
+    const int K_5000 = 5000;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+TEST_F(StreamDfxTest, TestRemotePushTimeOut)
+{
+    LOG(INFO) << "TestRemotePushTimeOut start!";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    std::string streamName = "testRemotePushTO";
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.RemoteSendOnePageView.end",
+                                           "1*return(K_RPC_UNAVAILABLE)"));
+    // Subscribe before send.
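+    // Note on the injected action above: the "1*return(K_RPC_UNAVAILABLE)" syntax is
+    // understood to make the injection point fail exactly once with K_RPC_UNAVAILABLE;
+    // later remote pushes proceed normally, which is why the element sent below is
+    // still expected to arrive at the consumer.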
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer));
+
+    const size_t testSize = 4ul * 1024ul;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    ProducerConf conf;
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    CreateElement(testSize, element, writeElement);
+    ASSERT_EQ(producer->Send(element), Status::OK());
+
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer->Receive(1, 100, outElements));
+    std::string actualData(reinterpret_cast<const char *>(outElements[0].ptr), outElements[0].size);
+    std::string data(reinterpret_cast<const char *>(writeElement.data()), writeElement.size());
+    EXPECT_EQ(data, actualData);
+    LOG(INFO) << "TestRemotePushTimeOut finish!";
+}
+
+TEST_F(StreamDfxTest, TestMultiThreadsClientInit)
+{
+    std::shared_ptr<StreamClient> client;
+    InitStreamClient(0, client);
+    size_t threadNum = 100;
+    ThreadPool threadPool(threadNum);
+
+    for (size_t i = 0; i < threadNum; ++i) {
+        threadPool.Execute([&client] { DS_ASSERT_OK(client->Init()); });
+    }
+}
+
+TEST_F(StreamDfxTest, TestCreateProducerConsumerFail)
+{
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer;
+    std::string streamName = "CreateProdConFail";
+    DS_ASSERT_OK(datasystem::inject::Set("ClientBaseImpl.init_fail_before_cursor", "2*return(K_INVALID)"));
+    DS_ASSERT_NOT_OK(CreateProducer(client1, streamName, producer));
+    DS_ASSERT_NOT_OK(CreateConsumer(client1, streamName, "sub1", consumer));
+    DS_ASSERT_OK(CreateProducer(client1, streamName, producer));
+    DS_ASSERT_OK(CreateConsumer(client1, streamName, "sub1", consumer));
+}
+
+/*
+AutoCleanup set to true. Create a producer and consumer on the same node.
+CreateConsumer fails and the producer closes. CheckCount shows no producer
+or consumer. AutoCleanup cleans the stream metadata.
+*/
+TEST_F(StreamDfxTest, TestConsumerFailWithAutoDelete)
+{
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+    ProducerConf conf;
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    conf.autoCleanup = true;
+    std::shared_ptr<Producer> producer;
+    std::string streamName = "ConFailWithAutoDelete";
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    std::shared_ptr<Consumer> consumer;
+    DS_ASSERT_OK(datasystem::inject::Set("ClientBaseImpl.init_fail_before_cursor", "1*return(K_INVALID)"));
+    DS_ASSERT_NOT_OK(CreateConsumer(client1, streamName, "sub1", consumer));
+    CheckCount(client1, streamName, 1, 0);
+    DS_ASSERT_OK(producer->Close());
+    CheckCount(client1, streamName, 0, 0);
+}
+
+/*
+AutoCleanup set to true. Create a producer and consumer on the same node.
+CreateProducer fails and the consumer closes. CheckCount shows no producer
+or consumer. AutoCleanup cleans the stream metadata.
+*/
+TEST_F(StreamDfxTest, TestProducerFailWithAutoDelete)
+{
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+    std::shared_ptr<Producer> producer;
+    std::shared_ptr<Consumer> consumer;
+    std::string streamName = "ProdFailWithAutoDelete";
+    DS_ASSERT_OK(CreateConsumer(client1, streamName, "sub1", consumer));
+    DS_ASSERT_OK(datasystem::inject::Set("ClientBaseImpl.init_fail_before_cursor", "1*return(K_INVALID)"));
+    DS_ASSERT_NOT_OK(CreateProducer(client1, streamName, producer));
+    CheckCount(client1, streamName, 0, 1);
+    DS_ASSERT_OK(consumer->Close());
+    CheckCount(client1, streamName, 0, 0);
+}
+
+TEST_F(StreamDfxTest, TestProducerTimerQueue)
+{
+    DS_ASSERT_OK(datasystem::inject::Set("ProducerImpl.ExecAndCancelTimer.sleep", "1*sleep(1000)"));
+    DS_ASSERT_OK(datasystem::inject::Set("ProducerImpl.ExecFlush.sleep", "1*sleep(5000)"));
+
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(CreateProducer(client1, "ProducerTimerQueue", producer));
+    const size_t testSize = 4ul * 1024ul;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    CreateElement(testSize, element, writeElement);
+    DS_ASSERT_OK(producer->Send(element));
+
+    // In the destructor, before we cancel the timer in ExecAndCloseTimer(), we sleep for 1 second to let the
+    // timer remove the task from the queue. The timer then sleeps for 5 seconds so that the destructor performs
+    // the flush and the producer is deallocated. We expect the timer to check, through the weak pointer, that the
+    // producer still exists before calling ExecFlush().
+    producer.reset();
+    LOG(INFO) << "producer destructed";
+    // Sleep extra to ensure no segmentation fault
+    const uint FIVE_SECS = 5;
+    sleep(FIVE_SECS);
+}
+
+TEST_F(StreamDfxTest, TestMasterSubTimeout)
+{
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    uint32_t timeoutMs = 3000;
+    DS_ASSERT_OK(InitClient(1, client1, timeoutMs));
+    DS_ASSERT_OK(InitClient(2, client2, timeoutMs));
+    std::string streamName = "testStream";
+    std::shared_ptr<Consumer> consumer;
+
+    for (int index = 0; index < WORKER_COUNT; index++) {
+        DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, index, "master.SubIncreaseNode.afterLock", "sleep(3000)"));
+    }
+    DS_ASSERT_NOT_OK(CreateConsumer(client1, streamName, streamName, consumer));
+
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(CreateProducer(client2, streamName, producer));
+}
+
+TEST_F(StreamDfxTest, TestDiskFullWithAutoDelete)
+{
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(1, client1));
+    ProducerConf conf;
+    conf.maxStreamSize = TEST_STREAM_SIZE;
+    conf.autoCleanup = true;
+    std::shared_ptr<Producer> producer;
+    std::string streamName = "DiskFullWithAutoDelete";
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+    sleep(1);
+    std::vector<std::string> infos;
+    GetResMonitorLogInfo(1, "resource.log", infos);
+    int streamCountIdx = static_cast<int>(ResMetricName::STREAM_COUNT) - static_cast<int>(ResMetricName::SHARED_MEMORY);
+    ASSERT_EQ(std::stoi(infos[streamCountIdx]), 1);
+
+    DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, 0, "master.rocksdb.delete", "return(K_KVSTORE_ERROR)"));
+    DS_ASSERT_OK(producer->Close());
+    sleep(1);
+    GetResMonitorLogInfo(1, "resource.log", infos);
+    ASSERT_EQ(std::stoi(infos[streamCountIdx]), 0);
+
+    DS_ASSERT_OK(cluster_->KillWorker(0));
+    DS_ASSERT_OK(cluster_->StartNode(WORKER, 0, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0));
+    sleep(3);  // wait 3s for resource log flush
+    GetResMonitorLogInfo(1, "resource.log", infos);
+    ASSERT_EQ(std::stoi(infos[streamCountIdx]), 0);
+}
+
+class StreamDfxMultiTest : public StreamDfxTest {
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        StreamDfxTest::SetClusterSetupOptions(opts);
+        const uint32_t workerCount = 4;
+        opts.numWorkers = workerCount;
+        opts.masterIdx = 3;
+        opts.enableDistributedMaster = "true";
+    }
+
+protected:
+    Status SetupMulti(int i, std::vector<std::shared_ptr<Producer>> &producers,
+                      std::vector<std::shared_ptr<Consumer>> &consumers, std::shared_ptr<StreamClient> &client1,
+                      std::shared_ptr<StreamClient> &client2, std::shared_ptr<StreamClient> &client3,
+                      std::shared_ptr<StreamClient> &client4, std::string strmName)
+    {
+        std::string streamName = strmName + std::to_string(i);
+        const int TWO = 2;
+        LOG(INFO) << FormatString("Setup Multi configuration %d!", i);
+        switch (i) {
+            case TWO:
+                LOG(INFO) << "Config: W1: p1, W2: p2, W3: c1, W4: c2";
+                RETURN_IF_NOT_OK(CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, {}, consumers));
+                RETURN_IF_NOT_OK(CreateProducerAndConsumer(client2, { { streamName, 1 } }, producers, {}, consumers));
+                RETURN_IF_NOT_OK(
+                    CreateProducerAndConsumer(client3, {}, producers, { { streamName, "sub1" } }, consumers));
+                RETURN_IF_NOT_OK(
+                    CreateProducerAndConsumer(client4, {}, producers, { { streamName, "sub2" } }, consumers));
+                c1_location = TWO;
+                break;
+            case 1:
+                LOG(INFO) << "Config: W1: p1 p2, W2: c1 c2";
+                RETURN_IF_NOT_OK(CreateProducerAndConsumer(client1, { { streamName, 2 } }, producers, {}, consumers));
+                RETURN_IF_NOT_OK(CreateProducerAndConsumer(
+                    client2, {}, producers, { { streamName, "sub1" }, { streamName, "sub2" } }, consumers));
+                c1_location = 1;
+                break;
+            case 0:
+                LOG(INFO) << "Config: W1: p1 p2 c1 c2";
+                RETURN_IF_NOT_OK(CreateProducerAndConsumer(client1, { { streamName, 2 } }, producers,
+                                                           { { streamName, "sub1" }, { streamName, "sub2" } },
+                                                           consumers));
+                c1_location = 0;
+                break;
+            default:
+                LOG(INFO) << "No configuration";
+                break;
+        }
+        return Status::OK();
+    }
+
+    Status MultiClientFaultHelper(int i, std::vector<std::shared_ptr<Producer>> &producers,
+                                  std::vector<std::shared_ptr<Consumer>> &consumers,
+                                  std::shared_ptr<StreamClient> &client1, std::shared_ptr<StreamClient> &client2,
+                                  std::shared_ptr<StreamClient> &client3, std::shared_ptr<StreamClient> &client4,
+                                  std::string strmName)
+    {
+        std::string streamName = strmName + std::to_string(i);
+        const int TWO = 2;
+        LOG(INFO) << FormatString("Setup Multi configuration %d!", i);
+        switch (i) {
+            case TWO:
+                LOG(INFO) << "Config: W1: p1, W2: p2, W3: c1, W4: c2";
+                RETURN_IF_NOT_OK(CreateProducerAndConsumer(client2, { { streamName, 1 } }, producers, {}, consumers));
+                RETURN_IF_NOT_OK(
+                    CreateProducerAndConsumer(client3, {}, producers, { { streamName, "sub1" } }, consumers));
+                RETURN_IF_NOT_OK(
+                    CreateProducerAndConsumer(client4, {}, producers, { { streamName, "sub2" } }, consumers));
+                c1_location = TWO;
+                break;
+            case 1:
+                LOG(INFO) << "Config: W1: p1 p2, W2: c1 c2";
+                RETURN_IF_NOT_OK(CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, {}, consumers));
+                RETURN_IF_NOT_OK(CreateProducerAndConsumer(
+                    client2, {}, producers, { { streamName, "sub1" }, { streamName, "sub2" } }, consumers));
+                c1_location = 1;
+                break;
+            case 0:
+                LOG(INFO) << "Config: W1: p1 p2 c1 c2";
+                RETURN_IF_NOT_OK(CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers,
+                                                           { { streamName, "sub1" }, { streamName, "sub2" } },
+                                                           consumers));
+                c1_location = 0;
+                break;
+            default:
+                break;
+        }
+        return Status::OK();
+    }
+
+    const uint32_t timeoutMs = 1000;
+    const uint32_t DEFAULT_TIMEOUT_MS = 60'000;
+    int c1_location = 0;
+    std::string data = "Hello World";
+    std::vector<Element> out;
+    int TWO = 2, THREE = 3;
+};
+
+TEST_F(StreamDfxMultiTest, TestMultiBasic1)
+{
+    std::shared_ptr<StreamClient> client1, client2, client3, client4;
+    std::string streamName = "TestMultiBasic1";
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    DS_ASSERT_OK(InitClient(TWO, client3));
+    DS_ASSERT_OK(InitClient(THREE, client4));
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    int NUM_CONFIGS = 3;
+    for (int config = 0; config < NUM_CONFIGS; config++) {
+        std::vector<std::shared_ptr<Producer>> producers;
+        std::vector<std::shared_ptr<Consumer>> consumers;
+        SetupMulti(config, producers, consumers, client1, client2, client3, client4, streamName);
+        DS_ASSERT_OK(producers[0]->Send(element));
+        DS_ASSERT_OK(consumers[0]->Receive(timeoutMs, out));
+        DS_ASSERT_OK(consumers[1]->Receive(timeoutMs, out));
+        DS_ASSERT_OK(producers[1]->Send(element));
+        DS_ASSERT_OK(consumers[0]->Receive(timeoutMs, out));
+        DS_ASSERT_OK(consumers[1]->Receive(timeoutMs, out));
+
+        // Close producers/consumers
+        DS_ASSERT_OK(producers[0]->Close());
+        DS_ASSERT_OK(producers[1]->Close());
+        DS_ASSERT_OK(consumers[0]->Close());
+        DS_ASSERT_OK(consumers[1]->Close());
+    }
+}
+
+TEST_F(StreamDfxMultiTest, TestMultiCloseProducer)
+{
+    std::shared_ptr<StreamClient> client1, client2, client3, client4;
+    std::string streamName = "testMultiCloseProd";
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    DS_ASSERT_OK(InitClient(TWO, client3));
+    DS_ASSERT_OK(InitClient(THREE, client4));
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    int NUM_CONFIGS = 3;
+    for (int config = 0; config < NUM_CONFIGS; config++) {
+        // Close P1 during data transmission
+        std::vector<std::shared_ptr<Producer>> producers;
+        std::vector<std::shared_ptr<Consumer>> consumers;
+        SetupMulti(config, producers, consumers, client1, client2, client3, client4, streamName);
+        DS_ASSERT_OK(producers[0]->Send(element));
+        DS_ASSERT_OK(producers[1]->Send(element));
+        DS_ASSERT_OK(producers[0]->Close());
+
+        // Assert normal functions
+        DS_ASSERT_OK(consumers[0]->Receive(1, timeoutMs, out));
+        DS_ASSERT_TRUE(out.size(), 1);
+        DS_ASSERT_OK(consumers[1]->Receive(1, timeoutMs, out));
+        DS_ASSERT_TRUE(out.size(), 1);
+
+        DS_ASSERT_OK(producers[1]->Send(element));
+        DS_ASSERT_OK(consumers[0]->Receive(1, timeoutMs, out));
+        DS_ASSERT_TRUE(out.size(), 1);
+        DS_ASSERT_OK(consumers[1]->Receive(1, timeoutMs, out));
+        DS_ASSERT_TRUE(out.size(), 1);
+        // Close producers/consumers due to out of scope
+    }
+}
+
+TEST_F(StreamDfxMultiTest, TestMultiCloseConsumer)
+{
+    std::shared_ptr<StreamClient> client1, client2, client3, client4;
+    std::string streamName = "testMultiCloseCon";
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    DS_ASSERT_OK(InitClient(TWO, client3));
+    DS_ASSERT_OK(InitClient(THREE, client4));
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    int NUM_CONFIGS = 3;
+    for (int config = 0; config < NUM_CONFIGS; config++) {
+        // Consumer C1 is closed during data transmission and receiving
+        std::vector<std::shared_ptr<Producer>> producers;
+        std::vector<std::shared_ptr<Consumer>> consumers;
+        SetupMulti(config, producers, consumers, client1, client2, client3, client4, streamName);
+        DS_ASSERT_OK(producers[0]->Send(element));
+        DS_ASSERT_OK(consumers[1]->Receive(1, timeoutMs, out));
+        DS_ASSERT_TRUE(out.size(), 1);
+        DS_ASSERT_OK(producers[1]->Send(element));
+
+        consumers[0]->Close();
+
+        // Assert normal functions
+        DS_ASSERT_OK(consumers[1]->Receive(1, timeoutMs, out));
+        DS_ASSERT_TRUE(out.size(), 1);
+
+        DS_ASSERT_OK(producers[0]->Send(element));
+        DS_ASSERT_OK(consumers[1]->Receive(1, timeoutMs, out));
+        DS_ASSERT_TRUE(out.size(), 1);
+        // Close producers/consumers due to out of scope
+    }
+}
+
+TEST_F(StreamDfxMultiTest, DISABLED_LEVEL1_TestMultiClientFault)
+{
+    std::shared_ptr<StreamClient> client1, client2, client3, client4;
+    std::string streamName = "testMultiClientFaults";
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    DS_ASSERT_OK(InitClient(TWO, client3));
+    DS_ASSERT_OK(InitClient(THREE, client4));
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    int NUM_CONFIGS = 3;
+    for (int config = 0; config < NUM_CONFIGS; config++) {
+        std::vector<std::shared_ptr<Producer>> producers;
+        std::vector<std::shared_ptr<Consumer>> consumers;
+        LOG(INFO) << FormatString("TestClientFault start configuration %d!", config);
+
+        MultiClientFaultHelper(config, producers, consumers, client1, client2, client3, client4, streamName);
+        auto pid = fork();
+        if (pid == 0) {
+            std::shared_ptr<StreamClient> client1a;
+            DS_ASSERT_OK(InitClient(0, client1a));
+            std::vector<std::shared_ptr<Producer>> p2;
+            std::string streamName = "testMultiClientFaults" + std::to_string(config);
+            DS_ASSERT_OK(CreateProducerAndConsumer(client1a, { { streamName, 1 } }, p2, {}, consumers));
+            // Fake a crash point within producer
+            datasystem::inject::Set("producer_insert", "1*abort()");
+            DS_ASSERT_NOT_OK(p2[0]->Send(element));
+            _exit(0);
+        }
+        int status;
+        waitpid(pid, &status, 0);
+        datasystem::inject::Clear("producer_insert");
+
+        LOG(INFO) << FormatString("C1 located at %d", c1_location);
+        // Verification
+        for (const auto &consumer : consumers) {
+            DS_ASSERT_OK(consumer->Receive(timeoutMs, out));
+            DS_ASSERT_TRUE(out.size(), 0);
+        }
+        DS_ASSERT_OK(producers[0]->Send(element));
+
+        for (const auto &consumer : consumers) {
+            DS_ASSERT_OK(consumer->Receive(DEFAULT_TIMEOUT_MS, out));
+            DS_ASSERT_TRUE(out.size(), 1);
+        }
+        LOG(INFO) << FormatString("TestClientFault configuration %d finished!", config);
+        // All prod/cons are closed since out of scope
+    }
+}
+
+TEST_F(StreamDfxMultiTest, TestMultiWorkerFault)
+{
+    std::string streamName = "testMultiWorkerFaults";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    std::shared_ptr<StreamClient> client3;
+    std::shared_ptr<StreamClient> client4;
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    DS_ASSERT_OK(InitClient(TWO, client3));
+    DS_ASSERT_OK(InitClient(THREE, client4));
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    int NUM_CONFIGS = 3;
+    for (int config = 0; config < NUM_CONFIGS; config++) {
+        // Consumer C1 is closed during data transmission and receiving
+        std::vector<std::shared_ptr<Producer>> producers;
+        std::vector<std::shared_ptr<Consumer>> consumers;
+        SetupMulti(config, producers, consumers, client1, client2, client3, client4, streamName);
+        DS_ASSERT_OK(producers[0]->Send(element));
+        DS_ASSERT_OK(producers[1]->Send(element));
+
+        consumers[0]->Close();
+
+        // Assert normal functions
+        DS_ASSERT_OK(consumers[1]->Receive(timeoutMs, out));
+        DS_ASSERT_OK(producers[1]->Send(element));
+        DS_ASSERT_OK(producers[0]->Send(element));
+        DS_ASSERT_OK(consumers[1]->Receive(TWO, timeoutMs, out));
+        DS_ASSERT_TRUE(out.size(), (unsigned int)TWO);
+        // Close producers/consumers due to out of scope
+    }
+}
+
+class StreamDfxWorkerCrashTest : public StreamDfxTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        StreamDfxTest::SetClusterSetupOptions(opts);
+        const uint32_t workerCount = 3;
+        opts.numWorkers = workerCount;
+        opts.masterIdx = 2;
+        opts.enableDistributedMaster = "false";
+    }
+
+protected:
+    const int maxStreamSizeMb = 10;
+};
+
+TEST_F(StreamDfxWorkerCrashTest, DISABLED_TestWorkerCrashStopRemotePush)
+{
+    LOG(INFO) << "TestWorkerCrashStopRemotePush start!";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
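+    // The send loop below documents the expected back-pressure behaviour: with the
+    // consumer-side worker down, pages cannot be pushed remotely, so Send() keeps
+    // filling local stream memory until maxStreamSize is exhausted and
+    // K_OUT_OF_MEMORY is returned.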
std::string streamName = "testWkrCrashStopRemotePush"; + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + + std::shared_ptr producer; + ProducerConf conf; + conf.maxStreamSize = maxStreamSizeMb * 1024 * 1024; + DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf)); + + std::shared_ptr consumer; + SubscriptionConfig config("sub", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer)); + + cluster_->ShutdownNode(ClusterNodeType::WORKER, 1); + std::this_thread::sleep_for(std::chrono::seconds(1)); + + const size_t testSize = 4ul * 1024ul; + // Keep sending until out of memory + size_t sendCount = 0; + Element element; + std::vector writeElement; + CreateElement(testSize, element, writeElement); + while (true) { + Status rc = producer->Send(element); + if (rc.IsOk()) { + ++sendCount; + continue; + } + ASSERT_EQ(rc.GetCode(), K_OUT_OF_MEMORY); + break; + } + ASSERT_TRUE(sendCount > 0); + LOG(INFO) << "Number of elements created: " << sendCount; + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, "")); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + const int64_t timeoutMs = 1000; + for (size_t i = 0; i < sendCount; i++) { + DS_ASSERT_OK(producer->Send(element, timeoutMs)); + } + LOG(INFO) << "TestWorkerCrashStopRemotePush finish!"; +} + +TEST_F(StreamDfxWorkerCrashTest, TestOneWorkerCrashAutoDelete) +{ + LOG(INFO) << "TestOneWorkerCrashAutoDelete start!"; + std::shared_ptr client1; + + DS_ASSERT_OK(InitClient(0, client1)); + + std::vector> producers; + std::vector> consumers; + std::string streamName = "testOneWkrCrashAutoDel"; + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers)); + + CheckCount(client1, streamName, 1, 1); + cluster_->ShutdownNode(ClusterNodeType::WORKER, 0); + + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 0, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0)); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + CheckCount(client1, streamName, 0, 0); + + LOG(INFO) << "TestOneWorkerCrashAutoDelete finish!"; +} + +TEST_F(StreamDfxWorkerCrashTest, LEVEL1_TestOneWorkerCrash) +{ + LOG(INFO) << "TestOneWorkerCrash start!"; + std::shared_ptr client1; + std::shared_ptr client2; + std::shared_ptr client3; + + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + DS_ASSERT_OK(InitClient(2, client3)); + + std::vector> producers; + std::vector> consumers; + std::string streamName = "testOneWkrCrashed"; + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers)); + DS_ASSERT_OK( + CreateProducerAndConsumer(client2, { { streamName, 1 } }, producers, { { streamName, "sub2" } }, consumers)); + DS_ASSERT_OK( + CreateProducerAndConsumer(client3, { { streamName, 1 } }, producers, { { streamName, "sub3" } }, consumers)); + + CheckCount(client1, streamName, 3, 3); + cluster_->ShutdownNode(ClusterNodeType::WORKER, 0); + CheckCount(client2, streamName, -1, 3); + CheckCount(client3, streamName, -1, 3); + + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 0, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0)); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + CheckCount(client2, streamName, 2, 2); + CheckCount(client3, streamName, 2, 2); + + const size_t testSize = 1024; + Element element; + std::vector writeElement; + 
+    CreateElement(testSize, element, writeElement);
+    std::vector<Element> outElements;
+    std::shared_ptr<Consumer> consumer;
+    DS_ASSERT_OK(CreateConsumer(client2, streamName, "sub1", consumer));
+    DS_ASSERT_OK(producers[1]->Send(element));
+    DS_ASSERT_OK(producers[2]->Send(element));
+    DS_ASSERT_OK(consumers[2]->Receive(2, 1000, outElements));
+    ASSERT_EQ(outElements.size(), size_t(2));
+    outElements.clear();
+    DS_ASSERT_OK(consumer->Receive(2, 1000, outElements));
+    ASSERT_EQ(outElements.size(), size_t(2));
+    LOG(INFO) << "TestOneWorkerCrash finish!";
+}
+
+TEST_F(StreamDfxWorkerCrashTest, DISABLED_TestTwoWorkerCrash)
+{
+    LOG(INFO) << "TestTwoWorkerCrash start!";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    std::shared_ptr<StreamClient> client3;
+
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    DS_ASSERT_OK(InitClient(2, client3));
+
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    std::string streamName = "testTwoWkrCrashed";
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers));
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client2, { { streamName, 1 } }, producers, { { streamName, "sub2" } }, consumers));
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client3, { { streamName, 1 } }, producers, { { streamName, "sub3" } }, consumers));
+
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "master.UpdateTopoNotification.setTimeout",
+                                           "call(5000)"));
+
+    CheckCount(client3, streamName, 3, 3);
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, 0);
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, 1);
+    CheckCount(client3, streamName, -1, 3);
+
+    // After worker0 starts, the master clears the metadata and notifies worker1 and worker2
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 0, ""));
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1));
+    std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout));
+    CheckCount(client3, streamName, -1, 1);
+
+    std::shared_ptr<Consumer> consumer1;
+    DS_ASSERT_OK(CreateConsumer(client3, streamName, "sub1", consumer1));
+    std::shared_ptr<Consumer> consumer2;
+    DS_ASSERT_OK(CreateConsumer(client3, streamName, "sub2", consumer2));
+
+    const size_t testSize = 1024;
+    Element element;
+    std::vector<uint8_t> writeElement;
+    CreateElement(testSize, element, writeElement);
+
+    DS_ASSERT_OK(producers[2]->Send(element));
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumer1->Receive(1, 1000, outElements));
+    ASSERT_EQ(outElements.size(), size_t(1));
+
+    outElements.clear();
+    DS_ASSERT_OK(consumer2->Receive(1, 1000, outElements));
+    ASSERT_EQ(outElements.size(), size_t(1));
+
+    outElements.clear();
+    DS_ASSERT_OK(consumers[2]->Receive(1, 1000, outElements));
+    ASSERT_EQ(outElements.size(), size_t(1));
+
+    LOG(INFO) << "TestTwoWorkerCrash finish!";
+}
+
+TEST_F(StreamDfxWorkerCrashTest, DISABLED_LEVEL1_TestProducerWorkerCrashWhileConsumerReceive)
+{
+    LOG(INFO) << "LEVEL1_TestProducerWorkerCrashWhileConsumerReceive start!";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    std::string streamName = "ProdWkrCrashWhileConRecv";
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+
+    std::shared_ptr<Producer> producer;
+    ProducerConf conf;
+    conf.maxStreamSize = maxStreamSizeMb * 1024 * 1024;
+    DS_ASSERT_OK(client1->CreateProducer(streamName, producer, conf));
+
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub", SubscriptionType::STREAM);
config("sub", SubscriptionType::STREAM); + DS_ASSERT_OK(client2->Subscribe(streamName, config, consumer)); + std::vector outElements; + + std::string data = "Hello"; + Element element(reinterpret_cast(&data.front()), data.size()); + DS_ASSERT_OK(producer->Send(element)); + + Status rc = consumer->Receive(1, -1, outElements); + ASSERT_EQ(rc, Status::OK()); + std::string actualData(reinterpret_cast(outElements[0].ptr), outElements[0].size); + EXPECT_EQ(data, actualData); + + // Producer worker crashes. But the consumer is not aware till the nodeDead period passes. + // Therefore, consumer makes a receive request and waits. The producer carsh report arrives after nodeDead period. + // The consumer is unblocked and the error code for producer crashed is returned. + ThreadPool threadPool(1); + threadPool.Submit([this, consumer]() { + std::vector outElements; + Status rc = consumer->Receive(1, waitNodeDead * 1000, outElements); + ASSERT_EQ(rc.GetCode(), K_SC_PRODUCER_NOT_FOUND); + }); + + cluster_->ShutdownNode(ClusterNodeType::WORKER, 0); + LOG(INFO) << "LEVEL1_TestProducerWorkerCrashWhileConsumerReceive finish!"; +} + +class StreamDfxMasterCrashTest : public StreamDfxTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + StreamDfxTest::SetClusterSetupOptions(opts); + const uint32_t workerCount = 3; + opts.masterIdx = 2; + opts.numWorkers = workerCount; + opts.enableDistributedMaster = "false"; + } +}; + +TEST_F(StreamDfxMasterCrashTest, LEVEL1_TestSameMetadata) +{ + LOG(INFO) << "TestSameMetadata start!"; + std::shared_ptr client1; + std::shared_ptr client2; + + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + + std::vector> producers; + std::vector> consumers; + + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { "SameMeta", 1 } }, producers, { { "SameMeta", "sub1" } }, consumers)); + DS_ASSERT_OK( + CreateProducerAndConsumer(client2, { { "SameMeta", 1 } }, producers, { { "SameMeta", "sub2" } }, consumers)); + + cluster_->ShutdownNode(ClusterNodeType::WORKER, 2); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 2, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 2)); + std::this_thread::sleep_for(std::chrono::seconds(nodeTimeout)); + CheckCount(client1, "SameMeta", 2, 2); + CheckCount(client2, "SameMeta", 2, 2); + + LOG(INFO) << "TestSameMetadata finish!"; +} + +TEST_F(StreamDfxMasterCrashTest, TestDiffMetadata) +{ + LOG(INFO) << "TestDiffMetadata start!"; + std::shared_ptr client1; + std::shared_ptr client2; + + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + + std::vector> producers; + std::vector> consumers; + + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { "DiffMeta", 1 } }, producers, { { "DiffMeta", "sub1" } }, consumers)); + DS_ASSERT_OK( + CreateProducerAndConsumer(client2, { { "DiffMeta", 1 } }, producers, { { "DiffMeta", "sub2" } }, consumers)); + + // close pub sub in worker 0, but not send to master + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseConsumer.beforeSendToMaster", + "1*return(K_OK)")); + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseProducer.beforeSendToMaster", + "1*return(K_OK)")); + producers[0] = nullptr; + consumers[0] = nullptr; + CheckCount(client1, "DiffMeta", 2, 2); + CheckCount(client2, "DiffMeta", 2, 2); + + cluster_->ShutdownNode(ClusterNodeType::WORKER, 2); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 2, "")); + 
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 2));
+    // Extend the sleep time for testcase stability purposes
+    std::this_thread::sleep_for(std::chrono::seconds(nodeTimeout));
+    CheckCount(client1, "DiffMeta", 1, 1);
+    CheckCount(client2, "DiffMeta", 1, 1);
+
+    LOG(INFO) << "TestDiffMetadata finish!";
+}
+
+TEST_F(StreamDfxMasterCrashTest, RecoveryAutoDeletePubSub)
+{
+    LOG(INFO) << "RecoveryAutoDeletePubSub start!";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    std::string streamName = "testRecoveryAutoDelPubSub";
+
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers));
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client2, { { streamName, 1 } }, producers, { { streamName, "sub2" } }, consumers));
+
+    // Close the pub/sub in worker 0, but do not send the close to the master
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseConsumer.beforeSendToMaster",
+                                           "1*return(K_OK)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseProducer.beforeSendToMaster",
+                                           "1*return(K_OK)"));
+    // Delete the producers and consumers on worker 0
+    producers[0] = nullptr;
+    consumers[0] = nullptr;
+    // The master still thinks there are two producers and consumers because of the above injected actions
+    CheckCount(client1, streamName, 2, 2);
+    CheckCount(client2, streamName, 2, 2);
+
+    // Remove the remaining producer and consumer so that the restart can invoke auto delete
+    producers[1] = nullptr;
+    consumers[1] = nullptr;
+
+    cluster_->QuicklyShutdownWorker(2);
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 2, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 2));
+    std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout));
+    CheckCount(client1, streamName, 0, 0);
+    CheckCount(client2, streamName, 0, 0);
+
+    LOG(INFO) << "RecoveryAutoDeletePubSub finish!";
+}
+
+TEST_F(StreamDfxMasterCrashTest, LEVEL2_RecoveryAutoDeleteSub)
+{
+    LOG(INFO) << "RecoveryAutoDeleteSub start!";
+    std::shared_ptr<StreamClient> client1;
+
+    DS_ASSERT_OK(InitClient(0, client1));
+
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    std::string streamName = "testRecoverAutoDelSub";
+
+    DS_ASSERT_OK(CreateProducerAndConsumer(client1, {}, producers, { { streamName, "sub1" } }, consumers));
+
+    // Close the sub in worker 0, but do not send the close to the master
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseConsumer.beforeSendToMaster",
+                                           "1*return(K_OK)"));
+
+    // Delete the consumer on worker 0
+    consumers[0] = nullptr;
+    // The master still thinks there is one consumer because of the above injected action
+    CheckCount(client1, streamName, 0, 1);
+
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, 2);
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 2, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 2));
+    std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout));
+    CheckCount(client1, streamName, 0, 0);
+
+    LOG(INFO) << "RecoveryAutoDeleteSub finish!";
+}
+
+TEST_F(StreamDfxMasterCrashTest, DISABLED_TestMasterAndClientCrash)
+{
+    LOG(INFO) << "TestMasterAndClientCrash start!";
+
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+
+    std::vector<std::shared_ptr<Producer>> producers;
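+    // Scenario sketch: the master node (worker 2) is already down when client1 shuts
+    // down, and the injected actions below make the close RPCs fail, so client1's
+    // producer/consumer can only be cleaned up when the restarted master reconciles,
+    // leaving exactly one producer and one consumer (client2's) in the final counts.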
+    std::vector<std::shared_ptr<Consumer>> consumers;
+
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client1, { { "stream", 1 } }, producers, { { "stream", "sub1" } }, consumers));
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client2, { { "stream", 1 } }, producers, { { "stream", "sub2" } }, consumers));
+    CheckCount(client1, "stream", 2, 2);
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, 2);
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseConsumer.beforeSendToMaster",
+                                           "1*return(K_RPC_UNAVAILABLE)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseProducer.beforeSendToMaster",
+                                           "1*return(K_RPC_UNAVAILABLE)"));
+
+    DS_ASSERT_OK(client1->ShutDown());
+
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 2, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 2));
+    std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout));
+    CheckCount(client2, "stream", 1, 1);
+    LOG(INFO) << "TestMasterAndClientCrash finish!";
+}
+
+TEST_F(StreamDfxMasterCrashTest, LEVEL1_TestCloseProducer)
+{
+    LOG(INFO) << "LEVEL1_TestCloseProducer start!";
+    std::shared_ptr<StreamClient> client1;
+
+    DS_ASSERT_OK(InitClient(0, client1));
+
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    std::string streamName = "closeProdTest";
+
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers));
+
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, 2);
+
+    ASSERT_EQ(producers[0]->Close().GetCode(), K_RPC_UNAVAILABLE);
+
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 2, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 2));
+    std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout));
+    CheckCount(client1, streamName, 1, 1);
+    DS_ASSERT_OK(producers[0]->Close());
+    CheckCount(client1, streamName, 0, 1);
+    LOG(INFO) << "LEVEL1_TestCloseProducer finish!";
+}
+
+TEST_F(StreamDfxMasterCrashTest, TestMasterFailRecoverMetaFromRocksDB)
+{
+    LOG(INFO) << "TestMasterFailRecoverMetaFromRocksDB start!";
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    std::string streamName = "testMstrFailRecoverMetaRocksDB";
+    std::string streamName2 = "testMstrFailRecoverMetaRocksDB_s2";
+    // Do not store metadata in RocksDB
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, K_TWO,
+                                           "master.RocksStreamMetaStore.DoNotAddPubSubMetadata", "6*return(K_OK)"));
+
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers));
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client1, { { streamName2, 1 } }, producers, { { streamName2, "sub1" } }, consumers));
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client2, { { streamName, 1 } }, producers, { { streamName, "sub2" } }, consumers));
+
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, K_TWO);
+
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, K_TWO, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, K_TWO));
+    std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout));
+    CheckCount(client1, streamName, K_TWO, K_TWO);
+    CheckCount(client1, streamName2, 1, 1);
+    CheckCount(client2, streamName, K_TWO, K_TWO);
+    LOG(INFO) << "TestMasterFailRecoverMetaFromRocksDB finish!";
+}
+
+TEST_F(StreamDfxMasterCrashTest, LEVEL1_TestMasterAndWorkerLostMetadata)
+{
+    std::string streamName = "testMstrWrkrLostMeta";
"testMstrWrkrLostMeta"; + LOG(INFO) << "TestMasterAndWorkerLostMetadata start!"; + auto pid = fork(); + if (pid == 0) { + // Do not store metadata on RocksDB + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, K_TWO, + "master.RocksStreamMetaStore.DoNotAddPubSubMetadata", "2*return(K_OK)")); + + cluster_->ShutdownNode(ClusterNodeType::WORKER, K_TWO); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, K_TWO, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, K_TWO)); + } + ASSERT_TRUE(pid > 0); + std::shared_ptr client1; + const int timeoutMs = 2000; + DS_ASSERT_OK(InitClient(0, client1, timeoutMs)); + std::vector> producers; + std::vector> consumers; + + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers)); + CheckCount(client1, streamName, 1, 1); + + cluster_->ShutdownNode(ClusterNodeType::WORKER, 0); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 0, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0)); + + int status; + waitpid(pid, &status, 0); + + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + CheckCount(client1, streamName, 0, 0); + + LOG(INFO) << "TestMasterAndWorkerLostMetadata finish!"; +} + +TEST_F(StreamDfxMasterCrashTest, LEVEL1_TestMasterAndSubscriberRestart) +{ + LOG(INFO) << "TestMasterAndWorkerRestart start!"; + std::shared_ptr client1, client2, client3; + std::string streamName = "testMstrAndSubRestart"; + + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + + // Do not store metadata on RocksDB + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, K_TWO, + "master.RocksStreamMetaStore.DoNotAddPubSubMetadata", "2*return(K_OK)")); + + std::vector> producers1, producers2; + std::vector> consumers1, consumers2; + + DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers1, {}, consumers1)); + DS_ASSERT_OK(CreateProducerAndConsumer(client2, {}, producers2, { { streamName, "sub1" } }, consumers2)); + CheckCount(client1, streamName, 1, 1); + + ThreadPool pool(K_TWO); + auto fut1 = pool.Submit([this]() { cluster_->ShutdownNode(ClusterNodeType::WORKER, 1); }); + auto fut2 = pool.Submit([this]() { cluster_->ShutdownNode(ClusterNodeType::WORKER, K_TWO); }); + fut1.get(); + fut2.get(); + + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, K_TWO, "")); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, "")); + + fut1 = pool.Submit([this]() { cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1); }); + fut2 = pool.Submit([this]() { cluster_->WaitNodeReady(ClusterNodeType::WORKER, K_TWO); }); + fut1.get(); + fut2.get(); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + + CheckCount(client2, streamName, 1, 0); + DS_ASSERT_OK(InitClient(1, client3)); + + DS_ASSERT_OK(CreateProducerAndConsumer(client3, {}, producers2, { { streamName, "sub2" } }, consumers2)); + std::string data = "This is some data"; + Element element(reinterpret_cast(&data.front()), data.size()); + producers1[0]->Send(element); + std::vector outElements; + DS_ASSERT_OK(consumers2[1]->Receive(1, 10000, outElements)); + ASSERT_EQ(outElements.size(), (size_t)1); + LOG(INFO) << "TestMasterAndWorkerRestart finish!"; +} + +TEST_F(StreamDfxMasterCrashTest, LEVEL2_TestMasterAndPublisherRestart) +{ + LOG(INFO) << "TestMasterAndWorkerRestart start!"; + std::shared_ptr client1, client2, client3; + + DS_ASSERT_OK(InitClient(0, client1)); + 
+    DS_ASSERT_OK(InitClient(1, client2));
+
+    // Do not store metadata in RocksDB
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, K_TWO,
+                                           "master.RocksStreamMetaStore.DoNotAddPubSubMetadata", "2*return(K_OK)"));
+
+    std::vector<std::shared_ptr<Producer>> producers1, producers2;
+    std::vector<std::shared_ptr<Consumer>> consumers1, consumers2;
+
+    DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { "MasterPublisherRestart", 1 } }, producers1, {}, consumers1));
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(client2, {}, producers2, { { "MasterPublisherRestart", "sub1" } }, consumers2));
+    CheckCount(client1, "MasterPublisherRestart", 1, 1);
+
+    ThreadPool pool(K_TWO);
+    auto fut1 = pool.Submit([this]() { cluster_->ShutdownNode(ClusterNodeType::WORKER, 0); });
+    auto fut2 = pool.Submit([this]() { cluster_->ShutdownNode(ClusterNodeType::WORKER, K_TWO); });
+    fut1.get();
+    fut2.get();
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, K_TWO, ""));
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 0, ""));
+    fut1 = pool.Submit([this]() { DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, K_TWO)); });
+    fut2 = pool.Submit([this]() { DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0)); });
+    fut1.get();
+    fut2.get();
+    std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout));
+
+    CheckCount(client2, "MasterPublisherRestart", 0, 1);
+    DS_ASSERT_OK(InitClient(1, client3));
+
+    DS_ASSERT_OK(CreateProducerAndConsumer(client3, { { "MasterPublisherRestart", 1 } }, producers1, {}, consumers1));
+    std::string data = "This is some data";
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    producers1[1]->Send(element);
+    std::vector<Element> outElements;
+    DS_ASSERT_OK(consumers2[0]->Receive(1, 10000, outElements));
+    ASSERT_EQ(outElements.size(), (size_t)1);
+    LOG(INFO) << "TestMasterAndPublisherRestart finish!";
+}
+
+TEST_F(StreamDfxMasterCrashTest, DISABLED_LEVEL1_TestQueryMetaProducerNotFound)
+{
+    // This testcase aims to test a bug fix where the producer pb can be empty
+    // in a QueryMeta request if the client id for the producer is not found
+    LOG(INFO) << "LEVEL1_TestQueryMetaProducerNotFound start!";
+    const int masterIdx = 2;
+
+    std::shared_ptr<StreamClient> client1;
+    std::shared_ptr<StreamClient> client2;
+    std::string streamName = "testQueryMetaProdNotFound";
+
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+
+    DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers,
+                                           { { streamName, "subscription1" } }, consumers));
+    DS_ASSERT_OK(CreateProducerAndConsumer(client2, { { streamName, 1 } }, producers,
+                                           { { streamName, "subscription2" } }, consumers));
+    CheckCount(client1, streamName, K_TWO, K_TWO);
+    cluster_->ShutdownNode(ClusterNodeType::WORKER, masterIdx);
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseConsumer.beforeSendToMaster",
+                                           "1*return(K_RPC_UNAVAILABLE)"));
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseProducer.beforeSendToMaster",
+                                           "1*return(K_RPC_UNAVAILABLE)"));
+
+    DS_ASSERT_OK(
+        cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "GetProducerConsumerMetadata.NotFound", "1*call()"));
+    // Clean up RocksDB so the stream fields can be updated at reconciliation
+    std::string rocksPath =
+        cluster_->GetRootDir() + "/worker" + std::to_string(masterIdx) + "/rocksdb/stream_meta_data";
+    LOG(INFO) << "Remove rocksdb at path " << rocksPath;
+    DS_ASSERT_OK(RemoveAll(rocksPath));
+
DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, masterIdx, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, masterIdx)); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + static_cast(cluster_.get())->KillWorker(masterIdx); + sleep(1); + static_cast(cluster_.get())->StartNode(ClusterNodeType::WORKER, masterIdx, ""); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + + CheckCount(client1, streamName, K_TWO, K_TWO); + LOG(INFO) << "LEVEL1_TestQueryMetaProducerNotFound finish!"; +} + +TEST_F(StreamDfxMasterCrashTest, TestMetadataNodeFault) +{ + // After producer P1 and consumer C1 are created, the node where the metadata is located is faulty. + // Data receiving and sending are not affected. + // Producer or consumer fails to close until the node where the metadata resides recovers. + + LOG(INFO) << "TestMetadataNodeFault start!"; + + std::shared_ptr client1; + std::shared_ptr client2; + std::string streamName = "testMetaNodeFault"; + std::string streamName2 = "testMetaNodeFault_s2"; + + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + + std::vector> producers; + std::vector> consumers; + + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers)); + + CheckCount(client1, streamName, 1, 1); + + // Injection for logging metadata not found + DS_ASSERT_OK( + cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "GetProducerConsumerMetadata.NotFound", "10*call()")); + // no trigger + DS_ASSERT_OK( + cluster_->SetInjectAction(ClusterNodeType::WORKER, K_TWO, "master.sc.close_producer_error", "1*return()")); + + // Shut down metadata node + cluster_->ShutdownNode(ClusterNodeType::WORKER, K_TWO); + + // Assert that data send and receive is not affected + std::string data = "This is some data"; + Element element(reinterpret_cast(&data.front()), data.size()); + producers[0]->Send(element); + std::vector outElements; + const int64_t timeoutMs = 1000; + DS_ASSERT_OK(consumers[0]->Receive(1, timeoutMs, outElements)); + DS_ASSERT_TRUE(outElements.size(), 1); + + // Assert that closing does not work + DS_ASSERT_NOT_OK(producers[0]->Close()); + DS_ASSERT_NOT_OK(consumers[0]->Close()); + + // Restart: + LOG(INFO) << "Restarting metadata node"; + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, K_TWO, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, K_TWO)); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + + DS_ASSERT_OK( + CreateProducerAndConsumer(client2, { { streamName2, 1 } }, producers, { { streamName2, "sub2" } }, consumers)); + + // Assert that closing works + LOG(INFO) << "Can close producer and consumer now"; + DS_ASSERT_OK(producers[0]->Close()); + DS_ASSERT_OK(consumers[0]->Close()); + DS_ASSERT_OK(producers[1]->Close()); + DS_ASSERT_OK(consumers[1]->Close()); + + CheckCount(client1, streamName, 0, 0); + LOG(INFO) << "TestMetadataNodeFault finish!"; +} + +class StreamDfxHeartbeatTest : public StreamDfxTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + StreamDfxTest::SetClusterSetupOptions(opts); + const uint32_t workerCount = 3; + opts.masterIdx = 2; + opts.numWorkers = workerCount; + opts.enableDistributedMaster = "false"; + opts.disableRocksDB = false; + } +}; + +TEST_F(StreamDfxHeartbeatTest, LEVEL1_TestWorkerCrashTimeout) +{ + LOG(INFO) << "LEVEL1_TestWorkerCrashTimeout start!"; + std::shared_ptr client1; + std::shared_ptr client2; + + 
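+    // A note on the CheckCount(client, stream, producerNum, consumerNum) calls below: an argument
+    // of -1 appears to act as a wildcard that skips verification of that counter, e.g.
+    //     CheckCount(client2, streamName, -1, 2);  // check consumer count only, producers in flux
+    // (an assumption inferred from how this file uses it while a worker is down).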
DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + + std::vector> producers; + std::vector> consumers; + std::string streamName = "testWrkrCrashTimeout"; + + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers)); + DS_ASSERT_OK( + CreateProducerAndConsumer(client2, { { streamName, 1 } }, producers, { { streamName, "sub2" } }, consumers)); + + CheckCount(client1, streamName, 2, 2); + cluster_->QuicklyShutdownWorker(0); + CheckCount(client2, streamName, -1, 2); + + // sleep until master clear the worker metadata. + std::this_thread::sleep_for(std::chrono::seconds(waitNodeDead)); + CheckCount(client2, streamName, -1, 1); + + const int K_TWO = 2; + ThreadPool pool(K_TWO); + cluster_->QuicklyShutdownWorker(1); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, "")); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 0, "")); + auto fut1 = pool.Submit([this]() { DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0)); }); + auto fut2 = pool.Submit([this]() { DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1)); }); + fut1.get(); + fut2.get(); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + CheckCount(client2, streamName, -1, 0); + + // sleep until node times out, and check that consumer can still be created + std::shared_ptr consumer1; + std::shared_ptr consumer2; + DS_ASSERT_OK(CreateConsumer(client1, streamName, "sub1", consumer1)); + DS_ASSERT_OK(CreateConsumer(client2, streamName, "sub2", consumer2)); + LOG(INFO) << "LEVEL1_TestWorkerCrashTimeout finish!"; +} + +TEST_F(StreamDfxHeartbeatTest, LEVEL1_TestMasterAndWorkerCrashNotStartWorker) +{ + LOG(INFO) << "LEVEL1_TestMasterAndWorkerCrashNotStartWorker start!"; + std::shared_ptr client1; + std::shared_ptr client2; + + const int timeoutMs = 2000; + DS_ASSERT_OK(InitClient(0, client1, timeoutMs)); + DS_ASSERT_OK(InitClient(1, client2, timeoutMs)); + + std::vector> producers; + std::vector> consumers; + std::string streamName = "MasterWorkerCrashNotStartWorker"; + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers)); + DS_ASSERT_OK( + CreateProducerAndConsumer(client2, { { streamName, 1 } }, producers, { { streamName, "sub2" } }, consumers)); + + client1.reset(); + const int workerIdx = 2; + static_cast(cluster_.get())->KillWorker(workerIdx); + static_cast(cluster_.get())->KillWorker(0); + static_cast(cluster_.get())->StartNode(ClusterNodeType::WORKER, 2, ""); + CheckCount(client2, streamName, -1, 2); + + std::this_thread::sleep_for(std::chrono::seconds(waitNodeDead)); + CheckCount(client2, streamName, -1, 1); + static_cast(cluster_.get())->KillWorker(workerIdx); + LOG(INFO) << "LEVEL1_TestMasterAndWorkerCrashNotStartWorker finish!"; +} + +TEST_F(StreamDfxHeartbeatTest, LEVEL1_TestWorkerToMasterTimeout) +{ + LOG(INFO) << "LEVEL1_TestWorkerToMasterTimeout start!"; + std::shared_ptr client1; + std::shared_ptr client2; + + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(1, client2)); + + std::vector> producers; + std::vector> consumers; + std::string streamName = "TestWorkerMasterTimeout"; + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers)); + DS_ASSERT_OK( + CreateProducerAndConsumer(client2, { { streamName, 1 } }, producers, { { streamName, "sub2" } }, consumers)); + + CheckCount(client1, streamName, 2, 2); + + // 
heartbeat timeout and node dead. + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "heartbeat.sleep", "1*sleep(10000)")); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeDead)); + CheckCount(client1, streamName, 1, 1); + + const size_t testSize = 1024; + Element element; + std::vector writeElement; + CreateElement(testSize, element, writeElement); + + DS_ASSERT_OK(producers[1]->Send(element)); + std::vector outElements; + DS_ASSERT_OK(consumers[1]->Receive(1, 100, outElements)); + LOG(INFO) << "LEVEL1_TestWorkerToMasterTimeout finish!"; +} + +class StreamDfxTopoTest : public StreamDfxTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + StreamDfxTest::SetClusterSetupOptions(opts); + const uint32_t workerCount = 4; + opts.numWorkers = workerCount; + opts.masterIdx = 3; + opts.workerGflagParams = "-node_timeout_s=2 -node_dead_timeout_s=60"; + opts.disableRocksDB = false; + } + + void SetUp() override + { + StreamDfxTest::SetUp(); + int index = 0; + DS_ASSERT_OK(InitClient(index++, client1_)); + DS_ASSERT_OK(InitClient(index++, client2_)); + DS_ASSERT_OK(InitClient(index++, client3_)); + } + + void TearDown() override + { + client1_ = nullptr; + client2_ = nullptr; + client3_ = nullptr; + StreamDfxTest::TearDown(); + } + +protected: + Status TryAndDeleteStream(std::shared_ptr &spClient, std::string streamName) + { + // if pending notifications retry delete + Status rc = Status::OK(); + do { + rc = spClient->DeleteStream(streamName); + if (rc.IsError()) { + sleep(K_TWO); + } + } while (rc.GetCode() == StatusCode::K_SC_STREAM_NOTIFICATION_PENDING); + return rc; + } + + void waitAbort(uint32_t idx); + std::shared_ptr client1_; + std::shared_ptr client2_; + std::shared_ptr client3_; + const int K_TWO = 2, K_3 = 3, K_5 = 5, K_200 = 200; +}; + +TEST_F(StreamDfxTopoTest, TestCreateConsumerTimeout) +{ + std::shared_ptr consumer; + DS_ASSERT_OK(CreateConsumer(client1_, "testCreateConsumerTimeout", "sub", consumer)); +} + +TEST_F(StreamDfxTopoTest, TestCreateProducerConsumer) +{ + std::shared_ptr consumer; + DS_ASSERT_OK(CreateConsumer(client1_, "TestDfxCreateProducerConsumer", "sub", consumer)); + std::shared_ptr producer; + DS_ASSERT_OK(CreateProducer(client1_, "TestDfxCreateProducerConsumer", producer)); + + std::string str = "hello hello"; + Element element(reinterpret_cast(const_cast(str.data())), str.length()); + + DS_ASSERT_OK(producer->Send(element)); + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(1, 0, outElements)); + ASSERT_EQ(outElements.size(), 1ul); + DS_ASSERT_OK(consumer->Close()); + DS_ASSERT_OK(producer->Close()); + + std::shared_ptr consumer2; + DS_ASSERT_OK(CreateConsumer(client1_, "TestDfxCreateProducerConsumer2", "sub", consumer2)); + std::shared_ptr producer2; + DS_ASSERT_OK(CreateProducer(client1_, "TestDfxCreateProducerConsumer2", producer2)); + + DS_ASSERT_OK(producer2->Send(element)); + DS_ASSERT_OK(consumer2->Receive(1, 0, outElements)); + ASSERT_EQ(outElements.size(), 1ul); + DS_ASSERT_OK(consumer2->Close()); + DS_ASSERT_OK(producer2->Close()); +} + +TEST_F(StreamDfxTopoTest, LEVEL1_TestTopoChangeWhenWorkerTimeout) +{ + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "master.UpdateTopoNotification.setTimeout", + "call(2000)")); + + std::shared_ptr p1w1; + std::shared_ptr p1w2; + std::shared_ptr p1w3; + std::string streamName = "TopoWhenWorkerTimeout"; + DS_ASSERT_OK(CreateProducer(client1_, streamName, p1w1)); + DS_ASSERT_OK(CreateProducer(client2_, streamName, p1w2)); + 
DS_ASSERT_OK(CreateProducer(client3_, streamName, p1w3)); + + std::shared_ptr c1w1; + DS_ASSERT_OK(CreateConsumer(client1_, streamName, "sub-w1", c1w1)); + + // Sleep 10s when worker 0 send heartbeat message. + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "heartbeat.sleep", "1*sleep(10000)")); + const uint32_t sleepTime = 18 * 1000; // 18s + auto heartBeartRecoverTime = std::chrono::system_clock::now() + std::chrono::milliseconds(sleepTime); + std::shared_ptr c1w2; + DS_ASSERT_OK(CreateConsumer(client2_, streamName, "sub-w2", c1w2)); + + // After 3 seconds, the master thinks the worker1 has timeout. + std::this_thread::sleep_for(std::chrono::seconds(3)); + std::shared_ptr c1w3; + DS_ASSERT_OK(CreateConsumer(client3_, streamName, "sub-w3", c1w3)); + DS_ASSERT_OK(c1w2->Close()); + DS_ASSERT_OK(c1w3->Close()); + + std::shared_ptr c2w2; + DS_ASSERT_OK(CreateConsumer(client2_, streamName, "sub-w22", c2w2)); + std::shared_ptr c2w3; + DS_ASSERT_OK(CreateConsumer(client3_, streamName, "sub-w23", c2w3)); + while (std::chrono::system_clock::now() <= heartBeartRecoverTime) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + + std::string str = "hello hello"; + Element element(reinterpret_cast(const_cast(str.data())), str.length()); + + DS_ASSERT_OK(p1w1->Send(element)); + DS_ASSERT_OK(p1w2->Send(element)); + DS_ASSERT_OK(p1w3->Send(element)); + std::vector outElements; + DS_ASSERT_OK(c1w1->Receive(K_3, K_5000, outElements)); + ASSERT_EQ(outElements.size(), 3ul); + DS_ASSERT_OK(c2w2->Receive(K_3, K_5000, outElements)); + ASSERT_EQ(outElements.size(), 3ul); + DS_ASSERT_OK(c2w3->Receive(K_3, K_5000, outElements)); + ASSERT_EQ(outElements.size(), 3ul); + + DS_ASSERT_OK(p1w1->Close()); + DS_ASSERT_OK(p1w2->Close()); + DS_ASSERT_OK(p1w3->Close()); + DS_ASSERT_OK(c1w1->Close()); + DS_ASSERT_OK(c2w2->Close()); + DS_ASSERT_OK(c2w3->Close()); + DS_ASSERT_OK(client1_->DeleteStream(streamName)); +} + +TEST_F(StreamDfxTopoTest, LEVEL1_TestNotAllowDeleteStreamIfExistPendingNotify) +{ + std::string streamName = "NotAllowDeleteStream"; + // The testcase tests that if there is pending async notification to send, the delete stream will not succeed. + // It is done via having NotifyDelConsumer UpdateTopoNotification request go through RPC unavailable. + std::shared_ptr c1w1; + DS_ASSERT_OK(CreateConsumer(client1_, streamName, "sub-w1", c1w1)); + + std::shared_ptr p1w2; + DS_ASSERT_OK(CreateProducer(client2_, streamName, p1w2)); + + DS_ASSERT_OK( + cluster_->SetInjectAction(ClusterNodeType::WORKER, 3, "master.SendPendingNotification", "1*sleep(10000)")); + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "heartbeat.sleep", "1*sleep(10000)")); + const uint32_t sleepTime = 20 * 1000; // 20s + auto heartBeartRecoverTime = std::chrono::system_clock::now() + std::chrono::milliseconds(sleepTime); + + // After 5 seconds, the master thinks the worker1 has timeout. 
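+    // (This fixture starts workers with -node_timeout_s=2 and -node_dead_timeout_s=60, so a 5 s
+    // pause comfortably exceeds the heartbeat timeout without the node being declared dead.)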
+ std::this_thread::sleep_for(std::chrono::seconds(5)); + + DS_ASSERT_OK(c1w1->Close()); + DS_ASSERT_OK(p1w2->Close()); + DS_ASSERT_NOT_OK(client2_->DeleteStream(streamName)); + + while (std::chrono::system_clock::now() <= heartBeartRecoverTime) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + DS_ASSERT_OK(client2_->DeleteStream(streamName)); +} + +TEST_F(StreamDfxTopoTest, DISABLED_TestContinueSendNotifyAfterMasterRestart) +{ + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 3, "master.UpdateTopoNotification.setTimeout", + "call(2000)")); + + std::shared_ptr p1w1; + DS_ASSERT_OK(CreateProducer(client1_, "test-stream", p1w1)); + std::shared_ptr c1w1; + DS_ASSERT_OK(CreateConsumer(client1_, "test-stream", "sub-w1", c1w1)); + + // Sleep 10s when worker 0 send heartbeat message. + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "heartbeat.sleep", "1*sleep(10000)")); + const uint32_t sleepTime = 20 * 1000; // 20s + auto heartBeartRecoverTime = std::chrono::system_clock::now() + std::chrono::milliseconds(sleepTime); + + // After 3 seconds, the master thinks worker1 has timeout. + std::this_thread::sleep_for(std::chrono::seconds(3)); + + std::shared_ptr c1w2; + DS_ASSERT_OK(CreateConsumer(client2_, "test-stream", "sub-w2", c1w2)); + std::shared_ptr p1w2; + DS_ASSERT_OK(CreateProducer(client2_, "test-stream", p1w2)); + + cluster_->ShutdownNode(ClusterNodeType::WORKER, 3); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 3, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 3)); + + while (std::chrono::system_clock::now() <= heartBeartRecoverTime) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + + std::string str = "hello hello"; + Element element(reinterpret_cast(const_cast(str.data())), str.length()); + + std::vector outElements; + DS_ASSERT_OK(p1w1->Send(element)); + DS_ASSERT_OK(c1w2->Receive(1, 1000, outElements)); + ASSERT_EQ(outElements.size(), 1ul); + + DS_ASSERT_OK(p1w2->Send(element)); + DS_ASSERT_OK(c1w1->Receive(1, 1000, outElements)); + ASSERT_EQ(outElements.size(), 1ul); + + DS_ASSERT_OK(p1w1->Close()); + DS_ASSERT_OK(c1w1->Close()); + DS_ASSERT_OK(p1w2->Close()); + DS_ASSERT_OK(c1w2->Close()); + + DS_ASSERT_OK(client1_->DeleteStream("test-stream")); +} + +TEST_F(StreamDfxTopoTest, DISABLED_TestMasterCrashWhenCreateProducer) +{ + std::shared_ptr c1w2; + DS_ASSERT_OK(CreateConsumer(client2_, "test-stream", "sub-w2", c1w2)); + std::shared_ptr c1w3; + DS_ASSERT_OK(CreateConsumer(client3_, "test-stream", "sub-w3", c1w3)); + + DS_ASSERT_OK(inject::Set("rpc_util.retry_on_rpc_error_by_count", "1*call(1)")); + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 3, + "master.PubIncreaseNodeImpl.afterSendNotification", "abort()")); + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CreateProducer.beforeSendToMaster", + "call(K_OK)")); + std::shared_ptr p1w1; + DS_ASSERT_NOT_OK(CreateProducer(client1_, "test-stream", p1w1)); + + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 3, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 3)); + + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + DS_ASSERT_OK(c1w2->Close()); + DS_ASSERT_OK(c1w3->Close()); + DS_ASSERT_OK(client2_->DeleteStream("test-stream")); +} + +void StreamDfxTopoTest::waitAbort(uint32_t idx) +{ + while (cluster_->CheckWorkerProcess(idx)) { + std::this_thread::sleep_for(std::chrono::milliseconds(K_200)); + } +} + +TEST_F(StreamDfxTopoTest, 
LEVEL2_TestMasterCrashWhenCloseProducer) +{ + std::string streamName = "MasterCrashWhenCloseProd"; + std::shared_ptr p1w1; + DS_ASSERT_OK(CreateProducer(client1_, streamName, p1w1)); + + std::shared_ptr c1w2; + DS_ASSERT_OK(CreateConsumer(client2_, streamName, "sub-w2", c1w2)); + std::shared_ptr c1w3; + DS_ASSERT_OK(CreateConsumer(client3_, streamName, "sub-w3", c1w3)); + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 3, "master.PubDecreaseNode.afterSendNotification", + "abort()")); + DS_ASSERT_OK( + cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseProducer.beforeSendToMaster", "call(K_OK)")); + + DS_ASSERT_NOT_OK(p1w1->Close()); + + (void)cluster_->KillWorker(K_3); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 3, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 3)); + + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + DS_ASSERT_OK(p1w1->Close()); + DS_ASSERT_OK(c1w2->Close()); + DS_ASSERT_OK(c1w3->Close()); + DS_ASSERT_OK(client2_->DeleteStream(streamName)); +} + +TEST_F(StreamDfxTopoTest, LEVEL1_TestMasterCrashWhenSubscribe) +{ + std::string streamName = "MasterCrashWhenSub"; + std::shared_ptr p1w2; + DS_ASSERT_OK(CreateProducer(client2_, streamName, p1w2)); + std::shared_ptr p1w3; + DS_ASSERT_OK(CreateProducer(client3_, streamName, p1w3)); + + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 3, + "master.SubIncreaseNodeImpl.afterSendNotification", "abort()")); + DS_ASSERT_OK( + cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.Subscribe.beforeSendToMaster", "call(K_OK)")); + std::shared_ptr c1w1; + DS_ASSERT_NOT_OK(CreateConsumer(client1_, streamName, "sub-w1", c1w1)); + + (void)cluster_->KillWorker(K_3); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 3, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 3)); + + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + + DS_ASSERT_OK(p1w2->Close()); + DS_ASSERT_OK(p1w3->Close()); + DS_ASSERT_OK(client2_->DeleteStream(streamName)); +} + +TEST_F(StreamDfxTopoTest, LEVEL1_TestMasterCrashWhenCloseConsumer) +{ + std::string streamName = "MasterCrashWhenCloseCon"; + std::shared_ptr c1w1; + DS_ASSERT_OK(CreateConsumer(client1_, streamName, "sub-w1", c1w1)); + std::shared_ptr p1w2; + DS_ASSERT_OK(CreateProducer(client2_, streamName, p1w2)); + std::shared_ptr p1w3; + DS_ASSERT_OK(CreateProducer(client3_, streamName, p1w3)); + + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 3, "master.SubDecreaseNode.afterSendNotification", + "abort()")); + DS_ASSERT_OK( + cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseConsumer.beforeSendToMaster", "call(K_OK)")); + DS_ASSERT_NOT_OK(c1w1->Close()); + + (void)cluster_->KillWorker(K_3); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 3, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 3)); + + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + + DS_ASSERT_OK(c1w1->Close()); + DS_ASSERT_OK(p1w2->Close()); + DS_ASSERT_OK(p1w3->Close()); + DS_ASSERT_OK(TryAndDeleteStream(client2_, streamName)); +} + +TEST_F(StreamDfxTopoTest, LEVEL1_TestWorkerRestartThenClosePubSub) +{ + std::shared_ptr p1w1; + DS_ASSERT_OK(CreateProducer(client1_, "WorkerRestartClosePubSub", p1w1)); + std::shared_ptr c1w1; + DS_ASSERT_OK(CreateConsumer(client1_, "WorkerRestartClosePubSub", "sub-w1", c1w1)); + + std::shared_ptr p2w2; + DS_ASSERT_OK(CreateProducer(client2_, "WorkerRestartClosePubSub", 
p2w2)); + std::shared_ptr c2w2; + DS_ASSERT_OK(CreateConsumer(client2_, "WorkerRestartClosePubSub", "sub-w2", c2w2)); + + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1)); + + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + + DS_ASSERT_OK(p1w1->Close()); + DS_ASSERT_OK(c1w1->Close()); + DS_ASSERT_OK(client1_->DeleteStream("WorkerRestartClosePubSub")); +} + +class StreamDfxDistMasterCrashTest : public StreamDfxTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + StreamDfxTest::SetClusterSetupOptions(opts); + opts.masterIdx = 2; + opts.numWorkers = workerCount_; + opts.enableDistributedMaster = "true"; + } + +protected: + const uint32_t workerCount_ = 3; +}; + +TEST_F(StreamDfxDistMasterCrashTest, TestSameMetadata) +{ + LOG(INFO) << "TestSameMetadata start!"; + std::shared_ptr client1; + std::shared_ptr client2; + + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(K_TWO, client2)); + + std::vector> producers; + std::vector> consumers; + + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { "stream", 1 } }, producers, { { "stream", "sub1" } }, consumers)); + DS_ASSERT_OK( + CreateProducerAndConsumer(client2, { { "stream", 1 } }, producers, { { "stream", "sub2" } }, consumers)); + + cluster_->ShutdownNode(ClusterNodeType::WORKER, 1); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1)); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + CheckCount(client1, "stream", K_TWO, K_TWO); + CheckCount(client2, "stream", K_TWO, K_TWO); + + LOG(INFO) << "TestSameMetadata finish!"; +} + +TEST_F(StreamDfxDistMasterCrashTest, DISABLED_TestDiffMetadata) +{ + LOG(INFO) << "TestDiffMetadata start!"; + std::shared_ptr client1; + std::shared_ptr client2; + + DS_ASSERT_OK(InitClient(0, client1)); + DS_ASSERT_OK(InitClient(K_TWO, client2)); + + std::vector> producers; + std::vector> consumers; + + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { "stream", 1 } }, producers, { { "stream", "sub1" } }, consumers)); + DS_ASSERT_OK( + CreateProducerAndConsumer(client2, { { "stream", 1 } }, producers, { { "stream", "sub2" } }, consumers)); + + // close pub sub in worker 0, but not send to master + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseConsumer.beforeSendToMaster", + "1*return(K_OK)")); + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseProducer.beforeSendToMaster", + "1*return(K_OK)")); + producers[0] = nullptr; + consumers[0] = nullptr; + CheckCount(client1, "stream", K_TWO, K_TWO); + CheckCount(client2, "stream", K_TWO, K_TWO); + + // Here stream is always hashed to worker 1 + cluster_->ShutdownNode(ClusterNodeType::WORKER, 1); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1)); + // Extend the sleep time for testcase stability purposes + const int WAIT_NODE_READY_TIME = 10; + std::this_thread::sleep_for(std::chrono::seconds(WAIT_NODE_READY_TIME)); + CheckCount(client1, "stream", 1, 1); + CheckCount(client2, "stream", 1, 1); + + LOG(INFO) << "TestDiffMetadata finish!"; +} + +TEST_F(StreamDfxDistMasterCrashTest, DISABLED_TestMasterAndClientCrash) +{ + LOG(INFO) << "TestMasterAndClientCrash start!"; + + std::shared_ptr client1; + std::shared_ptr client2; + + DS_ASSERT_OK(InitClient(0, 
client1)); + DS_ASSERT_OK(InitClient(K_TWO, client2)); + + std::vector> producers; + std::vector> consumers; + + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { "stream", 1 } }, producers, { { "stream", "sub1" } }, consumers)); + DS_ASSERT_OK( + CreateProducerAndConsumer(client2, { { "stream", 1 } }, producers, { { "stream", "sub2" } }, consumers)); + CheckCount(client1, "stream", K_TWO, K_TWO); + cluster_->ShutdownNode(ClusterNodeType::WORKER, 1); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseConsumer.beforeSendToMaster", + "1*return(K_RPC_UNAVAILABLE)")); + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CloseProducer.beforeSendToMaster", + "1*return(K_RPC_UNAVAILABLE)")); + + DS_ASSERT_OK(client1->ShutDown()); + + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, "")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1)); + std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout)); + CheckCount(client2, "stream", 1, 1); + LOG(INFO) << "TestMasterAndClientCrash finish!"; +} + +TEST_F(StreamDfxDistMasterCrashTest, TestCloseProducer) +{ + LOG(INFO) << "TestCloseProducer start!"; + std::shared_ptr client1; + + DS_ASSERT_OK(InitClient(0, client1)); + + std::vector> producers; + std::vector> consumers; + std::string streamName = "testCloseProd"; + + DS_ASSERT_OK( + CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, { { streamName, "sub1" } }, consumers)); + + for (uint32_t i = 0; i < workerCount_; i++) { + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, i, "master.CloseProducerImpl", + "1*return(K_RPC_UNAVAILABLE)")); + } + + ASSERT_EQ(producers[0]->Close().GetCode(), K_RPC_UNAVAILABLE); + + CheckCount(client1, streamName, 1, 1); + DS_ASSERT_OK(producers[0]->Close()); + CheckCount(client1, streamName, 0, 1); + LOG(INFO) << "TestCloseProducer finish!"; +} + +class StreamDistDfxTopoTest : public StreamDfxTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + StreamDfxTest::SetClusterSetupOptions(opts); + const uint32_t workerCount = 2; + opts.numWorkers = workerCount; + opts.enableDistributedMaster = "true"; + opts.workerGflagParams = "-node_timeout_s=2 -node_dead_timeout_s=60 -v=2"; + } + + void SetUp() override + { + StreamDfxTest::SetUp(); + int index = 0; + DS_ASSERT_OK(InitClient(index++, client1_)); + DS_ASSERT_OK(InitClient(index++, client2_)); + } + + void TearDown() override + { + client1_ = nullptr; + client2_ = nullptr; + StreamDfxTest::TearDown(); + } + + template + void UntilTrueOrTimeout(Func &&func, uint64_t timeoutMs) + { + auto timeOut = std::chrono::steady_clock::now() + std::chrono::milliseconds(timeoutMs); + while (std::chrono::steady_clock::now() < timeOut) { + std::string value; + if (func()) { + return; + } + const int interval = 1000; // 1000ms; + std::this_thread::sleep_for(std::chrono::milliseconds(interval)); + } + ASSERT_TRUE(false) << "Timeout"; + } + +protected: + std::shared_ptr client1_; + std::shared_ptr client2_; +}; + +TEST_F(StreamDistDfxTopoTest, DISABLED_LEVEL1_TestWorkerExistsMetaAndStart) +{ + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "master.UpdateTopoNotification.setTimeout", + "call(2000)")); + + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 1, "master.UpdateTopoNotification.setTimeout", + "call(2000)")); + std::shared_ptr p1w1; + DS_ASSERT_OK(CreateProducer(client1_, 
"test-stream", p1w1)); + std::shared_ptr c1w2; + DS_ASSERT_OK(CreateConsumer(client2_, "test-stream", "sub-w2-1", c1w2)); + + // Sleep 10s when worker 0 send heartbeat message. + DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "heartbeat.sleep", "1*sleep(10000)")); + cluster_->ShutdownNode(ClusterNodeType::WORKER, 0); + DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 0, "-inject_actions=worker.InitRing:call()")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0)); + + std::shared_ptr c2w2; + DS_ASSERT_OK(CreateConsumer(client2_, "test-stream", "sub-w2-2", c2w2)); +} + +TEST_F(StreamDistDfxTopoTest, TestWorkerRestartRetryCheckMeta) +{ + std::vector> producers; + std::vector> consumers; + const int streamCount = 10; + for (int i = 0; i < streamCount; i++) { + std::string streamName = "test-stream-" + std::to_string(i); + std::shared_ptr producer; + DS_ASSERT_OK(CreateProducer(client1_, streamName, producer)); + std::shared_ptr consumer; + DS_ASSERT_OK(CreateConsumer(client1_, streamName, "sub", consumer)); + producers.emplace_back(std::move(producer)); + consumers.emplace_back(std::move(consumer)); + } + + DS_ASSERT_OK(cluster_->KillWorker(1)); + DS_ASSERT_OK( + cluster_->StartNode(ClusterNodeType::WORKER, 1, + "-inject_actions=worker.MasterRemoteWorkerSCApi.QueryMetadata:3*return(K_TRY_AGAIN)")); + DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1)); + + int maxTimeout = 10000; // 10s. + for (int i = 0; i < streamCount; i++) { + std::string streamName = "test-stream-" + std::to_string(i); + UntilTrueOrTimeout( + [this, &streamName] { + uint64_t gProducerNum = 0; + uint64_t gConsumerNum = 0; + LOG_IF_ERROR(client1_->QueryGlobalProducersNum(streamName, gProducerNum), + "QueryGlobalProducersNum failed"); + LOG_IF_ERROR(client1_->QueryGlobalConsumersNum(streamName, gConsumerNum), + "QueryGlobalConsumersNum failed"); + return gProducerNum == 1 && gConsumerNum == 1; + }, + maxTimeout); + } +} + +class StreamDfxSingleProducerMultiConsumerTest : public StreamDfxTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + StreamDfxTest::SetClusterSetupOptions(opts); + const uint32_t workerCount = 3; + opts.masterIdx = 2; + opts.numWorkers = workerCount; + opts.enableDistributedMaster = "true"; + } + +protected: + int timeOut = 1000; + std::string data = "Hello World"; +}; + +/* +On same node. Create 1 producer 2 consumers. producer is faulty on send through injection. +consumers should not be able to receive anything. Producer can send normally after +injection cleared. Consumer will receive after. 
+*/ +TEST_F(StreamDfxSingleProducerMultiConsumerTest, SameNodeProducerClientFault) +{ + LOG(INFO) << "SameNodeProducerClientFault start!"; + std::shared_ptr client1; + DS_ASSERT_OK(InitClient(0, client1)); + std::vector> producers; + std::vector> consumers; + std::string streamName = "SameNodeProducerClientFault"; + DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, + { { streamName, "sub1" }, { streamName, "sub2" } }, consumers)); + std::vector outElements; + Element element(reinterpret_cast(&data.front()), data.size()); + + std::thread sendThread([&]() { + datasystem::inject::Set("producer_insert", "1*return(K_INVALID)"); + DS_ASSERT_NOT_OK(producers[0]->Send(element)); + }); + for (const auto &consumer : consumers) { + DS_ASSERT_OK(consumer->Receive(timeOut, outElements)); + DS_ASSERT_TRUE(outElements.size(), 0); + } + sendThread.join(); + datasystem::inject::Clear("producer_insert"); + + DS_ASSERT_OK(producers[0]->Send(element)); + for (const auto &consumer : consumers) { + DS_ASSERT_OK(consumer->Receive(timeOut, outElements)); + DS_ASSERT_TRUE(outElements.size(), 1); + } + + LOG(INFO) << "SameNodeProducerClientFault finish!"; +} + +/* +On same node. Create 1 producer 2 consumers. 1 Consumer is faulty on receive through +injection. Consumer2 should be able to receive still. +*/ +TEST_F(StreamDfxSingleProducerMultiConsumerTest, SameNodeOneConsumerClientFault) +{ + LOG(INFO) << "SameNodeOneConsumerClientFault start!"; + std::shared_ptr client1, client2, client3; + DS_ASSERT_OK(InitClient(0, client1)); + std::vector> producers; + std::vector> consumers; + std::string streamName = "SameNodeOneConsumerClientFault"; + DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, + { { streamName, "sub1" }, { streamName, "sub2" } }, consumers)); + std::vector outElements1; + std::vector outElements2; + Element element(reinterpret_cast(&data.front()), data.size()); + + std::thread sendThread([&]() { DS_ASSERT_OK(producers[0]->Send(element)); }); + datasystem::inject::Set("consumerImpl.receive.fail", "1*return(K_INVALID)"); + DS_ASSERT_NOT_OK(consumers[0]->Receive(timeOut, outElements1)); + DS_ASSERT_OK(consumers[1]->Receive(timeOut, outElements2)); + + sendThread.join(); + DS_ASSERT_TRUE(outElements1.size(), 0); + DS_ASSERT_TRUE(outElements2.size(), 1); + + LOG(INFO) << "SameNodeOneConsumerClientFault finish!"; +} + +/* +On same node. Create 1 producer 2 consumers. Both consumers are faulty on receive through +injection. Producer sends, but both consumers do not receive. Can receive normally after +injection clear. 
+*/
+TEST_F(StreamDfxSingleProducerMultiConsumerTest, SameNodeBothConsumerClientFault)
+{
+    LOG(INFO) << "SameNodeBothConsumerClientFault start!";
+    std::shared_ptr<StreamClient> client1;
+    DS_ASSERT_OK(InitClient(0, client1));
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    std::string streamName = "stream_BothConsumerClientFault";
+    DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers,
+        { { streamName, "sub1" }, { streamName, "sub2" } }, consumers));
+    std::vector<Element> outElements;
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    std::thread sendThread([&]() { DS_ASSERT_OK(producers[0]->Send(element)); });
+    datasystem::inject::Set("consumerImpl.receive.fail", "2*return(K_INVALID)");
+    for (const auto &consumer : consumers) {
+        DS_ASSERT_NOT_OK(consumer->Receive(timeOut, outElements));
+    }
+    sendThread.join();
+    datasystem::inject::Clear("consumerImpl.receive.fail");
+
+    for (const auto &consumer : consumers) {
+        DS_ASSERT_OK(consumer->Receive(timeOut, outElements));
+        DS_ASSERT_TRUE(outElements.size(), 1);
+    }
+
+    LOG(INFO) << "SameNodeBothConsumerClientFault finish!";
+}
+
+/*
+On different nodes. Create 1 producer and 2 consumers on 3 separate nodes. The node that hosts
+the producer becomes faulty and is shut down, so the producer cannot send. After the node
+restarts, the global producer count should be 0 and the consumers should receive nothing.
+Creating a new producer after the node restart works fine.
+*/
+TEST_F(StreamDfxSingleProducerMultiConsumerTest, DiffNodeProducerWorkerFault)
+{
+    LOG(INFO) << "DiffNodeProducerWorkerFault start!";
+    std::shared_ptr<StreamClient> client1, client2, client3;
+    int idx = 2;
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    DS_ASSERT_OK(InitClient(idx, client3));
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    std::string streamName = "DiffNodeProducerWorkerFault";
+    DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, {}, consumers));
+    DS_ASSERT_OK(CreateProducerAndConsumer(client2, { {} }, producers, { { streamName, "sub1" } }, consumers));
+    DS_ASSERT_OK(CreateProducerAndConsumer(client3, { {} }, producers, { { streamName, "sub2" } }, consumers));
+    std::vector<Element> outElements;
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+
+    DS_ASSERT_OK(cluster_->QuicklyShutdownWorker(0));
+    DS_ASSERT_NOT_OK(producers[0]->Send(element));
+
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 0, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 0));
+
+    std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout));
+    int consumerCheck = 2;
+    CheckCount(client1, streamName, 0, consumerCheck);
+    for (const auto &consumer : consumers) {
+        DS_ASSERT_OK(consumer->Receive(timeOut, outElements));
+        DS_ASSERT_TRUE(outElements.size(), 0);
+    }
+
+    LOG(INFO) << "DiffNodeProducerWorkerFault finish!";
+}
+
+/*
+On different nodes. Create 1 producer and 2 consumers on 3 separate nodes. The node that hosts
+the first consumer becomes faulty and is shut down, so consumers[0] cannot receive. After the
+node restarts, the global consumer count should be 1 and only one consumer can receive.
+Creating a new consumer after the node restart works fine.
+*/
+TEST_F(StreamDfxSingleProducerMultiConsumerTest, LEVEL1_DiffNodeOneConsumerWorkerFault)
+{
+    LOG(INFO) << "DiffNodeOneConsumerWorkerFault start!";
+    std::shared_ptr<StreamClient> client1, client2, client3;
+    int idx = 2;
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    DS_ASSERT_OK(InitClient(idx, client3));
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    std::string streamName = "DiffNodeOneConsumerWorkerFault";
+    DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, {}, consumers));
+    DS_ASSERT_OK(CreateProducerAndConsumer(client2, { {} }, producers, { { streamName, "sub1" } }, consumers));
+    DS_ASSERT_OK(CreateProducerAndConsumer(client3, { {} }, producers, { { streamName, "sub2" } }, consumers));
+    std::vector<Element> outElements;
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producers[0]->Send(element));
+
+    DS_ASSERT_OK(cluster_->QuicklyShutdownWorker(1));
+    DS_ASSERT_NOT_OK(consumers[0]->Receive(timeOut, outElements));
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1));
+
+    std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout));
+    CheckCount(client2, streamName, 1, 1);
+    DS_ASSERT_OK(consumers[1]->Receive(timeOut, outElements));
+    DS_ASSERT_TRUE(outElements.size(), 1);
+
+    LOG(INFO) << "DiffNodeOneConsumerWorkerFault finish!";
+}
+
+/*
+On different nodes. Create 1 producer and 2 consumers on 3 separate nodes. Both nodes that host
+consumers become faulty and are shut down, so neither consumer can receive. After the nodes
+restart, the global producer count should still be 1 while the consumer count drops to 0.
+Creating new consumers after the node restart works fine.
+*/
+TEST_F(StreamDfxSingleProducerMultiConsumerTest, DiffNodeBothConsumerWorkerFault)
+{
+    LOG(INFO) << "DiffNodeBothConsumerWorkerFault start!";
+    std::shared_ptr<StreamClient> client1, client2, client3;
+    int idx = 2;
+    DS_ASSERT_OK(InitClient(0, client1));
+    DS_ASSERT_OK(InitClient(1, client2));
+    DS_ASSERT_OK(InitClient(idx, client3));
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    std::string streamName = "DiffNodeBothConsumerWorkerFault";
+    DS_ASSERT_OK(CreateProducerAndConsumer(client1, { { streamName, 1 } }, producers, {}, consumers));
+    DS_ASSERT_OK(CreateProducerAndConsumer(client2, { {} }, producers, { { streamName, "sub1" } }, consumers));
+    DS_ASSERT_OK(CreateProducerAndConsumer(client3, { {} }, producers, { { streamName, "sub2" } }, consumers));
+    Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+    DS_ASSERT_OK(producers[0]->Send(element));
+    ThreadPool pool(idx);
+    auto fut1 = pool.Submit([this, consumers]() {
+        DS_ASSERT_OK(cluster_->QuicklyShutdownWorker(1));
+        std::vector<Element> outElements;
+        DS_ASSERT_NOT_OK(consumers[0]->Receive(timeOut, outElements));
+    });
+    auto fut2 = pool.Submit([this, consumers]() {
+        DS_ASSERT_OK(cluster_->QuicklyShutdownWorker(2));
+        std::vector<Element> outElements1;
+        DS_ASSERT_NOT_OK(consumers[1]->Receive(timeOut, outElements1));
+    });
+    fut1.get();
+    fut2.get();
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 1, ""));
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 2, ""));
+    fut1 = pool.Submit([this, consumers]() { DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 1)); });
+    fut2 = pool.Submit([this, consumers]() { DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 2)); });
+    fut1.get();
+    fut2.get();
+    std::this_thread::sleep_for(std::chrono::seconds(waitNodeTimeout));
+    CheckCount(client3,
streamName, 1, 0); + + LOG(INFO) << "DiffNodeBothConsumerWorkerFault finish!"; +} + +// DFX testcases for Multiple-Producer and Single-Consumer +class StreamDfxClientCrashMPSC : public StreamDfxTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + StreamDfxTest::SetClusterSetupOptions(opts); + opts.numWorkers = workerCount_; + } + +protected: + const uint32_t workerCount_ = 3; +}; + +TEST_F(StreamDfxClientCrashMPSC, DISABLED_SameNodeProducerProcessCrash) +{ + std::string streamName("stream1"); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + + // Create Producer1 + auto pid = fork(); + if (pid == 0) { + std::shared_ptr client1; + DS_ASSERT_OK(InitClient(0, client1)); + + std::shared_ptr Producer1; + DS_ASSERT_OK(CreateProducer(client1, streamName, Producer1)); + // Fake a crash point within producer after it holds the lock + datasystem::inject::Set("producer_obtained_lock", "1*abort()"); + DS_ASSERT_NOT_OK(Producer1->Send(element)); + _exit(0); + } + ASSERT_TRUE(pid > 0); + + // Create Producer2 + pid = fork(); + if (pid == 0) { + std::shared_ptr client2; + DS_ASSERT_OK(InitClient(0, client2)); + std::shared_ptr Producer2; + DS_ASSERT_OK(CreateProducer(client2, streamName, Producer2)); + // Fake a crash point within producer after it holds the lock + datasystem::inject::Set("producer_obtained_lock", "1*abort()"); + DS_ASSERT_NOT_OK(Producer2->Send(element)); + _exit(0); + } + ASSERT_TRUE(pid > 0); + int status; + waitpid(pid, &status, 0); + + // Create Consumer1 + std::shared_ptr client3; + DS_ASSERT_OK(InitClient(0, client3)); + std::shared_ptr Consumer1; + DS_ASSERT_OK(CreateConsumer(client3, streamName, "sub1", Consumer1)); + + std::vector outElements; + const uint32_t timeoutMs = 100; + DS_ASSERT_OK(Consumer1->Receive(timeoutMs, outElements)); + DS_ASSERT_TRUE(outElements.size(), 0); + DS_ASSERT_OK(Consumer1->Close()); +} + +TEST_F(StreamDfxClientCrashMPSC, SameNodeConsumerProcessCrash) +{ + std::string streamName("testSameNodeConProcessCrash"); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + + std::shared_ptr client1; + DS_ASSERT_OK(InitClient(0, client1)); + + // Create Consumer1 + auto pid = fork(); + if (pid == 0) { + std::shared_ptr client2; + DS_ASSERT_OK(InitClient(0, client2)); + std::shared_ptr Consumer1; + DS_ASSERT_OK(CreateConsumer(client2, streamName, "sub1", Consumer1)); + std::vector outElements; + + const int SLEEP_TIME = 1000; + std::this_thread::sleep_for(std::chrono::milliseconds(SLEEP_TIME)); + + // Fake a crash point within consumer after it gets the DataPage from worker + datasystem::inject::Set("consumer_after_get_datapage", "abort()"); + const uint32_t timeoutMs = 2000; + DS_ASSERT_NOT_OK(Consumer1->Receive(timeoutMs, outElements)); + DS_ASSERT_TRUE(outElements.size(), 0); + _exit(0); + } + ASSERT_TRUE(pid > 0); + + // Create Producer1 and Producer2 + std::shared_ptr Producer1; + DS_ASSERT_OK(CreateProducer(client1, streamName, Producer1)); + std::shared_ptr Producer2; + DS_ASSERT_OK(CreateProducer(client1, streamName, Producer2)); + + uint64_t totalConsumerNum = 0; + while (totalConsumerNum == 0) { + client1->QueryGlobalConsumersNum(streamName, totalConsumerNum); + } + + int threadNum = 2; + ThreadPool pool(threadNum); + auto fut1 = pool.Submit([this, Producer1, element]() { + Producer1->Send(element); + Producer1->Close(); + }); + auto fut2 = pool.Submit([this, Producer2, element]() { + Producer2->Send(element); 
+ Producer2->Close(); + }); + + int status; + waitpid(pid, &status, 0); + fut1.get(); + fut2.get(); +} + +class StreamDfxWorkerCrashMPSC : public StreamDfxTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numWorkers = numWorkers_; + opts.numEtcd = numEtcd_; + opts.numRpcThreads = numRpcThreads_; + SCClientCommon::SetClusterSetupOptions(opts); + } + + void SetUp() override + { + StreamDfxTest::SetUp(); + InitTest(); + } + + void TearDown() override + { + client1_ = nullptr; + client2_ = nullptr; + client3_ = nullptr; + StreamDfxTest::TearDown(); + } + +protected: + void InitTest() + { + int workerIndex = 0; + InitClient(workerIndex++, client1_); + InitClient(workerIndex++, client2_); + InitClient(workerIndex, client3_); + } + + std::shared_ptr client1_ = nullptr; + std::shared_ptr client2_ = nullptr; + std::shared_ptr client3_ = nullptr; + + // Cluster config + int numWorkers_ = 3; + int numEtcd_ = 1; + int numRpcThreads_ = 0; +}; + +TEST_F(StreamDfxWorkerCrashMPSC, CrossNodeProducerWorkerCrash) +{ + std::string stream1("CrossNodeProducerWorkerCrash"); + + std::shared_ptr Producer1; + DS_ASSERT_OK(CreateProducer(client1_, stream1, Producer1)); + std::shared_ptr Producer2; + DS_ASSERT_OK(CreateProducer(client2_, stream1, Producer2)); + std::shared_ptr Consumer1; + DS_ASSERT_OK(CreateConsumer(client3_, stream1, "sub1", Consumer1)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + int threadNum = 4; + ThreadPool pool(threadNum); + uint32_t workerIndex = 0; + + cluster_->QuicklyShutdownWorker(workerIndex); + + // Set pause at send + auto fut2 = pool.Submit([this, Producer2, element]() { DS_ASSERT_OK(Producer2->Send(element)); }); + datasystem::inject::Set("ProducerImpl.beforeCreateWritePage", "sleep(3000)"); + auto fut1 = pool.Submit([this, Producer1, element]() { DS_ASSERT_NOT_OK(Producer1->Send(element)); }); + auto fut3 = pool.Submit([this, Consumer1]() { + uint32_t timeoutMs = 5000; + std::vector outElements; + DS_ASSERT_OK(Consumer1->Receive(timeoutMs, outElements)); + DS_ASSERT_TRUE(outElements.size(), 1); + }); + + // Wait for all send/receive to finish + fut1.get(); + fut2.get(); + fut3.get(); +} + +TEST_F(StreamDfxWorkerCrashMPSC, DISABLED_CrossNodeBothProducerWorkerCrash) +{ + std::string stream1("CrossNodeBothProducerWorkerCrash"); + + std::shared_ptr Producer1; + DS_ASSERT_OK(CreateProducer(client1_, stream1, Producer1)); + std::shared_ptr Producer2; + DS_ASSERT_OK(CreateProducer(client2_, stream1, Producer2)); + std::shared_ptr Consumer1; + DS_ASSERT_OK(CreateConsumer(client3_, stream1, "sub1", Consumer1)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + int threadNum = 5; + ThreadPool pool(threadNum); + uint32_t workerIndex = 0; + + // Close workers of both Producer1 and Producer2 + auto fut = pool.Submit([this, workerIndex]() { cluster_->QuicklyShutdownWorker(workerIndex); }); + workerIndex++; + auto fut0 = pool.Submit([this, workerIndex]() { cluster_->QuicklyShutdownWorker(workerIndex); }); + fut.get(); + fut0.get(); + + // Set pause at send + datasystem::inject::Set("ProducerImpl.beforeCreateWritePage", "sleep(3000)"); + auto fut1 = pool.Submit([this, Producer1, element]() { DS_ASSERT_NOT_OK(Producer1->Send(element)); }); + auto fut2 = pool.Submit([this, Producer2, element]() { DS_ASSERT_NOT_OK(Producer2->Send(element)); }); + auto fut3 = pool.Submit([this, Consumer1]() { + uint32_t timeoutMs = 3000; + std::vector 
outElements; + DS_ASSERT_OK(Consumer1->Receive(timeoutMs, outElements)); + DS_ASSERT_TRUE(outElements.size(), 0); + }); + + // Wait for all send/receive to return + fut1.get(); + fut2.get(); + fut3.get(); +} + +TEST_F(StreamDfxWorkerCrashMPSC, LEVEL1_CrossNodeConsumerWorkerCrash) +{ + std::string stream1("CrossNodeConsumerWorkerCrash"); + + std::shared_ptr Producer1; + DS_ASSERT_OK(CreateProducer(client1_, stream1, Producer1)); + std::shared_ptr Producer2; + DS_ASSERT_OK(CreateProducer(client2_, stream1, Producer2)); + std::shared_ptr Consumer1; + DS_ASSERT_OK(CreateConsumer(client3_, stream1, "sub1", Consumer1)); + + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + int threadNum = 3; + ThreadPool pool(threadNum); + uint32_t workerIndex = 2; + + // shutdown the consumer worker + cluster_->QuicklyShutdownWorker(workerIndex); + + // Set pause at receive/send + auto fut1 = pool.Submit([this, Producer1, element]() { DS_ASSERT_OK(Producer1->Send(element)); }); + auto fut2 = pool.Submit([this, Producer2, element]() { DS_ASSERT_OK(Producer2->Send(element)); }); + auto fut3 = pool.Submit([this, Consumer1]() { + uint32_t timeoutMs = 5000; + std::vector outElements; + DS_ASSERT_NOT_OK(Consumer1->Receive(timeoutMs, outElements)); + ASSERT_TRUE(outElements.size() < 2); + }); + + // Wait for all send/receive to finish + fut1.get(); + fut2.get(); + fut3.get(); +} + +TEST_F(StreamDfxWorkerCrashMPSC, CrossNodeNetworkIssueBetweenWorkers) +{ + std::string stream1("CrossNodeNetworkIssueBetweenWorkers"); + + std::shared_ptr Producer1; + DS_ASSERT_OK(CreateProducer(client1_, stream1, Producer1)); + std::shared_ptr Producer2; + DS_ASSERT_OK(CreateProducer(client2_, stream1, Producer2)); + std::shared_ptr Consumer1; + DS_ASSERT_OK(CreateConsumer(client3_, stream1, "sub1", Consumer1)); + + // Simulate lost rpc request + int workerIndex = 2; + DS_ASSERT_OK( + cluster_->SetInjectAction(ClusterNodeType::WORKER, workerIndex, "PushElementsCursors.begin", "sleep(10000)")); + std::string data = "Hello World"; + Element element(reinterpret_cast(&data.front()), data.size()); + int threadNum = 3; + ThreadPool pool(threadNum); + + // Set pause after PushElementsCursors being sent + + auto fut1 = pool.Submit([this, Producer1, element]() { DS_ASSERT_OK(Producer1->Send(element)); }); + auto fut2 = pool.Submit([this, Producer2, element]() { DS_ASSERT_OK(Producer2->Send(element)); }); + auto fut3 = pool.Submit([this, Consumer1]() { + uint32_t timeoutMs = 5000; + std::vector outElements; + DS_ASSERT_OK(Consumer1->Receive(timeoutMs, outElements)); + ASSERT_TRUE(outElements.size() == 0); + }); + + // Wait for all send/receive to finish + fut1.get(); + fut2.get(); + fut3.get(); +} + +} // namespace st +} // namespace datasystem diff --git a/tests/st/client/stream_cache/stream_meta_shm_test.cpp b/tests/st/client/stream_cache/stream_meta_shm_test.cpp new file mode 100644 index 0000000..05889ae --- /dev/null +++ b/tests/st/client/stream_cache/stream_meta_shm_test.cpp @@ -0,0 +1,351 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Test StreamMetaShm. + */ +#include +#include +#include +#include "client/stream_cache/sc_client_common.h" +#include "common.h" +#include "common/stream_cache/stream_common.h" +#include "datasystem/common/util/raii.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/stream/element.h" +#include "datasystem/stream/producer.h" +#include "datasystem/utils/status.h" +namespace datasystem { +namespace st { +constexpr int K_TWO = 2; +using namespace datasystem::client::stream_cache; +class StreamMetaShmTest : public SCClientCommon { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override; + + void SetUp() override; + + void TearDown() override; + +protected: + Status TryAndDeleteStream(std::shared_ptr &spClient, std::string streamName) + { + // if pending notifications retry delete + Status rc = Status::OK(); + do { + rc = spClient->DeleteStream(streamName); + if (rc.IsError()) { + sleep(K_TWO); + } + } while (rc.GetCode() == StatusCode::K_SC_STREAM_NOTIFICATION_PENDING); + return rc; + } + + Status SendHelper(std::shared_ptr producer, size_t numElements, bool &stop, size_t idx, int minEleSize, + int maxEleSize); + Status ReceiveHelper(std::shared_ptr consumer, size_t numElements, bool &stopReceive); + void BasicMPSCTest(int minEleSize, int maxEleSize); + + // Mock producer worker. 
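+    // Worker addresses; these are resolved from the external cluster in SetUp() below via
+    // cluster_->GetWorkerAddr(), one per worker of the 3-node test cluster.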
+ HostPort w1Addr_; + HostPort w2Addr_; + HostPort w3Addr_; + + std::vector> clients; + + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; + const int DEFAULT_WORKER_NUM = 3; + const int DEFAULT_LOG_LEVEL = 1; +}; + +void StreamMetaShmTest::SetClusterSetupOptions(ExternalClusterOptions &opts) +{ + opts.numEtcd = 1; + opts.numWorkers = DEFAULT_WORKER_NUM; + opts.enableDistributedMaster = "true"; + opts.numRpcThreads = 0; + opts.vLogLevel = DEFAULT_LOG_LEVEL; + SCClientCommon::SetClusterSetupOptions(opts); +} + +void StreamMetaShmTest::SetUp() +{ + ExternalClusterTest::SetUp(); + DS_ASSERT_OK(cluster_->GetWorkerAddr(0, w1Addr_)); + DS_ASSERT_OK(cluster_->GetWorkerAddr(1, w2Addr_)); + DS_ASSERT_OK(cluster_->GetWorkerAddr(K_TWO, w3Addr_)); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0)); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 1)); + DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, K_TWO)); + std::shared_ptr c1, c2, c3; + InitStreamClient(0, c1); + InitStreamClient(1, c2); + InitStreamClient(K_TWO, c3); + + clients.emplace_back(c1); + clients.emplace_back(c2); + clients.emplace_back(c3); +} + +void StreamMetaShmTest::TearDown() +{ + for (auto &client : clients) { + client.reset(); + } + ExternalClusterTest::TearDown(); +} + +Status StreamMetaShmTest::SendHelper(std::shared_ptr producer, size_t numElements, bool &stop, size_t idx, + int minEleSize, int maxEleSize) +{ + const int DEFAULT_SLEEP_TIME = 300; + int retryLimit = 300; + uint64_t totalEleSize = 0; + Raii raii([&totalEleSize, idx]() { + LOG(INFO) << "TotalSendSize: " << totalEleSize << ", producer: " << idx; + }); + for (size_t i = 0; i < numElements; i++) { + RandomData rand; + int64_t dataSize = rand.GetRandomUint64(minEleSize, maxEleSize); + std::string writeElement = RandomData().GetRandomString(dataSize); + Element element(reinterpret_cast(writeElement.data()), writeElement.size()); + datasystem::Status rc = producer->Send(element); + if (rc.IsError()) { + while (rc.GetCode() == K_OUT_OF_MEMORY && retryLimit-- > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(DEFAULT_SLEEP_TIME)); + CHECK_FAIL_RETURN_STATUS(!stop, K_RUNTIME_ERROR, ""); + rc = producer->Send(element); + } + } + if (rc) { + totalEleSize += dataSize; + } + CHECK_FAIL_RETURN_STATUS(!stop, K_RUNTIME_ERROR, ""); + RETURN_IF_NOT_OK(rc); + } + return Status::OK(); +} + +Status StreamMetaShmTest::ReceiveHelper(std::shared_ptr consumer, size_t numElements, bool &stopReceive) +{ + Timer timer; + size_t remaining = numElements; + const int PER_RECEIVE_NUM = 1; + const int DEFAULT_WAIT_TIME = 1000; + uint64_t receiveSize = 0; + size_t receiveCount = 0; + while (remaining > 0 && !stopReceive) { + std::vector outElements; + RETURN_IF_NOT_OK(consumer->Receive(PER_RECEIVE_NUM, DEFAULT_WAIT_TIME, outElements)); + if (!outElements.empty()) { + remaining -= outElements.size(); + CHECK_FAIL_RETURN_STATUS(outElements.size() == PER_RECEIVE_NUM, K_RUNTIME_ERROR, "aaa"); + receiveSize += outElements[0].size; + receiveCount++; + RETURN_IF_NOT_OK(consumer->Ack(outElements.back().id)); + } + } + LOG(INFO) << "TotalReceiveSize:" << receiveSize << ", receiveCount: " << receiveCount; + CHECK_FAIL_RETURN_STATUS(remaining == 0, K_RUNTIME_ERROR, "failed to receive all data"); + return Status::OK(); +} + +void StreamMetaShmTest::BasicMPSCTest(int minEleSize, int maxEleSize) +{ + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer1; + 
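+    // The single consumer subscribes before any producer exists; the ProducerConf that follows
+    // sets retainForNumConsumers = 1, which (judging by the field name) should make the stream
+    // retain data until at least one consumer is attached.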
DS_ASSERT_OK(clients[K_TWO]->Subscribe("stream1", config, consumer1)); + + ProducerConf conf; + conf.maxStreamSize = 10 * MB; // 10MB + conf.streamMode = StreamMode::MPSC; + conf.pageSize = 1 * MB; + conf.retainForNumConsumers = 1; + conf.autoCleanup = true; + + std::vector> producers; + int producerCount = 3; + for (int i = 0; i < producerCount; i++) { + std::shared_ptr producer; + DS_ASSERT_OK(clients[i % clients.size()]->CreateProducer("stream1", producer, conf)); + producers.emplace_back(std::move(producer)); + } + + const int threadNum = producers.size() + 1; + const int totalEleSize = 1 * GB; + const size_t numElementsPerPub = totalEleSize / producers.size() / ((minEleSize + maxEleSize) / K_TWO); + const size_t totalEleNum = numElementsPerPub * producers.size(); + + ThreadPool pool(threadNum); + bool stopAll = false; + std::vector> producerFuts; + for (size_t i = 0; i < producers.size(); i++) { + auto p = producers[i]; + producerFuts.emplace_back(pool.Submit([this, p, &stopAll, i, numElementsPerPub, minEleSize, maxEleSize]() { + return SendHelper(p, numElementsPerPub, stopAll, i, minEleSize, maxEleSize); + })); + } + auto consumerFut = pool.Submit( + [this, consumer1, &stopAll, &totalEleNum]() { return ReceiveHelper(consumer1, totalEleNum, stopAll); }); + + Status lastPRc; + while (!producerFuts.empty()) { + for (auto itr = producerFuts.begin(); itr != producerFuts.end();) { + if (!itr->valid()) { + ++itr; + continue; + } + auto pRc = itr->get(); + if (pRc.IsError() && !stopAll) { + lastPRc = pRc; + stopAll = true; + } + itr = producerFuts.erase(itr); + } + sleep(1); + } + auto sRc = consumerFut.get(); + DS_ASSERT_OK(lastPRc); + DS_ASSERT_OK(sRc); + + for (auto &p : producers) { + DS_ASSERT_OK(p->Close()); + } + DS_ASSERT_OK(consumer1->Close()); + DS_ASSERT_OK(TryAndDeleteStream(clients[0], "stream1")); +} + +TEST_F(StreamMetaShmTest, TestRemoteConsumerNotReceive) +{ + SubscriptionConfig config("sub1", SubscriptionType::STREAM); + std::shared_ptr consumer1; + DS_ASSERT_OK(clients[1]->Subscribe("stream1", config, consumer1)); + + int maxStreamSize = 4 * MB; + ProducerConf conf; + conf.maxStreamSize = maxStreamSize; + conf.streamMode = StreamMode::SPSC; + conf.pageSize = 1 * MB; + std::shared_ptr producer1; + DS_ASSERT_OK(clients[0]->CreateProducer("stream1", producer1, conf)); + + const size_t sizeElement = 1 * KB; + std::string writeElement1 = RandomData().GetRandomString(sizeElement); + Element element1(reinterpret_cast(writeElement1.data()), writeElement1.size()); + + auto maxEleCountPerNode = maxStreamSize / sizeElement; + + for (auto i = 0ul; i < maxEleCountPerNode; ++i) { + DS_ASSERT_OK(producer1->Send(element1)); + } + + // Wait for all ele to be flushed to w1. After flushing, the shared memory occupied by this stream on w0 should be + // 0. 
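+    // Consequently, a second burst of maxEleCountPerNode sends should fit into shared memory
+    // again, and only the send after that is expected to fail with K_OUT_OF_MEMORY.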
+void StreamMetaShmTest::BasicMPSCTest(int minEleSize, int maxEleSize)
+{
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    std::shared_ptr<Consumer> consumer1;
+    DS_ASSERT_OK(clients[K_TWO]->Subscribe("stream1", config, consumer1));
+
+    ProducerConf conf;
+    conf.maxStreamSize = 10 * MB;  // 10MB
+    conf.streamMode = StreamMode::MPSC;
+    conf.pageSize = 1 * MB;
+    conf.retainForNumConsumers = 1;
+    conf.autoCleanup = true;
+
+    std::vector<std::shared_ptr<Producer>> producers;
+    int producerCount = 3;
+    for (int i = 0; i < producerCount; i++) {
+        std::shared_ptr<Producer> producer;
+        DS_ASSERT_OK(clients[i % clients.size()]->CreateProducer("stream1", producer, conf));
+        producers.emplace_back(std::move(producer));
+    }
+
+    const int threadNum = producers.size() + 1;
+    const int totalEleSize = 1 * GB;
+    const size_t numElementsPerPub = totalEleSize / producers.size() / ((minEleSize + maxEleSize) / K_TWO);
+    const size_t totalEleNum = numElementsPerPub * producers.size();
+
+    ThreadPool pool(threadNum);
+    bool stopAll = false;
+    std::vector<std::future<Status>> producerFuts;
+    for (size_t i = 0; i < producers.size(); i++) {
+        auto p = producers[i];
+        producerFuts.emplace_back(pool.Submit([this, p, &stopAll, i, numElementsPerPub, minEleSize, maxEleSize]() {
+            return SendHelper(p, numElementsPerPub, stopAll, i, minEleSize, maxEleSize);
+        }));
+    }
+    auto consumerFut = pool.Submit(
+        [this, consumer1, &stopAll, &totalEleNum]() { return ReceiveHelper(consumer1, totalEleNum, stopAll); });
+
+    Status lastPRc;
+    while (!producerFuts.empty()) {
+        for (auto itr = producerFuts.begin(); itr != producerFuts.end();) {
+            if (!itr->valid()) {
+                ++itr;
+                continue;
+            }
+            auto pRc = itr->get();
+            if (pRc.IsError() && !stopAll) {
+                lastPRc = pRc;
+                stopAll = true;
+            }
+            itr = producerFuts.erase(itr);
+        }
+        sleep(1);
+    }
+    auto sRc = consumerFut.get();
+    DS_ASSERT_OK(lastPRc);
+    DS_ASSERT_OK(sRc);
+
+    for (auto &p : producers) {
+        DS_ASSERT_OK(p->Close());
+    }
+    DS_ASSERT_OK(consumer1->Close());
+    DS_ASSERT_OK(TryAndDeleteStream(clients[0], "stream1"));
+}
+
+TEST_F(StreamMetaShmTest, TestRemoteConsumerNotReceive)
+{
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    std::shared_ptr<Consumer> consumer1;
+    DS_ASSERT_OK(clients[1]->Subscribe("stream1", config, consumer1));
+
+    int maxStreamSize = 4 * MB;
+    ProducerConf conf;
+    conf.maxStreamSize = maxStreamSize;
+    conf.streamMode = StreamMode::SPSC;
+    conf.pageSize = 1 * MB;
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(clients[0]->CreateProducer("stream1", producer1, conf));
+
+    const size_t sizeElement = 1 * KB;
+    std::string writeElement1 = RandomData().GetRandomString(sizeElement);
+    Element element1(reinterpret_cast<uint8_t *>(writeElement1.data()), writeElement1.size());
+
+    auto maxEleCountPerNode = maxStreamSize / sizeElement;
+
+    for (auto i = 0ul; i < maxEleCountPerNode; ++i) {
+        DS_ASSERT_OK(producer1->Send(element1));
+    }
+
+    // Wait for all elements to be flushed to w1. After flushing, the shared memory occupied by this stream on w0
+    // should be 0.
+    int waitFlushTimeSec = 2;
+    sleep(waitFlushTimeSec);
+
+    for (auto i = 0ul; i < maxEleCountPerNode; ++i) {
+        DS_ASSERT_OK(producer1->Send(element1));
+    }
+
+    auto rc = producer1->Send(element1);
+    ASSERT_EQ(rc.GetCode(), K_OUT_OF_MEMORY);
+
+    DS_ASSERT_OK(producer1->Close());
+    DS_ASSERT_OK(consumer1->Close());
+    DS_ASSERT_OK(TryAndDeleteStream(clients[0], "stream1"));
+}
+
+TEST_F(StreamMetaShmTest, TestRemoteConsumerReceiveAfterOOM)
+{
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    std::shared_ptr<Consumer> consumer1;
+    DS_ASSERT_OK(clients[1]->Subscribe("stream1", config, consumer1));
+
+    int maxStreamSize = 4 * MB;
+    int pageSize = 1 * MB;
+    ProducerConf conf;
+    conf.maxStreamSize = maxStreamSize;
+    conf.streamMode = StreamMode::SPSC;
+    conf.pageSize = 1 * MB;
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_OK(clients[0]->CreateProducer("stream1", producer1, conf));
+
+    const size_t sizeElement = 1 * KB;
+    std::string writeElement1 = RandomData().GetRandomString(sizeElement);
+    Element element1(reinterpret_cast<uint8_t *>(writeElement1.data()), writeElement1.size());
+
+    auto maxEleCountPerNode = maxStreamSize / sizeElement;
+    auto maxEleCountPerPage = pageSize / sizeElement;
+
+    for (auto i = 0ul; i < maxEleCountPerNode; ++i) {
+        DS_ASSERT_OK(producer1->Send(element1));
+    }
+
+    // Wait for all elements to be flushed to w1. After flushing, the shared memory occupied by this stream on w0
+    // should be 0.
+    int waitFlushTimeSec = 2;
+    sleep(waitFlushTimeSec);
+
+    for (auto i = 0ul; i < maxEleCountPerNode; ++i) {
+        DS_ASSERT_OK(producer1->Send(element1));
+    }
+
+    auto rc = producer1->Send(element1);
+    ASSERT_EQ(rc.GetCode(), K_OUT_OF_MEMORY);
+
+    std::vector<Element> eleToReceive;
+    DS_ASSERT_OK(consumer1->Receive(maxEleCountPerPage + 1, 0, eleToReceive));
+    ASSERT_EQ(eleToReceive.size(), maxEleCountPerPage + 1);
+    // Ack a whole page of elements to ensure that the shared memory usage of the stream is reduced.
+    DS_ASSERT_OK(consumer1->Ack(eleToReceive.rbegin()->id));
+
+    sleep(waitFlushTimeSec);
+
+    DS_ASSERT_OK(producer1->Send(element1));
+
+    DS_ASSERT_OK(producer1->Close());
+    DS_ASSERT_OK(consumer1->Close());
+    DS_ASSERT_OK(TryAndDeleteStream(clients[0], "stream1"));
+}
+
+// All are normal elements
+TEST_F(StreamMetaShmTest, DISABLED_EXCLUSIVE_LEVEL1_MPSCTest1)
+{
+    int minEleSize = 200;
+    int maxEleSize = 1000;
+    BasicMPSCTest(minEleSize, maxEleSize);
+}
+
+// Mix of normal elements and big elements
+TEST_F(StreamMetaShmTest, DISABLED_EXCLUSIVE_LEVEL1_MPSCTest2)
+{
+    int minEleSize = 1 * MB;
+    int maxEleSize = 3 * MB;
+    BasicMPSCTest(minEleSize, maxEleSize);
+}
+
+// All are big elements
+TEST_F(StreamMetaShmTest, DISABLED_EXCLUSIVE_LEVEL1_MPSCTest3)
+{
+    int minEleSize = 2 * MB;
+    int maxEleSize = 3 * MB;
+    BasicMPSCTest(minEleSize, maxEleSize);
+}
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/stream_multi_tenant.cpp b/tests/st/client/stream_cache/stream_multi_tenant.cpp
new file mode 100644
index 0000000..7822f8a
--- /dev/null
+++ b/tests/st/client/stream_cache/stream_multi_tenant.cpp
@@ -0,0 +1,299 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Unit test for multi-tenant
+ */
+
+#include <gtest/gtest.h>
+#include <memory>
+#include "common.h"
+#include "common/stream_cache/stream_common.h"
+#include "sc_client_common.h"
+#include "datasystem/common/inject/inject_point.h"
+#include "datasystem/common/util/status_helper.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/stream/element.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream_client.h"
+
+using namespace datasystem::client::stream_cache;
+namespace datasystem {
+namespace st {
+const uint32_t WORKER_NUM = 6;
+
+class StreamMultiTenant : public SCClientCommon {
+public:
+    Status CreateConsumer(std::shared_ptr<StreamClient> client, const std::string &streamName,
+                          const std::string &subName, std::shared_ptr<Consumer> &consumer)
+    {
+        SubscriptionConfig config(subName, SubscriptionType::STREAM);
+        return client->Subscribe(streamName, config, consumer);
+    }
+
+    Status CreateProducer(std::shared_ptr<StreamClient> client, const std::string &streamName,
+                          std::shared_ptr<Producer> &producer)
+    {
+        const int64_t autoFlushTime = 10 * 1000;  // 10s
+        const int64_t pageSize = 4 * 1024;        // The size of a page is 4096 bytes
+        ProducerConf conf = { .delayFlushTime = autoFlushTime, .pageSize = pageSize,
+                              .maxStreamSize = TEST_STREAM_SIZE };
+        return client->CreateProducer(streamName, producer, conf);
+    }
+
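+    // SendAndReceiveData pushes one element through `producer`, then asserts that
+    // every consumer in `receiveCon` sees exactly that element and that every
+    // consumer in `unReceiveCon` comes back empty -- the tenant-isolation check.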
+    using VECCONSUMER = std::vector<std::shared_ptr<Consumer>>;
+    void SendAndReceiveData(const size_t dataSize, std::shared_ptr<Producer> producer,
+                            VECCONSUMER receiveCon, VECCONSUMER unReceiveCon)
+    {
+        std::string data = RandomData().GetRandomString(dataSize);
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+        DS_ASSERT_OK(producer->Send(element));
+        sleep(1);
+
+        for (auto &consumer : receiveCon) {
+            std::vector<Element> outElements;
+            ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK());
+            ASSERT_EQ(memcmp(outElements[0].ptr, data.c_str(), outElements[0].size), 0);
+        }
+
+        for (auto &consumer : unReceiveCon) {
+            std::vector<Element> outElements;
+            ASSERT_EQ(consumer->Receive(1, 0, outElements), Status::OK());
+            EXPECT_EQ(outElements.size(), (size_t)0);
+        }
+    }
+
+    void QueryProducerAndConsumer(std::shared_ptr<StreamClient> queryProducer, uint64_t producerCnt,
+                                  std::shared_ptr<StreamClient> queryConsumer, uint64_t consumerCnt,
+                                  std::string streamName)
+    {
+        uint64_t producersCount = 0;
+        uint64_t consumersCount = 0;
+        DS_ASSERT_OK(queryProducer->QueryGlobalProducersNum(streamName, producersCount));
+        ASSERT_EQ(producersCount, producerCnt);
+        DS_ASSERT_OK(queryConsumer->QueryGlobalConsumersNum(streamName, consumersCount));
+        ASSERT_EQ(consumersCount, consumerCnt);
+    }
+
+    void TearDown() override
+    {
+        client0_.reset();
+        client1_.reset();
+        client2_.reset();
+        client3_.reset();
+        client4_.reset();
+        client5_.reset();
+        ExternalClusterTest::TearDown();
+    }
+
+    void IdenticalStreamNameDataIsolation(std::string streamName)
+    {
+        std::shared_ptr<Producer> client0Pro, client2Pro, client3Pro;
+        std::shared_ptr<Consumer> client1Con, client4Con, client5Con;
+        DS_ASSERT_OK(CreateProducer(client0_, streamName, client0Pro));
+        DS_ASSERT_OK(CreateProducer(client2_, streamName, client2Pro));
+        DS_ASSERT_OK(CreateProducer(client3_, streamName, client3Pro));
+        DS_ASSERT_OK(CreateConsumer(client1_, streamName, "subname1", client1Con));
+        DS_ASSERT_OK(CreateConsumer(client4_, streamName, "subname2", client4Con));
+        DS_ASSERT_OK(CreateConsumer(client5_, streamName, "subname3", client5Con));
+
+        // 1. Send small data from clients of different tenants using the same stream name.
+        const size_t smallElementSize = 10;
+        SendAndReceiveData(smallElementSize, client0Pro, VECCONSUMER{client1Con}, VECCONSUMER{client4Con, client5Con});
+        SendAndReceiveData(smallElementSize, client2Pro, VECCONSUMER{client4Con, client5Con}, VECCONSUMER{client1Con});
+
+        // 2. Query the number of producers and consumers by stream name.
+        QueryProducerAndConsumer(client0_, 1, client1_, 1, streamName);  // Expect 1 producer and 1 consumer
+        QueryProducerAndConsumer(client2_, 2, client3_, 2, streamName);  // Expect 2 producers and 2 consumers
+
+        // 3. Query the numbers after closing one producer and one consumer.
+        DS_ASSERT_OK(client2Pro->Close());
+        DS_ASSERT_OK(client4Con->Close());
+        QueryProducerAndConsumer(client0_, 1, client1_, 1, streamName);
+        QueryProducerAndConsumer(client3_, 1, client5_, 1, streamName);
+
+        // 4. Send data and query the numbers after deleting the stream.
+        DS_ASSERT_OK(client0Pro->Close());
+        DS_ASSERT_OK(client1Con->Close());
+        DS_ASSERT_OK(client0_->DeleteStream(streamName));
+        QueryProducerAndConsumer(client0_, 0, client1_, 0, streamName);  // Expect no producer and no consumer
+        QueryProducerAndConsumer(client3_, 1, client5_, 1, streamName);  // Expect 1 producer and 1 consumer
+        SendAndReceiveData(smallElementSize, client3Pro, VECCONSUMER{client5Con}, VECCONSUMER{});
+        std::string data = RandomData().GetRandomString(10);
+        Element element(reinterpret_cast<uint8_t *>(&data.front()), data.size());
+        DS_ASSERT_NOT_OK(client0Pro->Send(element));
+    }
+
+    void DifferentStreamNameDataIsolation(std::string streamName)
+    {
+        std::string streamName1 = streamName + std::to_string(1);
+        std::string streamName2 = streamName + std::to_string(2);
+        std::shared_ptr<Producer> client0Pro1, client0Pro2;
+        std::shared_ptr<Consumer> client1Con1, client1Con2;
+        DS_ASSERT_OK(CreateProducer(client0_, streamName1, client0Pro1));
+        DS_ASSERT_OK(CreateProducer(client0_, streamName2, client0Pro2));
+        DS_ASSERT_OK(CreateConsumer(client1_, streamName1, "subname", client1Con1));
+        DS_ASSERT_OK(CreateConsumer(client1_, streamName2, "subname", client1Con2));
+
+        // 1. Send small data on stream 1 and verify that stream 2's consumer receives nothing.
+        const size_t smallElementSize = 10;
+        SendAndReceiveData(smallElementSize, client0Pro1, VECCONSUMER{client1Con1}, VECCONSUMER{client1Con2});
+
+        // 2. Query the number of producers and consumers by stream name.
+        QueryProducerAndConsumer(client0_, 1, client1_, 1, streamName1);  // Expect 1 producer and 1 consumer
+        QueryProducerAndConsumer(client1_, 1, client0_, 1, streamName2);  // Expect 1 producer and 1 consumer
+
+        // 3. Query the numbers after closing stream 1's producer and consumer and deleting stream 1.
+        DS_ASSERT_OK(client0Pro1->Close());
+        DS_ASSERT_OK(client1Con1->Close());
+        DS_ASSERT_OK(client0_->DeleteStream(streamName1));
+        QueryProducerAndConsumer(client0_, 0, client1_, 0, streamName1);  // Expect no producer and no consumer
+        QueryProducerAndConsumer(client1_, 1, client0_, 1, streamName2);  // Expect 1 producer and 1 consumer
+
+        // 4. Query the numbers after closing stream 2's producer and consumer and deleting stream 2.
+        DS_ASSERT_OK(client0Pro2->Close());
+        DS_ASSERT_OK(client1Con2->Close());
+        DS_ASSERT_OK(client1_->DeleteStream(streamName2));
+        QueryProducerAndConsumer(client0_, 0, client1_, 0, streamName1);  // Expect no producer and no consumer
+        QueryProducerAndConsumer(client1_, 0, client0_, 0, streamName2);  // Expect no producer and no consumer
+    }
+
+protected:
+    std::shared_ptr<StreamClient> client0_;
+    std::shared_ptr<StreamClient> client1_;
+    std::shared_ptr<StreamClient> client2_;
+    std::shared_ptr<StreamClient> client3_;
+    std::shared_ptr<StreamClient> client4_;
+    std::shared_ptr<StreamClient> client5_;
+};
+
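+// Workers 0-1 are pinned to TenantId1 and workers 2-5 to TenantId2 via the
+// "worker.auth" inject point, so streams with identical names created through
+// different workers must stay invisible across the tenant boundary.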
+class StreamMultiTenantTokenAuth : public StreamMultiTenant {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.numWorkers = WORKER_NUM;
+        opts.workerGflagParams = " -authorization_enable=true -v=2 "
+                                 "-page_size=4096 -shared_memory_size_mb=10240 ";
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+        for (size_t i = 0; i < WORKER_NUM; i++) {
+            DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, i));
+            if (i <= 1) {
+                DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "worker.auth", "100*return(Token, TenantId1)"));
+            } else {
+                DS_ASSERT_OK(cluster_->SetInjectAction(WORKER, i, "worker.auth", "100*return(Token, TenantId2)"));
+            }
+        }
+
+        DS_ASSERT_OK(InitClient(0, client0_));  // Init client to worker 0
+        DS_ASSERT_OK(InitClient(1, client1_));  // Init client to worker 1
+        DS_ASSERT_OK(InitClient(2, client2_));  // Init client to worker 2
+        DS_ASSERT_OK(InitClient(3, client3_));  // Init client to worker 3
+        DS_ASSERT_OK(InitClient(4, client4_));  // Init client to worker 4
+        DS_ASSERT_OK(InitClient(5, client5_));  // Init client to worker 5
+    }
+
+private:
+    Status InitClient(int index, std::shared_ptr<StreamClient> &client)
+    {
+        HostPort workerAddress;
+        RETURN_IF_NOT_OK(cluster_->GetWorkerAddr(index, workerAddress));
+        LOG(INFO) << "worker index " << index << ": " << workerAddress.ToString();
+        ConnectOptions connectOptions = { .host = workerAddress.Host(), .port = workerAddress.Port() };
+        if (index <= 1) {
+            connectOptions.SetAkSkAuth(accessKey_, secretKey_, "TenantId1");
+        } else {
+            connectOptions.SetAkSkAuth(accessKey_, secretKey_, "TenantId2");
+        }
+        client = std::make_shared<StreamClient>(connectOptions);
+        return client->Init();
+    }
+
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+TEST_F(StreamMultiTenantTokenAuth, IdenticalStreamNameDataIsolation)
+{
+    std::string streamName = "MultiTenantTokenAuthIdenticalName";
+    IdenticalStreamNameDataIsolation(streamName);
+}
+
+TEST_F(StreamMultiTenantTokenAuth, DISABLED_DifferentStreamNameDataIsolation)
+{
+    std::string streamName = "MultiTenantTokenAuthDiffName";
+    DifferentStreamNameDataIsolation(streamName);
+}
+
+class StreamMultiTenantAkSkAuth : public StreamMultiTenant {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numEtcd = 1;
+        opts.numWorkers = WORKER_NUM;
+        opts.workerGflagParams = "-page_size=4096 -shared_memory_size_mb=10240 -v=2";
+        opts.systemAccessKey = accessKey_;
+        opts.systemSecretKey = secretKey_;
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+        for (size_t i = 0; i < WORKER_NUM; i++) {
+            DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, i));
+        }
+        DS_ASSERT_OK(InitClient(0, client0_, "TenantId1"));  // Init client to worker 0
+        DS_ASSERT_OK(InitClient(1, client1_, "TenantId1"));  // Init client to worker 1
+        DS_ASSERT_OK(InitClient(2, client2_, "TenantId2"));  // Init client to worker 2
+        DS_ASSERT_OK(InitClient(3, client3_, "TenantId2"));  // Init client to worker 3
+        DS_ASSERT_OK(InitClient(4, client4_, "TenantId2"));  // Init client to worker 4
+        DS_ASSERT_OK(InitClient(5, client5_, "TenantId2"));  // Init client to worker 5
+    }
+
+protected:
+    Status InitClient(int index, std::shared_ptr<StreamClient> &client, std::string tenantId)
+    {
+        HostPort workerAddress;
+        RETURN_IF_NOT_OK(cluster_->GetWorkerAddr(index, workerAddress));
+        LOG(INFO) << "worker index " << index << ": " << workerAddress.ToString();
+        ConnectOptions connectOptions = { .host = workerAddress.Host(), .port = workerAddress.Port() };
+        connectOptions.SetAkSkAuth(accessKey_, secretKey_, tenantId);
+        client = std::make_shared<StreamClient>(connectOptions);
+        return client->Init();
+    }
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
+TEST_F(StreamMultiTenantAkSkAuth, IdenticalStreamNameDataIsolation)
+{
+    std::string streamName = "MultiTenantAkSkAuthIdenticalName";
+    IdenticalStreamNameDataIsolation(streamName);
+}
+
+TEST_F(StreamMultiTenantAkSkAuth, DifferentStreamNameDataIsolation)
+{
+    std::string streamName = "MultiTenantAkSkAuthDiffName";
+    DifferentStreamNameDataIsolation(streamName);
+}
+
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/stream_observability_test.cpp b/tests/st/client/stream_cache/stream_observability_test.cpp
new file mode 100644
index 0000000..e52fadd
--- /dev/null
+++ b/tests/st/client/stream_cache/stream_observability_test.cpp
@@ -0,0 +1,432 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Test observability.
+ */
+
+#include <fstream>
+#include <thread>
+#ifdef BUILD_OBSERVABILITY
+#include <nlohmann/json.hpp>
+#endif
+
+#include "common.h"
+#include "common/stream_cache/element_generator.h"
+#include "common/stream_cache/stream_common.h"
+#include "sc_client_common.h"
+#include "datasystem/stream_client.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/common/metrics/res_metric_collector.h"
+
+DS_DECLARE_string(log_dir);
+
+namespace datasystem {
+namespace st {
+using namespace datasystem::client::stream_cache;
+#ifdef BUILD_OBSERVABILITY
+using json = nlohmann::json;
+#endif
+
+class StreamObservabilityTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override;
+
+    void SetUp() override;
+
+    void TearDown() override;
+
+    static std::string streamName_;
+
+protected:
+    void GetResMonitorLogInfo(int index, const std::string &fileName, std::vector<std::string> &infos);
+    std::string GetMetric(const std::vector<std::string> &workerMetrics, ResMetricName metric);
+    Status Produce(std::shared_ptr<Producer> &producer, std::string producerName, int numEle, uint64_t eleSz,
+                   int timeout = 0);
+    Status CreateProducerAndConsumer(const std::shared_ptr<StreamClient> &client,
+                                     std::vector<std::pair<std::string, size_t>> producerDesc,
+                                     std::vector<std::shared_ptr<Producer>> &producers,
+                                     std::vector<std::pair<std::string, std::string>> consumerDesc,
+                                     std::vector<std::shared_ptr<Consumer>> &consumers, bool autoCleanup);
+
+    // Mock producer worker.
+    HostPort w1Addr_;
+    HostPort w2Addr_;
+
+    std::shared_ptr<StreamClient> w1Client_ = nullptr;
+    std::shared_ptr<StreamClient> w2Client_ = nullptr;
+    ProducerConf defaultProducerConf_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+    const int NUM_WORKER = 2;
+    const int V_LEVEL = 2;
+    const int SLEEP_TIME = 2;
+    const int DEFAULT_NUM_ELEMENT = 20;
+    const int TEST_ELEMENT_SIZE = 2 * KB - 128;
+    const int MAX_STREAM_SIZE = 2 * MB;
+    const int DELAY_FLUSH_TIME = 3000;
+};
+std::string StreamObservabilityTest::streamName_ = "stream";
+
+void StreamObservabilityTest::SetClusterSetupOptions(ExternalClusterOptions &opts)
+{
+    opts.numEtcd = 1;
+    opts.numWorkers = NUM_WORKER;
+    opts.enableDistributedMaster = "false";
+    opts.workerGflagParams = " -page_size=" + std::to_string(PAGE_SIZE) +
+                             " -log_monitor=true -log_monitor_interval_ms=2000 -shared_memory_size_mb=64";
+    opts.numRpcThreads = 0;
+    opts.vLogLevel = V_LEVEL;
+    SCClientCommon::SetClusterSetupOptions(opts);
+}
+
+void StreamObservabilityTest::SetUp()
+{
+    ExternalClusterTest::SetUp();
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(0, w1Addr_));
+    DS_ASSERT_OK(cluster_->GetWorkerAddr(1, w2Addr_));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 0));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(WORKER, 1));
+    // Worker 1.
+    InitStreamClient(0, w1Client_);
+    // Worker 2.
+    InitStreamClient(1, w2Client_);
+    defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE;
+}
+
+void StreamObservabilityTest::TearDown()
+{
+    w1Client_ = nullptr;
+    w2Client_ = nullptr;
+    ExternalClusterTest::TearDown();
+}
+
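+// GetResMonitorLogInfo reads worker<index>'s resource.log, keeps only its last
+// line, and splits it on " | ". The first seven columns are log metadata and
+// are dropped, so `infos` lines up with the ResMetricName enum afterwards.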
+void StreamObservabilityTest::GetResMonitorLogInfo(int index, const std::string &fileName,
+                                                   std::vector<std::string> &infos)
+{
+    std::string fullName = FormatString("%s/../worker%d/log/%s", FLAGS_log_dir.c_str(), index, fileName);
+    std::ifstream ifs(fullName);
+    ASSERT_TRUE(ifs.is_open());
+    std::string line;
+    std::streampos prev = ifs.tellg();
+    std::streampos pos = ifs.tellg();
+    // Get the last line
+    while (std::getline(ifs, line)) {
+        prev = pos;
+        pos = ifs.tellg();
+    }
+    ifs.clear();
+    ifs.seekg(prev);
+    std::getline(ifs, line);
+    infos = Split(line, " | ");
+    const int ignoreCount = 7;
+    ASSERT_TRUE(infos.size() == static_cast<size_t>(ResMetricName::RES_METRICS_END) + ignoreCount);
+    infos.erase(infos.begin(), infos.begin() + ignoreCount);
+}
+
+std::string StreamObservabilityTest::GetMetric(const std::vector<std::string> &workerMetrics, ResMetricName metric)
+{
+    int index = (int)metric - (int)ResMetricName::SHARED_MEMORY;
+    return workerMetrics[index];
+}
+
+Status StreamObservabilityTest::CreateProducerAndConsumer(const std::shared_ptr<StreamClient> &client,
+                                                          std::vector<std::pair<std::string, size_t>> producerDesc,
+                                                          std::vector<std::shared_ptr<Producer>> &producers,
+                                                          std::vector<std::pair<std::string, std::string>> consumerDesc,
+                                                          std::vector<std::shared_ptr<Consumer>> &consumers,
+                                                          bool autoCleanup)
+{
+    ProducerConf conf;
+    conf.delayFlushTime = DELAY_FLUSH_TIME;
+    conf.pageSize = PAGE_SIZE;  // 4K
+    conf.maxStreamSize = MAX_STREAM_SIZE;
+    conf.autoCleanup = autoCleanup;
+    for (const auto &kv : producerDesc) {
+        for (size_t i = 0; i < kv.second; i++) {
+            std::shared_ptr<Producer> producer;
+            RETURN_IF_NOT_OK(client->CreateProducer(kv.first, producer, conf));
+            producers.emplace_back(producer);
+        }
+    }
+
+    for (const auto &kv : consumerDesc) {
+        std::shared_ptr<Consumer> consumer;
+        SubscriptionConfig config(kv.second, SubscriptionType::STREAM);
+        RETURN_IF_NOT_OK(client->Subscribe(kv.first, config, consumer, false));
+        consumers.emplace_back(consumer);
+    }
+    return Status::OK();
+}
+
+Status StreamObservabilityTest::Produce(std::shared_ptr<Producer> &producer, std::string producerName, int numEle,
+                                        uint64_t eleSz, int timeout)
+{
+    Status stat = Status::OK();
+    ElementGenerator elementGenerator(eleSz, eleSz);
+    auto strs = elementGenerator.GenElements(producerName, numEle, 1);
+    Status rc;
+
+    for (int i = 0; i < numEle; i++) {
+        if (timeout) {
+            rc = producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size()), timeout);
+        } else {
+            rc = producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size()));
+        }
+        if (rc.IsError()) {
+            stat = rc;
+        }
+    }
+    return stat;
+}
+
+TEST_F(StreamObservabilityTest, DISABLED_StreamMetricsLog)
+{
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(w1Client_->CreateProducer(streamName_, producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub1", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w2Client_->Subscribe(streamName_, config, consumer));
+    const uint32_t eleSz = 512;
+    const uint32_t eleNum = 1000;
+    ElementGenerator elementGenerator(eleSz);
+    auto strs = elementGenerator.GenElements("producer1", eleNum, 1);
+
+    auto sender = [&producer, &strs]() {
+        for (uint32_t i = 0; i < eleNum; i++) {
+            DS_ASSERT_OK(producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size())));
+        }
+    };
+
+    const uint32_t recvNum = 100;
+    std::vector<Element> outElements;
+    auto recver = [eleNum, recvNum, &consumer, &outElements]() {
+        for (size_t i = 0; i < eleNum / recvNum; ++i) {
+            DS_ASSERT_OK(consumer->Receive(recvNum, 5000, outElements));
+            ASSERT_EQ(outElements.size(), recvNum);
+            DS_ASSERT_OK(consumer->Ack(recvNum));
+        }
+    };
+
+    std::thread sendThr(sender);
+    std::thread recvThr(recver);
+    sendThr.join();
+    recvThr.join();
+
+    std::vector<std::string> infos;
+    GetResMonitorLogInfo(1, "resource.log", infos);
+    // expected pattern/format of stream metrics
+    re2::RE2 strNumPattern("[0-9 ]+");
+    ASSERT_TRUE(re2::RE2::FullMatch(infos[static_cast<int>(ResMetricName::STREAM_COUNT)], strNumPattern));
+}
+
+TEST_F(StreamObservabilityTest, StreamCount)
+{
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    const int streamNum = 2;
+    DS_ASSERT_OK(CreateProducerAndConsumer(
+        w1Client_, { { "stream1", 3 } }, producers,
+        { { "stream1", "sub1" }, { "stream2", "sub1" }, { "stream1", "sub2" }, { "stream2", "sub2" } }, consumers,
+        true));
+
+    DS_ASSERT_OK(
+        CreateProducerAndConsumer(w2Client_, { { "stream2", 3 }, { "stream3", 1 } }, producers, {}, consumers, true));
+
+    sleep(SLEEP_TIME);
+
+    std::vector<std::string> worker0Metrics;
+    GetResMonitorLogInfo(0, "resource.log", worker0Metrics);
+
+    std::vector<std::string> worker1Metrics;
+    GetResMonitorLogInfo(1, "resource.log", worker1Metrics);
+    ASSERT_EQ(std::stoi(GetMetric(worker0Metrics, ResMetricName::STREAM_COUNT)), streamNum);
+    ASSERT_EQ(std::stoi(GetMetric(worker1Metrics, ResMetricName::STREAM_COUNT)), streamNum);
+
+    for (auto &producer : producers) {
+        DS_ASSERT_OK(producer->Close());
+    }
+    for (auto &consumer : consumers) {
+        DS_ASSERT_OK(consumer->Close());
+    }
+    sleep(SLEEP_TIME);
+
+    GetResMonitorLogInfo(0, "resource.log", worker0Metrics);
+    GetResMonitorLogInfo(1, "resource.log", worker1Metrics);
+    ASSERT_EQ(std::stoi(GetMetric(worker0Metrics, ResMetricName::STREAM_COUNT)), 0);
+    ASSERT_EQ(std::stoi(GetMetric(worker1Metrics, ResMetricName::STREAM_COUNT)), 0);
+}
+
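+// Each of the 5 producers sends 20 elements of just under 2 KB, i.e. about
+// 40 KB per producer, so the stream pages pinned in shared memory must total
+// at least 5 * 40 KB; the reported limit must stay within the configured 64 MB.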
+TEST_F(StreamObservabilityTest, StreamSharedMemory)
+{
+    std::vector<std::shared_ptr<Producer>> producers;
+    std::vector<std::shared_ptr<Consumer>> consumers;
+    std::vector<std::string> worker0Metrics;
+    std::unordered_map<std::string, std::vector<std::string>> sc0Metrics;
+    const int theoreticalMem = 5 * 40 * KB;
+    const int memoryLimit = 64 * MB;
+
+    DS_ASSERT_OK(CreateProducerAndConsumer(w1Client_, { { "stream1", 3 }, { "stream2", 2 } }, producers,
+                                           { { "stream1", "sub1" }, { "stream2", "sub1" }, { "stream2", "sub2" } },
+                                           consumers, false));
+
+    int i = 0;
+    // Each producer sends 20 elements -> 40KB
+    // Each element is slightly less than 2KB to account for page header overhead
+    for (auto &producer : producers) {
+        i++;
+        DS_ASSERT_OK(Produce(producer, "producer" + std::to_string(i), DEFAULT_NUM_ELEMENT, TEST_ELEMENT_SIZE));
+    }
+    sleep(SLEEP_TIME);
+
+    GetResMonitorLogInfo(0, "resource.log", worker0Metrics);
+    // TotalStreamMemoryUsed is the real allocated memory size, which might be larger than the theoretical value.
+    std::vector<std::string> metrics = Split(GetMetric(worker0Metrics, ResMetricName::SHARED_MEMORY), "/");
+    int streamMemoryUsage = std::stoi(metrics[4]);
+    int streamMemoryLimit = std::stoi(metrics[5]);
+    ASSERT_TRUE(streamMemoryUsage >= theoreticalMem);
+    ASSERT_TRUE(streamMemoryLimit <= memoryLimit);
+}
+#ifdef BUILD_OBSERVABILITY
+class StreamYrObservabilityTest : public StreamObservabilityTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override;
+    void GetObservabilityMetricsLog0(int index);
+    void GetObservabilityMetricsLog1(int index);
+};
+
+void StreamYrObservabilityTest::SetClusterSetupOptions(ExternalClusterOptions &opts)
+{
+    opts.numEtcd = 1;
+    opts.numWorkers = NUM_WORKER;
+    opts.enableDistributedMaster = "false";
+    opts.workerGflagParams = " -page_size=" + std::to_string(PAGE_SIZE) +
+                             " -log_monitor=true -log_monitor_interval_ms=100 -log_monitor_exporter=yr_file_exporter";
+    opts.numRpcThreads = 0;
+    opts.vLogLevel = V_LEVEL;
+    SCClientCommon::SetClusterSetupOptions(opts);
+}
+
+void StreamYrObservabilityTest::GetObservabilityMetricsLog0(int index)
+{
+    std::string fullName =
+        FormatString("%s/../worker%d/log/observability/yr_metrics.data", FLAGS_log_dir.c_str(), index);
+    std::ifstream ifs(fullName, std::ios::binary);
+    ASSERT_TRUE(ifs.is_open());
+    std::string line;
+    bool found = false;
+    while (std::getline(ifs, line)) {
+        ASSERT_TRUE(!line.empty());
+        json jsonLine = json::parse(line);
+        const std::string &name = jsonLine["name"];
+        const std::string &value = jsonLine["value"];
+        if (name == "STREAM_COUNT_STREAM_COUNT" && std::stoi(value) > 0) {
+            found = true;
+            break;
+        }
+    }
+    ASSERT_TRUE(found);
+}
+
+TEST_F(StreamYrObservabilityTest, StreamMetricsLog0)
+{
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(w1Client_->CreateProducer(streamName_, producer, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer;
+    SubscriptionConfig config("sub0", SubscriptionType::STREAM);
+    DS_ASSERT_OK(w2Client_->Subscribe(streamName_, config, consumer));
+    const uint32_t eleSz = 512;
+    const uint32_t eleNum = 1000;
+    ElementGenerator elementGenerator(eleSz);
+    auto strs = elementGenerator.GenElements("producer1", eleNum, 1);
+
+    auto sender = [&producer, &strs]() {
+        for (uint32_t i = 0; i < eleNum; i++) {
+            DS_ASSERT_OK(producer->Send(Element((uint8_t *)strs[i].data(), strs[i].size())));
+        }
+    };
+
+    const uint32_t recvNum = 100;
+    std::vector<Element> outElements;
+    auto recver = [eleNum, recvNum, &consumer, &outElements]() {
+        for (size_t i = 0; i < eleNum / recvNum; ++i) {
+            DS_ASSERT_OK(consumer->Receive(recvNum, 5000, outElements));
+            ASSERT_EQ(outElements.size(), recvNum);
+            DS_ASSERT_OK(consumer->Ack(recvNum));
+        }
+    };
+
+    std::thread sendThr(sender);
+    std::thread recvThr(recver);
+    sendThr.join();
+    recvThr.join();
+
+    GetObservabilityMetricsLog0(0);
+    GetObservabilityMetricsLog0(1);
+}
+
+TEST_F(StreamYrObservabilityTest, StreamMetricsLog1)
+{
+    std::shared_ptr<Producer> producer0;
+    std::string streamName0 = streamName_ + "0";
+    DS_ASSERT_OK(w1Client_->CreateProducer(streamName0, producer0, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer0;
+    DS_ASSERT_OK(w2Client_->Subscribe(streamName0, SubscriptionConfig("sub0", SubscriptionType::STREAM), consumer0));
+
+    std::shared_ptr<Producer> producer1;
+    std::string streamName1 = streamName_ + "1";
+    DS_ASSERT_OK(w1Client_->CreateProducer(streamName1, producer1, defaultProducerConf_));
+    std::shared_ptr<Consumer> consumer1;
+    DS_ASSERT_OK(w2Client_->Subscribe(streamName1, SubscriptionConfig("sub1", SubscriptionType::STREAM), consumer1));
+
+    const uint32_t eleSz = 512;
+    const uint32_t eleNum = 1000;
+    ElementGenerator elementGenerator(eleSz);
+    auto strs = elementGenerator.GenElements("producer1", eleNum, 1);
+
+    auto sender = [&producer0, &producer1, &strs]() {
+        for (uint32_t i = 0; i < eleNum; i++) {
+            DS_ASSERT_OK(producer0->Send(Element((uint8_t *)strs[i].data(), strs[i].size())));
+            DS_ASSERT_OK(producer1->Send(Element((uint8_t *)strs[i].data(), strs[i].size())));
+        }
+    };
+
+    const uint32_t recvNum = 100;
+    std::vector<Element> outElements0;
+    std::vector<Element> outElements1;
+    auto recver = [eleNum, recvNum, &consumer0, &consumer1, &outElements0, &outElements1]() {
+        for (size_t i = 0; i < eleNum / recvNum; ++i) {
+            DS_ASSERT_OK(consumer0->Receive(recvNum, 5000, outElements0));
+            ASSERT_EQ(outElements0.size(), recvNum);
+            DS_ASSERT_OK(consumer0->Ack(recvNum));
+            DS_ASSERT_OK(consumer1->Receive(recvNum, 5000, outElements1));
+            ASSERT_EQ(outElements1.size(), recvNum);
+            DS_ASSERT_OK(consumer1->Ack(recvNum));
+        }
+    };
+
+    std::thread sendThr(sender);
+    std::thread recvThr(recver);
+    sendThr.join();
+    recvThr.join();
+
+    GetObservabilityMetricsLog0(0);
+    GetObservabilityMetricsLog0(1);
+}
+#endif
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client/stream_cache/stream_size_test.cpp b/tests/st/client/stream_cache/stream_size_test.cpp
new file mode 100644
index 0000000..7b150b1
--- /dev/null
+++ b/tests/st/client/stream_cache/stream_size_test.cpp
@@ -0,0 +1,241 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Unit test for stream cache
+ */
+#include <gtest/gtest.h>
+
+#include "common.h"
+#include "sc_client_common.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/stream/element.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream_client.h"
+
+using namespace datasystem::client::stream_cache;
+namespace datasystem {
+namespace st {
+constexpr uint64_t MB = 1024 * 1024;
+class StreamSizeTest : public SCClientCommon {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        const int vLogLevel = 3;
+        const int numWorkers = 3;
+        opts.masterIdx = 2;
+        opts.numEtcd = 1;
+        opts.numWorkers = numWorkers;
+        opts.vLogLevel = vLogLevel;
+        SCClientCommon::SetClusterSetupOptions(opts);
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+        int index = 0;
+        DS_ASSERT_OK(InitClient(index++, client1_));
+        DS_ASSERT_OK(InitClient(index++, client2_));
+    }
+
+    void TearDown() override
+    {
+        client1_ = nullptr;
+        client2_ = nullptr;
+        ExternalClusterTest::TearDown();
+    }
+
+protected:
+    Status InitClient(int index, std::shared_ptr<StreamClient> &client)
+    {
+        HostPort workerAddress;
+        RETURN_IF_NOT_OK(cluster_->GetWorkerAddr(index, workerAddress));
+        LOG(INFO) << "worker index " << index << ": " << workerAddress.ToString();
+        ConnectOptions options;
+        options.host = workerAddress.Host();
+        options.port = workerAddress.Port();
+        options.secretKey = secretKey_;
+        options.accessKey = accessKey_;
+        client = std::make_shared<StreamClient>(options);
+        return client->Init();
+    }
+
+    Status CreateConsumer(std::shared_ptr<StreamClient> client, const std::string &streamName,
+                          const std::string &subName, std::shared_ptr<Consumer> &consumer)
+    {
+        SubscriptionConfig config(subName, SubscriptionType::STREAM);
+        return client->Subscribe(streamName, config, consumer);
+    }
+
+    Status CreateProducer(std::shared_ptr<StreamClient> client, const std::string &streamName, uint64_t maxStreamSize,
+                          std::shared_ptr<Producer> &producer)
+    {
+        ProducerConf conf;
+        conf.maxStreamSize = maxStreamSize;
+        return client->CreateProducer(streamName, producer, conf);
+    }
+
+    std::shared_ptr<StreamClient> client1_;
+    std::shared_ptr<StreamClient> client2_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+};
+
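+// maxStreamSize is fixed by the first producer of a stream: later producers
+// must declare the same value, no matter which worker they connect through.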
"MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; +}; + +TEST_F(StreamSizeTest, TestCreateProducerWithDiffSize) +{ + std::string streamName = "CreateProducerWithDiffSize"; + std::shared_ptr producer1; + std::shared_ptr producer2; + std::shared_ptr producer3; + std::shared_ptr producer4; + DS_ASSERT_OK(CreateProducer(client1_, streamName, 10 * MB, producer1)); + DS_ASSERT_NOT_OK(CreateProducer(client1_, streamName, 12 * MB, producer2)); + DS_ASSERT_NOT_OK(CreateProducer(client2_, streamName, 12 * MB, producer3)); + DS_ASSERT_OK(CreateProducer(client1_, streamName, 10 * MB, producer4)); +} + +TEST_F(StreamSizeTest, TestCreateProducerThenConsumer) +{ + std::string streamName = "CreateProducerThenConsumer"; + std::shared_ptr producer; + DS_ASSERT_OK(CreateProducer(client1_, streamName, 10 * MB, producer)); + std::shared_ptr consumer; + DS_ASSERT_OK(CreateConsumer(client1_, streamName, "sub", consumer)); + + std::string str = "hello hello"; + Element element(reinterpret_cast(const_cast(str.data())), str.length()); + + DS_ASSERT_OK(producer->Send(element)); + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(1, 0, outElements)); + ASSERT_EQ(outElements.size(), 1ul); + DS_ASSERT_OK(consumer->Close()); + DS_ASSERT_OK(producer->Close()); +} + +TEST_F(StreamSizeTest, TestCreateConsumerThenProducer) +{ + std::string streamName = "CreateConsumerThenProducer"; + std::shared_ptr consumer; + DS_ASSERT_OK(CreateConsumer(client1_, streamName, "sub", consumer)); + std::shared_ptr producer; + DS_ASSERT_OK(CreateProducer(client1_, streamName, 10 * MB, producer)); + + std::string str = "hello hello"; + Element element(reinterpret_cast(const_cast(str.data())), str.length()); + + DS_ASSERT_OK(producer->Send(element)); + std::vector outElements; + DS_ASSERT_OK(consumer->Receive(1, 0, outElements)); + ASSERT_EQ(outElements.size(), 1ul); + DS_ASSERT_OK(consumer->Close()); + DS_ASSERT_OK(producer->Close()); +} + +TEST_F(StreamSizeTest, TestCreateProducerThenConsumerTwoWorker) +{ + std::string streamName = "CreateProducerThenConsumerTwoWorker"; + std::shared_ptr producer; + DS_ASSERT_OK(CreateProducer(client1_, streamName, 10 * MB, producer)); + LOG(INFO) << "Created producer"; + std::shared_ptr consumer; + DS_ASSERT_OK(CreateConsumer(client2_, streamName, "sub", consumer)); + LOG(INFO) << "Created consumer"; + + std::string str = "hello hello"; + Element element(reinterpret_cast(const_cast(str.data())), str.length()); + + LOG(INFO) << "do send"; + DS_ASSERT_OK(producer->Send(element)); + LOG(INFO) << "do flush"; + std::vector outElements; + const uint32_t waitTime = 3000; // 3s; + LOG(INFO) << "do consumer receive"; + DS_ASSERT_OK(consumer->Receive(1, waitTime, outElements)); + ASSERT_EQ(outElements.size(), 1ul); + LOG(INFO) << "do consumer and producer close"; + DS_ASSERT_OK(consumer->Close()); + DS_ASSERT_OK(producer->Close()); +} + +TEST_F(StreamSizeTest, TestCreateConsumerThenProducerTwoWorker) +{ + std::string streamName = "CreateConsumerThenProducerTwoWorker"; + std::shared_ptr consumer; + DS_ASSERT_OK(CreateConsumer(client1_, streamName, "sub", consumer)); + std::shared_ptr producer; + DS_ASSERT_OK(CreateProducer(client2_, streamName, 10 * MB, producer)); + + std::string str = "hello hello"; + Element element(reinterpret_cast(const_cast(str.data())), str.length()); + + DS_ASSERT_OK(producer->Send(element)); + std::vector outElements; + const uint32_t waitTime = 3000; // 3s; + DS_ASSERT_OK(consumer->Receive(1, waitTime, outElements)); + ASSERT_EQ(outElements.size(), 1ul); + DS_ASSERT_OK(consumer->Close()); + 
+TEST_F(StreamSizeTest, TestCreateConsumerThenProducerTwoWorker)
+{
+    std::string streamName = "CreateConsumerThenProducerTwoWorker";
+    std::shared_ptr<Consumer> consumer;
+    DS_ASSERT_OK(CreateConsumer(client1_, streamName, "sub", consumer));
+    std::shared_ptr<Producer> producer;
+    DS_ASSERT_OK(CreateProducer(client2_, streamName, 10 * MB, producer));
+
+    std::string str = "hello hello";
+    Element element(reinterpret_cast<uint8_t *>(const_cast<char *>(str.data())), str.length());
+
+    DS_ASSERT_OK(producer->Send(element));
+    std::vector<Element> outElements;
+    const uint32_t waitTime = 3000;  // 3s
+    DS_ASSERT_OK(consumer->Receive(1, waitTime, outElements));
+    ASSERT_EQ(outElements.size(), 1ul);
+    DS_ASSERT_OK(consumer->Close());
+    DS_ASSERT_OK(producer->Close());
+}
+
+TEST_F(StreamSizeTest, LEVEL1_TestCreateProducerAfterMasterRestart)
+{
+    std::shared_ptr<Producer> producer1;
+    std::string streamName = "CreateProducerAfterMasterRestart";
+    DS_ASSERT_OK(CreateProducer(client1_, streamName, 10 * MB, producer1));
+
+    DS_ASSERT_OK(cluster_->StartNode(ClusterNodeType::WORKER, 2, ""));
+    DS_ASSERT_OK(cluster_->WaitNodeReady(ClusterNodeType::WORKER, 2));
+
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_NOT_OK(CreateProducer(client2_, streamName, 12 * MB, producer2));
+}
+
+TEST_F(StreamSizeTest, TestCreateProducerWorkerFailed)
+{
+    std::string streamName = "CreateProducerWorkerFailed";
+    DS_ASSERT_OK(cluster_->SetInjectAction(ClusterNodeType::WORKER, 0, "worker.CreateProducer.beforeSendToMaster",
+                                           "1*return(K_RUNTIME_ERROR)"));
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_NOT_OK(CreateProducer(client1_, streamName, 10 * MB, producer1));
+
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_OK(CreateProducer(client1_, streamName, 12 * MB, producer2));
+}
+
+class StreamSizeCentralizedTest : public StreamSizeTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.enableDistributedMaster = "false";
+        StreamSizeTest::SetClusterSetupOptions(opts);
+    }
+};
+
+TEST_F(StreamSizeCentralizedTest, TestCreateProducerMasterFailed)
+{
+    // Master has to be centralized for this testcase to work.
+    DS_ASSERT_OK(cluster_->SetInjectAction(
+        ClusterNodeType::WORKER, 2, "master.PubIncreaseNodeImpl.beforeSendNotification", "1*return(K_RUNTIME_ERROR)"));
+
+    std::shared_ptr<Producer> producer1;
+    DS_ASSERT_NOT_OK(CreateProducer(client1_, "stream1", 10 * MB, producer1));
+
+    std::shared_ptr<Producer> producer2;
+    DS_ASSERT_OK(CreateProducer(client1_, "stream1", 12 * MB, producer2));
+}
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/client_c_api/stream_cache/stream_cache_test.cpp b/tests/st/client_c_api/stream_cache/stream_cache_test.cpp
new file mode 100644
index 0000000..b3ceed1
--- /dev/null
+++ b/tests/st/client_c_api/stream_cache/stream_cache_test.cpp
@@ -0,0 +1,290 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: test cases for c client api.
+ */
+
+#include <chrono>
+#include <string>
+#include <thread>
+
+#include
+
+#include "common.h"
+#include "datasystem/common/log/log.h"
+#include "datasystem/common/util/random_data.h"
+#include "datasystem/stream/element.h"
+#include "datasystem/stream_client.h"
+
+namespace datasystem {
+namespace st {
+class StreamCacheTest : public ExternalClusterTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        int numWorkers = 2;
+        opts.numWorkers = numWorkers;
+        opts.numMasters = 1;
+        opts.numEtcd = 1;
+        opts.workerGflagParams = "-shared_memory_size_mb=10000";
+        opts.isStreamCacheCase = true;
+    }
+
+    void SetUp() override
+    {
+        ClusterTest::SetUp();
+        HostPort srcWorkerAddress;
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(0, srcWorkerAddress));
+        client0_ = CreateStreamCacheClient(srcWorkerAddress.Host(), srcWorkerAddress.Port(), 60000, "", "", "", "",
+                                           ak_, sk_, "", "", "", "", "true");
+        ASSERT_EQ(StreamConnectWorker(client0_, false).code, K_OK);
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(1, srcWorkerAddress));
+        client1_ = CreateStreamCacheClient(srcWorkerAddress.Host(), srcWorkerAddress.Port(), 60000, "", "", "", "",
+                                           ak_, sk_, "", "", "", "", "true");
+        ASSERT_EQ(StreamConnectWorker(client1_, false).code, K_OK);
+    }
+
+    void TearDown() override
+    {
+        if (client0_ != nullptr) {
+            StreamFreeClient(client0_);
+        }
+        if (client1_ != nullptr) {
+            StreamFreeClient(client1_);
+        }
+    }
+
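+    // Thin wrapper over the C API factory. The token and OAuth parameters are
+    // accepted for signature compatibility but deliberately ignored here; only
+    // the AK/SK and key material are forwarded to StreamCreateClient.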
+    StreamClient_p CreateStreamCacheClient(const std::string &workerHost, const int workerPort, const int timeOut,
+                                           const std::string &token, const std::string &clientPublicKey,
+                                           const std::string &clientPrivateKey, const std::string &serverPublicKey,
+                                           const std::string &accessKey, const std::string &secretKey,
+                                           const std::string &oAuthClientid, const std::string &oAuthClientSecret,
+                                           const std::string &oAuthUrl, const std::string &tenantId,
+                                           const std::string &enableCrossNodeConnection)
+    {
+        (void)oAuthClientid;
+        (void)oAuthClientSecret;
+        (void)oAuthUrl;
+        (void)token;
+        return StreamCreateClient(workerHost.c_str(), workerPort, timeOut,
+                                  clientPublicKey.c_str(), clientPublicKey.length(), clientPrivateKey.c_str(),
+                                  clientPrivateKey.length(), serverPublicKey.c_str(), serverPublicKey.length(),
+                                  accessKey.c_str(), accessKey.length(), secretKey.c_str(), secretKey.length(),
+                                  tenantId.c_str(), tenantId.length(), enableCrossNodeConnection.c_str());
+    }
+
+    void Subscribe(StreamClient_p clientPtr, const std::string &streamName, const std::string &subName,
+                   Consumer_p *consumer)
+    {
+        auto rc = StreamSubscribe(clientPtr, streamName.c_str(), streamName.length(), subName.c_str(),
+                                  subName.length(), SubType::STREAM, false, false,
+                                  SubscriptionConfig::SC_CACHE_CAPACITY, SubscriptionConfig::SC_CACHE_LWM, consumer);
+        ASSERT_EQ(rc.code, K_OK);
+    }
+
+    void CreateProducer(StreamClient_p clientPtr, const std::string &streamName, int64_t delayFlushTime,
+                        int64_t pageSize, uint64_t maxStreamSize, bool autoCleanup, Producer_p *producer)
+    {
+        auto rc = StreamCreateProducer(clientPtr, streamName.c_str(), streamName.length(), delayFlushTime, pageSize,
+                                       maxStreamSize, autoCleanup, producer);
+        ASSERT_EQ(rc.code, K_OK);
+    }
+
+    void CreateProducerWithConfig(StreamClient_p clientPtr, const std::string &streamName, int64_t delayFlushTime,
+                                  int64_t pageSize, uint64_t maxStreamSize, bool autoCleanup,
+                                  uint64_t retainForNumConsumers, bool encryptStream, uint64_t reserveSize,
+                                  Producer_p *producer)
+    {
+        auto rc = StreamCreateProducerWithConfig(clientPtr, streamName.c_str(), streamName.length(), delayFlushTime,
+                                                 pageSize, maxStreamSize, autoCleanup, retainForNumConsumers,
+                                                 encryptStream, reserveSize, producer);
+        ASSERT_EQ(rc.code, K_OK);
+    }
+
+    void CreateElement(size_t elementSize, Element &element, std::string &writeElement)
+    {
+        writeElement = RandomData().GetRandomString(elementSize);
+        element = Element(reinterpret_cast<uint8_t *>(&writeElement[0]), elementSize);
+    }
+
+    void CreateElements(size_t numEle, size_t elementSize, std::vector<Element> &elements,
+                        std::vector<std::string> &writeElements)
+    {
+        elements.clear();
+        elements.resize(numEle);
+        writeElements.clear();
+        writeElements.resize(numEle);
+        for (size_t i = 0; i < numEle; ++i) {
+            CreateElement(elementSize, elements[i], writeElements[i]);
+        }
+    }
+
+    void SendElements(Producer_p producerPtr, std::vector<Element> &elements)
+    {
+        StatusC rc;
+        for (auto &ele : elements) {
+            rc = StreamProducerSend(producerPtr, ele.ptr, ele.size, ele.id);
+            ASSERT_EQ(rc.code, K_OK);
+        }
+    }
+
+    void ReceiveElements(Consumer_p consumerPtr, std::vector<std::string> &elements)
+    {
+        StreamElement *eles = nullptr;
+        uint64_t count = 0;
+        elements.clear();
+        auto rc = StreamConsumerReceive(consumerPtr, timeout_, &eles, &count);
+        ASSERT_EQ(rc.code, K_OK);
+        elements.reserve(count);
+        for (uint64_t i = 0; i < count; ++i) {
+            elements.emplace_back(reinterpret_cast<char *>(eles[i].ptr), eles[i].size);
+        }
+        // Only ack and release when something was actually received; with count == 0
+        // there is no valid last element id to ack.
+        if (count > 0) {
+            rc = StreamConsumerAck(consumerPtr, eles[count - 1].id);
+            ASSERT_EQ(rc.code, K_OK);
+            delete eles;
+        }
+    }
+
+    void ReceiveElementsExpected(Consumer_p consumerPtr, uint32_t numExpect, std::vector<std::string> &elements)
+    {
+        StreamElement *eles = nullptr;
+        uint64_t count = 0;
+        elements.clear();
+        auto rc = StreamConsumerReceiveExpect(consumerPtr, numExpect, timeout_, &eles, &count);
+        ASSERT_EQ(rc.code, K_OK);
+        elements.reserve(count);
+        for (uint64_t i = 0; i < count; ++i) {
+            elements.emplace_back(reinterpret_cast<char *>(eles[i].ptr), eles[i].size);
+        }
+        if (count > 0) {
+            rc = StreamConsumerAck(consumerPtr, eles[count - 1].id);
+            ASSERT_EQ(rc.code, K_OK);
+            delete eles;
+        }
+    }
+
+protected:
+    std::string ak_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string sk_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+    StreamClient_p client0_{ nullptr };
+    StreamClient_p client1_{ nullptr };
+    int64_t delayFlushTime_{ 5 };
+    int64_t pageSize_{ 1024 * 1024ul };
+    uint64_t maxStreamSize_{ 1024 * 1024 * 1024ul };
+    uint32_t timeout_ = 100;
+    uint64_t retainForNumConsumers = 0;
+    bool autoCleanup = false;
+    bool encryptStream = false;
+};
+
+TEST_F(StreamCacheTest, CreateProducerConsumer)
+{
+    Producer_p producer = nullptr;
+    std::string streamName = "CreateProducerConsumer";
+    CreateProducer(client0_, streamName, delayFlushTime_, pageSize_, maxStreamSize_, autoCleanup, &producer);
+    std::string subName = "CreateProducerConsumerSub";
+    Consumer_p consumer = nullptr;
+    Subscribe(client0_, streamName, subName, &consumer);
+}
+
+TEST_F(StreamCacheTest, SendReceiveElements1)
+{
+    Producer_p producer = nullptr;
+    std::string streamName = "SendReceiveElements1";
+    CreateProducer(client0_, streamName, delayFlushTime_, pageSize_, maxStreamSize_, autoCleanup, &producer);
+    std::string subName = "SendReceiveElements1Sub";
+    Consumer_p consumer = nullptr;
+    Subscribe(client0_, streamName, subName, &consumer);
+    size_t numEle = 100;
+    size_t eleSize = 1024;
+    std::vector<Element> elements;
+    std::vector<std::string> writeElements;
+    CreateElements(numEle, eleSize, elements, writeElements);
+
+    SendElements(producer, elements);
+    auto rc = StreamProducerFlush(producer);
+    ASSERT_EQ(rc.code, K_OK);
+    std::vector<std::string> outElements;
+    ReceiveElements(consumer, outElements);
+    ASSERT_EQ(writeElements.size(), outElements.size());
+    for (size_t i = 0; i < writeElements.size(); ++i) {
+        ASSERT_EQ(writeElements[i], outElements[i]);
+    }
+}
+
+TEST_F(StreamCacheTest, SendReceiveElements2)
+{
+    Producer_p producer = nullptr;
+    std::string streamName = "SendReceiveElements2";
+    CreateProducer(client0_, streamName, delayFlushTime_, pageSize_, maxStreamSize_, autoCleanup, &producer);
+    std::string subName = "SendReceiveElements2Sub";
+    Consumer_p consumer = nullptr;
+    Subscribe(client0_, streamName, subName, &consumer);
+    size_t numEle = 100;
+    size_t eleSize = 1024;
+    std::vector<std::string> writeElements;
+    for (size_t i = 0; i < numEle; ++i) {
+        std::vector<Element> elements;
+        std::vector<std::string> writeElement;
+        CreateElements(1, eleSize, elements, writeElement);
+        SendElements(producer, elements);
+        auto rc = StreamProducerFlush(producer);
+        ASSERT_EQ(rc.code, K_OK);
+        writeElements.emplace_back(writeElement[0]);
+    }
+    std::vector<std::string> outElements;
+    ReceiveElements(consumer, outElements);
+    ASSERT_EQ(writeElements.size(), outElements.size());
+    for (size_t i = 0; i < writeElements.size(); ++i) {
+        ASSERT_EQ(writeElements[i], outElements[i]);
+    }
+}
+
+TEST_F(StreamCacheTest, SendReceiveElementsExpectFromRemote)
+{
+    Producer_p producer = nullptr;
+    std::string streamName = "SendReceiveElementsFromRemote";
+    CreateProducer(client0_, streamName, delayFlushTime_, pageSize_, maxStreamSize_, autoCleanup, &producer);
+    std::string subName = "SendReceiveElementsFromRemoteSub";
+    Consumer_p consumer = nullptr;
+    // connect to remote worker
+    Subscribe(client1_, streamName, subName, &consumer);
+    size_t numEle = 50;
+    size_t eleSize = 1024;
+    std::vector<std::string> writeElements;
+    for (size_t i = 0; i < numEle; ++i) {
+        std::vector<Element> elements;
+        std::vector<std::string> writeElement;
+        CreateElements(1, eleSize, elements, writeElement);
+        SendElements(producer, elements);
+        auto rc = StreamProducerFlush(producer);
+        ASSERT_EQ(rc.code, K_OK);
+        writeElements.emplace_back(writeElement[0]);
+    }
+    std::vector<std::string> outElements;
+    int sleepTime = 1;
+    std::this_thread::sleep_for(std::chrono::seconds(sleepTime));
+    for (size_t i = 0; i < numEle; ++i) {
+        std::vector<std::string> outElement;
+        ReceiveElementsExpected(consumer, 1, outElement);
+        outElements.emplace_back(outElement[0]);
+    }
+    for (size_t i = 0; i < numEle; ++i) {
+        ASSERT_EQ(writeElements[i], outElements[i]);
+    }
+    ReceiveElementsExpected(consumer, 1, outElements);
+    ASSERT_EQ(outElements.size(), 0u);
+}
+} // namespace st
+} // namespace datasystem
\ No newline at end of file
diff --git a/tests/st/cluster/base_cluster.h b/tests/st/cluster/base_cluster.h
index 7d9bd13..38c0fbc 100644
--- a/tests/st/cluster/base_cluster.h
+++ b/tests/st/cluster/base_cluster.h
@@ -66,6 +66,8 @@ struct BaseClusterOptions {
           numZmqServerCtx(DEFAULT_ZMQ_SERVER_IO_CTX_NUM),
           numOcThreadNum(DEFAULT_THREAD_NUM),
           numSpillThreadNum(DEFAULT_THREAD_NUM),
+          numScRegularSocket(DEFAULT_THREAD_NUM),
+          numScStreamSocket(DEFAULT_THREAD_NUM),
           numEtcd(0),
           numOBS(0)
     {
@@ -100,6 +102,14 @@ struct BaseClusterOptions {
     // Default: 4
     uint32_t numSpillThreadNum;
 
+    // The number of regular backend sockets for stream cache.
+    // Default: 4
+    uint32_t numScRegularSocket;
+
+    // The number of stream backend sockets for stream cache.
+    // Default: 4
+    uint32_t numScStreamSocket;
+
     // Master ip address.
     // Default: None. Must be specified by the user.
     std::vector<std::string> masterIpAddrs;
@@ -109,9 +119,10 @@ struct BaseClusterOptions {
     std::vector<std::string> workerConfigs;
 
     // Extra tcp/ip port for worker <-> worker direct connection.
-    // One for WorkerWorkerOCService
+    // One for WorkerWorkerOCService and one for WorkerWorkerSCService
     // Default: None. It is optional but its size must match workerConfigs
     std::vector<int> workerOcDirectPorts;
+    std::vector<int> workerScDirectPorts;
 
     uint32_t numEtcd;
diff --git a/tests/st/cluster/external_cluster.cpp b/tests/st/cluster/external_cluster.cpp
index 7125848..b63dee4 100644
--- a/tests/st/cluster/external_cluster.cpp
+++ b/tests/st/cluster/external_cluster.cpp
@@ -910,7 +910,7 @@ Status ExternalCluster::StartMaster(int index)
         }
         etcdUrl += addrs.first.ToString();
     }
-    masterCmd += " -backend_store=etcd -etcd_address=" + etcdUrl + " -az_name=" + opts_.etcdPrefix;
+    masterCmd += " -backend_store=etcd -etcd_address=" + etcdUrl + " -cluster_name=" + opts_.etcdPrefix;
     LOG(INFO) << "Launch master [" << index << "] command: " << masterCmd;
     auto masterProcess = std::make_unique(masterCmd, opts_.masterIpAddrs[index]);
@@ -999,7 +999,15 @@ Status ExternalCluster::StartWorker(int index, const HostPort &address, std::string
         spillDir = rootDir + "/spill";
     }
     (void)DeleteFile(healthFile);
-
+    if (opts_.isStreamCacheCase) {
+        opts_.numRpcThreads = 1;
+        opts_.numOcThreadNum = 1;
+        opts_.workerGflagParams = " -sc_regular_socket_num=" + std::to_string(opts_.numScRegularSocket) +
+                                  " -sc_stream_socket_num=" + std::to_string(opts_.numScStreamSocket) + " " +
+                                  opts_.workerGflagParams;
+    } else {
+        opts_.workerGflagParams = " -sc_regular_socket_num=0 -sc_stream_socket_num=0 " + opts_.workerGflagParams;
+    }
     std::string injectActions = "test.start.notWait:call(0)" +
                                 (opts_.injectActions.empty() ? "" : ";" + opts_.injectActions) +
                                 (opts_.disableRocksDB ? ";master.disableRocksDb:1*call()" : "");
@@ -1175,6 +1183,7 @@ ExternalClusterOptions::ExternalClusterOptions()
       waitWorkerReady(true),
       enableLivenessProbe(false),
       skipWorkerPreShutdown(true),
+      isStreamCacheCase(false),
       disableRocksDB(true)
 {
     if (isObjectCache) {
diff --git a/tests/st/cluster/external_cluster.h b/tests/st/cluster/external_cluster.h
index 00f6083..61eef17 100644
--- a/tests/st/cluster/external_cluster.h
+++ b/tests/st/cluster/external_cluster.h
@@ -81,7 +81,7 @@ public:
     std::map<std::string, std::string> crossAZMap;
 
     // Parameters for starting the worker
-    // For example, "-shared_memory_size_mb=1024"
+    // For example, "-page_size=102400 -shared_memory_size_mb=1024"
     std::string workerGflagParams;
 
     // Parameter for specific worker.
@@ -122,6 +122,8 @@ public:
     // Skip the shutdown process to accelerate worker exit.
     bool skipWorkerPreShutdown;
 
+    bool isStreamCacheCase;
+
     // Disable rocksdb, default true
     bool disableRocksDB;
diff --git a/tests/st/common/kvstore/etcd_store_test.cpp b/tests/st/common/kvstore/etcd_store_test.cpp
index c810465..17b115d 100644
--- a/tests/st/common/kvstore/etcd_store_test.cpp
+++ b/tests/st/common/kvstore/etcd_store_test.cpp
@@ -45,7 +45,7 @@
 using namespace datasystem;
 
 DS_DECLARE_string(etcd_address);
-DS_DECLARE_string(az_name);
+DS_DECLARE_string(cluster_name);
 DS_DECLARE_bool(enable_etcd_auth);
 DS_DECLARE_string(etcd_target_name_override);
 DS_DECLARE_string(encrypt_kit);
@@ -848,7 +848,7 @@ TEST_F(EtcdStoreTest, TestGetEtcdPrefix)
 {
     LOG(INFO) << "Test EtcdStore GetEtcdPrefix";
     std::string table_prefix = "AZ1";
-    FLAGS_az_name = table_prefix;
+    FLAGS_cluster_name = table_prefix;
     InitTestEtcdInstance();
     std::string prefix;
     DS_ASSERT_OK(db_->GetEtcdPrefix(tableName_, prefix));
diff --git a/tests/st/common/kvstore/rocks_store_test.cpp b/tests/st/common/kvstore/rocks_store_test.cpp
index 326bab7..58e3ca7 100644
--- a/tests/st/common/kvstore/rocks_store_test.cpp
+++ b/tests/st/common/kvstore/rocks_store_test.cpp
@@ -27,9 +27,10 @@
 #include "datasystem/common/kvstore/rocksdb/rocks_store.h"
 #include "datasystem/common/util/file_util.h"
 #include "datasystem/common/util/random_data.h"
+#include "datasystem/utils/status.h"
 
 using namespace datasystem;
-
+DS_DECLARE_string(rocksdb_write_mode);
 namespace datasystem {
 namespace st {
 class RocksStoreTest : public CommonTest {
@@ -64,6 +65,7 @@ RandomData RocksStoreTest::random_;
 void RocksStoreTest::SetUp()
 {
+    FLAGS_rocksdb_write_mode = "sync";
     db_ = RocksStore::GetInstance(dbName_);
     if (db_ != nullptr) {
         rocksdb::ColumnFamilyOptions options;
diff --git a/tests/st/common/stream_cache/element_generator.cpp b/tests/st/common/stream_cache/element_generator.cpp
new file mode 100644
index 0000000..fe9b76f
--- /dev/null
+++ b/tests/st/common/stream_cache/element_generator.cpp
@@ -0,0 +1,268 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Element Generator.
+ */
+#include "element_generator.h"
+
+#include <algorithm>
+#include <iomanip>
+#include <iterator>
+#include <mutex>
+#include <sstream>
+
+#include <openssl/md5.h>
+
+#include "common.h"
+#include "datasystem/common/util/strings_util.h"
+#include "datasystem/common/util/thread_pool.h"
+#include "datasystem/common/util/timer.h"
+#include "datasystem/common/log/log.h"
+
+namespace datasystem {
+std::string Md5SumStr(uint8_t *data, uint64_t sz)
+{
+    uint8_t digest[MD5_DIGEST_LENGTH];
+    MD5(data, sz, digest);
+
+    std::stringstream result;
+    constexpr uint32_t hexWidthForUint8 = 2;
+    for (int i : digest) {
+        result << std::setw(hexWidthForUint8) << std::setfill('0') << std::hex << i;
+    }
+    return result.str();
+}
+
+std::string Md5Sum(const std::string &str)
+{
+    return Md5SumStr((uint8_t *)str.data(), str.size());
+}
+
+std::string ElementGenerator::GenElement(const std::string &producerId)
+{
+    return GenElements(producerId, 1).front();
+}
+
+ElementBuilder &ElementBuilder::SetProducerId(std::string producerId)
+{
+    producerId_ = std::move(producerId);
+    return *this;
+}
+
+ElementBuilder &ElementBuilder::SetSeqNo(uint64_t seqNo)
+{
+    seqNo_ = seqNo;
+    return *this;
+}
+
+ElementBuilder &ElementBuilder::SetData(uint8_t *data, uint64_t sz)
+{
+    data_ = data;
+    dataSz_ = sz;
+    return *this;
+}
+
+std::string ElementBuilder::Build() const
+{
+    // | lenOfId | id | seqNo | lenData | Data | lenOfMd5 | md5Sum |.
+    uint64_t cap = sizeof(uint64_t) * 3 + producerId_.size() + dataSz_;
+    std::string str;
+    constexpr uint64_t MD5_STR_LEN = 32;
+    str.reserve(cap + sizeof(uint64_t) + MD5_STR_LEN);
+
+    // Scratch uint64 whose byte representation is appended for each length field.
+    uint64_t uint64 = producerId_.size();
+    auto ptr = (uint8_t *)(&uint64);
+    auto ptr_end = ptr + sizeof(uint64_t);
+
+    // Id.
+    std::copy(ptr, ptr_end, std::back_inserter(str));
+    std::copy(producerId_.data(), producerId_.data() + producerId_.size(), std::back_inserter(str));
+
+    // SeqNo.
+    uint64 = seqNo_;
+    std::copy(ptr, ptr_end, std::back_inserter(str));
+
+    // Data.
+    uint64 = dataSz_;
+    std::copy(ptr, ptr_end, std::back_inserter(str));
+    std::copy(data_, data_ + dataSz_, std::back_inserter(str));
+
+    // Md5Sum.
+    auto md5Sum = Md5SumStr((uint8_t *)str.data(), cap);
+    uint64 = md5Sum.size();
+
+    std::copy(ptr, ptr_end, std::back_inserter(str));
+    std::copy(begin(md5Sum), end(md5Sum), std::back_inserter(str));
+    return str;
+}
+
+ElementGenerator::ElementGenerator(uint64_t maxSz, uint64_t minSz)
+    : maxSz_(maxSz), minSz_(minSz), randomData_(), randomBytes_(randomData_.RandomBytes(maxSz * 2))
+{
+}
+
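+// GenElements builds `numElements` self-describing elements for `producerId`.
+// `pool` builds candidate strings concurrently while `pool2`, chained through
+// the `token` promises, commits results into `res` strictly in order, so the
+// per-producer sequence numbers stay FIFO even with numThreads > 1.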
+ auto md5Sum = Md5SumStr((uint8_t *)str.data(), cap); + uint64 = md5Sum.size(); + + std::copy(ptr, ptr_end, std::back_inserter(str)); + std::copy(begin(md5Sum), end(md5Sum), std::back_inserter(str)); + return str; +} + +ElementGenerator::ElementGenerator(uint64_t maxSz, uint64_t minSz) + : maxSz_(maxSz), minSz_(minSz), randomData_(), randomBytes_(randomData_.RandomBytes(maxSz * 2)) +{ +} + +std::vector ElementGenerator::GenElements(const std::string &producerId, uint64_t numElements, + int numThreads) +{ + Timer timer; + std::vector res; + res.resize(numElements); + auto it = seqNoMap_.find(producerId); + if (it == seqNoMap_.end()) { + it = seqNoMap_.emplace(producerId, 0).first; + } + auto &seqNo = it->second; + auto startSeqNo = seqNo; + std::vector> token(numElements + 1); + token[0].set_value(true); + + std::mutex mtx; + std::unique_ptr pool; + std::unique_ptr pool2; + LOG_IF_EXCEPTION_OCCURS(pool = std::make_unique(numThreads)); + LOG_IF_EXCEPTION_OCCURS(pool2 = std::make_unique(numThreads)); + for (auto i = 0u; i < numElements; i++) { + std::shared_future strFut = pool->Submit([i, this, &mtx, &seqNo, startSeqNo, producerId]() { + uint64_t sz; + if (i % 10 == 0) { + sz = maxSz_; + } else { + sz = randomData_.GetRandomUint64(minSz_, maxSz_); + } + auto pos = randomData_.GetRandomUint64() % maxSz_; + + auto builder = std::make_shared(); + builder->SetProducerId(producerId).SetData(randomBytes_.data() + pos, sz); + builder->SetSeqNo(i + startSeqNo); + { + std::lock_guard lck(mtx); + seqNo++; + } + auto str = builder->Build(); + return str; + }); + pool2->Submit([i, strFut, &token, &res]() { + token[i].get_future().get(); + res[i] = strFut.get(); + token[i + 1].set_value(true); + }); + } + pool.reset(); + pool2.reset(); + LOG(INFO) << FormatString("Stream Gen elapsed: [%.6lf]s", timer.ElapsedSecond()); + return res; +} + +ElementView::ElementView(std::string view) : view_(std::move(view)) +{ +} + +Status ElementView::ParseData() +{ + // | lenOfId | id | seqNo | lenData | Data | lenOfMd5 | md5Sum |. 
+ uint64_t off = 0; + auto lenOfId = *reinterpret_cast(view_.data() + off); + off += sizeof(uint64_t); + CHECK_FAIL_RETURN_STATUS(off + lenOfId < view_.size(), StatusCode::K_RUNTIME_ERROR, + FormatString("off: [%zu], lenOfId: [%zu], viewSz: [%zu]", off, lenOfId, view_.size())); + producerId_ = std::string(view_.data() + off, lenOfId); + off += lenOfId; + CHECK_FAIL_RETURN_STATUS( + off + sizeof(uint64_t) < view_.size(), StatusCode::K_RUNTIME_ERROR, + FormatString("off: [%zu], lenOfSeqNo: [%zu], viewSz: [%zu]", off, sizeof(uint64_t), view_.size())); + seqNo_ = *reinterpret_cast(view_.data() + off); + off += sizeof(uint64_t); + + CHECK_FAIL_RETURN_STATUS( + off + sizeof(uint64_t) < view_.size(), StatusCode::K_RUNTIME_ERROR, + FormatString("off: [%zu], lenMeta: [%zu], viewSz: [%zu]", off, sizeof(uint64_t), view_.size())); + auto lenOfData = *reinterpret_cast(view_.data() + off); + off += sizeof(uint64_t); + + CHECK_FAIL_RETURN_STATUS(off + lenOfData < view_.size(), StatusCode::K_RUNTIME_ERROR, + FormatString("off: [%zu], data: [%zu], viewSz: [%zu]", off, lenOfData, view_.size())); + data_ = std::string(view_.data() + off, lenOfData); + off += lenOfData; + + CHECK_FAIL_RETURN_STATUS( + off + sizeof(uint64_t) < view_.size(), StatusCode::K_RUNTIME_ERROR, + FormatString("off: [%zu], data: [%zu], viewSz: [%zu]", off, sizeof(uint64_t), view_.size())); + auto lenOfMd5 = *reinterpret_cast(view_.data() + off); + off += sizeof(uint64_t); + md5sum_ = std::string(view_.data() + off, lenOfMd5); + + CHECK_FAIL_RETURN_STATUS(off + lenOfMd5 == view_.size(), StatusCode::K_RUNTIME_ERROR, + FormatString("Producer id:[%s], seq no:[%zu], off: [%zu], data: [%zu], viewSz: [%zu]", + producerId_, seqNo_, off, lenOfMd5, view_.size())); + lenPayload_ = off - sizeof(uint64_t); + isParsed = true; + DLOG(INFO) << FormatString("id: [%s], seqNo: [%zu], lenData: [%zu], md5sum: [%s]", producerId_, seqNo_, lenOfData, + md5sum_); + return Status::OK(); +} + +Status ElementView::VerifyIntegrity() +{ + if (!isParsed) { + RETURN_IF_NOT_OK(ParseData()); + } + auto md5Sum = Md5SumStr((uint8_t *)view_.data(), lenPayload_); + CHECK_FAIL_RETURN_STATUS(md5Sum == md5sum_, StatusCode::K_RUNTIME_ERROR, + FormatString("Integrity violation; ground truth: [%s]; got: [%s]", md5sum_, md5Sum)); + return Status::OK(); +} + +Status ElementView::VerifyFifo(std::unordered_map &seqNoMap, uint64_t offset) +{ + if (!isParsed) { + RETURN_IF_NOT_OK(ParseData()); + } + + std::string producerId{ producerId_.data(), producerId_.size() }; + auto it = seqNoMap.find(producerId); + if (it == seqNoMap.end()) { + it = seqNoMap.emplace(producerId, offset).first; + } + + auto seqNo = it->second; + CHECK_FAIL_RETURN_STATUS( + seqNo == seqNo_, StatusCode::K_RUNTIME_ERROR, + FormatString("FIFO SeqNo violation [%s], got: [%zu], expect: [%zu]", producerId, seqNo_, seqNo)); + it->second++; + return Status::OK(); +} + +Status ElementView::VerifyFifoInitOff(std::unordered_map &seqNoMap) +{ + if (!isParsed) { + RETURN_IF_NOT_OK(ParseData()); + } + std::string producerId{ producerId_.data(), producerId_.size() }; + auto it = seqNoMap.find(producerId); + if (it == seqNoMap.end()) { + it = seqNoMap.emplace(producerId, seqNo_).first; + } + + auto seqNo = it->second; + CHECK_FAIL_RETURN_STATUS( + seqNo == seqNo_, StatusCode::K_RUNTIME_ERROR, + FormatString("FIFO SeqNo violation [%s], gt: [%zu], got: [%zu]", producerId, seqNo_, seqNo)); + it->second++; + return Status::OK(); +} +} // namespace datasystem diff --git a/tests/st/common/stream_cache/element_generator.h 
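The two-pool pipeline in GenElements above is worth calling out: pool builds elements concurrently, while pool2 drains a chain of promises so results are committed to res in submission order. A self-contained sketch of that hand-off pattern, using plain std::thread instead of the project's ThreadPool (illustrative only, not part of the patch):

#include <future>
#include <iostream>
#include <string>
#include <thread>
#include <vector>

int main()
{
    const size_t n = 8;
    std::vector<std::string> results(n);
    std::vector<std::promise<bool>> token(n + 1);
    token[0].set_value(true);  // unblock the first committer

    std::vector<std::thread> workers;
    for (size_t i = 0; i < n; i++) {
        workers.emplace_back([i, &results, &token]() {
            std::string value = "element-" + std::to_string(i);  // concurrent work
            token[i].get_future().get();    // wait until the predecessor has committed
            results[i] = std::move(value);  // commit in submission order
            token[i + 1].set_value(true);   // release the successor
        });
    }
    for (auto &t : workers) {
        t.join();
    }
    for (const auto &r : results) {
        std::cout << r << "\n";
    }
}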
diff --git a/tests/st/common/stream_cache/element_generator.h b/tests/st/common/stream_cache/element_generator.h
new file mode 100644
index 0000000..5d9fbb8
--- /dev/null
+++ b/tests/st/common/stream_cache/element_generator.h
@@ -0,0 +1,115 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Element Generator.
+ */
+#ifndef DATASYSTEM_TEST_ST_WORKER_STREAM_CACHE_ELEMENT_GENERATOR_H
+#define DATASYSTEM_TEST_ST_WORKER_STREAM_CACHE_ELEMENT_GENERATOR_H
+
+#include <unordered_map>
+
+#include "datasystem/common/util/random_data.h"
+#include "datasystem/common/util/status_helper.h"
+
+namespace datasystem {
+std::string Md5Sum(const std::string &str);
+std::string Md5SumStr(uint8_t *data, uint64_t sz);
+
+/*
+ * Element layout:
+ * | lenOfId | id | seqNo | lenData | Data | lenOfMd5 | md5Sum |
+ * i.e., | Payload | md5Sum |, where md5Sum is computed over the whole payload.
+ */
+class ElementBuilder {
+public:
+    ElementBuilder() = default;
+
+    ElementBuilder &SetProducerId(std::string producerId);
+
+    ElementBuilder &SetSeqNo(uint64_t seqNo);
+
+    ElementBuilder &SetData(uint8_t *data, uint64_t sz);
+
+    std::string Build() const;
+
+private:
+    std::string producerId_;
+    uint64_t seqNo_ = 0;
+    uint8_t *data_ = nullptr;
+    uint64_t dataSz_ = 0;
+};
+
+class ElementView {
+public:
+    explicit ElementView(std::string view);
+
+    Status VerifyFifo(std::unordered_map<std::string, uint64_t> &seqNoMap, uint64_t offset = 0);
+
+    Status VerifyFifoInitOff(std::unordered_map<std::string, uint64_t> &seqNoMap);
+
+    Status VerifyIntegrity();
+
+    Status GetSeqNo(uint64_t &seqNo)
+    {
+        if (!isParsed) {
+            RETURN_IF_NOT_OK(ParseData());
+        }
+        seqNo = seqNo_;
+        return Status::OK();
+    }
+
+    Status GetProducerId(std::string &outProducerId)
+    {
+        if (!isParsed) {
+            RETURN_IF_NOT_OK(ParseData());
+        }
+        outProducerId = producerId_;
+        return Status::OK();
+    }
+
+private:
+    Status ParseData();
+
+    std::string producerId_;
+    uint64_t seqNo_ = 0;
+    std::string data_;
+    std::string md5sum_;
+    bool isParsed = false;
+    uint64_t lenPayload_ = 0;
+
+    std::string view_;
+};
+
+class ElementGenerator {
+public:
+    explicit ElementGenerator(uint64_t maxSz, uint64_t minSz = 1);
+
+    std::vector<std::string> GenElements(const std::string &producerId, uint64_t numElements, int numThreads = 1);
+
+    std::string GenElement(const std::string &producerId);
+
+private:
+    uint64_t maxSz_;
+    uint64_t minSz_;
+    RandomData randomData_;
+    std::vector<uint8_t> randomBytes_;
+    std::unordered_map<std::string, uint64_t> seqNoMap_;
+};
+} // namespace datasystem
+
+#endif // DATASYSTEM_TEST_ST_WORKER_STREAM_CACHE_ELEMENT_GENERATOR_H
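The builder/view pair above round-trips as follows; a usage sketch (not part of the patch) with an arbitrary producer name and size bound:

#include <unordered_map>
#include "element_generator.h"

datasystem::Status VerifyGeneratedElements()
{
    datasystem::ElementGenerator generator(1024);  // element payloads of 1..1024 bytes
    auto elements = generator.GenElements("producer1", 100);

    std::unordered_map<std::string, uint64_t> seqNoMap;
    for (auto &raw : elements) {
        datasystem::ElementView view(raw);
        // Recompute the MD5 over the payload and compare it with the trailing digest.
        RETURN_IF_NOT_OK(view.VerifyIntegrity());
        // Check that per-producer sequence numbers are contiguous starting from 0.
        RETURN_IF_NOT_OK(view.VerifyFifo(seqNoMap));
    }
    return datasystem::Status::OK();
}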
diff --git a/tests/st/common/stream_cache/mock_evictmanager.h b/tests/st/common/stream_cache/mock_evictmanager.h
new file mode 100644
index 0000000..05b735a
--- /dev/null
+++ b/tests/st/common/stream_cache/mock_evictmanager.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Mock evict manager that allocates stream memory directly in tests.
+ */
+#ifndef DATASYSTEM_TEST_ST_WORKER_STREAM_CACHE_MOCK_EVICTMANAGER_H
+#define DATASYSTEM_TEST_ST_WORKER_STREAM_CACHE_MOCK_EVICTMANAGER_H
+
+#include <string>
+
+#include "datasystem/common/shared_memory/shm_unit.h"
+#include "datasystem/worker/stream_cache/worker_sc_allocate_memory.h"
+
+namespace datasystem {
+namespace st {
+class MockEvictManager : public datasystem::worker::stream_cache::WorkerSCAllocateMemory {
+public:
+    MockEvictManager() : WorkerSCAllocateMemory(nullptr) {}
+
+    Status AllocateMemoryForStream(const std::string &tenantId, const std::string &streamId,
+                                   const uint64_t needSize, bool populate, ShmUnit &shmUnit, bool retryOnOOM)
+    {
+        (void)streamId;
+        (void)retryOnOOM;
+        return shmUnit.AllocateMemory(tenantId, needSize, populate);
+    }
+};
+} // namespace st
+} // namespace datasystem
+#endif // DATASYSTEM_TEST_ST_WORKER_STREAM_CACHE_MOCK_EVICTMANAGER_H
\ No newline at end of file
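MockEvictManager short-circuits the worker's eviction logic: the stream/tenant bookkeeping is dropped and the request goes straight to shared memory. A minimal usage sketch (illustrative only; "tenant1", "stream1" and the 4 KB size are arbitrary):

#include "mock_evictmanager.h"

datasystem::Status AllocateOnePage(datasystem::ShmUnit &shmUnit)
{
    datasystem::st::MockEvictManager evictManager;
    // populate = false: do not pre-fault pages; retryOnOOM = false: fail fast in tests.
    return evictManager.AllocateMemoryForStream("tenant1", "stream1", 4 * 1024, false, shmUnit, false);
}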
diff --git a/tests/st/common/stream_cache/stream_common.h b/tests/st/common/stream_cache/stream_common.h
new file mode 100644
index 0000000..3958d71
--- /dev/null
+++ b/tests/st/common/stream_cache/stream_common.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Common constants for stream cache tests.
+ */
+#ifndef DATASYSTEM_TEST_ST_WORKER_STREAM_CACHE_SINGLE_STREAM_COMMON_H
+#define DATASYSTEM_TEST_ST_WORKER_STREAM_CACHE_SINGLE_STREAM_COMMON_H
+
+#include <cstdint>
+
+#include "common.h"
+#include "mock_evictmanager.h"
+
+namespace datasystem {
+namespace st {
+constexpr uint64_t SHM_CAP = 128L * 1024L * 1024L;
+constexpr uint64_t BIG_SHM_CAP = 10L * 1024L * 1024L * 1024L;
+constexpr uint64_t PAGE_SIZE = 4L * 1024L;
+constexpr uint64_t BIG_PAGE_SIZE = 512L * 1024L;
+constexpr uint64_t BIG_SIZE_RATIO = 16L;
+constexpr uint64_t BIG_SIZE = PAGE_SIZE / BIG_SIZE_RATIO;
+constexpr uint64_t KB = 1024L;
+constexpr uint64_t MB = 1024L * KB;
+constexpr uint64_t GB = 1024L * MB;
+constexpr uint64_t TEST_STREAM_SIZE = 64 * MB;
+
+constexpr int NUM_ELES = 50;
+
+} // namespace st
+} // namespace datasystem
+
+#endif // DATASYSTEM_TEST_ST_WORKER_STREAM_CACHE_SINGLE_STREAM_COMMON_H
diff --git a/tests/st/device/dev_object_client_test.cpp b/tests/st/device/dev_object_client_test.cpp
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/st/device/dev_object_hetero_test.cpp b/tests/st/device/dev_object_hetero_test.cpp
index c815234..f95cf41 100644
--- a/tests/st/device/dev_object_hetero_test.cpp
+++ b/tests/st/device/dev_object_hetero_test.cpp
@@ -48,7 +48,6 @@
 #include "datasystem/client/hetero_cache/device_util.h"
 #include "device/dev_test_helper.h"
 
-using datasystem::memory::Allocator;
 using datasystem::memory::DevMemFuncRegister;
 
 namespace datasystem {
@@ -209,7 +208,7 @@ void DevObjectHeteroTest::SwapInOutPerformanceTest()
 void DevObjectHeteroTest::Sub(const std::vector &inObjectKeys, const std::vector &strVec,
                               size_t batch, std::shared_ptr client)
 {
-    int32_t deviceId = 5;
+    int32_t deviceId = 1;
     std::shared_ptr localClient;
     if (client == nullptr) {
         InitAcl(deviceId);
@@ -241,7 +240,7 @@ void DevObjectHeteroTest::Sub(const std::vector &inObjectKeys, cons
 void DevObjectHeteroTest::Pub(const std::vector &inObjectKeys, const std::vector &strVec,
                               size_t batch, std::shared_ptr client)
 {
-    int32_t deviceId = 4;
+    int32_t deviceId = 0;
     std::shared_ptr localClient;
     if (client == nullptr) {
         InitAcl(deviceId);
@@ -685,7 +684,7 @@ TEST_F(DevObjectHeteroTest, DISABLED_AllocateDeviceMemTest)
         return Status::OK();
     };
 
-    auto *allocator = Allocator::Instance();
+    auto *allocator = datasystem::memory::Allocator::Instance();
     struct DevMemFuncRegister regFunc;
     regFunc.devDeviceCreateFunc = allocateFunc;
     regFunc.devDeviceDestroyFunc = destroyFunc;
@@ -712,7 +711,8 @@ TEST_F(DevObjectHeteroTest, DISABLED_AllocateDeviceMemTest)
     int loopNums = 5;
    for (int i = 0; i < loopNums; i++) {
         std::string tenantId1 = "tenant1";
-        auto rc = shmUnit.AllocateMemory(tenantId1, maxSize2 - 1, false, datasystem::memory::CacheType::DEV_HOST);
+        auto rc = shmUnit.AllocateMemory(tenantId1, maxSize2 - 1, false, ServiceType::OBJECT,
+                                         datasystem::memory::CacheType::DEV_HOST);
         LOG(INFO) << "allocate info " << rc.ToString();
         auto freeRc = shmUnit.FreeMemory();
         LOG(INFO) << "Free result is : " << freeRc.ToString();
@@ -720,6 +720,268 @@ TEST_F(DevObjectHeteroTest, DISABLED_AllocateDeviceMemTest)
     }
 }
 
+TEST_F(DevObjectHeteroTest, DISABLED_TestRecvRootInfoDeadlock_ExchangeDataWithEachOther)
+{
+    /**
+     c1: DevMSet(key1) and DevMGet(key2)
+     c2: DevMSet(key2) and DevMGet(key1)
+     deadlock: c1 RecvRootInfo, c2 RecvRootInfo
+     */
+    size_t blkSz = 10;
+    auto blksPerObj = 10;
+    size_t timeout = 40 * 1000;
+    std::vector inObjectKeys = { "key1", "key2" };
+    std::vector swapOutBlobList;
+    std::vector swapInBlobList;
+    DS_ASSERT_OK(inject::Set("CreateHcclCommInSend.sleep", "2*sleep(10)"));
+    auto child1 = 
ForkForTest([&]() { + int deviceId = 0; + InitAcl(deviceId); + + std::shared_ptr client1; + InitTestHeteroClient(0, client1); + std::vector failedIdList; + PrePareDevData(1, blksPerObj, blkSz, swapOutBlobList, swapInBlobList, deviceId); + DS_ASSERT_OK(client1->DevMSet({ inObjectKeys[0] }, swapInBlobList, failedIdList)); + DS_ASSERT_TRUE(failedIdList.empty(), true); + + DS_ASSERT_OK(client1->DevMGet({ inObjectKeys[1] }, swapOutBlobList, failedIdList, timeout)); + auto expectContent = std::string(blkSz, 'b'); + for (auto &devBlobList : swapOutBlobList) { + for (auto &blob : devBlobList.blobs) { + CheckDevPtrContent(blob.pointer, blkSz, expectContent); + } + } + + // Use ObjectClient to synchronize + std::shared_ptr objectClient; + InitTestClient(0, objectClient); + std::vector> buffers; + CreateParam param = CreateParam{}; + std::string value = "notice"; + DS_ASSERT_OK(objectClient->Put({ inObjectKeys[1] + "_DevMGet_finish" }, + reinterpret_cast(value.data()), value.size(), param)); + DS_ASSERT_OK(objectClient->Get({ inObjectKeys[0] + "_DevMGet_finish" }, timeout, buffers)); + }); + auto child2 = ForkForTest([&]() { + int deviceId = 1; + InitAcl(deviceId); + + std::shared_ptr client2; + InitTestHeteroClient(0, client2); + std::vector failedIdList; + PrePareDevData(1, blksPerObj, blkSz, swapOutBlobList, swapInBlobList, deviceId); + DS_ASSERT_OK(client2->DevMSet({ inObjectKeys[1] }, swapInBlobList, failedIdList)); + DS_ASSERT_TRUE(failedIdList.empty(), true); + + DS_ASSERT_OK(client2->DevMGet({ inObjectKeys[0] }, swapOutBlobList, failedIdList, timeout)); + auto expectContent = std::string(blkSz, 'b'); + for (auto &devBlobList : swapOutBlobList) { + for (auto &blob : devBlobList.blobs) { + CheckDevPtrContent(blob.pointer, blkSz, expectContent); + } + } + + // Use ObjectClient to synchronize + std::shared_ptr objectClient; + InitTestClient(0, objectClient); + std::vector> buffers; + CreateParam param = CreateParam{}; + std::string value = "notice"; + DS_ASSERT_OK(objectClient->Put({ inObjectKeys[0] + "_DevMGet_finish" }, + reinterpret_cast(value.data()), value.size(), param)); + DS_ASSERT_OK(objectClient->Get({ inObjectKeys[1] + "_DevMGet_finish" }, timeout, buffers)); + }); + DS_ASSERT_TRUE(WaitForChildFork(child1), 0); + DS_ASSERT_TRUE(WaitForChildFork(child2), 0); +} + +TEST_F(DevObjectHeteroTest, DISABLED_TestSendRootInfoDeadlock_ExchangeDataWithEachOther) +{ + /** + c1: DevMSet(key1) and DevMGet(key2) + c2: DevMSet(key2) and DevMGet(key1) + deadlock: c1 SendRootInfo 、 c2 SendRootInfo + */ + size_t blkSz = 10; + auto blksPerObj = 10; + size_t timeout = 40 * 1000; + std::vector inObjectKeys = { "key1", "key2" }; + std::vector swapOutBlobList; + std::vector swapInBlobList; + DS_ASSERT_OK(inject::Set("CreateHcclCommInRecv.sleep", "1*sleep(1000)")); + auto child1 = ForkForTest([&]() { + int deviceId = 0; + InitAcl(deviceId); + + std::shared_ptr client1; + InitTestHeteroClient(0, client1); + std::vector failedIdList; + PrePareDevData(1, blksPerObj, blkSz, swapOutBlobList, swapInBlobList, deviceId); + DS_ASSERT_OK(client1->DevMSet({ inObjectKeys[0] }, swapInBlobList, failedIdList)); + DS_ASSERT_TRUE(failedIdList.empty(), true); + + DS_ASSERT_OK(client1->DevMGet({ inObjectKeys[1] }, swapOutBlobList, failedIdList, timeout)); + auto expectContent = std::string(blkSz, 'b'); + for (auto &devBlobList : swapOutBlobList) { + for (auto &blob : devBlobList.blobs) { + CheckDevPtrContent(blob.pointer, blkSz, expectContent); + } + } + + // Use ObjectClient to synchronize + std::shared_ptr objectClient; + 
InitTestClient(0, objectClient); + std::vector> buffers; + CreateParam param = CreateParam{}; + std::string value = "notice"; + DS_ASSERT_OK(objectClient->Put({ inObjectKeys[1] + "_DevMGet_finish" }, + reinterpret_cast(value.data()), value.size(), param)); + DS_ASSERT_OK(objectClient->Get({ inObjectKeys[0] + "_DevMGet_finish" }, timeout, buffers)); + }); + auto child2 = ForkForTest([&]() { + int deviceId = 1; + InitAcl(deviceId); + + std::shared_ptr client2; + InitTestHeteroClient(0, client2); + std::vector failedIdList; + PrePareDevData(1, blksPerObj, blkSz, swapOutBlobList, swapInBlobList, deviceId); + DS_ASSERT_OK(client2->DevMSet({ inObjectKeys[1] }, swapInBlobList, failedIdList)); + DS_ASSERT_TRUE(failedIdList.empty(), true); + + DS_ASSERT_OK(client2->DevMGet({ inObjectKeys[0] }, swapOutBlobList, failedIdList, timeout)); + auto expectContent = std::string(blkSz, 'b'); + for (auto &devBlobList : swapOutBlobList) { + for (auto &blob : devBlobList.blobs) { + CheckDevPtrContent(blob.pointer, blkSz, expectContent); + } + } + + // Use ObjectClient to synchronize + std::shared_ptr objectClient; + InitTestClient(0, objectClient); + std::vector> buffers; + CreateParam param = CreateParam{}; + std::string value = "notice"; + DS_ASSERT_OK(objectClient->Put({ inObjectKeys[0] + "_DevMGet_finish" }, + reinterpret_cast(value.data()), value.size(), param)); + DS_ASSERT_OK(objectClient->Get({ inObjectKeys[1] + "_DevMGet_finish" }, timeout, buffers)); + }); + DS_ASSERT_TRUE(WaitForChildFork(child1), 0); + DS_ASSERT_TRUE(WaitForChildFork(child2), 0); +} + +TEST_F(DevObjectHeteroTest, DISABLED_TestDeadlock_Ring) +{ + /** + c1: DevMSet(key1) and DevMGet(key2) + c2: DevMSet(key2) and DevMGet(key3) + c3: DevMSet(key3) and DevMGet(key1) + deadlock: c1 SendRootInfo 、 c2 SendRootInfo、 c3 SendRootInfo + */ + size_t blkSz = 10; + auto blksPerObj = 10; + size_t timeout = 40 * 1000; + std::vector inObjectKeys = { "key1", "key2", "key3" }; + std::vector swapOutBlobList; + std::vector swapInBlobList; + DS_ASSERT_OK(inject::Set("CreateHcclCommInSend.sleep", "2*sleep(1000)")); + auto child1 = ForkForTest([&]() { + LOG(ERROR) << "Start process1"; + int deviceId = 0; + InitAcl(deviceId); + + std::shared_ptr client1; + InitTestHeteroClient(0, client1); + std::vector failedIdList; + LOG(ERROR) << "process1, Start to DevMSet"; + PrePareDevData(1, blksPerObj, blkSz, swapOutBlobList, swapInBlobList, deviceId); + DS_ASSERT_OK(client1->DevMSet({ inObjectKeys[0] }, swapInBlobList, failedIdList)); + DS_ASSERT_TRUE(failedIdList.empty(), true); + + DS_ASSERT_OK(client1->DevMGet({ inObjectKeys[1] }, swapOutBlobList, failedIdList, timeout)); + auto expectContent = std::string(blkSz, 'b'); + for (auto &devBlobList : swapOutBlobList) { + for (auto &blob : devBlobList.blobs) { + CheckDevPtrContent(blob.pointer, blkSz, expectContent); + } + } + // Use ObjectClient to synchronize + std::shared_ptr objectClient; + InitTestClient(0, objectClient); + std::vector> buffers; + CreateParam param = CreateParam{}; + std::string value = "notice"; + DS_ASSERT_OK(objectClient->Put({ inObjectKeys[1] + "_DevMGet_finish" }, + reinterpret_cast(value.data()), value.size(), param)); + DS_ASSERT_OK(objectClient->Get({ inObjectKeys[0] + "_DevMGet_finish" }, timeout, buffers)); + }); + auto child2 = ForkForTest([&]() { + LOG(ERROR) << "Start process2"; + int deviceId = 1; + InitAcl(deviceId); + + std::shared_ptr client2; + InitTestHeteroClient(0, client2); + std::vector failedIdList; + LOG(ERROR) << "process2, Start to DevMSet"; + PrePareDevData(1, 
blksPerObj, blkSz, swapOutBlobList, swapInBlobList, deviceId); + DS_ASSERT_OK(client2->DevMSet({ inObjectKeys[1] }, swapInBlobList, failedIdList)); + DS_ASSERT_TRUE(failedIdList.empty(), true); + + DS_ASSERT_OK(client2->DevMGet({ inObjectKeys[2] }, swapOutBlobList, failedIdList, timeout)); + auto expectContent = std::string(blkSz, 'b'); + for (auto &devBlobList : swapOutBlobList) { + for (auto &blob : devBlobList.blobs) { + CheckDevPtrContent(blob.pointer, blkSz, expectContent); + } + } + // Use ObjectClient to synchronize + std::shared_ptr objectClient; + InitTestClient(0, objectClient); + std::vector> buffers; + CreateParam param = CreateParam{}; + std::string value = "notice"; + DS_ASSERT_OK(objectClient->Put({ inObjectKeys[2] + "_DevMGet_finish" }, + reinterpret_cast(value.data()), value.size(), param)); + DS_ASSERT_OK(objectClient->Get({ inObjectKeys[1] + "_DevMGet_finish" }, timeout, buffers)); + }); + auto child3 = ForkForTest([&]() { + LOG(ERROR) << "Start process3"; + int deviceId = 2; + InitAcl(deviceId); + + std::shared_ptr client3; + InitTestHeteroClient(0, client3); + std::vector failedIdList; + LOG(ERROR) << "process3, Start to DevMSet"; + PrePareDevData(1, blksPerObj, blkSz, swapOutBlobList, swapInBlobList, deviceId); + DS_ASSERT_OK(client3->DevMSet({ inObjectKeys[2] }, swapInBlobList, failedIdList)); + DS_ASSERT_TRUE(failedIdList.empty(), true); + + DS_ASSERT_OK(client3->DevMGet({ inObjectKeys[0] }, swapOutBlobList, failedIdList, timeout)); + auto expectContent = std::string(blkSz, 'b'); + for (auto &devBlobList : swapOutBlobList) { + for (auto &blob : devBlobList.blobs) { + CheckDevPtrContent(blob.pointer, blkSz, expectContent); + } + } + // Use ObjectClient to synchronize + std::shared_ptr objectClient; + InitTestClient(0, objectClient); + std::vector> buffers; + CreateParam param = CreateParam{}; + std::string value = "notice"; + DS_ASSERT_OK(objectClient->Put({ inObjectKeys[0] + "_DevMGet_finish" }, + reinterpret_cast(value.data()), value.size(), param)); + DS_ASSERT_OK(objectClient->Get({ inObjectKeys[2] + "_DevMGet_finish" }, timeout, buffers)); + }); + DS_ASSERT_TRUE(WaitForChildFork(child1), 0); + DS_ASSERT_TRUE(WaitForChildFork(child2), 0); + DS_ASSERT_TRUE(WaitForChildFork(child3), 0); +} + TEST_F(DevObjectHeteroTest, DISABLED_TestPartOfMGet) { InitAcl(deviceId_); diff --git a/tests/st/device/hetero_d2h_test.cpp b/tests/st/device/hetero_d2h_test.cpp new file mode 100644 index 0000000..9c6525a --- /dev/null +++ b/tests/st/device/hetero_d2h_test.cpp @@ -0,0 +1,166 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: hetero d2h test. 
+ */
+#include "device/dev_test_helper.h"
+
+namespace datasystem {
+using namespace acl;
+namespace st {
+class HeteroD2HTest : public DevTestHelper {
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numWorkers = 1;
+        opts.workerGflagParams =
+            " -v=0 -authorization_enable=true -shared_memory_size_mb=4096 -enable_fallocate=false -arena_per_tenant=2 ";
+        opts.enableDistributedMaster = "false";
+        opts.numEtcd = 1;
+        FLAGS_v = 0;
+    }
+
+    void SetUp() override
+    {
+        const char *ascend_root = std::getenv("ASCEND_HOME_PATH");
+        if (ascend_root == nullptr) {
+            DS_ASSERT_OK(datasystem::inject::Set("NO_USE_FFTS", "call()"));
+            DS_ASSERT_OK(datasystem::inject::Set("client.GetOrCreateHcclComm.setIsSameNode", "call(0)"));
+            BINEXPECT_CALL(AclDeviceManager::Instance, ()).WillRepeatedly(Return(&managerMock_));
+        }
+        ExternalClusterTest::SetUp();
+    }
+
+    void TearDown() override
+    {
+        ExternalClusterTest::TearDown();
+    }
+};
+
+TEST_F(HeteroD2HTest, Perf)
+{
+    std::vector<uint64_t> keyNums{ 1, 10, 450, 500, 1000 };
+    std::vector<std::vector<double>> costs(keyNums.size());
+    std::shared_ptr<HeteroClient> client;
+    InitAcl(0);
+    InitTestHeteroClient(0, client);
+    auto repeatNum = 100u, testIdx = 0u, blkSz = 1024u, blksPerObj = 1u;
+    for (auto numOfObjs : keyNums) {
+        auto idx = 0u;
+        for (auto i = 0u; i < repeatNum; i++) {
+            LOG(INFO) << FormatString("start MSetD2H test ##### %lu keys, Round:%lu", numOfObjs, idx);
+
+            std::vector<std::string> inObjectKeys;
+            std::vector<DeviceBlobList> devGetBlobList, devSetBlobList;
+            for (auto j = 0ul; j < numOfObjs; j++) {
+                inObjectKeys.emplace_back(FormatString("round_%lu_key_%lu", idx, j));
+            }
+            PrePareDevData(numOfObjs, blksPerObj, blkSz, devGetBlobList, devSetBlobList, 0);
+            Timer t;
+            DS_ASSERT_OK(client->MSetD2H(inObjectKeys, devSetBlobList));
+            costs[testIdx].push_back(t.ElapsedMicroSecond());
+            idx++;
+        }
+        testIdx++;
+    }
+    for (auto i = 0u; i < keyNums.size(); i++) {
+        LOG(INFO) << FormatString(
+            "#### MSetD2H Test ##### %4lu keys, Key Size: %4lu B, Key num:%4lu Round:%4lu, --- result(us): %s",
+            keyNums[i], blkSz * blksPerObj, keyNums[i], repeatNum, GetTimeCostUsState(costs[i]));
+    }
+}
+
+TEST_F(HeteroD2HTest, TestNoExist)
+{
+    std::vector<uint64_t> keyNums{ 450, 500 };
+    std::vector<std::vector<double>> costs(keyNums.size());
+    std::shared_ptr<HeteroClient> client;
+    InitAcl(0);
+    InitTestHeteroClient(0, client);
+    auto testIdx = 0u, blkSz = 1024u, blksPerObj = 1u;
+    for (auto numOfObjs : keyNums) {
+        LOG(INFO) << FormatString("start MSetD2H test ##### %lu keys", numOfObjs);
+        std::vector<std::string> inObjectKeys, failedKeys;
+        std::vector<DeviceBlobList> devGetBlobList, devSetBlobList;
+        for (auto j = 0ul; j < numOfObjs; j++) {
+            inObjectKeys.emplace_back(FormatString("key_%lu", j));
+        }
+        PrePareDevData(numOfObjs, blksPerObj, blkSz, devGetBlobList, devSetBlobList, 0);
+        DS_ASSERT_OK(client->MSetD2H(inObjectKeys, devSetBlobList));
+        DS_ASSERT_OK(client->MGetH2D(inObjectKeys, devGetBlobList, failedKeys, MIN_RPC_TIMEOUT_MS));
+        DS_ASSERT_TRUE(failedKeys.empty(), true);
+        DS_ASSERT_OK(IsSameContent(devGetBlobList, devGetBlobList, 'b'));
+        testIdx++;
+    }
+}
+
+TEST_F(HeteroD2HTest, TestAllExist)
+{
+    std::vector<uint64_t> keyNums{ 450, 500 };
+    std::vector<std::vector<double>> costs(keyNums.size());
+    std::shared_ptr<HeteroClient> client;
+    InitAcl(0);
+    InitTestHeteroClient(0, client);
+    auto testIdx = 0u, blkSz = 1024u, blksPerObj = 1u;
+    for (auto numOfObjs : keyNums) {
+        LOG(INFO) << FormatString("start MSetD2H test ##### %lu keys", numOfObjs);
+        std::vector<std::string> inObjectKeys, failedKeys;
+        std::vector<DeviceBlobList> devGetBlobList, devSetBlobList;
+        for (auto j = 0ul; j < numOfObjs; j++) {
+            inObjectKeys.emplace_back(FormatString("key_%lu", j));
+        }
+        PrePareDevData(numOfObjs, blksPerObj, blkSz, devGetBlobList, devSetBlobList, 0);
+        DS_ASSERT_OK(client->MSetD2H(inObjectKeys, devSetBlobList));
+        DS_ASSERT_OK(client->MSetD2H(inObjectKeys, devSetBlobList));
+        DS_ASSERT_OK(client->MGetH2D(inObjectKeys, devGetBlobList, failedKeys, MIN_RPC_TIMEOUT_MS));
+        DS_ASSERT_TRUE(failedKeys.empty(), true);
+        DS_ASSERT_OK(IsSameContent(devGetBlobList, devGetBlobList, 'b'));
+        testIdx++;
+    }
+}
+
+TEST_F(HeteroD2HTest, TestPartExist)
+{
+    std::vector<uint64_t> keyNums{ 450, 500 };
+    std::vector<std::vector<double>> costs(keyNums.size());
+    std::shared_ptr<HeteroClient> client;
+    InitAcl(0);
+    InitTestHeteroClient(0, client);
+    auto testIdx = 0u, blkSz = 1024u, blksPerObj = 1u;
+    for (auto numOfObjs : keyNums) {
+        LOG(INFO) << FormatString("start MSetD2H test ##### %lu keys", numOfObjs);
+        std::vector<std::string> inObjectKeys, failedKeys, partKeys;
+        std::vector<DeviceBlobList> devGetBlobList, devSetBlobList, partDevBlobList;
+        for (auto j = 0ul; j < numOfObjs; j++) {
+            inObjectKeys.emplace_back(FormatString("key_%lu", j));
+        }
+        PrePareDevData(numOfObjs, blksPerObj, blkSz, devGetBlobList, devSetBlobList, 0);
+        std::vector<size_t> randIdxs = { 1, 5, 6, 7, 10, 33 };
+        for (auto randIdx : randIdxs) {
+            partKeys.push_back(inObjectKeys[randIdx]);
+            partDevBlobList.push_back(devSetBlobList[randIdx]);
+        }
+        DS_ASSERT_OK(client->MSetD2H(partKeys, partDevBlobList));
+        DS_ASSERT_OK(client->MSetD2H(inObjectKeys, devSetBlobList));
+        DS_ASSERT_OK(client->MGetH2D(inObjectKeys, devGetBlobList, failedKeys, MIN_RPC_TIMEOUT_MS));
+        DS_ASSERT_TRUE(failedKeys.empty(), true);
+        DS_ASSERT_OK(IsSameContent(devGetBlobList, devGetBlobList, 'b'));
+        testIdx++;
+    }
+}
+
+} // namespace st
+} // namespace datasystem
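Condensed shape of the D2H round trip these tests exercise (illustrative only; the blob-list type and MIN_RPC_TIMEOUT_MS follow their use in the tests above):

datasystem::Status D2HRoundTrip(std::shared_ptr<datasystem::HeteroClient> &client,
                                std::vector<std::string> &keys,
                                std::vector<datasystem::DeviceBlobList> &setBlobs,
                                std::vector<datasystem::DeviceBlobList> &getBlobs)
{
    std::vector<std::string> failedKeys;
    // MSetD2H swaps device blobs out to the host side; re-running it on keys that
    // already exist is expected to succeed (see TestAllExist/TestPartExist).
    RETURN_IF_NOT_OK(client->MSetD2H(keys, setBlobs));
    // MGetH2D swaps the objects back to device memory; keys that cannot be fetched
    // are reported through failedKeys rather than failing the whole call.
    RETURN_IF_NOT_OK(client->MGetH2D(keys, getBlobs, failedKeys, MIN_RPC_TIMEOUT_MS));
    return failedKeys.empty() ? datasystem::Status::OK()
                              : datasystem::Status(datasystem::StatusCode::K_RUNTIME_ERROR, "some keys missing");
}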
diff --git a/tests/st/device/hetero_get_meta_info_test.cpp b/tests/st/device/hetero_get_meta_info_test.cpp
index 4796f7b..7a057d7 100644
--- a/tests/st/device/hetero_get_meta_info_test.cpp
+++ b/tests/st/device/hetero_get_meta_info_test.cpp
@@ -449,8 +449,7 @@ TEST_F(HeteroGetMetaInfoTest, Exceed10000Key)
     inObjectKeys.pop_back();
     devSetBlobList.pop_back();
     devSetBlobList[0].deviceIdx = -1;
-    DS_ASSERT_OK(
-        hasStr(client->DevPublish(inObjectKeys, devSetBlobList, futures), "Got Error/ABNORMAL device, deviceId: -1"));
+    DS_ASSERT_TRUE(client->DevPublish(inObjectKeys, devSetBlobList, futures).GetCode(), StatusCode::K_INVALID);
 };
 } // namespace st
 } // namespace datasystem
\ No newline at end of file
diff --git a/tests/st/device/mock/ascend_device_manager_mock.cpp b/tests/st/device/mock/ascend_device_manager_mock.cpp
index 1118a73..05b2763 100644
--- a/tests/st/device/mock/ascend_device_manager_mock.cpp
+++ b/tests/st/device/mock/ascend_device_manager_mock.cpp
@@ -348,10 +348,9 @@ public:
         return Status::OK();
     }
 
-    Status DSHcclGetCommAsyncError(HcclComm comm, HcclResult *asyncError) override
+    Status DSHcclGetCommAsyncError(HcclComm comm) override
     {
         (void)comm;
-        *asyncError = HcclResult::HCCL_SUCCESS;
         return Status::OK();
     }
 
@@ -425,15 +424,13 @@ public:
         return Status::OK();
     }
 
-    Status aclrtQueryDeviceStatus(uint32_t deviceId, int32_t *deviceStatus) override
+    Status aclrtQueryDeviceStatus(uint32_t deviceId) override
     {
         uint32_t count;
         aclrtGetDeviceCount(&count);
-        int32_t mockStatus = 0;
         if (deviceId >= count) {
             return Status(K_INVALID, "Illegal or abnormal device id.");
        }
-        *deviceStatus = mockStatus;
         return Status::OK();
     }
 
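Both mocked methods track an API simplification in this patch: device-status and async-error queries now report problems purely through the returned Status instead of filling an out-parameter. A hypothetical call-site sketch (the namespace and virtual on AclDeviceManager are assumptions):

// Illustrative only; assumes datasystem::acl::AclDeviceManager exposes the updated signature.
datasystem::Status CheckDeviceHealthy(datasystem::acl::AclDeviceManager &manager, uint32_t deviceId)
{
    // Formerly: int32_t status; RETURN_IF_NOT_OK(manager.aclrtQueryDeviceStatus(deviceId, &status));
    // Now the Status return value alone carries the result.
    return manager.aclrtQueryDeviceStatus(deviceId);
}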
diff --git a/tests/st/master/object_cache/oc_giveup_primary_test.cpp b/tests/st/master/object_cache/oc_giveup_primary_test.cpp
index b56596f..733c91e 100644
--- a/tests/st/master/object_cache/oc_giveup_primary_test.cpp
+++ b/tests/st/master/object_cache/oc_giveup_primary_test.cpp
@@ -141,6 +141,7 @@ public:
         param.workerWorkerService = nullptr;
         param.workerWorkerService = nullptr;
         param.isOcEnabled = true;
+        param.isScEnabled = false;
         RETURN_IF_NOT_OK(replicaManager_->Init(param));
         RETURN_IF_NOT_OK(replicaManager_->AddOrSwitchTo(param.currWorkerId, ReplicaType::Primary));
         etcdCM_->SetWorkerReady();
diff --git a/tests/st/master/object_cache/oc_migrate_metadata_manager_test.cpp b/tests/st/master/object_cache/oc_migrate_metadata_manager_test.cpp
index 56c1daf..4aa282a 100644
--- a/tests/st/master/object_cache/oc_migrate_metadata_manager_test.cpp
+++ b/tests/st/master/object_cache/oc_migrate_metadata_manager_test.cpp
@@ -108,6 +108,7 @@ public:
         param.masterWorkerService = nullptr;
         param.workerWorkerService = nullptr;
         param.isOcEnabled = true;
+        param.isScEnabled = false;
         RETURN_IF_NOT_OK(replicaManager_->Init(param));
         RETURN_IF_NOT_OK(replicaManager_->AddOrSwitchTo(param.currWorkerId, ReplicaType::Primary));
         etcdCM_->SetWorkerReady();
diff --git a/tests/st/master/replica_manager_test.cpp b/tests/st/master/replica_manager_test.cpp
index 6754ef4..f869131 100644
--- a/tests/st/master/replica_manager_test.cpp
+++ b/tests/st/master/replica_manager_test.cpp
@@ -65,10 +65,13 @@ public:
         (void)primaryNodeId;
         return Status::OK();
     }
-    Status CreateMetaManager(const std::string &dbName, RocksStore *objectRocksStore) override
+
+    Status CreateMetaManager(const std::string &dbName, RocksStore *objectRocksStore,
+                             RocksStore *streamRocksStore) override
     {
         (void)dbName;
         (void)objectRocksStore;
+        (void)streamRocksStore;
         return Status::OK();
     }
     Status DestroyMetaManager(const std::string &dbName) override
@@ -156,7 +159,9 @@ public:
         param.etcdCM = nullptr;
         param.masterWorkerService = nullptr;
         param.workerWorkerService = nullptr;
+        param.rpcSessionManager = nullptr;
         param.isOcEnabled = true;
+        param.isScEnabled = false;
         auto manager = std::make_unique();
         RETURN_IF_NOT_OK(manager->Init(param));
         std::lock_guard locker(mutex_);
diff --git a/tests/st/master/stream_cache/pub_sub_topo_concurrent_test.cpp b/tests/st/master/stream_cache/pub_sub_topo_concurrent_test.cpp
new file mode 100644
index 0000000..2a78976
--- /dev/null
+++ b/tests/st/master/stream_cache/pub_sub_topo_concurrent_test.cpp
@@ -0,0 +1,363 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Stream cache pub/sub topology tests with concurrent clients.
+ */ +#include +#include +#include "common.h" + +#include "common/stream_cache/stream_common.h" +#include "datasystem/client/mmap_manager.h" +#include "datasystem/stream_client.h" +#include "datasystem/stream/producer.h" +#include "datasystem/stream/consumer.h" +#include "datasystem/worker/stream_cache/worker_master_sc_api.h" + +namespace datasystem { +namespace st { +constexpr int K_TWO = 2; +using namespace datasystem::client::stream_cache; +class PubSubTopoConcurrentTest : public ExternalClusterTest { +public: + void SetClusterSetupOptions(ExternalClusterOptions &opts) override + { + opts.numWorkers = 5; + opts.numEtcd = 1; + opts.isStreamCacheCase = true; + } + + void SetUp() override + { + ExternalClusterTest::SetUp(); + InitTest(); + } + + void TearDown() override + { + clientVector_.clear(); + ExternalClusterTest::TearDown(); + } + +protected: + void InitTest() + { + std::vector workerAddressVector(clientNum_); + for (int i = 0; i < clientNum_; ++i) { + DS_ASSERT_OK(cluster_->GetWorkerAddr(i, workerAddressVector[i])); + LOG(INFO) << FormatString("Worker%d: <%s>", i, workerAddressVector[i].ToString()); + } + + clientVector_.resize(clientNum_); + for (size_t i = 0; i < clientVector_.size(); i++) { + ConnectOptions option; + option.host = workerAddressVector[i].Host(); + option.port = workerAddressVector[i].Port(); + option.accessKey = accessKey_; + option.secretKey = secretKey_; + clientVector_[i] = std::make_unique(option); + EXPECT_NE(clientVector_[i], nullptr); + DS_ASSERT_OK(clientVector_[i]->Init()); + } + defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + } + + Status TryAndDeleteStream(std::unique_ptr &spClient, std::string streamName) + { + // if pending notifications retry delete + Status rc = Status::OK(); + do { + rc = spClient->DeleteStream(streamName); + if (rc.IsError()) { + sleep(K_TWO); + } + } while (rc.GetCode() == StatusCode::K_SC_STREAM_NOTIFICATION_PENDING); + return rc; + } + + std::vector> clientVector_; + uint8_t clientNum_ = 5; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; +}; + +TEST_F(PubSubTopoConcurrentTest, MNodeMPMC) +{ + std::string stream1("stream1"); + std::vector configVector = { SubscriptionConfig("sub1", SubscriptionType::STREAM), + SubscriptionConfig("sub2", SubscriptionType::STREAM), + SubscriptionConfig("sub3", SubscriptionType::STREAM), + SubscriptionConfig("sub4", SubscriptionType::STREAM), + SubscriptionConfig("sub5", SubscriptionType::STREAM) }; + { + ThreadPool pool(clientNum_); + auto t1 = pool.Submit([this, stream1, &configVector]() { + std::shared_ptr n0p0; + std::shared_ptr n4p1; + std::shared_ptr n0c0; + DS_ASSERT_OK(clientVector_[0]->CreateProducer(stream1, n0p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[4]->CreateProducer(stream1, n4p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[0]->Subscribe(stream1, configVector[0], n0c0)); + DS_ASSERT_OK(n0p0->Close()); + DS_ASSERT_OK(n4p1->Close()); + DS_ASSERT_OK(n0c0->Close()); + }); + auto t2 = pool.Submit([this, stream1, &configVector]() { + std::shared_ptr n1p0; + std::shared_ptr n0p1; + std::shared_ptr n1c0; + DS_ASSERT_OK(clientVector_[1]->CreateProducer(stream1, n1p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[0]->CreateProducer(stream1, n0p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[1]->Subscribe(stream1, configVector[1], n1c0)); + DS_ASSERT_OK(n1p0->Close()); + DS_ASSERT_OK(n0p1->Close()); + DS_ASSERT_OK(n1c0->Close()); + }); + 
auto t3 = pool.Submit([this, stream1, &configVector]() { + std::shared_ptr n2p0; + std::shared_ptr n1p1; + std::shared_ptr n2c0; + DS_ASSERT_OK(clientVector_[2]->CreateProducer(stream1, n2p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[1]->CreateProducer(stream1, n1p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[2]->Subscribe(stream1, configVector[2], n2c0)); + DS_ASSERT_OK(n2p0->Close()); + DS_ASSERT_OK(n1p1->Close()); + DS_ASSERT_OK(n2c0->Close()); + }); + auto t4 = pool.Submit([this, stream1, &configVector]() { + std::shared_ptr n3p0; + std::shared_ptr n2p1; + std::shared_ptr n3c0; + DS_ASSERT_OK(clientVector_[3]->CreateProducer(stream1, n3p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[2]->CreateProducer(stream1, n2p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[3]->Subscribe(stream1, configVector[3], n3c0)); + DS_ASSERT_OK(n3p0->Close()); + DS_ASSERT_OK(n2p1->Close()); + DS_ASSERT_OK(n3c0->Close()); + }); + auto t5 = pool.Submit([this, stream1, &configVector]() { + std::shared_ptr n4p0; + std::shared_ptr n3p1; + std::shared_ptr n4c0; + DS_ASSERT_OK(clientVector_[4]->CreateProducer(stream1, n4p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[3]->CreateProducer(stream1, n3p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[4]->Subscribe(stream1, configVector[4], n4c0)); + DS_ASSERT_OK(n4p0->Close()); + DS_ASSERT_OK(n3p1->Close()); + DS_ASSERT_OK(n4c0->Close()); + }); + t1.wait(); + t2.wait(); + t3.wait(); + t4.wait(); + t5.wait(); + // wait sync notification finish. + int delaySec = 3; + sleep(delaySec); + EXPECT_EQ(clientVector_[1]->DeleteStream(stream1), Status::OK()); + } +} + +TEST_F(PubSubTopoConcurrentTest, MNodeMPMCSerial) +{ + std::string stream1("stream1"); + std::vector configVector = { SubscriptionConfig("sub1", SubscriptionType::STREAM), + SubscriptionConfig("sub2", SubscriptionType::STREAM), + SubscriptionConfig("sub3", SubscriptionType::STREAM), + SubscriptionConfig("sub4", SubscriptionType::STREAM), + SubscriptionConfig("sub5", SubscriptionType::STREAM) }; + std::shared_ptr n0p0; + std::shared_ptr n4p1; + std::shared_ptr n0c0; + DS_ASSERT_OK(clientVector_[0]->CreateProducer(stream1, n0p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[4]->CreateProducer(stream1, n4p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[0]->Subscribe(stream1, configVector[0], n0c0)); + DS_ASSERT_OK(n0p0->Close()); + DS_ASSERT_OK(n4p1->Close()); + DS_ASSERT_OK(n0c0->Close()); + + std::shared_ptr n1p0; + std::shared_ptr n0p1; + std::shared_ptr n1c0; + DS_ASSERT_OK(clientVector_[1]->CreateProducer(stream1, n1p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[0]->CreateProducer(stream1, n0p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[1]->Subscribe(stream1, configVector[1], n1c0)); + DS_ASSERT_OK(n1p0->Close()); + DS_ASSERT_OK(n0p1->Close()); + DS_ASSERT_OK(n1c0->Close()); + + std::shared_ptr n2p0; + std::shared_ptr n1p1; + std::shared_ptr n2c0; + DS_ASSERT_OK(clientVector_[2]->CreateProducer(stream1, n2p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[1]->CreateProducer(stream1, n1p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[2]->Subscribe(stream1, configVector[2], n2c0)); + DS_ASSERT_OK(n2p0->Close()); + DS_ASSERT_OK(n1p1->Close()); + DS_ASSERT_OK(n2c0->Close()); + + std::shared_ptr n3p0; + std::shared_ptr n2p1; + std::shared_ptr n3c0; + DS_ASSERT_OK(clientVector_[3]->CreateProducer(stream1, n3p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[2]->CreateProducer(stream1, n2p1, 
defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[3]->Subscribe(stream1, configVector[3], n3c0)); + DS_ASSERT_OK(n3p0->Close()); + DS_ASSERT_OK(n2p1->Close()); + DS_ASSERT_OK(n3c0->Close()); + + std::shared_ptr n4p0; + std::shared_ptr n3p1; + std::shared_ptr n4c0; + DS_ASSERT_OK(clientVector_[4]->CreateProducer(stream1, n4p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[3]->CreateProducer(stream1, n3p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[4]->Subscribe(stream1, configVector[4], n4c0)); + DS_ASSERT_OK(n4p0->Close()); + DS_ASSERT_OK(n3p1->Close()); + DS_ASSERT_OK(n4c0->Close()); + EXPECT_EQ(clientVector_[4]->DeleteStream(stream1), Status::OK()); +} + +TEST_F(PubSubTopoConcurrentTest, TwoNodeMPMC) +{ + std::string stream1("stream1"); + std::vector configVector = { SubscriptionConfig("sub0", SubscriptionType::STREAM), + SubscriptionConfig("sub1", SubscriptionType::STREAM) }; + { + ThreadPool pool(2); + auto t1 = pool.Submit([this, stream1, &configVector]() { + std::shared_ptr n0p0; + std::shared_ptr n1p1; + std::shared_ptr n0c0; + DS_ASSERT_OK(clientVector_[0]->CreateProducer(stream1, n0p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[1]->CreateProducer(stream1, n1p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[0]->Subscribe(stream1, configVector[0], n0c0)); + DS_ASSERT_OK(n0p0->Close()); + DS_ASSERT_OK(n1p1->Close()); + DS_ASSERT_OK(n0c0->Close()); + }); + auto t2 = pool.Submit([this, stream1, &configVector]() { + std::shared_ptr n1p0; + std::shared_ptr n0p1; + std::shared_ptr n1c0; + DS_ASSERT_OK(clientVector_[1]->CreateProducer(stream1, n1p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[0]->CreateProducer(stream1, n0p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[1]->Subscribe(stream1, configVector[1], n1c0)); + DS_ASSERT_OK(n1p0->Close()); + DS_ASSERT_OK(n0p1->Close()); + DS_ASSERT_OK(n1c0->Close()); + }); + t1.wait(); + t2.wait(); + DS_ASSERT_OK(TryAndDeleteStream(clientVector_[1], stream1)); + } +} + +TEST_F(PubSubTopoConcurrentTest, DISABLED_MSMNodeMPMC) +{ + std::vector configVector = { SubscriptionConfig("sub1", SubscriptionType::STREAM), + SubscriptionConfig("sub2", SubscriptionType::STREAM), + SubscriptionConfig("sub3", SubscriptionType::STREAM), + SubscriptionConfig("sub4", SubscriptionType::STREAM), + SubscriptionConfig("sub5", SubscriptionType::STREAM) }; + ThreadPool pool(10); + for (int i = 0; i < 10; ++i) { + pool.Submit([this, i, &configVector]() { + std::string streamName = "stream" + std::to_string(i); + ThreadPool pool1(clientNum_); + auto t1 = pool1.Submit([this, streamName, &configVector]() { + std::shared_ptr n0p0; + std::shared_ptr n4p1; + std::shared_ptr n0c0; + std::shared_ptr n0p2; + DS_ASSERT_OK(clientVector_[0]->CreateProducer(streamName, n0p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[4]->CreateProducer(streamName, n4p1, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[0]->Subscribe(streamName, configVector[0], n0c0)); + DS_ASSERT_OK(clientVector_[0]->CreateProducer(streamName, n0p2, defaultProducerConf_)); + DS_ASSERT_OK(n0p0->Close()); + DS_ASSERT_OK(n4p1->Close()); + DS_ASSERT_OK(n0c0->Close()); + DS_ASSERT_OK(n0p2->Close()); + }); + auto t2 = pool1.Submit([this, streamName, &configVector]() { + std::shared_ptr n1p0; + std::shared_ptr n0p1; + std::shared_ptr n1c0; + std::shared_ptr n1p2; + DS_ASSERT_OK(clientVector_[1]->CreateProducer(streamName, n1p0, defaultProducerConf_)); + DS_ASSERT_OK(clientVector_[0]->CreateProducer(streamName, n0p1, defaultProducerConf_)); + 
DS_ASSERT_OK(clientVector_[1]->Subscribe(streamName, configVector[1], n1c0));
+                DS_ASSERT_OK(clientVector_[1]->CreateProducer(streamName, n1p2, defaultProducerConf_));
+                DS_ASSERT_OK(n1p0->Close());
+                DS_ASSERT_OK(n0p1->Close());
+                DS_ASSERT_OK(n1c0->Close());
+                DS_ASSERT_OK(n1p2->Close());
+            });
+            auto t3 = pool1.Submit([this, streamName, &configVector]() {
+                std::shared_ptr n2p0;
+                std::shared_ptr n1p1;
+                std::shared_ptr n2c0;
+                std::shared_ptr n2p2;
+                DS_ASSERT_OK(clientVector_[2]->CreateProducer(streamName, n2p0, defaultProducerConf_));
+                DS_ASSERT_OK(clientVector_[1]->CreateProducer(streamName, n1p1, defaultProducerConf_));
+                DS_ASSERT_OK(clientVector_[2]->Subscribe(streamName, configVector[2], n2c0));
+                DS_ASSERT_OK(clientVector_[2]->CreateProducer(streamName, n2p2, defaultProducerConf_));
+                DS_ASSERT_OK(n2p0->Close());
+                DS_ASSERT_OK(n1p1->Close());
+                DS_ASSERT_OK(n2c0->Close());
+                DS_ASSERT_OK(n2p2->Close());
+            });
+            auto t4 = pool1.Submit([this, streamName, &configVector]() {
+                std::shared_ptr n3p0;
+                std::shared_ptr n2p1;
+                std::shared_ptr n3c0;
+                std::shared_ptr n3p2;
+                DS_ASSERT_OK(clientVector_[3]->CreateProducer(streamName, n3p0, defaultProducerConf_));
+                DS_ASSERT_OK(clientVector_[2]->CreateProducer(streamName, n2p1, defaultProducerConf_));
+                DS_ASSERT_OK(clientVector_[3]->Subscribe(streamName, configVector[3], n3c0));
+                DS_ASSERT_OK(clientVector_[3]->CreateProducer(streamName, n3p2, defaultProducerConf_));
+                DS_ASSERT_OK(n3p0->Close());
+                DS_ASSERT_OK(n2p1->Close());
+                DS_ASSERT_OK(n3c0->Close());
+                DS_ASSERT_OK(n3p2->Close());
+            });
+            auto t5 = pool1.Submit([this, streamName, &configVector]() {
+                std::shared_ptr n4p0;
+                std::shared_ptr n3p1;
+                std::shared_ptr n4c0;
+                std::shared_ptr n4p2;
+                DS_ASSERT_OK(clientVector_[4]->CreateProducer(streamName, n4p0, defaultProducerConf_));
+                DS_ASSERT_OK(clientVector_[3]->CreateProducer(streamName, n3p1, defaultProducerConf_));
+                DS_ASSERT_OK(clientVector_[4]->Subscribe(streamName, configVector[4], n4c0));
+                DS_ASSERT_OK(clientVector_[4]->CreateProducer(streamName, n4p2, defaultProducerConf_));
+                DS_ASSERT_OK(n4p0->Close());
+                DS_ASSERT_OK(n3p1->Close());
+                DS_ASSERT_OK(n4c0->Close());
+                DS_ASSERT_OK(n4p2->Close());
+            });
+            t1.wait();
+            t2.wait();
+            t3.wait();
+            t4.wait();
+            t5.wait();
+            EXPECT_EQ(clientVector_[1]->DeleteStream(streamName), Status::OK());
+        });
+    }
+}
+} // namespace st
+} // namespace datasystem
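A condensed sketch of the client lifecycle these topology tests exercise (illustrative only; host, port and the stream name are placeholders, and all calls mirror their use in the tests above):

#include <memory>
#include "datasystem/stream_client.h"

datasystem::Status RunOneProducerOneConsumer()
{
    datasystem::ConnectOptions options;
    options.host = "127.0.0.1";  // placeholder worker address
    options.port = 31501;
    auto client = std::make_shared<datasystem::StreamClient>(options);
    RETURN_IF_NOT_OK(client->Init());

    datasystem::ProducerConf conf;
    conf.maxStreamSize = 64 * 1024 * 1024;
    std::shared_ptr<datasystem::Producer> producer;
    RETURN_IF_NOT_OK(client->CreateProducer("stream1", producer, conf));

    datasystem::SubscriptionConfig sub("sub1", datasystem::SubscriptionType::STREAM);
    std::shared_ptr<datasystem::Consumer> consumer;
    RETURN_IF_NOT_OK(client->Subscribe("stream1", sub, consumer));

    // Tear down in reverse order; the tests above only delete a stream after all
    // producers and consumers are closed (see TryAndDeleteStream's retry loop).
    RETURN_IF_NOT_OK(producer->Close());
    RETURN_IF_NOT_OK(consumer->Close());
    return client->DeleteStream("stream1");
}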
diff --git a/tests/st/master/stream_cache/pub_sub_topo_test.cpp b/tests/st/master/stream_cache/pub_sub_topo_test.cpp
new file mode 100644
index 0000000..aa49670
--- /dev/null
+++ b/tests/st/master/stream_cache/pub_sub_topo_test.cpp
@@ -0,0 +1,406 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Stream cache pub/sub topology tests.
+ */
+#include 
+
+#include 
+
+#include "common.h"
+
+#include "common/stream_cache/stream_common.h"
+#include "datasystem/client/mmap_manager.h"
+#include "datasystem/stream_client.h"
+#include "datasystem/stream/producer.h"
+#include "datasystem/stream/consumer.h"
+#include "datasystem/worker/stream_cache/worker_master_sc_api.h"
+
+namespace datasystem {
+namespace st {
+using namespace datasystem::client::stream_cache;
+class PubSubTopoTest : public ExternalClusterTest {
+public:
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numWorkers = 3;
+        opts.numRpcThreads = 1;
+        opts.numEtcd = 1;
+        opts.isStreamCacheCase = true;
+    }
+
+    void SetUp() override
+    {
+        ExternalClusterTest::SetUp();
+        InitTest();
+    }
+
+    void TearDown() override
+    {
+        client1_ = nullptr;
+        client2_ = nullptr;
+        client3_ = nullptr;
+        ExternalClusterTest::TearDown();
+    }
+
+    void InitStreamClient(uint32_t index, std::shared_ptr &client, int32_t timeoutMs = 60000)
+    {
+        HostPort workerAddress;
+        ASSERT_TRUE(index < cluster_->GetWorkerNum());
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(index, workerAddress));
+        LOG(INFO) << "worker index " << index << ": " << workerAddress.ToString();
+        ConnectOptions connectOptions;
+        connectOptions = { .host = workerAddress.Host(), .port = workerAddress.Port(), .connectTimeoutMs = timeoutMs };
+        connectOptions.accessKey = accessKey_;
+        connectOptions.secretKey = secretKey_;
+        client = std::make_shared(connectOptions);
+        DS_ASSERT_OK(client->Init());
+    }
+
+protected:
+    void InitTest()
+    {
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(0, workerAddress1_));
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(1, workerAddress2_));
+        DS_ASSERT_OK(cluster_->GetWorkerAddr(2, workerAddress3_));
+        LOG(INFO) << FormatString("\n Worker1: <%s>\n Worker2: <%s>\n Worker3: <%s>", workerAddress1_.ToString(),
+                                  workerAddress2_.ToString(), workerAddress3_.ToString());
+        InitStreamClient(0, client1_);
+        InitStreamClient(1, client2_);
+        InitStreamClient(2, client3_); // index is 2.
+ defaultProducerConf_.maxStreamSize = TEST_STREAM_SIZE; + } + + HostPort workerAddress1_; + HostPort workerAddress2_; + HostPort workerAddress3_; + + std::shared_ptr client1_ = nullptr; + std::shared_ptr client2_ = nullptr; + std::shared_ptr client3_ = nullptr; + ProducerConf defaultProducerConf_; + std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC"; + std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc"; +}; + +TEST_F(PubSubTopoTest, SingleStreamSingleProducerSingleConsumerBySequence) +{ + std::string stream1("stream1"); + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + SubscriptionConfig config3("sub3", SubscriptionType::STREAM); + + uint64_t producerNum = 0; + uint64_t consumerNum = 0; + + std::shared_ptr node1Producer1; + DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer1, defaultProducerConf_)); + DS_ASSERT_OK(client1_->QueryGlobalProducersNum(stream1, producerNum)); + + std::shared_ptr node1Consumer1; + DS_ASSERT_OK(client1_->Subscribe(stream1, config1, node1Consumer1)); + + std::shared_ptr node2Producer1; + DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Producer1, defaultProducerConf_)); + std::shared_ptr node2Consumer1; + DS_ASSERT_OK(client2_->Subscribe(stream1, config2, node2Consumer1)); + + producerNum = 0; + consumerNum = 0; + std::shared_ptr node3Producer1; + DS_ASSERT_OK(client3_->CreateProducer(stream1, node3Producer1, defaultProducerConf_)); + std::shared_ptr node3Consumer1; + DS_ASSERT_OK(client3_->Subscribe(stream1, config3, node3Consumer1)); + DS_ASSERT_OK(client3_->QueryGlobalProducersNum(stream1, producerNum)); + DS_ASSERT_OK(client3_->QueryGlobalConsumersNum(stream1, consumerNum)); + EXPECT_EQ(consumerNum, size_t(3)); + + producerNum = 0; + consumerNum = 0; + DS_ASSERT_OK(client2_->QueryGlobalProducersNum(stream1, producerNum)); + DS_ASSERT_OK(client2_->QueryGlobalConsumersNum(stream1, consumerNum)); + EXPECT_EQ(consumerNum, size_t(3)); + + producerNum = 0; + consumerNum = 0; + DS_ASSERT_OK(node1Producer1->Close()); + DS_ASSERT_OK(client1_->QueryGlobalProducersNum(stream1, producerNum)); + DS_ASSERT_OK(client1_->QueryGlobalConsumersNum(stream1, consumerNum)); + + EXPECT_EQ(consumerNum, size_t(3)); + + consumerNum = 0; + DS_ASSERT_OK(node1Consumer1->Close()); + DS_ASSERT_OK(client2_->QueryGlobalConsumersNum(stream1, consumerNum)); + EXPECT_EQ(consumerNum, size_t(2)); + + producerNum = 0; + consumerNum = 0; + DS_ASSERT_OK(node2Producer1->Close()); + DS_ASSERT_OK(client3_->QueryGlobalProducersNum(stream1, producerNum)); + DS_ASSERT_OK(client3_->QueryGlobalConsumersNum(stream1, consumerNum)); + EXPECT_EQ(consumerNum, size_t(2)); + + producerNum = 0; + consumerNum = 0; + DS_ASSERT_OK(node2Consumer1->Close()); + DS_ASSERT_OK(client1_->QueryGlobalProducersNum(stream1, producerNum)); + DS_ASSERT_OK(client2_->QueryGlobalConsumersNum(stream1, consumerNum)); + EXPECT_EQ(consumerNum, size_t(1)); + + producerNum = 0; + consumerNum = 0; + DS_ASSERT_OK(node3Producer1->Close()); + DS_ASSERT_OK(client2_->QueryGlobalProducersNum(stream1, producerNum)); + DS_ASSERT_OK(client3_->QueryGlobalConsumersNum(stream1, consumerNum)); + EXPECT_EQ(consumerNum, size_t(1)); + + producerNum = 0; + consumerNum = 0; + DS_ASSERT_OK(node3Consumer1->Close()); + DS_ASSERT_OK(client1_->QueryGlobalProducersNum(stream1, producerNum)); + DS_ASSERT_OK(client3_->QueryGlobalConsumersNum(stream1, consumerNum)); + EXPECT_EQ(consumerNum, size_t(0)); +} + +TEST_F(PubSubTopoTest, 
SingleStreamSingleProducerSingleConsumerByRandom) +{ + std::string stream1("stream1"); + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + SubscriptionConfig config3("sub3", SubscriptionType::STREAM); + + std::shared_ptr node1Producer1; + DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer1, defaultProducerConf_)); + std::shared_ptr node2Producer1; + DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Producer1, defaultProducerConf_)); + std::shared_ptr node3Producer1; + DS_ASSERT_OK(client3_->CreateProducer(stream1, node3Producer1, defaultProducerConf_)); + + std::shared_ptr node1Consumer1; + DS_ASSERT_OK(client1_->Subscribe(stream1, config1, node1Consumer1)); + std::shared_ptr node2Consumer1; + DS_ASSERT_OK(client2_->Subscribe(stream1, config2, node2Consumer1)); + std::shared_ptr node3Consumer1; + DS_ASSERT_OK(client3_->Subscribe(stream1, config3, node3Consumer1)); + + DS_ASSERT_OK(node1Producer1->Close()); + DS_ASSERT_OK(node1Consumer1->Close()); + DS_ASSERT_OK(node2Producer1->Close()); + DS_ASSERT_OK(node2Consumer1->Close()); + DS_ASSERT_OK(node3Producer1->Close()); + DS_ASSERT_OK(node3Consumer1->Close()); +} + +TEST_F(PubSubTopoTest, MultiStreamSingleProducerSingleConsumerBySequence) +{ + std::string stream1("stream1"); + std::string stream2("stream2"); + SubscriptionConfig config1("sub1", SubscriptionType::STREAM); + SubscriptionConfig config2("sub2", SubscriptionType::STREAM); + SubscriptionConfig config3("sub3", SubscriptionType::STREAM); + + std::shared_ptr node1Stream1Producer1; + DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Stream1Producer1, defaultProducerConf_)); + std::shared_ptr node2Stream1Producer1; + DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Stream1Producer1, defaultProducerConf_)); + std::shared_ptr node3Stream1Producer1; + DS_ASSERT_OK(client3_->CreateProducer(stream1, node3Stream1Producer1, defaultProducerConf_)); + + std::shared_ptr node1Stream1Consumer1; + DS_ASSERT_OK(client1_->Subscribe(stream1, config1, node1Stream1Consumer1)); + std::shared_ptr node2Stream1Consumer1; + DS_ASSERT_OK(client2_->Subscribe(stream1, config2, node2Stream1Consumer1)); + std::shared_ptr node3Stream1Consumer1; + DS_ASSERT_OK(client3_->Subscribe(stream1, config3, node3Stream1Consumer1)); + + DS_ASSERT_OK(node1Stream1Producer1->Close()); + DS_ASSERT_OK(node1Stream1Consumer1->Close()); + DS_ASSERT_OK(node2Stream1Producer1->Close()); + DS_ASSERT_OK(node2Stream1Consumer1->Close()); + DS_ASSERT_OK(node3Stream1Producer1->Close()); + DS_ASSERT_OK(node3Stream1Consumer1->Close()); + + std::shared_ptr node1Stream2Producer1; + DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Stream2Producer1, defaultProducerConf_)); + std::shared_ptr node2Stream2Producer1; + DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Stream2Producer1, defaultProducerConf_)); + std::shared_ptr node3Stream2Producer1; + DS_ASSERT_OK(client3_->CreateProducer(stream1, node3Stream2Producer1, defaultProducerConf_)); + + std::shared_ptr node1Stream2Consumer1; + DS_ASSERT_OK(client1_->Subscribe(stream1, config1, node1Stream2Consumer1)); + std::shared_ptr node2Stream2Consumer1; + DS_ASSERT_OK(client2_->Subscribe(stream1, config2, node2Stream2Consumer1)); + std::shared_ptr node3Stream2Consumer1; + DS_ASSERT_OK(client3_->Subscribe(stream1, config3, node3Stream2Consumer1)); + + DS_ASSERT_OK(node1Stream2Producer1->Close()); + DS_ASSERT_OK(node1Stream2Consumer1->Close()); + DS_ASSERT_OK(node2Stream2Producer1->Close()); + 
+TEST_F(PubSubTopoTest, MultiStreamSingleProducerSingleConsumerByRandom)
+{
+    std::string stream1("stream1");
+    std::string stream2("stream2");
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    SubscriptionConfig config3("sub3", SubscriptionType::STREAM);
+
+    // n1p1->n3c1->n2p1->n1c1->n3p1->n2c1
+    std::shared_ptr<Producer> node1Stream1Producer1;
+    DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Stream1Producer1, defaultProducerConf_));
+    std::shared_ptr<Consumer> node3Stream1Consumer1;
+    DS_ASSERT_OK(client3_->Subscribe(stream1, config3, node3Stream1Consumer1));
+    std::shared_ptr<Producer> node2Stream1Producer1;
+    DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Stream1Producer1, defaultProducerConf_));
+    std::shared_ptr<Consumer> node1Stream1Consumer1;
+    DS_ASSERT_OK(client1_->Subscribe(stream1, config1, node1Stream1Consumer1));
+    std::shared_ptr<Producer> node3Stream1Producer1;
+    DS_ASSERT_OK(client3_->CreateProducer(stream1, node3Stream1Producer1, defaultProducerConf_));
+    std::shared_ptr<Consumer> node2Stream1Consumer1;
+    DS_ASSERT_OK(client2_->Subscribe(stream1, config2, node2Stream1Consumer1));
+
+    DS_ASSERT_OK(node3Stream1Consumer1->Close());
+    DS_ASSERT_OK(node1Stream1Producer1->Close());
+    DS_ASSERT_OK(node3Stream1Producer1->Close());
+    DS_ASSERT_OK(node1Stream1Consumer1->Close());
+    DS_ASSERT_OK(node2Stream1Consumer1->Close());
+    DS_ASSERT_OK(node2Stream1Producer1->Close());
+
+    // n2c1->n3c1->n1p1->n3p1->n1c1->n2p1
+    std::shared_ptr<Consumer> node2Stream2Consumer1;
+    DS_ASSERT_OK(client2_->Subscribe(stream2, config2, node2Stream2Consumer1));
+    std::shared_ptr<Consumer> node3Stream2Consumer1;
+    DS_ASSERT_OK(client3_->Subscribe(stream2, config3, node3Stream2Consumer1));
+    std::shared_ptr<Producer> node1Stream2Producer1;
+    DS_ASSERT_OK(client1_->CreateProducer(stream2, node1Stream2Producer1, defaultProducerConf_));
+    std::shared_ptr<Producer> node3Stream2Producer1;
+    DS_ASSERT_OK(client3_->CreateProducer(stream2, node3Stream2Producer1, defaultProducerConf_));
+    std::shared_ptr<Consumer> node1Stream2Consumer1;
+    DS_ASSERT_OK(client1_->Subscribe(stream2, config1, node1Stream2Consumer1));
+    std::shared_ptr<Producer> node2Stream2Producer1;
+    DS_ASSERT_OK(client2_->CreateProducer(stream2, node2Stream2Producer1, defaultProducerConf_));
+
+    // n3c1->n2p1->n1c1->n1p1->n2c1->n3p1
+    DS_ASSERT_OK(node3Stream2Consumer1->Close());
+    DS_ASSERT_OK(node2Stream2Producer1->Close());
+    DS_ASSERT_OK(node1Stream2Consumer1->Close());
+    DS_ASSERT_OK(node1Stream2Producer1->Close());
+    DS_ASSERT_OK(node2Stream2Consumer1->Close());
+    DS_ASSERT_OK(node3Stream2Producer1->Close());
+}
+
+TEST_F(PubSubTopoTest, SingleStreamMultiProducerSingleConsumerBySequence)
+{
+    std::string stream1("stream1");
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    SubscriptionConfig config3("sub3", SubscriptionType::STREAM);
+
+    std::shared_ptr<Producer> node1Producer1;
+    DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer1, defaultProducerConf_));
+    std::shared_ptr<Producer> node1Producer2;
+    DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer2, defaultProducerConf_));
+    std::shared_ptr<Consumer> node1Consumer1;
+    DS_ASSERT_OK(client1_->Subscribe(stream1, config1, node1Consumer1));
+
+    std::shared_ptr<Producer> node2Producer1;
+    DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Producer1, defaultProducerConf_));
+    std::shared_ptr<Producer> node2Producer2;
+    DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Producer2, defaultProducerConf_));
+    std::shared_ptr<Consumer> node2Consumer1;
+    DS_ASSERT_OK(client2_->Subscribe(stream1, config2, node2Consumer1));
+
+    std::shared_ptr<Producer> node3Producer1;
+    DS_ASSERT_OK(client3_->CreateProducer(stream1, node3Producer1, defaultProducerConf_));
+    std::shared_ptr<Producer> node3Producer2;
+    DS_ASSERT_OK(client3_->CreateProducer(stream1, node3Producer2, defaultProducerConf_));
+    std::shared_ptr<Consumer> node3Consumer1;
+    DS_ASSERT_OK(client3_->Subscribe(stream1, config3, node3Consumer1));
+
+    DS_ASSERT_OK(node1Producer1->Close());
+    DS_ASSERT_OK(node1Producer2->Close());
+    DS_ASSERT_OK(node1Consumer1->Close());
+
+    DS_ASSERT_OK(node2Producer1->Close());
+    DS_ASSERT_OK(node2Producer2->Close());
+    DS_ASSERT_OK(node2Consumer1->Close());
+
+    DS_ASSERT_OK(node3Producer1->Close());
+    DS_ASSERT_OK(node3Producer2->Close());
+    DS_ASSERT_OK(node3Consumer1->Close());
+}
+
+TEST_F(PubSubTopoTest, SingleStreamMultiProducerMultiConsumerBySequence)
+{
+    std::string stream1("stream1");
+    SubscriptionConfig config1("sub1", SubscriptionType::STREAM);
+    SubscriptionConfig config1Cpy("sub1_cpy", SubscriptionType::STREAM);
+    SubscriptionConfig config2("sub2", SubscriptionType::STREAM);
+    SubscriptionConfig config2Cpy("sub2_cpy", SubscriptionType::STREAM);
+    SubscriptionConfig config3("sub3", SubscriptionType::STREAM);
+    SubscriptionConfig config3Cpy("sub3_cpy", SubscriptionType::STREAM);
+
+    std::shared_ptr<Producer> node1Producer1;
+    DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer1, defaultProducerConf_));
+    std::shared_ptr<Producer> node1Producer2;
+    DS_ASSERT_OK(client1_->CreateProducer(stream1, node1Producer2, defaultProducerConf_));
+    std::shared_ptr<Consumer> node1Consumer1;
+    DS_ASSERT_OK(client1_->Subscribe(stream1, config1, node1Consumer1));
+    std::shared_ptr<Consumer> node1Consumer2;
+    DS_ASSERT_OK(client1_->Subscribe(stream1, config1Cpy, node1Consumer2));
+
+    std::shared_ptr<Producer> node2Producer1;
+    DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Producer1, defaultProducerConf_));
+    std::shared_ptr<Producer> node2Producer2;
+    DS_ASSERT_OK(client2_->CreateProducer(stream1, node2Producer2, defaultProducerConf_));
+    std::shared_ptr<Consumer> node2Consumer1;
+    DS_ASSERT_OK(client2_->Subscribe(stream1, config2, node2Consumer1));
+    std::shared_ptr<Consumer> node2Consumer2;
+    DS_ASSERT_OK(client2_->Subscribe(stream1, config2Cpy, node2Consumer2));
+
+    std::shared_ptr<Producer> node3Producer1;
+    DS_ASSERT_OK(client3_->CreateProducer(stream1, node3Producer1, defaultProducerConf_));
+    std::shared_ptr<Producer> node3Producer2;
+    DS_ASSERT_OK(client3_->CreateProducer(stream1, node3Producer2, defaultProducerConf_));
+    std::shared_ptr<Consumer> node3Consumer1;
+    DS_ASSERT_OK(client3_->Subscribe(stream1, config3, node3Consumer1));
+    std::shared_ptr<Consumer> node3Consumer2;
+    DS_ASSERT_OK(client3_->Subscribe(stream1, config3Cpy, node3Consumer2));
+
+    DS_ASSERT_OK(node1Producer1->Close());
+    DS_ASSERT_OK(node1Producer2->Close());
+    DS_ASSERT_OK(node1Consumer1->Close());
+    DS_ASSERT_OK(node1Consumer2->Close());
+
+    DS_ASSERT_OK(node2Producer1->Close());
+    DS_ASSERT_OK(node2Producer2->Close());
+    DS_ASSERT_OK(node2Consumer1->Close());
+    DS_ASSERT_OK(node2Consumer2->Close());
+
+    DS_ASSERT_OK(node3Producer1->Close());
+    DS_ASSERT_OK(node3Producer2->Close());
+    DS_ASSERT_OK(node3Consumer1->Close());
+    DS_ASSERT_OK(node3Consumer2->Close());
+}
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/worker/object_cache/evict_mem_test.cpp b/tests/st/worker/object_cache/evict_mem_test.cpp
index a652d9e..1e04966 100644
--- a/tests/st/worker/object_cache/evict_mem_test.cpp
+++ b/tests/st/worker/object_cache/evict_mem_test.cpp
@@ -80,14 +80,11 @@ TEST_F(EvictMemTest, EvictThresHoldTest)
     auto put = [&client0](uint64_t mb) {
         uint64_t dataSize = mb*MB_TO_BYTES;
         // Put, spill will be triggered
-        CreateParam param;
+        SetParam param;
         param.writeMode = WriteMode::NONE_L2_CACHE_EVICT;
         std::string objectKey = FormatString("key_%lu", mb);
         std::string data(dataSize, 'x');
-        std::shared_ptr<Buffer> buffer;
-        DS_ASSERT_OK(client0->Object()->Create(objectKey, dataSize, param, buffer));
-        DS_ASSERT_OK(buffer->MemoryCopy(data.data(), dataSize));
-        DS_ASSERT_OK(buffer->Publish());
+        DS_ASSERT_OK(client0->KV()->Set(objectKey, data, param));
     };
     inject::Set("Exist.QueryLocalMem", "call()");
     auto exist = [&client0](std::string objectKey) -> bool {
diff --git a/tests/st/worker/object_cache/worker_oc_eviction_test.cpp b/tests/st/worker/object_cache/worker_oc_eviction_test.cpp
index 64ee434..146c870 100644
--- a/tests/st/worker/object_cache/worker_oc_eviction_test.cpp
+++ b/tests/st/worker/object_cache/worker_oc_eviction_test.cpp
@@ -65,6 +65,16 @@ static Status RetryCreate(std::shared_ptr<ObjectClient> client, const std::strin
     return rc;
 }
 
+static Status RetrySet(std::shared_ptr<KVClient> client, const std::string &objectKey, std::string &data,
+                       SetParam param)
+{
+    Status rc;
+    do {
+        rc = client->Set(objectKey, data, param);
+    } while (rc.GetCode() == K_OUT_OF_MEMORY);
+    return rc;
+}
+
 static bool ExistsNone(std::vector<Optional<Buffer>> &buffers)
 {
     return std::any_of(buffers.cbegin(), buffers.cend(), [](const Optional<Buffer> &buffer) { return !buffer; });
@@ -1001,9 +1011,11 @@ TEST_F(EvictionManagerSaveToRedisTest, TestEvictWriteThroughObj)
 {
     std::shared_ptr<ObjectClient> client1;
     std::shared_ptr<ObjectClient> client2;
+    std::shared_ptr<KVClient> client;
     InitTestClient(0, client1);
     InitTestClient(1, client2);
-    CreateParam param{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE };
+    InitTestKVClient(0, client);
+    SetParam param{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE };
 
     HostPort metaAddress;
     cluster_->GetMetaServerAddr(metaAddress);
@@ -1016,10 +1028,7 @@ TEST_F(EvictionManagerSaveToRedisTest, TestEvictWriteThroughObj)
     for (int i = 0; i < objNum; i++) {
         // Put exceed shared_memory_size_mb, will trigger evict
         std::string objectKey = "key_" + std::to_string(i);
-        std::shared_ptr<Buffer> buffer;
-        DS_ASSERT_OK(RetryCreate(client1, objectKey, dataSize, param, buffer));
-        buffer->MemoryCopy(reinterpret_cast<void *>(const_cast<char *>(data.c_str())), data.size());
-        DS_ASSERT_OK(buffer->Publish());
+        DS_ASSERT_OK(RetrySet(client, objectKey, data, param));
 
         // Remote get exceed shared_memory_size_mb, will trigger evict
         std::vector<std::shared_ptr<Buffer>> buffers;
@@ -1129,9 +1138,11 @@ TEST_F(EvictionManagerEndToEndTest, LEVEL2_TestEvictWriteThroughObj)
 {
     std::shared_ptr<ObjectClient> client1;
     std::shared_ptr<ObjectClient> client2;
+    std::shared_ptr<KVClient> client;
     InitTestClient(0, client1);
     InitTestClient(1, client2);
-    CreateParam param{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE };
+    InitTestKVClient(0, client);
+    SetParam param{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE };
 
     HostPort metaAddress;
     cluster_->GetMetaServerAddr(metaAddress);
@@ -1144,10 +1155,7 @@ TEST_F(EvictionManagerEndToEndTest, LEVEL2_TestEvictWriteThroughObj)
     for (int i = 0; i < objNum; i++) {
         // Put exceed shared_memory_size_mb, will trigger evict
         std::string objectKey = "key_" + std::to_string(i);
-        std::shared_ptr<Buffer> buffer;
-        DS_ASSERT_OK(RetryCreate(client1, objectKey, dataSize, param, buffer));
-        buffer->MemoryCopy(reinterpret_cast<void *>(const_cast<char *>(data.c_str())), data.size());
-        DS_ASSERT_OK(buffer->Publish());
+        DS_ASSERT_OK(RetrySet(client, objectKey, data, param));
 
         // Remote get exceed shared_memory_size_mb, will trigger evict
         std::vector<std::shared_ptr<Buffer>> buffers;
@@ -1476,7 +1484,7 @@ TEST_F(EvictionManagerEndToEndTest, DISABLED_MutableSpillMultiNodeTest)
     InitTestClient(0, client0);
     InitTestClient(1, client1);
     uint64_t dataSize = 500 * 1024;
-    CreateParam param{ .writeMode = WriteMode::NONE_L2_CACHE, .consistencyType = ConsistencyType::CAUSAL };
+    CreateParam param{ .consistencyType = ConsistencyType::CAUSAL };
     HostPort metaAddress;
     cluster_->GetMetaServerAddr(metaAddress);
     int objNum = 30;  // obj num is 30;
@@ -1547,7 +1555,7 @@ TEST_F(EvictionManagerEndToEndTest2, DISABLED_TestEvictWriteThroughSpaceFull)
     std::shared_ptr<ObjectClient> client2;
     InitTestClient(0, client1);
     InitTestClient(1, client2);
-    CreateParam param{ .writeMode = WriteMode::WRITE_THROUGH_L2_CACHE };
+    CreateParam param{};
     HostPort metaAddress;
     cluster_->GetMetaServerAddr(metaAddress);
 
diff --git a/tests/st/worker/stream_cache/master_worker_sc_api_test.cpp b/tests/st/worker/stream_cache/master_worker_sc_api_test.cpp
new file mode 100644
index 0000000..e0277a0
--- /dev/null
+++ b/tests/st/worker/stream_cache/master_worker_sc_api_test.cpp
@@ -0,0 +1,65 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
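+/**
+ * Description: Testing MasterWorkerSCApi. This case only drives Init() on the
+ * master-to-worker stream-cache stub using a local address (127.0.0.2) that
+ * differs from the meta server's IP; no stream traffic is exchanged.
+ */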
+#include "datasystem/common/rpc/rpc_stub_cache_mgr.h"
+#include "datasystem/master/stream_cache/master_worker_sc_api.h"
+
+#include "common.h"
+
+using datasystem::master::MasterWorkerSCApi;
+
+namespace datasystem {
+namespace st {
+class MasterWorkerSCApiTest : public ExternalClusterTest {
+public:
+    void SetUp() override
+    {
+        akSkManager_ = std::make_shared<AkSkManager>(0);
+        akSkManager_->SetClientAkSk(accessKey_, secretKey_);
+        ClusterTest::SetUp();
+        InitMasterWorkerSCApi();
+    }
+
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numWorkers = 1;
+        opts.numEtcd = 1;
+        opts.isStreamCacheCase = true;
+    }
+
+protected:
+    void InitMasterWorkerSCApi()
+    {
+        HostPort metaAddr;
+        DS_ASSERT_OK(cluster_->GetMetaServerAddr(metaAddr));
+        HostPort localHost("127.0.0.2", 18888);
+        int stubCacheNum = 100;
+        RpcStubCacheMgr::Instance().Init(stubCacheNum);
+        masterWorkerSCApi_ = MasterWorkerSCApi::CreateMasterWorkerSCApi(localHost, localHost, akSkManager_, nullptr);
+    }
+
+    std::shared_ptr<MasterWorkerSCApi> masterWorkerSCApi_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+    std::shared_ptr<AkSkManager> akSkManager_{ nullptr };
+};
+
+TEST_F(MasterWorkerSCApiTest, TestMasterWorkerDiffIp)
+{
+    DS_ASSERT_OK(masterWorkerSCApi_->Init());
+}
+} // namespace st
+} // namespace datasystem
diff --git a/tests/st/worker/stream_cache/worker_master_sc_api_test.cpp b/tests/st/worker/stream_cache/worker_master_sc_api_test.cpp
new file mode 100644
index 0000000..d97a01d
--- /dev/null
+++ b/tests/st/worker/stream_cache/worker_master_sc_api_test.cpp
@@ -0,0 +1,64 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
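+/**
+ * Description: Testing WorkerMasterSCApi. Mirror image of
+ * master_worker_sc_api_test: builds the worker-to-master stub with a distinct
+ * local IP and checks that Init() succeeds.
+ */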
+#include "datasystem/worker/stream_cache/worker_master_sc_api.h"
+
+#include "common.h"
+
+using datasystem::worker::stream_cache::WorkerMasterSCApi;
+
+namespace datasystem {
+namespace st {
+class WorkerMasterSCApiTest : public ExternalClusterTest {
+public:
+    void SetUp() override
+    {
+        ClusterTest::SetUp();
+        InitWorkerMasterSCApi();
+    }
+
+    void SetClusterSetupOptions(ExternalClusterOptions &opts) override
+    {
+        opts.numWorkers = 1;
+        opts.numEtcd = 1;
+        opts.isStreamCacheCase = true;
+    }
+
+protected:
+    void InitWorkerMasterSCApi()
+    {
+        HostPort metaAddr;
+        DS_ASSERT_OK(cluster_->GetMetaServerAddr(metaAddr));
+        akSkManager_ = std::make_shared<AkSkManager>(0);
+        akSkManager_->SetClientAkSk(accessKey_, secretKey_);
+        int stubCacheNum = 100;
+        RpcStubCacheMgr::Instance().Init(stubCacheNum);
+        workerMasterSCApi_ =
+            WorkerMasterSCApi::CreateWorkerMasterSCApi(metaAddr, HostPort("127.0.0.1", 18888), akSkManager_);
+    }
+
+    std::shared_ptr<WorkerMasterSCApi> workerMasterSCApi_;
+    std::string accessKey_ = "QTWAOYTTINDUT2QVKYUC";
+    std::string secretKey_ = "MFyfvK41ba2giqM7**********KGpownRZlmVmHc";
+    std::shared_ptr<AkSkManager> akSkManager_;
+};
+
+TEST_F(WorkerMasterSCApiTest, TestWorkerMasterDiffIp)
+{
+    DS_ASSERT_OK(workerMasterSCApi_->Init());
+}
+} // namespace st
+} // namespace datasystem
diff --git a/tests/ut/CMakeLists.txt b/tests/ut/CMakeLists.txt
index 2710012..e3e022d 100644
--- a/tests/ut/CMakeLists.txt
+++ b/tests/ut/CMakeLists.txt
@@ -21,6 +21,7 @@ set(DS_UT_DEPEND_LIBS
     common_ak_sk
     common_persistence_api
     common_immutable_string
+    string_ref
     httpclient
     master_object_cache
     worker_object_cache
@@ -35,6 +36,7 @@ include_directories(${PROJECT_DIR}/src)
 
 # fetch ut test files and remove the files we don't care about.
 file(GLOB_RECURSE DS_TEST_UT_SRCS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")
+file(GLOB_RECURSE DS_UT_STREAM_SRCS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/**/stream_cache/*.cpp")
 file(GLOB_RECURSE DS_UT_OBJECT_SRCS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/**/object_cache/*.cpp")
 
 set(UT_COMMON_SRCS
@@ -44,21 +46,28 @@ set(UT_COMMON_SRCS
 list(FILTER DS_TEST_UT_SRCS EXCLUDE REGEX .*/device/.*)
 list(FILTER DS_TEST_UT_SRCS EXCLUDE REGEX .*/binmock/.*)
 list(FILTER DS_TEST_UT_SRCS EXCLUDE REGEX .*/flags/.*)
+list(FILTER DS_TEST_UT_SRCS EXCLUDE REGEX .*/stream_cache/.*)
 list(FILTER DS_TEST_UT_SRCS EXCLUDE REGEX .*/object_cache/.*)
 
+list(APPEND DS_UT_STREAM_SRCS)
+
 add_executable(ds_ut ${DS_TEST_UT_SRCS})
+add_executable(ds_ut_stream ${DS_UT_STREAM_SRCS} ${UT_COMMON_SRCS})
 add_executable(ds_ut_object ${DS_UT_OBJECT_SRCS} ${UT_COMMON_SRCS})
 target_link_libraries(ds_ut PRIVATE ${DS_UT_DEPEND_LIBS})
+target_link_libraries(ds_ut_stream PRIVATE ${DS_UT_DEPEND_LIBS} worker_stream_cache ds_worker)
 target_link_libraries(ds_ut_object PRIVATE ${DS_UT_DEPEND_LIBS})
 
 set(BIN_PATH_LIST LLT_BIN_PATH="${CMAKE_CURRENT_BINARY_DIR}")
 target_compile_definitions(ds_ut PRIVATE ${BIN_PATH_LIST})
+target_compile_definitions(ds_ut_stream PRIVATE ${BIN_PATH_LIST})
 target_compile_definitions(ds_ut_object PRIVATE ${BIN_PATH_LIST})
 
 set(TEST_ENVIRONMENT "LD_LIBRARY_PATH=${OpenSSL_LIB_PATH}:${gRPC_LIB_PATH}:${OBS_LIB_PATH}:${ICONV_ROOT}/lib:$ENV{LD_LIBRARY_PATH}" CACHE STRING "")
 add_datasystem_test(ds_ut TEST_ENVIRONMENTS ${TEST_ENVIRONMENT})
+add_datasystem_test(ds_ut_stream TEST_ENVIRONMENTS ${TEST_ENVIRONMENT})
 add_datasystem_test(ds_ut_object TEST_ENVIRONMENTS ${TEST_ENVIRONMENT})
 
 add_executable(flags_ut "${CMAKE_CURRENT_SOURCE_DIR}/common/flags/flags_test.cpp")
diff --git a/tests/ut/common.cpp b/tests/ut/common.cpp
index 53f9d5c..41b1ac7 100644
--- a/tests/ut/common.cpp
+++ b/tests/ut/common.cpp
@@ -83,7 +83,8 @@ CommonTest::CommonTest()
 #ifdef USE_URMA
     // Test fixture can override SetUp(), so keeping init here
     if (UrmaManager::IsUrmaEnabled()) {
-        UrmaManager::Instance().Init("127.0.0.1");
+        HostPort addr("127.0.0.1", 0);
+        UrmaManager::Instance().Init(addr);
         FLAGS_arena_per_tenant = 1;
     }
 #endif
diff --git a/tests/ut/common/kvstore/rocks_replica_test.cpp b/tests/ut/common/kvstore/rocks_replica_test.cpp
index b466295..0d12710 100644
--- a/tests/ut/common/kvstore/rocks_replica_test.cpp
+++ b/tests/ut/common/kvstore/rocks_replica_test.cpp
@@ -26,6 +26,7 @@
 
 #include "common.h"
 #include "datasystem/common/util/file_util.h"
+#include "datasystem/master/stream_cache/store/rocks_stream_meta_store.h"
 #include "datasystem/utils/status.h"
 
 namespace datasystem {
@@ -165,6 +166,17 @@ TEST_F(RocksReplicaTest, TestFullSync)
     primary.RegisterRpcChannel(&mockChannel);
     backup.RegisterRpcChannel(&mockChannel);
 
+    // write stream metadata
+    using master::stream_cache::RocksStreamMetaStore;
+    DS_ASSERT_OK(Replica::CreateScTable(primary.GetStreamRocksStore()));
+    RocksStreamMetaStore primaryMetaStore(primary.GetStreamRocksStore());
+    ProducerMetaPb producerMeta;
+    producerMeta.set_stream_name("stream1");
+    producerMeta.mutable_worker_address()->set_host("127.0.0.1");
+    int port = 8080;
+    producerMeta.mutable_worker_address()->set_port(port);
+    DS_ASSERT_OK(primaryMetaStore.AddPubNode(producerMeta));
+
     auto masterStore = primary.GetObjectRocksStore();
     DS_ASSERT_OK(masterStore->CreateTable("table"));
     int logNum = 100000;
@@ -184,6 +196,12 @@ TEST_F(RocksReplicaTest, TestFullSync)
     auto followerStore = backup.GetObjectRocksStore();
     DS_ASSERT_OK(followerStore->Get("table", "key_0", value));
     ASSERT_EQ(value, "value");
+
+    RocksStreamMetaStore backupMetaStore(backup.GetStreamRocksStore());
+    std::vector<ProducerMetaPb> producerMetaPbs;
+    DS_ASSERT_OK(backupMetaStore.GetOneStreamProducers("stream1", producerMetaPbs));
+    ASSERT_TRUE(!producerMetaPbs.empty());
+    LOG(INFO) << producerMetaPbs[0].ShortDebugString();
 }
 } // namespace ut
 } // namespace datasystem
diff --git a/tests/ut/common/log/logging_test.cpp b/tests/ut/common/log/logging_test.cpp
index 9c34489..0a1103e 100644
--- a/tests/ut/common/log/logging_test.cpp
+++ b/tests/ut/common/log/logging_test.cpp
@@ -318,7 +318,7 @@ TEST_F(LoggingTest, TestEnvSucceed)
                                       RETRY_TIMEOUT_SECONDS, interval));
 }
 
-TEST_F(LoggingTest, TestMultiTimeCostLoggerRecord)
+TEST_F(LoggingTest, DISABLED_TestMultiTimeCostLoggerRecord)
 {
     FLAGS_log_monitor = true;
     FLAGS_max_log_size = 10;
@@ -518,7 +518,7 @@ TEST_F(LoggingTest, TestWriteLogWhenChangeEnv)
     }
 }
 
-TEST_F(LoggingTest, TestMinLogLevel)
+TEST_F(LoggingTest, TestMinLogLevelNotWriteToFile)
 {
     int replace = 1;
     FLAGS_logbufsecs = 0;
@@ -553,6 +553,22 @@
     ASSERT_TRUE(isAccessLogExist);
 }
 
+TEST_F(LoggingTest, TestMinLogLevelNotCallFunction)
+{
+    auto expensiveCall = [] {
+        sleep(1);
+        return "hello";
+    };
+    FLAGS_minloglevel = 1;
+    Timer timer;
+    const int loopCount = 10;
+    for (int i = 0; i < loopCount; i++) {
+        LOG(INFO) << expensiveCall();
+    }
+    ASSERT_LT(timer.ElapsedSecond(), 1);
+    LOG(ERROR) << "cost: " << timer.ElapsedMicroSecond();
+}
+
 TEST_F(LoggingTest, TestDisableClientLogMonitor)
 {
     int replace = 1;
@@ -643,6 +659,5 @@ TEST_F(LoggingTest, TestLogName)
     filepath = FLAGS_log_dir + "/test_client.INFO.log";
     ASSERT_TRUE(FileExist(filepath));
 }
-
 } // namespace ut
 } // namespace datasystem
diff --git a/tests/ut/common/log/spdlog/log_message_test.cpp b/tests/ut/common/log/spdlog/log_message_test.cpp
index 341d787..15bfdb2 100644
--- a/tests/ut/common/log/spdlog/log_message_test.cpp
+++ b/tests/ut/common/log/spdlog/log_message_test.cpp
@@ -327,6 +327,7 @@ TEST_F(LogMessageTest, AllCheckMacrosPassWhenConditionTrue)
     const char *first_string = "apple";
     const char *second_string = "banana";
     CHECK_STRNE(first_string, second_string);
+    ASSERT_STREQ(SafeStringOutput(nullptr), "(null)");
 }
 
 using LogMessageDeathTest = LogMessageTest;
diff --git a/tests/ut/common/shared_memory/allocator_test.cpp b/tests/ut/common/shared_memory/allocator_test.cpp
index 1fafc50..dae2814 100644
--- a/tests/ut/common/shared_memory/allocator_test.cpp
+++ b/tests/ut/common/shared_memory/allocator_test.cpp
@@ -71,6 +71,7 @@ struct AllocatorConfig {
     ssize_t decayMs = 5'000;
     int objectThreshold = 100;
     int streamThreshold = 100;
+    ServiceType serviceType = ServiceType::OBJECT;
     memory::CacheType cacheType = memory::CacheType::MEMORY;
 
     AllocatorConfig() = default;
@@ -100,7 +101,7 @@ public:
         }
         return datasystem::memory::Allocator::Instance()->Init(config_.shmSize, config_.shdSize, config_.populate,
                                                                config_.scaling, config_.decayMs,
-                                                               config_.objectThreshold);
+                                                               config_.objectThreshold, config_.streamThreshold);
     }
 
     uint64_t MaxSize()
@@ -111,17 +112,19 @@ public:
     Status AllocateMemory(uint64_t needSize, ShmUnitInfo &unit, const std::string &tenantId = DEFAULT_TENANT_ID)
     {
         return datasystem::memory::Allocator::Instance()->AllocateMemory(
-            tenantId, needSize, config_.populate, unit.pointer, unit.fd, unit.offset, unit.mmapSize, config_.cacheType);
+            tenantId, needSize, config_.populate, unit.pointer, unit.fd, unit.offset, unit.mmapSize,
+            config_.serviceType, config_.cacheType);
     }
 
     Status AllocateMemory(ShmUnit &unit, uint64_t needSize, const std::string &tenantId = DEFAULT_TENANT_ID)
     {
-        return unit.AllocateMemory(tenantId, needSize, config_.populate, config_.cacheType);
+        return unit.AllocateMemory(tenantId, needSize, config_.populate, config_.serviceType, config_.cacheType);
     }
 
     Status FreeMemory(void *&pointer, const std::string &tenantId = DEFAULT_TENANT_ID)
     {
-        return datasystem::memory::Allocator::Instance()->FreeMemory(tenantId, pointer, config_.cacheType);
+        return datasystem::memory::Allocator::Instance()->FreeMemory(tenantId, pointer, config_.serviceType,
+                                                                     config_.cacheType);
     }
 
     Status FreeMemory(ShmUnit &unit)
@@ -151,9 +154,9 @@ public:
         return datasystem::memory::Allocator::Instance()->GetMemoryUsage(tenantId, config_.cacheType);
     }
 
-    uint64_t GetMaxMemorySize()
+    uint64_t GetMaxMemorySize(ServiceType serviceType = ServiceType::OBJECT)
     {
-        return datasystem::memory::Allocator::Instance()->GetMaxMemorySize(config_.cacheType);
+        return datasystem::memory::Allocator::Instance()->GetMaxMemorySize(serviceType, config_.cacheType);
     }
 
 protected:
@@ -320,26 +323,32 @@ TEST_F(AllocatorTest, TestAllocateMemoryWithThreshold)
     LOG(INFO) << "Test allocate memory with threshold.";
     auto *allocator = datasystem::memory::Allocator::Instance();
     uint64_t maxSize = 64 * 1024ul * 1024ul;
-    uint64_t ocShmPercentage = 80;
-    uint64_t maxOcSize = (maxSize * ocShmPercentage) / 100;
+    uint64_t scShmPercentage = 90, ocShmPercentage = 80;
+    uint64_t maxOcSize = (maxSize * ocShmPercentage) / 100, maxScSize = (maxSize * scShmPercentage) / 100;
     ssize_t decayMs = 5000;
     ShmUnit shmUnit;
     ResetShmUnit(shmUnit);
-    DS_ASSERT_OK(allocator->Init(maxSize, 0, false, true, decayMs, ocShmPercentage));
-    ASSERT_EQ(allocator->GetMaxMemorySize(), size_t(maxOcSize));
+    DS_ASSERT_OK(allocator->Init(maxSize, 0, false, true, decayMs, ocShmPercentage, scShmPercentage));
+    ASSERT_EQ(allocator->GetMaxMemorySize(ServiceType::OBJECT), size_t(maxOcSize));
+    ASSERT_EQ(allocator->GetMaxMemorySize(ServiceType::STREAM), size_t(maxScSize));
     ASSERT_EQ(allocator->GetMemoryUsage(), size_t(0));
 
     EXPECT_EQ(allocator
                   ->AllocateMemory(DEFAULT_TENANT_ID, maxOcSize + 1, false, shmUnit.pointer, shmUnit.fd, shmUnit.offset,
-                                   shmUnit.mmapSize)
+                                   shmUnit.mmapSize, ServiceType::OBJECT)
                   .GetCode(),
               StatusCode::K_OUT_OF_MEMORY);
     ExpectUnChanged(shmUnit);
 
+    uint64_t needSize = 16;
     ResetShmUnit(shmUnit);
+    DS_ASSERT_OK(allocator->AllocateMemory(DEFAULT_TENANT_ID, needSize, false, shmUnit.pointer, shmUnit.fd,
+                                           shmUnit.offset, shmUnit.mmapSize, ServiceType::STREAM));
+    ASSERT_EQ(allocator->GetMemoryUsage(), needSize);
 
-    DS_ASSERT_NOT_OK(allocator->FreeMemory(shmUnit.pointer));
+    DS_ASSERT_NOT_OK(allocator->FreeMemory(shmUnit.pointer, ServiceType::OBJECT));
+    DS_ASSERT_OK(allocator->FreeMemory(shmUnit.pointer, ServiceType::STREAM));
     ASSERT_EQ(allocator->GetMemoryUsage(), size_t(0));
 
     allocator->Shutdown();
@@ -546,6 +555,7 @@ TEST_F(AllocatorTest, TestAllocateMemoryInMultiThreads)
     allocator->Shutdown();
 }
 
+
 void AllocatorTest::TestAllocatedAddresses()
 {
     LOG(INFO) << "Test allocated addresses.";
@@ -647,7 +657,7 @@ TEST_F(AllocatorTest, TestArenaBasicFunction)
 {
     AllocatorConfig config;
     config.shmSize = 64 * 1024ul * 1024ul;  // 64 MB
-    config.decayMs = 1'000;                 // 1'000 MS
+    config.decayMs = 1'000;  // 1'000 MS
     DS_ASSERT_OK(Init(config));
     TestArenaBasicFunction();
 }
@@ -656,7 +666,7 @@ TEST_F(AllocatorTest, TestArenaBasicFunctionDisk)
 {
     AllocatorConfig config;
     config.shdSize = 64 * 1024ul * 1024ul;  // 64 MB
-    config.decayMs = 1'000;                 // 1'000 MS
+    config.decayMs = 1'000;  // 1'000 MS
     config.cacheType = memory::CacheType::DISK;
     DS_ASSERT_OK(Init(config));
     TestArenaBasicFunction();
@@ -964,7 +974,7 @@ void AllocatorTest::TestUsedupAndFree2()
 TEST_F(AllocatorTest, TestUsedupAndFree2)
 {
     AllocatorConfig config;
-    config.shmSize = 2 * 1024ul * 1024ul * 1024ul;  // 2 GB
+    config.shmSize = 2 * 1024ul * 1024ul * 1024ul;     // 2 GB
     config.decayMs = 10 * SEC_PER_MIN * MS_PER_SECOND;  // 10 minutes.
     DS_ASSERT_OK(Init(config));
     TestUsedupAndFree2();
@@ -973,7 +983,7 @@ TEST_F(AllocatorTest, TestUsedupAndFree2)
 TEST_F(AllocatorTest, TestUsedupAndFree2Disk)
 {
     AllocatorConfig config;
-    config.shdSize = 2 * 1024ul * 1024ul * 1024ul;  // 2 GB
+    config.shdSize = 2 * 1024ul * 1024ul * 1024ul;     // 2 GB
     config.decayMs = 10 * SEC_PER_MIN * MS_PER_SECOND;  // 10 minutes.
     config.cacheType = memory::CacheType::DISK;
     DS_ASSERT_OK(Init(config));
@@ -1167,7 +1177,7 @@ TEST_F(AllocatorTest, FakeAllocate)
 {
     FLAGS_arena_per_tenant = 1;
     AllocatorConfig config;
-    config.shmSize = 10 * 1024ul * 1024ul * 1024ul;  // 10 GB
+    config.shmSize = 10 * 1024ul * 1024ul * 1024ul;    // 10 GB
     config.decayMs = 10 * SEC_PER_MIN * MS_PER_SECOND;  // 10 minutes.
     DS_ASSERT_OK(Init(config));
     FakeAllocate();
@@ -1177,7 +1187,7 @@ TEST_F(AllocatorTest, FakeAllocateDisk)
 {
     FLAGS_shared_disk_arena_per_tenant = 1;
     AllocatorConfig config;
-    config.shdSize = 10 * 1024ul * 1024ul * 1024ul;  // 10 GB
+    config.shdSize = 10 * 1024ul * 1024ul * 1024ul;    // 10 GB
    config.decayMs = 10 * SEC_PER_MIN * MS_PER_SECOND;  // 10 minutes.
     config.cacheType = memory::CacheType::DISK;
     DS_ASSERT_OK(Init(config));
@@ -1398,13 +1408,14 @@ public:
     Status AllocateMemory(uint64_t needSize, ShmUnitInfo &unit, memory::CacheType cacheType,
                           const std::string &tenantId = DEFAULT_TENANT_ID)
     {
-        return datasystem::memory::Allocator::Instance()->AllocateMemory(
-            tenantId, needSize, false, unit.pointer, unit.fd, unit.offset, unit.mmapSize, cacheType);
+        return datasystem::memory::Allocator::Instance()->AllocateMemory(tenantId, needSize, false, unit.pointer,
+                                                                         unit.fd, unit.offset, unit.mmapSize,
+                                                                         ServiceType::OBJECT, cacheType);
     }
 
     Status FreeMemory(void *&pointer, memory::CacheType cacheType, const std::string &tenantId = DEFAULT_TENANT_ID)
     {
-        return datasystem::memory::Allocator::Instance()->FreeMemory(tenantId, pointer, cacheType);
+        return datasystem::memory::Allocator::Instance()->FreeMemory(tenantId, pointer, ServiceType::OBJECT, cacheType);
     }
 };
 
diff --git a/tests/ut/common/stream_cache/shared_mem_view_lock_test.cpp b/tests/ut/common/stream_cache/shared_mem_view_lock_test.cpp
new file mode 100644
index 0000000..732c91e
--- /dev/null
+++ b/tests/ut/common/stream_cache/shared_mem_view_lock_test.cpp
@@ -0,0 +1,65 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Testing SharedMemViewLock
+ */
+#include "datasystem/common/stream_cache/cursor.h"
+
+#include "common.h"
+
+namespace datasystem {
+namespace ut {
+class SharedMemViewLockTest : public CommonTest {
+public:
+    SharedMemView view_;
+};
+
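+// Note: the 4-byte lock word packs an exclusive (writer) bit together with a
+// shared-reader count, as in the SharedMemViewLock implementation this patch
+// removes from shm_lock_test.cpp (WRITER == 1, readers counted in steps of
+// READER == 2).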
+TEST_F(SharedMemViewLockTest, WriteLockTimeoutTest)
+{
+    // Protect the lock size from changes.
+    const uint EXPECTED_SIZE_OF_SHAREDMEMVIEWLOCK = 4;
+    ASSERT_EQ(sizeof(view_.lock_), EXPECTED_SIZE_OF_SHAREDMEMVIEWLOCK);
+
+    // Simulate the following:
+    // 1. Thread A holding the read lock for a long period of time.
+    // 2. Thread B try to get a write lock but timeout.
+    // 3. Read lock is obtainable after Thread A and Thread B finish.
+
+    // Thread A
+    const uint TWO_SECS = 2;
+    ThreadPool pool(1);
+    auto func = [this]() {
+        SharedMemViewLock lock(&view_.lock_);
+        return lock.LockSharedAndExec([]() { sleep(TWO_SECS); }, ONE_THOUSAND);
+    };
+    std::future<Status> fut = pool.Submit(func);
+
+    sleep(1);
+
+    // Thread B
+    SharedMemViewLock lock(&view_.lock_);
+    ASSERT_EQ(lock.LockExclusiveAndExec([]() { sleep(1); }, ONE_THOUSAND).GetCode(), K_TRY_AGAIN);
+
+    // Wait for Thread A to finish.
+    DS_ASSERT_OK(fut.get());
+
+    // Should be able to get the read lock again as no one is waiting for the write lock.
+    DS_ASSERT_OK(lock.LockSharedAndExec([]() { sleep(1); }, ONE_THOUSAND));
+}
+
+} // namespace ut
+} // namespace datasystem
\ No newline at end of file
diff --git a/tests/ut/common/stream_cache/stream_meta_shm_test.cpp b/tests/ut/common/stream_cache/stream_meta_shm_test.cpp
new file mode 100644
index 0000000..fdda0b8
--- /dev/null
+++ b/tests/ut/common/stream_cache/stream_meta_shm_test.cpp
@@ -0,0 +1,59 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Testing StreamMetaShm.
+ */
+#include <memory>
+
+#include "common.h"
+#include "datasystem/common/shared_memory/allocator.h"
+#include "datasystem/common/shared_memory/shm_unit.h"
+#include "datasystem/common/stream_cache/stream_meta_shm.h"
+
+namespace datasystem {
+namespace ut {
+class StreamMetaShmTest : public CommonTest {
+public:
+    const uint64_t streamMetaShmSize_ = 64;
+    const uint64_t maxStreamSize_ = 4 * 1024;
+};
+
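+// Usage accounting is saturating in both directions: TryDecUsage fails once
+// usage would drop below zero, and TryIncUsage fails once it would exceed
+// maxStreamSize_; the alternating OK/NOT_OK assertions below pin that down.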
+TEST_F(StreamMetaShmTest, BasicTest)
+{
+    size_t maxSize = 1024 * 1024ul * 1024ul;
+    DS_ASSERT_OK(datasystem::memory::Allocator::Instance()->Init(maxSize));
+
+    auto shmUnitOfStreamMeta = std::make_unique<ShmUnit>();
+    DS_ASSERT_OK(shmUnitOfStreamMeta->AllocateMemory("", streamMetaShmSize_, false, ServiceType::STREAM));
+    auto rc = memset_s(shmUnitOfStreamMeta->GetPointer(), streamMetaShmSize_, 0, streamMetaShmSize_);
+    ASSERT_EQ(rc, 0);
+    auto streamMetaShm = std::make_unique<StreamMetaShm>("stream0", shmUnitOfStreamMeta->GetPointer(),
+                                                         streamMetaShmSize_, maxStreamSize_);
+    DS_ASSERT_OK(streamMetaShm->Init());
+
+    DS_ASSERT_NOT_OK(streamMetaShm->TryDecUsage(1));
+    DS_ASSERT_OK(streamMetaShm->TryIncUsage(1));
+    DS_ASSERT_OK(streamMetaShm->TryDecUsage(1));
+    DS_ASSERT_NOT_OK(streamMetaShm->TryDecUsage(1));
+    DS_ASSERT_OK(streamMetaShm->TryIncUsage(maxStreamSize_));
+    DS_ASSERT_NOT_OK(streamMetaShm->TryIncUsage(1));
+    DS_ASSERT_OK(streamMetaShm->TryDecUsage(1));
+    DS_ASSERT_OK(streamMetaShm->TryIncUsage(1));
+}
+
+} // namespace ut
+} // namespace datasystem
diff --git a/tests/ut/common/string_intern/string_ref_bench_test.cpp b/tests/ut/common/string_intern/string_ref_bench_test.cpp
new file mode 100644
index 0000000..678267b
--- /dev/null
+++ b/tests/ut/common/string_intern/string_ref_bench_test.cpp
@@ -0,0 +1,228 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Testing StringRef.
+ */
+#include <algorithm>
+#include <numeric>
+#include <random>
+#include <thread>
+
+#include "common.h"
+#include "datasystem/common/immutable_string/immutable_string.h"
+#include "datasystem/common/string_intern/string_ref.h"
+#include "datasystem/common/util/timer.h"
+#include "datasystem/common/util/wait_post.h"
+
+namespace datasystem {
+namespace ut {
+const size_t threadCnt8 = 8;
+class StringRefBenchTest : public CommonTest {
+public:
+    std::string GetBenchCost(std::vector<std::vector<double>> &costsVec)
+    {
+        std::vector<double> costs;
+        for (const auto &c : costsVec) {
+            for (const auto &cost : c) {
+                costs.emplace_back(cost);
+            }
+        }
+        std::stringstream ss;
+        std::sort(costs.begin(), costs.end());
+        // remove last
+        double totalTimeCost = std::accumulate(costs.begin(), costs.end(), 0.0);  // ms
+        double avg = totalTimeCost / static_cast<double>(costs.size());
+        auto count = costs.size();
+        const size_t p90 = 90;
+        const size_t p99 = 99;
+        const size_t pmax = 100;
+        ss << costs.size();
+        ss << "," << avg;                        // avg ms
+        ss << "," << costs[0];                   // min ms
+        ss << "," << costs[count * p90 / pmax];  // p90 ms
+        ss << "," << costs[count * p99 / pmax];  // p99 ms
+        ss << "," << costs[count - 1];           // max ms
+        return ss.str();
+    }
+
+    static std::string GenUniqueString(size_t threadIdx, size_t iter)
+    {
+        const size_t appendSize = 64;
+        return "thread_" + std::to_string(threadIdx) + "_iter_" + std::to_string(iter) + "_" +
+               std::string(appendSize, 'a');
+    }
+
+    static std::string GenDupString(size_t threadIdx, size_t iter)
+    {
+        (void)threadIdx;
+        const size_t appendSize = 64;
+        const size_t count = 100;
+        return "thread_x_iter_" + std::to_string(iter % count) + "_" + std::string(appendSize, 'a');
+    }
+
+    template <typename KeyType, typename G, typename F>
+    void PerfIntern(size_t threadCnt, G &&gen, F &&fn)
+    {
+        const size_t countPerThread = 100'000;
+        const size_t batchCnt = 1024;
+        std::vector<std::thread> threads;
+        // generate string
+        std::vector<std::vector<std::string>> datas;
+        std::vector<std::vector<double>> costPerThread1;
+        std::vector<std::vector<double>> costPerThread2;
+        std::vector<std::vector<double>> costPerThread3;
+        datas.resize(threadCnt);
+        costPerThread1.resize(threadCnt);
+        costPerThread2.resize(threadCnt);
+        costPerThread3.resize(threadCnt);
+        Barrier barrier1(threadCnt);
+        Barrier barrier2(threadCnt);
+        Barrier barrier3(threadCnt);
+        using TbbMap = tbb::concurrent_hash_map<KeyType, size_t>;
+        TbbMap secondMap;
+        for (size_t i = 0; i < threadCnt; i++) {
+            auto &data = datas[i];
+            auto &costs1 = costPerThread1[i];
+            auto &costs2 = costPerThread2[i];
+            auto &costs3 = costPerThread3[i];
+            data.reserve(countPerThread);
+            costs1.reserve(countPerThread / batchCnt);
+            costs2.reserve(countPerThread / batchCnt);
+            costs3.reserve(countPerThread / batchCnt);
+
+            for (size_t n = 0; n < countPerThread; n++) {
+                data.emplace_back(std::move(gen(i, n)));
+            }
+            std::shuffle(data.begin(), data.end(), std::mt19937{ std::random_device{}() });
+            threads.emplace_back([&] {
+                std::vector<KeyType> keys;
+                keys.reserve(countPerThread);
+                barrier1.Wait();
+                // test intern
+                for (size_t i = 0; i < countPerThread; i += batchCnt) {
+                    Timer timer1;
+                    for (size_t b = 0; b < batchCnt && i + b < countPerThread; b++) {
+                        keys.emplace_back(fn(data[i + b]));
+                    }
+                    costs1.emplace_back(timer1.ElapsedMilliSecond());
+                }
+                barrier2.Wait();
+                // test intern key insert
+                size_t count = 0;
+                Timer timer2;
+                for (const auto &key : keys) {
+                    typename TbbMap::const_accessor accessor;
+                    secondMap.insert(accessor, key);
+                    if (count == batchCnt) {
+                        count = 0;
+                        costs2.emplace_back(timer2.ElapsedMilliSecondAndReset());
+                    }
+                    count += 1;
+                }
+
+                barrier3.Wait();
+                // test find
+                Timer timer3;
+                count = 0;
+                for (const auto &key : keys) {
+                    typename TbbMap::const_accessor accessor;
+                    secondMap.find(accessor, key);
+                    if (count == batchCnt) {
+                        count = 0;
+                        costs3.emplace_back(timer3.ElapsedMilliSecondAndReset());
+                    }
+                    count += 1;
+                }
+            });
+        }
+        for (auto &t : threads) {
+            t.join();
+        }
+        std::string caseName;
+        std::string name;
+        ut::GetCurTestName(caseName, name);
+        LOG(INFO) << "BENCHMARK," << caseName << "," << name << ",Thread-" << threadCnt << ", intern,"
+                  << GetBenchCost(costPerThread1);
+        LOG(INFO) << "BENCHMARK," << caseName << "," << name << ",Thread-" << threadCnt << ", key insert,"
+                  << GetBenchCost(costPerThread2);
+        LOG(INFO) << "BENCHMARK," << caseName << "," << name << ",Thread-" << threadCnt << ", key find,"
+                  << GetBenchCost(costPerThread3);
+    }
+};
+
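+// The cases below are benchmarks rather than functional tests: each interns
+// 100'000 strings per thread (unique vs. heavily duplicated inputs) through
+// one of three key types, then logs avg/min/p90/p99/max batch latencies in
+// CSV form for offline comparison.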
+TEST_F(StringRefBenchTest, StringRefInternUnique1)
+{
+    PerfIntern<ObjectKey>(1, GenUniqueString, ObjectKey::Intern);
+}
+
+TEST_F(StringRefBenchTest, StringRefInternDuplicate1)
+{
+    PerfIntern<ObjectKey>(1, GenDupString, ObjectKey::Intern);
+}
+
+TEST_F(StringRefBenchTest, StringRefInternUnique8)
+{
+    PerfIntern<ObjectKey>(threadCnt8, GenUniqueString, ObjectKey::Intern);
+}
+
+TEST_F(StringRefBenchTest, StringRefInternDuplicate8)
+{
+    PerfIntern<ObjectKey>(threadCnt8, GenDupString, ObjectKey::Intern);
+}
+
+TEST_F(StringRefBenchTest, StdStringUnique1)
+{
+    PerfIntern<std::string>(1, GenUniqueString, [](const std::string &str) { return std::string(str); });
+}
+
+TEST_F(StringRefBenchTest, StdStringDuplicate1)
+{
+    PerfIntern<std::string>(1, GenDupString, [](const std::string &str) { return std::string(str); });
+}
+
+TEST_F(StringRefBenchTest, StdStringUnique8)
+{
+    PerfIntern<std::string>(threadCnt8, GenUniqueString, [](const std::string &str) { return std::string(str); });
+}
+
+TEST_F(StringRefBenchTest, StdStringDuplicate8)
+{
+    PerfIntern<std::string>(threadCnt8, GenDupString, [](const std::string &str) { return std::string(str); });
+}
+
+TEST_F(StringRefBenchTest, ImmutableStringInternUnique1)
+{
+    PerfIntern<ImmutableString>(1, GenUniqueString, [](const std::string &str) { return ImmutableString(str); });
+}
+
+TEST_F(StringRefBenchTest, ImmutableStringInternDuplicate1)
+{
+    PerfIntern<ImmutableString>(1, GenDupString, [](const std::string &str) { return ImmutableString(str); });
+}
+
+TEST_F(StringRefBenchTest, ImmutableStringInternUnique8)
+{
+    PerfIntern<ImmutableString>(threadCnt8, GenUniqueString,
+                                [](const std::string &str) { return ImmutableString(str); });
+}
+
+TEST_F(StringRefBenchTest, ImmutableStringInternDuplicate8)
+{
+    PerfIntern<ImmutableString>(threadCnt8, GenDupString, [](const std::string &str) { return ImmutableString(str); });
+}
+} // namespace ut
+} // namespace datasystem
diff --git a/tests/ut/common/string_intern/string_ref_test.cpp b/tests/ut/common/string_intern/string_ref_test.cpp
new file mode 100644
index 0000000..0de77d7
--- /dev/null
+++ b/tests/ut/common/string_intern/string_ref_test.cpp
@@ -0,0 +1,389 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Description: Testing StringRef.
+ */
+#include <map>
+#include <set>
+#include <shared_mutex>
+#include <thread>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "common.h"
+#include "datasystem/common/string_intern/string_ref.h"
+#include "datasystem/common/util/thread_pool.h"
+#include "datasystem/common/util/uuid_generator.h"
+
+namespace datasystem {
+namespace ut {
+using ImmutableString = OtherKey;
+using ImmutableStringPool = OtherKeyPool;
+class StringRefTest : public CommonTest {
+public:
+    static void CheckImmutableStringEqual(const ImmutableString &im1, const ImmutableString &im2)
+    {
+        ASSERT_EQ(im1, im2);
+        ASSERT_EQ(im1.ToString(), im2.ToString());
+        ASSERT_EQ(&im1.ToString(), &im2.ToString());
+    }
+
+    static void CheckSetErase(tbb::concurrent_unordered_set<ImmutableString, std::hash<ImmutableString>> &set1,
+                              tbb::concurrent_unordered_set<ImmutableString, std::hash<ImmutableString>> &set2)
+    {
+        set1.unsafe_erase("123");
+        set2.unsafe_erase(ImmutableString("123"));
+        EXPECT_EQ(ImmutableStringPool::Instance().Size(), 1UL);
+
+        set1.unsafe_erase("456");
+        set2.unsafe_erase(std::string("456"));
+        EXPECT_EQ(ImmutableStringPool::Instance().Size(), 0UL);
+    }
+
+    template <typename T>
+    static void CheckSetErase(T &set1, T &set2)
+    {
+        set1.erase("123");
+        set2.erase(ImmutableString("123"));
+        EXPECT_EQ(ImmutableStringPool::Instance().Size(), 1UL);
+
+        set1.erase("456");
+        set2.erase(std::string("456"));
+        EXPECT_EQ(ImmutableStringPool::Instance().Size(), 0UL);
+    }
+
+    template <typename T>
+    static void ImMapCheckMemoryReduce()
+    {
+        auto key1 = GetStringUuid();
+
+        auto value1 = RandomData().GetRandomUint32();
+        auto value2 = RandomData().GetRandomUint32();
+        T map1;
+        T map2;
+        {
+            auto im1 = ImmutableString(key1);
+            // insert by ImmutableString
+            map1[im1] = value1;
+            ASSERT_EQ(map1[key1], value1);
+
+            // insert by std::string
+            map2[key1] = value2;
+            ASSERT_EQ(map2[im1], value2);
+
+            EXPECT_EQ(ImmutableStringPool::Instance().Size(), 1UL);
+            map1.erase(key1);
+            EXPECT_EQ(ImmutableStringPool::Instance().Size(), 1UL);
+            map2.erase(key1);
+        }
+
+        EXPECT_EQ(ImmutableStringPool::Instance().Size(), 0UL);
+    }
+
+    template <typename T>
+    static void ImSetCheckMemoryReduce()
+    {
+        T set1;
+        T set2;
+        std::string test1 = "123";
+        std::string test2 = "456";
+
+        auto pair = set1.insert(test1);
+        ASSERT_TRUE(pair.second);
+        pair = set1.insert(ImmutableString(test2));
+        ASSERT_TRUE(pair.second);
+        pair = set1.insert(ImmutableString(test2));
+        ASSERT_FALSE(pair.second);
+        // After insert, 2 RefCountString in pool.
+        EXPECT_EQ(ImmutableStringPool::Instance().Size(), 2UL);
+        // find by ImmutableString
+        auto iter = set1.find(ImmutableString(test1));
+        ASSERT_TRUE(iter != set1.end());
+        ASSERT_EQ(*iter, test1);
+        // find by std::string
+        iter = set1.find(test2);
+        ASSERT_TRUE(iter != set1.end());
+        ASSERT_EQ(*iter, test2);
+
+        // find by const char*
+        iter = set1.find("456");
+        ASSERT_TRUE(iter != set1.end());
+        ASSERT_EQ(*iter, test2);
+
+        // find by not exist key
+        iter = set1.find("789");
+        ASSERT_TRUE(iter == set1.end());
+
+        pair = set2.insert(ImmutableString(test1));
+        ASSERT_TRUE(pair.second);
+        pair = set2.insert(ImmutableString(test2));
+        ASSERT_TRUE(pair.second);
+        EXPECT_EQ(ImmutableStringPool::Instance().Size(), 2UL);
+
+        auto iterInSet1 = set1.find(test1);
+        auto iterInSet2 = set2.find(test1);
+        ASSERT_EQ(*iterInSet1, *iterInSet2);
+
+        CheckSetErase(set1, set2);
+    }
+};
+
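+// Construction from std::string, a string literal and a char array must all
+// funnel into the same pooled RefCountString, so the pool size counts
+// distinct byte sequences rather than handles.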
+TEST_F(StringRefTest, TestConstructor)
+{
+    std::string test1 = "123";
+    std::string test2 = "456";
+    char test3[] = "123";
+    {
+        ImmutableString im1 = ImmutableString(test1);
+        ImmutableString im2 = ImmutableString(test1);
+        ImmutableString im3 = ImmutableString(test2);
+        ImmutableString im4 = ImmutableString("123");
+        ImmutableString im5 = ImmutableString("456");
+        ImmutableString im6 = ImmutableString(test3);
+
+        LOG(INFO) << "check im1, im2";
+        CheckImmutableStringEqual(im1, im2);
+        LOG(INFO) << "check im1, im4";
+        CheckImmutableStringEqual(im1, im4);
+        LOG(INFO) << "check im1, im6";
+        CheckImmutableStringEqual(im1, im6);
+        LOG(INFO) << "check im3, im5";
+        CheckImmutableStringEqual(im3, im5);
+        CHECK_NE(im1, im3);
+        CHECK_NE(im1.ToString(), im3.ToString());
+        EXPECT_EQ(ImmutableStringPool::Instance().Size(), 2UL);
+    }
+    EXPECT_EQ(ImmutableStringPool::Instance().Size(), 0UL);
+}
+
+TEST_F(StringRefTest, TestBigString)
+{
+    const size_t strSize = 1024UL * 1024 * 1024;
+    const std::string str = RandomData().GetPartRandomString(strSize, 100);
+    const size_t imSize = 2;
+    std::vector<ImmutableString> imVec;
+    imVec.reserve(imSize);
+    for (size_t i = 0; i < imSize; i++) {
+        LOG(INFO) << "loop: " << i;
+        // Need copy once.
+        imVec.emplace_back(str);
+    }
+    EXPECT_EQ(ImmutableStringPool::Instance().Size(), 1UL);
+    imVec.clear();
+    EXPECT_EQ(ImmutableStringPool::Instance().Size(), 0UL);
+}
+
+TEST_F(StringRefTest, TestDestructorInParallel)
+{
+    size_t strNum = 2;
+    const size_t strLen = 100;
+    std::vector<std::string> strVec;
+    strVec.reserve(strNum);
+    for (size_t i = 0; i < strNum; i++) {
+        strVec.emplace_back(RandomData().GetRandomString(strLen));
+    }
+
+    const size_t threadNum = 32;
+    const size_t loopCnt = 10000;
+    auto pool = std::make_unique<ThreadPool>(threadNum);
+    for (size_t i = 0; i < threadNum; i++) {
+        pool->Execute([&strVec, i, strNum]() {
+            for (size_t j = 0; j < loopCnt; j++) {
+                ImmutableString im = ImmutableString(strVec[i % strNum]);
+            }
+        });
+    }
+    pool.reset();
+
+    EXPECT_EQ(ImmutableStringPool::Instance().Size(), 0UL);
+}
+
+TEST_F(StringRefTest, TestImInTbbUnorderedSet)
+{
+    ImSetCheckMemoryReduce<tbb::concurrent_unordered_set<ImmutableString, std::hash<ImmutableString>>>();
+}
+
+TEST_F(StringRefTest, TestImInSTLUnorderedSet)
+{
+    ImSetCheckMemoryReduce<std::unordered_set<ImmutableString>>();
+}
+
+TEST_F(StringRefTest, TestImInSTLSet)
+{
+    ImSetCheckMemoryReduce<std::set<ImmutableString>>();
+}
+
+TEST_F(StringRefTest, ImInTbbHashMap)
+{
+    auto key1 = GetStringUuid();
+
+    auto value1 = RandomData().GetRandomUint32();
+    auto value2 = RandomData().GetRandomUint32();
+
+    using MapType = tbb::concurrent_hash_map<ImmutableString, uint32_t>;
+
+    MapType map1;
+    MapType map2;
+
+    auto im1 = ImmutableString(key1);
+    MapType::accessor ac;
+    // insert by ImmutableString
+    map1.insert(ac, im1);
+    ac->second = value1;
+    ac.release();
+    map1.find(ac, im1);
+    ASSERT_EQ(ac->second, value1);
+    ac.release();
+
+    // insert by std::string
+    map2.insert(ac, key1);
+    ac->second = value2;
+    ac.release();
+    map2.find(ac, key1);
+    ASSERT_EQ(ac->second, value2);
+    ac.release();
+
+    EXPECT_EQ(ImmutableStringPool::Instance().Size(), 1UL);
+}
+
+TEST_F(StringRefTest, ImInUnorderedMapInParallel)
+{
+    auto key1 = GetStringUuid();
+    using MapType = std::unordered_map<ImmutableString, int>;
+    MapType map1;
+    const size_t minThreadCnt = 10;
+    auto pool = std::make_unique<ThreadPool>(minThreadCnt);
+    std::shared_timed_mutex mutex;
+
+    for (size_t i = 0; i < minThreadCnt; i++) {
+        pool->Execute([&key1, &map1, &mutex]() {
+            std::lock_guard<std::shared_timed_mutex> lck(mutex);
+            auto iter = map1.find(key1);
+            if (iter == map1.end()) {
+                map1.emplace(key1, 1);
+            } else {
+                iter->second = 2;  // change to 2
+            }
+        });
+        pool->Execute([&key1, &map1, &mutex]() {
+            std::shared_lock<std::shared_timed_mutex> lck(mutex);
+            auto iter = map1.find(key1);
+            if (iter != map1.end()) {
+                LOG(INFO) << iter->first;
+            }
+        });
+    }
+    pool.reset();
+}
+
+TEST_F(StringRefTest, ImInSTLHashMap)
+{
+    using MapType = std::map<ImmutableString, uint32_t>;
+    ImMapCheckMemoryReduce<MapType>();
+}
+
+TEST_F(StringRefTest, ImInSTLUnorderedMap)
+{
+    using MapType = std::unordered_map<ImmutableString, uint32_t>;
+    ImMapCheckMemoryReduce<MapType>();
+}
+
+TEST_F(StringRefTest, BaseTest)
+{
+    auto s1 = ObjectKey::Intern("abc");
+    auto s2 = ObjectKey::Intern("abc");
+    auto s3 = ObjectKey::Intern("abcd");
+    ASSERT_EQ(s1, s2);
+    ASSERT_NE(s1, s3);
+
+    std::unordered_map<ObjectKey, int> map;
+    map.emplace(s1, 1);
+    ASSERT_EQ(map[s2], 1);
+}
+
+TEST_F(StringRefTest, TestMove)
+{
+    auto s1 = ObjectKey::Intern("abc");
+    auto s2 = std::move(s1);
+    ASSERT_NE(s1, s2);
+    ASSERT_TRUE(s1.Size() == 0);
+    ASSERT_EQ(s1.ToString(), "");
+    ASSERT_EQ(s2.ToString(), "abc");
+
+    auto s3 = ObjectKey::Intern("abcd");
+    s3 = std::move(s2);
+    ASSERT_EQ(s2.ToString(), "");
+    ASSERT_EQ(s3.ToString(), "abc");
+
+    std::string ss1 = "abc";
+    std::string ss2 = "abcd";
+
+    ss2 = std::move(ss1);
+    ASSERT_EQ(ss1, "");
+    ASSERT_EQ(ss2, "abc");
+}
+
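+// Interns and immediately drops the same key from many threads so that a
+// fresh Intern can race against the pool erase that runs when the last
+// reference goes away.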
+TEST_F(StringRefTest, InternAndErase)
+{
+    std::string key = "hello";
+    std::vector<std::thread> threads;
+    const size_t threadCnt = 8;
+    const size_t testCnt = 10000;
+
+    for (size_t i = 0; i < threadCnt; i++) {
+        threads.emplace_back([&key] {
+            for (size_t n = 0; n < testCnt; n++) {
+                auto k = ObjectKey::Intern(key);
+                (void)k;
+            }
+        });
+    }
+
+    for (auto &t : threads) {
+        t.join();
+    }
+}
+
+TEST_F(StringRefTest, SupportAddString)
+{
+    ObjectKey key = ObjectKey::Intern("hello");
+    auto v1 = "world";
+    ASSERT_EQ(key + v1, "helloworld");
+    ASSERT_EQ(v1 + key, "worldhello");
+    ASSERT_EQ(key.ToString(), "hello");
+
+    std::string v2("world");
+    ASSERT_EQ(key + v2, "helloworld");
+    ASSERT_EQ(v2 + key, "worldhello");
+    ASSERT_EQ(key.ToString(), "hello");
+}
+
+TEST_F(StringRefTest, TestEmptyString)
+{
+    ObjectKey key1 = ObjectKey::Intern("");
+    ObjectKey key2 = ObjectKey::Intern("");
+    ObjectKey key3 = ObjectKey::Intern("123");
+
+    ASSERT_EQ(key1, key2);
+    ASSERT_NE(key1, key3);
+    ObjectKey key4 = std::move(key3);
+    ASSERT_EQ(key1, key3);
+    ASSERT_NE(key3, key4);
+    ASSERT_EQ(key4.ToString(), "123");
+}
+} // namespace ut
+} // namespace datasystem
diff --git a/tests/ut/common/util/immutable_string_test.cpp b/tests/ut/common/util/immutable_string_test.cpp
index 3166de5..41dc2d4 100644
--- a/tests/ut/common/util/immutable_string_test.cpp
+++ b/tests/ut/common/util/immutable_string_test.cpp
@@ -163,7 +163,7 @@ TEST_F(ImmutableStringTest, TestBigString)
 {
     size_t strSize = 1024ul * 1024 * 1024;
     std::string str = RandomData().GetPartRandomString(strSize, 100);
-    size_t imSize = 10;
+    size_t imSize = 2;
     std::vector<ImmutableString> imVec;
     imVec.reserve(imSize);
     for (size_t i = 0; i < imSize; i++) {
diff --git a/tests/ut/common/util/shm_lock_test.cpp b/tests/ut/common/util/shm_lock_test.cpp
index 4346e81..5b731b5 100644
--- a/tests/ut/common/util/shm_lock_test.cpp
+++ b/tests/ut/common/util/shm_lock_test.cpp
@@ -20,131 +20,14 @@
 #include "common.h"
 #include "datasystem/common/log/log.h"
 #include "datasystem/common/inject/inject_point.h"
+#include "datasystem/common/stream_cache/cursor.h"
 #include "datasystem/common/util/safe_shm_lock.h"
-#include "datasystem/common/util/timer.h"
 #include "datasystem/utils/status.h"
 #include "securec.h"
 
 namespace datasystem {
 namespace ut {
-class SharedMemViewLock {
-public:
-    explicit SharedMemViewLock(uint32_t *lockWord);
-    Status LockExclusiveAndExec(const std::function<void()> &writeFunc, uint64_t timeoutMs);
-    Status LockSharedAndExec(const std::function<void()> &readFunc, uint64_t timeoutMs);
-
-private:
-    uint32_t *lockWord_;
-    constexpr static const uint32_t WRITER = 1;
-    constexpr static const uint32_t READER = 2;
-    constexpr static const int TIMEOUT_WARNING_LIMIT_MS = 3000;
-};
-
-SharedMemViewLock::SharedMemViewLock(uint32_t *lockWord) : lockWord_(lockWord)
-{
-}
-
-Status SharedMemViewLock::LockExclusiveAndExec(const std::function<void()> &writeFunc, uint64_t timeoutMs)
-{
-    Timer timer;
-    bool isFirstTimeout = false;
-    Status rc;
-    do {
-        uint32_t val = __atomic_load_n(lockWord_, __ATOMIC_ACQUIRE);
-        uint32_t expected = val & ~WRITER;
-        if (!__atomic_compare_exchange_n(lockWord_, &expected, val | WRITER, true, __ATOMIC_ACQUIRE,
-                                         __ATOMIC_RELAXED)) {
-            if (timer.ElapsedMilliSecond() > TIMEOUT_WARNING_LIMIT_MS && !isFirstTimeout) {
-                isFirstTimeout = true;
-                LOG(WARNING) << "Fetching a write-lock on shared memory takes more than " << TIMEOUT_WARNING_LIMIT_MS
-                             << " ms, waiting for writer to release the lock.";
-            }
-            // If timeout send an error
-            CHECK_FAIL_RETURN_STATUS(timer.ElapsedMilliSecond() < timeoutMs, K_TRY_AGAIN,
-                                     FormatString("[%s:%s] Timeout after %zu ms", __FUNCTION__, __LINE__, timeoutMs));
-            continue;
-        }
-        // Write bit has been set, we must unset the writer bit before going out of scope.
-        while (val & ~WRITER) {
-            // Wait for all readers to go away
-            val = __atomic_load_n(lockWord_, __ATOMIC_ACQUIRE);
-            if (timer.ElapsedMilliSecond() > TIMEOUT_WARNING_LIMIT_MS && !isFirstTimeout) {
-                isFirstTimeout = true;
-                LOG(WARNING) << "Fetching a write-lock on shared memory takes more than " << TIMEOUT_WARNING_LIMIT_MS
-                             << " ms, waiting for readers to release the lock.";
-            }
-            // If timeout send an error
-            if (timer.ElapsedMilliSecond() >= timeoutMs) {
-                // Unset the writer bit before returning error.
-                __atomic_fetch_sub(lockWord_, WRITER, __ATOMIC_RELEASE);
-                RETURN_STATUS(K_TRY_AGAIN,
-                              FormatString("[%s:%s] Timeout after %zu ms", __FUNCTION__, __LINE__, timeoutMs));
-            }
-        }
-        // cache exception to avoid the lock not released.
-        try {
-            // Execute the user function after we get the lock in X
-            writeFunc();
-        } catch (const std::exception &e) {
-            auto msg = FormatString("Exception when execute writeFunc get: %s", e.what());
-            rc = Status(K_RUNTIME_ERROR, msg);
-        }
-        __atomic_fetch_sub(lockWord_, WRITER, __ATOMIC_RELEASE);
-        if (isFirstTimeout) {
-            LOG(WARNING) << "Fetching a write-lock on shared memory takes " << timer.ElapsedMilliSecond() << " ms";
-        }
-        if (rc.IsError()) {
-            LOG(ERROR) << rc.GetMsg();
-        }
-        return rc;
-    } while (true);
-}
-
-Status SharedMemViewLock::LockSharedAndExec(const std::function<void()> &readFunc, uint64_t timeoutMs)
-{
-    Timer timer;
-    bool isFirstTimeout = false;
-    Status rc;
-    do {
-        while (__atomic_load_n(lockWord_, __ATOMIC_ACQUIRE) & WRITER) {
-            // Block on writer
-            if (timer.ElapsedMilliSecond() > TIMEOUT_WARNING_LIMIT_MS && !isFirstTimeout) {
-                isFirstTimeout = true;
-                LOG(WARNING) << "Fetching a read-lock on shared memory takes more than " << TIMEOUT_WARNING_LIMIT_MS
-                             << " ms, waiting for writer to release the lock";
-            }
-
-            // If timeout send an error
-            CHECK_FAIL_RETURN_STATUS(timer.ElapsedMilliSecond() < timeoutMs, K_TRY_AGAIN,
-                                     FormatString("[%s:%s] Timeout after %zu ms", __FUNCTION__, __LINE__, timeoutMs));
-        }
-        if ((__atomic_add_fetch(lockWord_, READER, __ATOMIC_ACQUIRE) & WRITER) == 0) {
-            // cache exception to avoid the lock not released.
-            try {
-                // Execute user function after we get the lock in shared mode
-                readFunc();
-            } catch (const std::exception &e) {
-                auto msg = FormatString("Exception when execute readFunc get: %s", e.what());
-                rc = Status(K_RUNTIME_ERROR, msg);
-            }
-
-            __atomic_fetch_sub(lockWord_, READER, __ATOMIC_RELEASE);
-            if (isFirstTimeout) {
-                LOG(WARNING) << "Fetching a read-lock on shared memory takes " << timer.ElapsedMilliSecond() << " ms";
-            }
-            if (rc.IsError()) {
-                LOG(ERROR) << rc.GetMsg();
-            }
-            return rc;
-        }
-        __atomic_fetch_sub(lockWord_, READER, __ATOMIC_RELEASE);  // A writer beats us. retry again
-        // If timeout send an error
-        CHECK_FAIL_RETURN_STATUS(timer.ElapsedMilliSecond() < timeoutMs, K_TRY_AGAIN,
-                                 FormatString("[%s:%s] Timeout after %zu ms", __FUNCTION__, __LINE__, timeoutMs));
-    } while (true);
-}
-
 class ShmLockTest : public CommonTest {
 protected:
     uint32_t lockWord = 0;
diff --git a/tests/ut/common/util/validator_test.cpp b/tests/ut/common/util/validator_test.cpp
index bc3a381..69a8398 100644
--- a/tests/ut/common/util/validator_test.cpp
+++ b/tests/ut/common/util/validator_test.cpp
@@ -38,6 +38,8 @@ TEST_F(ValidatorTest, TestValidator1)
     EXPECT_FALSE(Validator::ValidateRealPath("FlagName", "/path/not/exist"));
     EXPECT_TRUE(Validator::ValidatePathString("FlagName", "/path/To/Dir/"));
     EXPECT_FALSE(Validator::ValidateL2CacheType("FlagName", "whatever"));
+    EXPECT_TRUE(Validator::ValidateRocksdbModeType("FlagName", "async"));
+    EXPECT_FALSE(Validator::ValidateRocksdbModeType("FlagName", "whatever"));
     std::vector<std::string> validPaths = { "/home/sn/ttt", "~/home/sn/ttt", "!/home/sn/ttt", "qqq/" };
     std::vector<std::string> notValidPath = { "/ /sdaa", " /wdq//w", "///", "~//ef", "/home/ sn/ttt" };
     for (auto &path : validPaths) {
diff --git a/tests/ut/master/object_cache/master_dev_dead_lock_manager_test.cpp b/tests/ut/master/object_cache/master_dev_dead_lock_manager_test.cpp
new file mode 100644
index 0000000..670c7dc
--- /dev/null
+++ b/tests/ut/master/object_cache/master_dev_dead_lock_manager_test.cpp
@@ -0,0 +1,70 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
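+/**
+ * Description: Testing MasterDevDeadLockManager. Dependency edges form a
+ * directed wait-for graph between clients; IsExistDeadlock() is expected to
+ * report exactly the cyclic configurations.
+ */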
+ */ + +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "datasystem/master/object_cache/device/master_dev_dead_lock_manager.h" +#include "gtest/gtest.h" + +using namespace datasystem::master; +namespace datasystem { +namespace ut { + +class MasterDevDeadLockManagerTest : public CommonTest { +public: + void SetUp() + { + masterDevDeadLockManager_ = std::make_shared(); + } + + void TearDown() + { + masterDevDeadLockManager_.reset(); + } + + std::shared_ptr masterDevDeadLockManager_{ nullptr }; +}; + +TEST_F(MasterDevDeadLockManagerTest, TestIsExistDeadlock_Ring) +{ + masterDevDeadLockManager_->AddDependencyEdge("client1", "client2"); + masterDevDeadLockManager_->AddDependencyEdge("client2", "client3"); + masterDevDeadLockManager_->AddDependencyEdge("client1", "client3"); + ASSERT_TRUE(masterDevDeadLockManager_->IsExistDeadlock()); + + masterDevDeadLockManager_->RemoveDependencyEdge("client1", "client3"); + ASSERT_FALSE(masterDevDeadLockManager_->IsExistDeadlock()); +} + +TEST_F(MasterDevDeadLockManagerTest, TestIsExistDeadlock_MutualDependence) +{ + masterDevDeadLockManager_->AddDependencyEdge("client1", "client2"); + ASSERT_FALSE(masterDevDeadLockManager_->IsExistDeadlock()); + + masterDevDeadLockManager_->AddDependencyEdge("client2", "client1"); + ASSERT_TRUE(masterDevDeadLockManager_->IsExistDeadlock()); + + masterDevDeadLockManager_->RemoveDependencyEdge("client2", "client1"); + ASSERT_FALSE(masterDevDeadLockManager_->IsExistDeadlock()); +} +} // namespace ut +} // namespace datasystem diff --git a/tests/ut/master/object_cache/master_dev_oc_manager_test.cpp b/tests/ut/master/object_cache/master_dev_oc_manager_test.cpp index 3a90c7b..beb94b9 100644 --- a/tests/ut/master/object_cache/master_dev_oc_manager_test.cpp +++ b/tests/ut/master/object_cache/master_dev_oc_manager_test.cpp @@ -166,7 +166,7 @@ TEST_F(MasterDevOcManagerTest, DISABLED_TestHcclSelect_ExitingHccl) // client b npu 1 -> client c npu 2 AddLocs(manager, tempObjectKey, clientB, clientC, clientBDeviceId, clientCDeviceId); - // add graph_ + // add hcclRelationshipGraph_ // client b npu 1 -> client a npu 0 AddGraph(manager, clientB, clientA, clientBDeviceId, clientADeviceId); diff --git a/tests/ut/master/object_cache/object_meta_store_test.cpp b/tests/ut/master/object_cache/object_meta_store_test.cpp index e0ac527..4ad793b 100644 --- a/tests/ut/master/object_cache/object_meta_store_test.cpp +++ b/tests/ut/master/object_cache/object_meta_store_test.cpp @@ -127,6 +127,7 @@ TEST_F(ObjectMetaStoreTest, TestCreateQueryRemoveMeta) EXPECT_EQ(this->StoreRemove(removeIds), Status::OK()); // Remove not exist EXPECT_EQ(this->StoreRemove(removeIds), Status::OK()); + sleep(1); } } // namespace ut diff --git a/tests/ut/master/stream_cache/rocks_streammeta_store_test.cpp b/tests/ut/master/stream_cache/rocks_streammeta_store_test.cpp new file mode 100644 index 0000000..1db4dc3 --- /dev/null +++ b/tests/ut/master/stream_cache/rocks_streammeta_store_test.cpp @@ -0,0 +1,356 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Test RocksStreamMetaStore basic functions. + */ + +#include +#include +#include +#include +#include + +#include "common.h" +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/common/util/file_util.h" +#include "datasystem/common/util/random_data.h" +#include "datasystem/common/util/uuid_generator.h" +#include "datasystem/master/stream_cache/store/rocks_stream_meta_store.h" +#include "datasystem/protos/master_stream.pb.h" +#include "datasystem/stream/stream_config.h" + +using namespace datasystem::master::stream_cache; +DS_DECLARE_string(rocksdb_write_mode); +namespace datasystem { +namespace ut { +class RocksStreamMetaStoreTest : public CommonTest { +public: + void SetUp() + { + FLAGS_rocksdb_write_mode = "sync"; + backStorePath_ = "rocks_streammeta_store_" + random_.GetRandomString(8); + rocksStore_ = RocksStore::GetInstance(backStorePath_); + rocksStreamMetaStore_ = std::make_unique(rocksStore_.get()); + CHECK_EQ(rocksStreamMetaStore_->Init(), Status::OK()); + } + + void TearDown() + { + rocksStreamMetaStore_.reset(); + DS_ASSERT_OK(RemoveAll(backStorePath_)); + } + + void MakePubWorkerMetas(size_t createNum, std::unordered_map &producerMetaPbs) + { + for (size_t i = 0; i < createNum; i++) { + ProducerMetaPb pubMetaPb; + std::string streamName = "streamName"; + pubMetaPb.set_stream_name(streamName + std::to_string(i)); + pubMetaPb.mutable_worker_address()->set_host("127.0.0.1"); + pubMetaPb.mutable_worker_address()->set_port(1000); + producerMetaPbs.emplace(pubMetaPb.stream_name(), std::move(pubMetaPb)); + } + } + + void MakeConsumerMetas(size_t createNum, std::unordered_map &consumerMetas) + { + for (size_t i = 0; i < createNum; i++) { + ConsumerMetaPb subMeta; + subMeta.set_stream_name("streamName" + std::to_string(i)); + subMeta.mutable_worker_address()->set_host("127.0.0.1"); + subMeta.mutable_worker_address()->set_port(1234); + subMeta.set_consumer_id(GetStringUuid()); + std::string subName("sub_" + std::to_string(i)); + subMeta.mutable_sub_config()->set_subscription_name(subName); + subMeta.mutable_sub_config()->set_subscription_type(SubscriptionTypePb::STREAM_PB); + subMeta.set_last_ack_cursor(0); + consumerMetas.emplace(subMeta.stream_name(), subMeta); + } + } + + inline std::string HostPb2Str(const HostPortPb &hostPb) noexcept + { + HostPort addr(hostPb.host(), hostPb.port()); + return addr.ToString(); + } + + void MakeStreamMetas(size_t createNum, std::vector &streamMetas) + { + for (size_t i = 0; i < createNum; i++) { + std::string streamName = "streamName" + std::to_string(i); + streamMetas.emplace_back(streamName); + } + } + + void MakeProducerExistIds(const std::unordered_map &producerMetas, + std::unordered_map &queryIds) + { + size_t index = 0; + for (const auto &meta : producerMetas) { + if (index % 2 == 0) { + LOG(INFO) << "Producer streamName:" << meta.first; + queryIds.emplace(meta.first, meta.second); + } + index++; + } + } + + void MakeConsumerExistIds(const std::unordered_map &consumerMetas, + std::unordered_map &queryIds) + { + size_t index = 0; + for (const auto &meta : consumerMetas) { + LOG(INFO) << "consumer streamName:" << meta.second.stream_name(); + queryIds.emplace(meta.second.stream_name(), meta.second.consumer_id()); + index++; + } + } + + void MakeStreamExistIds(const std::vector &streamMetas, std::list &queryIds) + { + size_t index = 0; + for (const auto &meta : streamMetas) {
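+ // For a stream, the name itself doubles as its removal key. +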
queryIds.emplace_back(meta); + LOG(INFO) << "meta.second.streamName:" << meta; + index++; + } + } + + Status StorePubWorker(const std::unordered_map &producerMetas) + { + for (const auto &kv : producerMetas) { + RETURN_IF_NOT_OK(rocksStreamMetaStore_->AddPubNode(kv.second)); + } + return Status::OK(); + } + + Status StoreConsumer(const std::unordered_map &consumerMetas) + { + LOG(INFO) << "inMetas:" << consumerMetas.size(); + for (const auto &kv : consumerMetas) { + RETURN_IF_NOT_OK(rocksStreamMetaStore_->AddSubNode(kv.second)); + } + return Status::OK(); + } + + Status StoreStream(const std::vector &inMetas, const StreamFields &streamFields) + { + for (const auto &meta : inMetas) { + RETURN_IF_NOT_OK(rocksStreamMetaStore_->AddStream(meta, streamFields)); + } + return Status::OK(); + } + + Status GetAllPubWorkers(const std::unordered_map &removeKeys, + std::vector &pubWorkerMetas) + { + for (const auto &removeKey : removeKeys) { + std::string streamName = removeKey.first; + LOG(INFO) << " PubWorker streamName:" << streamName; + RETURN_IF_NOT_OK(rocksStreamMetaStore_->GetOneStreamProducers(streamName, pubWorkerMetas)); + } + RETURN_IF_NOT_OK(rocksStreamMetaStore_->GetOneStreamProducers("streamName4", pubWorkerMetas)); + return Status::OK(); + } + + Status GetAllConsumer(const std::unordered_map &removeKeys, + std::vector &consumerMetas) + { + for (const auto &removeKey : removeKeys) { + std::string streamName = removeKey.first; + LOG(INFO) << " Consumer streamName:" << streamName; + RETURN_IF_NOT_OK(rocksStreamMetaStore_->GetOneStreamConsumers(streamName, consumerMetas)); + } + return Status::OK(); + } + + Status StoreGetStream(std::vector &streamMetas) + { + RETURN_IF_NOT_OK(rocksStreamMetaStore_->GetAllStream(streamMetas)); + return Status::OK(); + } + + Status AddRemovePubWorker(const std::unordered_map &removeKeys) + { + for (const auto &removeKey : removeKeys) { + RETURN_IF_NOT_OK(this->rocksStreamMetaStore_->DelPubNode(removeKey.second)); + } + return Status::OK(); + } + + Status AddRemoveConsumer(const std::unordered_map &removeKeys) + { + for (const auto &removeKey : removeKeys) { + const std::string &streamName = removeKey.first; + const std::string &consumerId = removeKey.second; + RETURN_IF_NOT_OK(this->rocksStreamMetaStore_->DelSubNode(streamName, consumerId)); + } + return Status::OK(); + } + + Status AddRemoveStream(std::list &removeKeys) + { + for (const auto &removeKey : removeKeys) { + const std::string &streamName = removeKey; + RETURN_IF_NOT_OK(this->rocksStreamMetaStore_->DelStream(streamName)); + } + return Status::OK(); + } + + std::string backStorePath_; + static RandomData random_; + std::shared_ptr rocksStore_; + std::unique_ptr rocksStreamMetaStore_; +}; + +RandomData RocksStreamMetaStoreTest::random_; + +TEST_F(RocksStreamMetaStoreTest, TestCreateQueryRemovePubWorkerMeta) +{ + // Create + size_t createNum = 10; + std::unordered_map producerMetas; + this->MakePubWorkerMetas(createNum, producerMetas); + EXPECT_EQ(this->StorePubWorker(producerMetas), Status::OK()); + // Create same + EXPECT_EQ(this->StorePubWorker(producerMetas), Status::OK()); + std::unordered_map removeKeys; + this->MakeProducerExistIds(producerMetas, removeKeys); + + // get data from rocksdb + std::vector producerOutMetas; + EXPECT_EQ(this->GetAllPubWorkers(removeKeys, producerOutMetas), Status::OK()); + + // Remove exist + EXPECT_EQ(this->AddRemovePubWorker(removeKeys), Status::OK()); + // Remove not exist + EXPECT_EQ(this->AddRemovePubWorker(removeKeys), Status::OK()); +} + 
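// The Make*/Store*/GetAll*/AddRemove* helpers above wrap a plain persistence round
+// trip. A minimal sketch of that flow, using a hypothetical stream name "demoStream"
+// (for illustration only; the tests drive it through MakePubWorkerMetas instead):
+//   ProducerMetaPb pb;
+//   pb.set_stream_name("demoStream");
+//   pb.mutable_worker_address()->set_host("127.0.0.1");
+//   pb.mutable_worker_address()->set_port(1000);
+//   DS_ASSERT_OK(rocksStreamMetaStore_->AddPubNode(pb));       // persist producer meta
+//   std::vector<ProducerMetaPb> outMetas;
+//   DS_ASSERT_OK(rocksStreamMetaStore_->GetOneStreamProducers("demoStream", outMetas));
+//   DS_ASSERT_OK(rocksStreamMetaStore_->DelPubNode(pb));       // deletes are idempotent
+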
+TEST_F(RocksStreamMetaStoreTest, TestCreateQueryRemoveSubWorkerMeta) +{ + // Create + size_t createNum = 2; + std::unordered_map consumerMetas; + this->MakeConsumerMetas(createNum, consumerMetas); + EXPECT_EQ(this->StoreConsumer(consumerMetas), Status::OK()); + // Create same + EXPECT_EQ(this->StoreConsumer(consumerMetas), Status::OK()); + std::unordered_map removeKeys; + this->MakeConsumerExistIds(consumerMetas, removeKeys); + + // get data from rocksdb + std::vector consumerOutMetas; + EXPECT_EQ(this->GetAllConsumer(removeKeys, consumerOutMetas), Status::OK()); + + // Remove exist + EXPECT_EQ(this->AddRemoveConsumer(removeKeys), Status::OK()); + // Remove not exist + EXPECT_EQ(this->AddRemoveConsumer(removeKeys), Status::OK()); +} + +TEST_F(RocksStreamMetaStoreTest, TestCreateQueryRemoveStreamWorkerMeta) +{ + // Create + size_t createNum = 10; + const uint64_t maxStreamSize = 1024 * 1024 * 1024; // 1G max stream size + const int64_t pageSize = 1024 * 1024; + StreamFields streamFields(maxStreamSize, pageSize, false, 0, false, 0, StreamMode::MPMC); + std::vector streamMetas; + this->MakeStreamMetas(createNum, streamMetas); + EXPECT_EQ(this->StoreStream(streamMetas, streamFields), Status::OK()); + // Create same + EXPECT_EQ(this->StoreStream(streamMetas, streamFields), Status::OK()); + std::list removeKeys; + this->MakeStreamExistIds(streamMetas, removeKeys); + + // get data from rocksdb + std::vector streamOutMetas; + EXPECT_EQ(this->StoreGetStream(streamOutMetas), Status::OK()); + + for (const auto &meta : streamOutMetas) { + EXPECT_EQ(meta.max_stream_size(), maxStreamSize); + EXPECT_EQ(meta.page_size(), pageSize); + EXPECT_EQ(meta.auto_cleanup(), false); + } + + // Remove exist + EXPECT_EQ(this->AddRemoveStream(removeKeys), Status::OK()); + // Remove not exist + EXPECT_EQ(this->AddRemoveStream(removeKeys), Status::OK()); +} + +TEST_F(RocksStreamMetaStoreTest, TestAddRemoveGetNotification) +{ + std::string worekrAddr1 = "127.0.0.1:8001"; + std::string worekrAddr2 = "127.0.0.1:8002"; + std::string worekrAddr3 = "127.0.0.1:8003"; + + const size_t createNum = 3; + // test pub + NotifyPubPb pubPbs[createNum]; + pubPbs[0].set_stream_name("stream1"); + pubPbs[0].set_worker_addr(worekrAddr2); + pubPbs[1].set_stream_name("stream2"); + pubPbs[1].set_worker_addr(worekrAddr3); + pubPbs[1].set_is_close(true); + pubPbs[2].set_stream_name("stream3"); + pubPbs[2].set_worker_addr(worekrAddr1); + + DS_ASSERT_OK(rocksStreamMetaStore_->AddNotifyPub(worekrAddr1, pubPbs[0])); + DS_ASSERT_OK(rocksStreamMetaStore_->AddNotifyPub(worekrAddr2, pubPbs[1])); + DS_ASSERT_OK(rocksStreamMetaStore_->AddNotifyPub(worekrAddr3, pubPbs[0])); + DS_ASSERT_OK(rocksStreamMetaStore_->AddNotifyPub(worekrAddr3, pubPbs[1])); + DS_ASSERT_OK(rocksStreamMetaStore_->AddNotifyPub(worekrAddr3, pubPbs[2])); + std::vector> pubs; + DS_ASSERT_OK(rocksStreamMetaStore_->GetAllNotifyPub(pubs)); + ASSERT_EQ(pubs.size(), 5ul); + DS_ASSERT_OK(rocksStreamMetaStore_->RemoveNotifyPub(worekrAddr1, pubPbs[0])); + DS_ASSERT_OK(rocksStreamMetaStore_->GetAllNotifyPub(pubs)); + ASSERT_EQ(pubs.size(), 4ul); + + // test sub + NotifyConsumerPb subPbs[createNum]; + std::unordered_map consumerMetas; + MakeConsumerMetas(createNum, consumerMetas); + int index = 0; + for (auto &kv : consumerMetas) { + *subPbs[index].mutable_consumer() = std::move(kv.second); + index++; + } + + DS_ASSERT_OK(rocksStreamMetaStore_->AddNotifySub(worekrAddr1, subPbs[0])); + DS_ASSERT_OK(rocksStreamMetaStore_->AddNotifySub(worekrAddr2, subPbs[1])); + 
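// worekrAddr3 accumulates three notifications so that RemoveNotificationByWorker at the end can be verified to purge them in a single call. +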
DS_ASSERT_OK(rocksStreamMetaStore_->AddNotifySub(worekrAddr3, subPbs[0])); + DS_ASSERT_OK(rocksStreamMetaStore_->AddNotifySub(worekrAddr3, subPbs[1])); + DS_ASSERT_OK(rocksStreamMetaStore_->AddNotifySub(worekrAddr3, subPbs[2])); + std::vector> subs; + DS_ASSERT_OK(rocksStreamMetaStore_->GetAllNotifySub(subs)); + ASSERT_EQ(subs.size(), 5ul); + DS_ASSERT_OK(rocksStreamMetaStore_->RemoveNotifySub(worekrAddr1, subPbs[0])); + DS_ASSERT_OK(rocksStreamMetaStore_->GetAllNotifySub(subs)); + ASSERT_EQ(subs.size(), 4ul); + + // test remove by worker + DS_ASSERT_OK(rocksStreamMetaStore_->RemoveNotificationByWorker(worekrAddr3)); + DS_ASSERT_OK(rocksStreamMetaStore_->GetAllNotifyPub(pubs)); + DS_ASSERT_OK(rocksStreamMetaStore_->GetAllNotifySub(subs)); + ASSERT_EQ(pubs.size(), 1ul); + ASSERT_EQ(subs.size(), 1ul); +} +} // namespace ut +} // namespace datasystem diff --git a/tests/ut/master/stream_cache/sc_migrate_metadata_manager_test.cpp b/tests/ut/master/stream_cache/sc_migrate_metadata_manager_test.cpp new file mode 100644 index 0000000..e27fd38 --- /dev/null +++ b/tests/ut/master/stream_cache/sc_migrate_metadata_manager_test.cpp @@ -0,0 +1,106 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Test SCMigrateMetadataManager. 
+ */ + +#include +#include +#include "common.h" +#include "datasystem/common/inject/inject_point.h" +#include "datasystem/common/util/net_util.h" +#include "datasystem/master/stream_cache/sc_migrate_metadata_manager.h" +#include "../../../common/binmock/binmock.h" +#include "datasystem/utils/status.h" +#include "datasystem/common/log/logging.h" + +using namespace datasystem::master; +using namespace ::testing; + +namespace datasystem { +namespace ut { + +class SCMigrateMetadataManagerTest : public CommonTest { +public: + void SetUp() override + { + Logging::GetInstance()->Start("ds_llt", true, 1); + } + + void MigrateMetaDataWithRetry(SCMigrateMetadataManager::MigrateMetaInfo &info) + { + datasystem::inject::Set("SCMigrateMetadataManager.MigrateMetaDataWithRetry.interval", "call(5)"); + BINEXPECT_CALL(&SCMigrateMetadataManager::MigrateMetaData, (_, _)) + .WillRepeatedly(Invoke(this, &SCMigrateMetadataManagerTest::MockMigrateMetaDataFailed)); + HostPort hostPort; + std::shared_ptr scMetadataManager = + std::make_shared(hostPort, nullptr, nullptr, nullptr, nullptr, ""); + DS_ASSERT_OK(migrateManager_.MigrateMetaDataWithRetry(scMetadataManager, info, true)); + } + + Status MockMigrateMetaDataFailed(const std::shared_ptr &scMetadataManager, + SCMigrateMetadataManager::MigrateMetaInfo &info) + { + (void)scMetadataManager; + info.failedStreamNames = info.streamNames; + return Status::OK(); + } + + void MigrateMetaDataWithError(SCMigrateMetadataManager::MigrateMetaInfo &info) + { + BINEXPECT_CALL(&SCMigrateMetadataManager::MigrateMetaData, (_, _)) + .WillRepeatedly(Invoke(this, &SCMigrateMetadataManagerTest::MockMigrateMetaDataRetError)); + HostPort hostPort; + std::shared_ptr scMetadataManager = + std::make_shared(hostPort, nullptr, nullptr, nullptr, nullptr, ""); + ASSERT_EQ(migrateManager_.MigrateMetaData(scMetadataManager, info).GetCode(), K_RUNTIME_ERROR); + } + + Status MockMigrateMetaDataRetError(const std::shared_ptr &scMetadataManager, + SCMigrateMetadataManager::MigrateMetaInfo &info) + { + (void)scMetadataManager; + info.failedStreamNames = info.streamNames; + return Status{ K_RUNTIME_ERROR, "runtime error" }; + } + +protected: + SCMigrateMetadataManager migrateManager_; +}; + +TEST_F(SCMigrateMetadataManagerTest, MigrateLimitedRetry) +{ + SCMigrateMetadataManager::MigrateMetaInfo info; + info.destAddr = "127.0.0.1:1"; + info.streamNames.emplace_back("stream1"); + info.streamNames.emplace_back("stream2"); + info.streamNames.emplace_back("stream3"); + MigrateMetaDataWithRetry(info); +} + +TEST_F(SCMigrateMetadataManagerTest, MigrateMeetError) +{ + SCMigrateMetadataManager::MigrateMetaInfo info; + info.destAddr = "127.0.0.1:1"; + info.streamNames.emplace_back("stream1"); + info.streamNames.emplace_back("stream2"); + info.streamNames.emplace_back("stream3"); + MigrateMetaDataWithError(info); +} + +} // namespace ut +} // namespace datasystem diff --git a/tests/ut/worker/object_cache/worker_oc_eviction_test.cpp b/tests/ut/worker/object_cache/worker_oc_eviction_test.cpp index 11111de..4ba0c0c 100644 --- a/tests/ut/worker/object_cache/worker_oc_eviction_test.cpp +++ b/tests/ut/worker/object_cache/worker_oc_eviction_test.cpp @@ -40,6 +40,7 @@ #include "datasystem/worker/object_cache/worker_oc_eviction_manager.h" #include "datasystem/worker/object_cache/worker_oc_service_impl.h" #include "datasystem/worker/object_cache/service/worker_oc_service_crud_common_api.h" +#include "datasystem/worker/stream_cache/worker_sc_allocate_memory.h" #include "eviction_manager_common.h" using namespace 
datasystem::object_cache; @@ -240,18 +241,21 @@ public: allocator = datasystem::memory::Allocator::Instance(); akSkManager_ = std::make_shared(0); - allocator->Init(maxSize_, 0, false, true, 5000, ocPercent_); // decay is 5000 ms. + allocator->Init(maxSize_, 0, false, true, 5000, ocPercent_, scPercent_); // decay is 5000 ms. std::shared_ptr &objectTable = GetObjectTable(); evictionManager_ = std::make_shared( objectTable, HostPort("127.0.0.1", 32131), // worker port is 32131, HostPort("127.0.0.1", 52319)); // master port is 52319; auto globalRefTable = std::make_shared(); DS_ASSERT_OK(evictionManager_->Init(globalRefTable, akSkManager_)); + scAllocateManager_ = std::make_shared(evictionManager_); } std::shared_ptr akSkManager_; std::shared_ptr evictionManager_; + std::shared_ptr scAllocateManager_; uint64_t maxSize_ = 0; + int scPercent_ = 0; int ocPercent_ = 0; int limit = 100 * 1024 * 1024; // spill limit size is 100 * 1024 * 1024; }; @@ -259,28 +263,71 @@ public: TEST_F(ScEvictionObjectTest, DISABLED_TestEvictSc50Oc50) { maxSize_ = 50 * 1024 * 1024; // shared memory size 50 * 1024 * 1024 + scPercent_ = 70; // sc shared memory max size is 70 / 100 * maxSize_ ocPercent_ = 100; // oc shared memory max size is 50 / 100 * maxSize_ constexpr size_t limit = 100 * 1024 * 1024; // FLAGS_spill_size_limit = limit; FLAGS_spill_directory = "./spill_TestEvictSc50Oc50"; InitTest(); + auto streamSize = 1 * 1024 * 1024; // stream page size is 1 * 1024 * 1024; for (int i = 0; i < 30; i++) { // object num is 30 auto prefix = "test_for_evict_"; auto objectSize = 1 * 1024 * 1024; DS_ASSERT_OK(CreateObject(prefix + std::to_string(i), objectSize)); evictionManager_->Add(prefix + std::to_string(i)); } + auto unit = std::make_shared(); + for (int i = 0; i < 30; i++) { // stream num is 30 + DS_ASSERT_OK(scAllocateManager_->AllocateMemoryForStream(DEFAULT_TENANT_ID, "qwer" + std::to_string(i), + streamSize, true, *unit, true)); + } +} + +TEST_F(ScEvictionObjectTest, TestEvictScSizeMax) +{ + maxSize_ = 50 * 1024 * 1024; // shared memory size 50 * 1024 * 1024 + scPercent_ = 50; // sc shared memory max size is 50% * maxSize_ + ocPercent_ = 100; // oc shared memory max size is 100% * maxSize_ + FLAGS_spill_size_limit = limit; + FLAGS_spill_directory = "./spill_TestEvictScSizeMax"; + InitTest(); + auto size = 27 * 1024 * 1024; // stream page size is 27 * 1024 * 1024; + auto unit = std::make_shared(); + auto status = unit->AllocateMemory("", size, true, ServiceType::STREAM); + ASSERT_EQ(status.GetCode(), StatusCode::K_OUT_OF_MEMORY) << status.GetMsg(); + status = scAllocateManager_->AllocateMemoryForStream(DEFAULT_TENANT_ID, "qwer", size, true, *unit, true); + ASSERT_TRUE(status.GetMsg().find("Stream cache memory size overflow, maxStreamSize") != std::string::npos); +} + +TEST_F(ScEvictionObjectTest, TestScNotEvictObject) +{ + maxSize_ = 50 * 1024 * 1024; // shared memory size 50 * 1024 * 1024 + scPercent_ = 100; // sc shared memory max size is 100% * maxSize_ + ocPercent_ = 100; // oc shared memory max size is 100% * maxSize_ + FLAGS_spill_size_limit = limit; + FLAGS_spill_directory = "./spill_TestScNotEvictObject"; + auto streamSize = 2 * 1024 * 1024; // stream page size is 2 * 1024 * 1024; + InitTest(); + auto unit = std::make_shared(); + for (int i = 0; i < 9; i++) { // stream page num is 9 + DS_ASSERT_OK(scAllocateManager_->AllocateMemoryForStream(DEFAULT_TENANT_ID, "qwer" + std::to_string(i), + streamSize, true, *unit, true)); + } } TEST_F(ScEvictionObjectTest, TestEvictObject) { 
LOG_IF_ERROR(inject::Set("worker.Spill.Sync", "return()"), "set inject point failed"); maxSize_ = 10 * 1024 * 1024; // shared memory size 10 * 1024 * 1024 + scPercent_ = 100; // sc shared memory max size is 100% * maxSize_ ocPercent_ = 50; // oc shared memory max size is 50% * maxSize_ constexpr size_t limit = 100 * 1024 * 1024; // spill limit size is 100 * 1024 * 1024; FLAGS_spill_size_limit = limit; FLAGS_spill_directory = "./spill_TestEvictObject"; + auto streamSize = 8 * 1024 * 1024; InitTest(); + auto unit = std::make_shared(); + DS_ASSERT_OK(scAllocateManager_->AllocateMemoryForStream(DEFAULT_TENANT_ID, "qwer", streamSize, true, *unit, true)); const int kNumObjectsToCreate = 10; for (int i = 0; i < kNumObjectsToCreate; i++) { auto prefix = "test_for_evict_"; diff --git a/tests/ut/worker/stream_cache/lock_map_test.cpp b/tests/ut/worker/stream_cache/lock_map_test.cpp new file mode 100644 index 0000000..ae4a471 --- /dev/null +++ b/tests/ut/worker/stream_cache/lock_map_test.cpp @@ -0,0 +1,344 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Usage Monitor test + */ + +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/common/util/lock_map.h" +#include +namespace datasystem { +namespace ut { +static constexpr uint64_t DEFAULT_TOTAL_SIZE = 10; +class LockMapTest : public CommonTest { +public: + LockMapTest() + { + } + + ~LockMapTest() override = default; + + void SetUp() override + { + FLAGS_v = 2; // vlog is 2. + lockMap_ = std::make_unique>(); + CommonTest::SetUp(); + } + + void TearDown() override + { + CommonTest::TearDown(); + } + + void JoinThreads() + { + if (!writerThreads_.empty()) { + for (auto &writerThread : writerThreads_) { + writerThread.join(); + } + } + + if (!readerThreads_.empty()) { + for (auto &readerThread : readerThreads_) { + readerThread.join(); + } + } + } + + std::unique_ptr> lockMap_; + std::vector writerThreads_; + std::vector readerThreads_; +}; + +TEST_F(LockMapTest, TestWriteLock) +{ + const int numWriterThreads = 10; + const int numLoops = 10; + int count = 0; + const int interval = 1000; + + for (int threadId = 0; threadId < numWriterThreads; ++threadId) { + writerThreads_.push_back(std::thread([this, threadId, &numLoops, &count]() { + // Do numLoops iterations of writes from this thread. 
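+            // Insert() returns with an exclusive per-key accessor held, so the read-modify-write on `count` below is serialized even though `count` itself is unsynchronized.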
+ for (int i = 0; i < numLoops; ++i) { + LockMap::Accessor lock; + lockMap_->Insert(lock, "s"); + int temp = count + 1; + usleep(interval); + count = temp; + lockMap_->TryErase(lock); + } + })); + } + JoinThreads(); + ASSERT_TRUE(count == numWriterThreads * numLoops); + ASSERT_TRUE(lockMap_->Size() == 0); +} + +TEST_F(LockMapTest, TestPerEntryWriteLock) +{ + const int numWriterThreads = 15; + const int numLoops = 10; + const int entries = 3; + std::vector count(entries); + const int interval = 1000; + + for (int threadId = 0; threadId < numWriterThreads; ++threadId) { + writerThreads_.push_back(std::thread([this, threadId, &numLoops, &count, &entries]() { + // Do numLoops iterations of writes from this thread. + int index = threadId % entries; + for (int i = 0; i < numLoops; ++i) { + LockMap::Accessor lock; + lockMap_->Insert(lock, "s" + std::to_string(index)); + int temp = count[index] + 1; + usleep(interval); + count[index] = temp; + lockMap_->TryErase(lock); + } + })); + } + JoinThreads(); + + const int numContentionThreads = numWriterThreads / entries; + for (auto &entry : count) { + ASSERT_TRUE(entry == numContentionThreads * numLoops); + } + ASSERT_TRUE(lockMap_->Size() == 0); +} + +TEST_F(LockMapTest, TestConcurrentPerEntryReadWrite) +{ + const int numWriterThreads = 3; + const int numReaderThreads = 15; + const int numLoops = 10; + const int entries = 3; + std::vector count(entries); + const int interval = 1000; + + for (int threadId = 0; threadId < numWriterThreads; ++threadId) { + writerThreads_.push_back(std::thread([this, threadId, &numLoops, &count, &entries]() { + // Do numLoops iterations of writes from this thread. + int index = threadId % entries; + for (int i = 0; i < numLoops; ++i) { + LockMap::Accessor lock; + lockMap_->Insert(lock, "s" + std::to_string(index)); + int temp = count[index] + 1; + usleep(interval); + count[index] = temp; + } + })); + } + for (int threadId = 0; threadId < numReaderThreads; ++threadId) { + readerThreads_.push_back(std::thread([this, threadId, &numLoops, &count, &entries]() { + // Do numLoops iterations of reads from this thread. + int lastCount = 0; + int index = threadId % entries; + for (int i = 0; i < numLoops; ++i) { + LockMap::ConstAccessor lock; + if (lockMap_->Find(lock, "s" + std::to_string(index))) { + ASSERT_TRUE(count[index] >= lastCount); + lastCount = count[index]; + } + } + })); + } + JoinThreads(); + + const int numContentionThreads = numWriterThreads / entries; + for (auto &entry : count) { + ASSERT_TRUE(entry == numContentionThreads * numLoops); + } + ASSERT_TRUE(lockMap_->Size() == entries); + for (int i = 0; i < entries; i++) { + LockMap::Accessor lock; + ASSERT_TRUE(lockMap_->Find(lock, "s" + std::to_string(i))); + ASSERT_TRUE(lockMap_->TryErase(lock)); + } + ASSERT_TRUE(lockMap_->Size() == 0); +} + +TEST_F(LockMapTest, TestConcurrentPerValueReadWrite) +{ + const int numWriterThreads = 3; + const int numReaderThreads = 15; + const int numLoops = 10; + const int entries = 3; + const int interval = 1000; + + for (int threadId = 0; threadId < numWriterThreads; ++threadId) { + writerThreads_.push_back(std::thread([this, threadId, &numLoops, &entries]() { + // Do numLoops iterations of writes from this thread. 
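+            // Unlike TestConcurrentPerEntryReadWrite above, the counter lives inside the map entry itself (lock.entry->data) rather than in an external array.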
+ int index = threadId % entries; + for (int i = 0; i < numLoops; ++i) { + LockMap::Accessor lock; + lockMap_->Insert(lock, "s" + std::to_string(index)); + int temp = lock.entry->data + 1; + usleep(interval); + lock.entry->data = temp; + } + })); + } + for (int threadId = 0; threadId < numReaderThreads; ++threadId) { + readerThreads_.push_back(std::thread([this, threadId, &numLoops, &entries]() { + // Do numLoops iterations of reads from this thread. + int lastCount = 0; + int index = threadId % entries; + for (int i = 0; i < numLoops; ++i) { + LockMap::ConstAccessor lock; + if (lockMap_->Find(lock, "s" + std::to_string(index))) { + ASSERT_TRUE(lock.entry->data >= lastCount); + lastCount = lock.entry->data; + } + } + })); + } + JoinThreads(); + + const int numContentionThreads = numWriterThreads / entries; + for (auto &entry : *lockMap_) { + ASSERT_TRUE(entry.second.data == numContentionThreads * numLoops); + } + ASSERT_TRUE(lockMap_->Size() == entries); + for (int i = 0; i < entries; i++) { + LockMap::Accessor lock; + ASSERT_TRUE(lockMap_->Find(lock, "s" + std::to_string(i))); + ASSERT_TRUE(lockMap_->TryErase(lock)); + } + ASSERT_TRUE(lockMap_->Size() == 0); +} + +TEST_F(LockMapTest, TestFind) +{ + LockMap::Accessor lock; + // Try erase with invalid accessor + ASSERT_FALSE(lockMap_->Find(lock, "s")); + lockMap_->Insert(lock, "s"); + lock.Release(); + ASSERT_TRUE(lockMap_->Find(lock, "s")); + ASSERT_FALSE(lockMap_->Find(lock, "s1")); + std::thread thread1([this]() { + LockMap::Accessor lock; + ASSERT_TRUE(lockMap_->Find(lock, "s")); + }); + lock.Release(); + thread1.join(); + ASSERT_TRUE(lockMap_->Find(lock, "s")); + ASSERT_TRUE(lockMap_->TryErase(lock)); + ASSERT_TRUE(lockMap_->Size() == 0); +} + +TEST_F(LockMapTest, TestErase) +{ + const int interval = 100000; + LockMap::Accessor lock; + // Try erase with invalid accessor + ASSERT_FALSE(lockMap_->TryErase(lock)); + lockMap_->Insert(lock, "s"); + + std::thread thread1([this]() { + LockMap::ConstAccessor lock; + lockMap_->Insert(lock, "s"); + }); + std::thread thread2([this]() { + LockMap::Accessor lock; + lockMap_->Insert(lock, "s"); + }); + std::thread thread3([this]() { + LockMap::ConstAccessor lock; + ASSERT_TRUE(lockMap_->Find(lock, "s")); + }); + std::thread thread4([this]() { + LockMap::Accessor lock; + ASSERT_TRUE(lockMap_->Find(lock, "s")); + }); + + usleep(interval); + ASSERT_FALSE(lockMap_->TryErase(lock)); + + lock.Release(); + thread1.join(); + thread2.join(); + thread3.join(); + thread4.join(); + ASSERT_FALSE(lockMap_->TryErase(lock)); + lockMap_->Insert(lock, "s"); + ASSERT_TRUE(lockMap_->TryErase(lock)); + ASSERT_TRUE(lockMap_->Size() == 0); +} + +TEST_F(LockMapTest, TestBlockingEraseReaders) +{ + // Test BlockingErase will always erase the entry + const int interval = 1000; + const int numReaderThreads = 30; + const int numLoops = 10; + LockMap::Accessor lock; + lockMap_->Insert(lock, "s"); + for (int threadId = 0; threadId < numReaderThreads; ++threadId) { + readerThreads_.push_back(std::thread([this, threadId, &numLoops]() { + // Do numLoops iterations of reads from this thread. 
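+            // A failed Find() here is legitimate: it means BlockingErase has already removed the entry.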
+ for (int i = 0; i < numLoops; ++i) { + LockMap::ConstAccessor lock; + if (lockMap_->Find(lock, "s")) { + usleep(interval); + } else { + ASSERT_TRUE(lockMap_->Size() == 0); + } + } + })); + } + usleep(interval); + lockMap_->BlockingErase(lock); + JoinThreads(); + ASSERT_TRUE(lockMap_->Size() == 0); +} + +TEST_F(LockMapTest, TestBlockingEraseWriters) +{ + // Test BlockingErase will always erase the entry + const int interval = 1000; + const int numWriterThreads = 2; + const int numLoops = 10; + LockMap::Accessor lock; + lockMap_->Insert(lock, "s"); + for (int threadId = 0; threadId < numWriterThreads; ++threadId) { + writerThreads_.push_back(std::thread([this, threadId, &numLoops]() { + // Do numLoops iterations of writes from this thread. + for (int i = 0; i < numLoops; ++i) { + LockMap::Accessor lock; + if (lockMap_->Find(lock, "s")) { + usleep(interval); + } else { + ASSERT_TRUE(lockMap_->Size() == 0); + } + } + })); + } + usleep(interval); + lockMap_->BlockingErase(lock); + JoinThreads(); + ASSERT_TRUE(lockMap_->Size() == 0); +} +} // namespace ut +} // namespace datasystem diff --git a/tests/ut/worker/stream_cache/shared_page_queue_group_test.cpp b/tests/ut/worker/stream_cache/shared_page_queue_group_test.cpp new file mode 100644 index 0000000..78fbdc0 --- /dev/null +++ b/tests/ut/worker/stream_cache/shared_page_queue_group_test.cpp @@ -0,0 +1,109 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: SharedPageQueueGroup test + */ + +#include + +#include "common.h" +#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h" +#include "datasystem/worker/stream_cache/page_queue/shared_page_queue_group.h" +#include "datasystem/common/util/uuid_generator.h" + +DS_DECLARE_uint32(sc_shared_page_group_count); + +using namespace datasystem::worker::stream_cache; +namespace datasystem { +namespace ut { +class SharedPageQueueGroupTest : public CommonTest { +public: + void SetUp() override + { + HostPort hostPort; + DS_ASSERT_OK(hostPort.ParseString("127.0.0.1:9000")); + const int pageGroupCount = 8; + FLAGS_sc_shared_page_group_count = pageGroupCount; + svc_ = std::make_shared(hostPort, hostPort, nullptr, nullptr, nullptr); + svc_->Init(); + pageQueueGroup_ = std::make_unique(hostPort, nullptr, svc_.get()); + CommonTest::SetUp(); + } + + void TearDown() override + { + CommonTest::TearDown(); + } + +protected: + std::unique_ptr pageQueueGroup_; + std::shared_ptr svc_; +}; + +TEST_F(SharedPageQueueGroupTest, TestPageQueueGroup) +{ + std::string tenantId1 = "tenant1"; + std::string streamName = "stream1"; + std::shared_ptr pageQueue1; + pageQueueGroup_->GetOrCreateSharedPageQueue(tenantId1 + "$" + streamName, pageQueue1); + std::shared_ptr pageQueue2; + pageQueueGroup_->GetOrCreateSharedPageQueue(tenantId1 + "$" + streamName, pageQueue2); + // using the same tenant and stream will get the same page queue.
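+    // (Assumption: the group hashes the "tenantId$streamName" key into one of
+    // FLAGS_sc_shared_page_group_count buckets, which is why TestMaxPageQueue below
+    // asserts an upper bound rather than equality.)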
+ ASSERT_EQ(pageQueue1, pageQueue2); + std::shared_ptr pageQueue3; + pageQueueGroup_->GetOrCreateSharedPageQueue(streamName, pageQueue3); + // a different tenant gets a different page queue. + ASSERT_NE(pageQueue1, pageQueue3); + + // get page queue by tenant and stream. + pageQueue2 = nullptr; + DS_ASSERT_OK(pageQueueGroup_->GetSharedPageQueue(tenantId1 + "$" + streamName, pageQueue2)); + ASSERT_EQ(pageQueue1, pageQueue2); + + // using unknown tenant. + pageQueue3 = nullptr; + DS_ASSERT_NOT_OK(pageQueueGroup_->GetSharedPageQueue("unknown$" + streamName, pageQueue3)); + + // remove page queue for tenant + DS_ASSERT_OK(pageQueueGroup_->RemoveSharedPageQueueForTenant(tenantId1)); + // remove page queue for unknown tenant + DS_ASSERT_NOT_OK(pageQueueGroup_->RemoveSharedPageQueueForTenant("unknown")); + // get after delete. + DS_ASSERT_NOT_OK(pageQueueGroup_->GetSharedPageQueue(tenantId1 + "$" + streamName, pageQueue2)); + DS_ASSERT_OK(pageQueueGroup_->GetSharedPageQueue(streamName, pageQueue2)); + + DS_ASSERT_OK(pageQueue1->AfterAck()); + std::vector recvElements; + DS_ASSERT_OK(pageQueue1->SendElements(nullptr, 0, 0, "127.0.0.1:9000", recvElements)); +} + +TEST_F(SharedPageQueueGroupTest, TestMaxPageQueue) +{ + std::string tenantId = "tenant1"; + std::set pageQueueMap; + int count = 200; + for (int i = 0; i < count; i++) { + std::shared_ptr pageQueue; + std::string streamName = GetStringUuid(); + pageQueueGroup_->GetOrCreateSharedPageQueue(tenantId + "$" + streamName, pageQueue); + pageQueueMap.insert(pageQueue.get()); + } + + ASSERT_LE(pageQueueMap.size(), FLAGS_sc_shared_page_group_count); +} +} // namespace ut +} // namespace datasystem diff --git a/tests/ut/worker/stream_cache/shared_page_queue_test.cpp b/tests/ut/worker/stream_cache/shared_page_queue_test.cpp new file mode 100644 index 0000000..0f89a4d --- /dev/null +++ b/tests/ut/worker/stream_cache/shared_page_queue_test.cpp @@ -0,0 +1,273 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/** + * Description: SharedPageQueue test + */ + +#include + +#include "common.h" +#include "../../../common/binmock/binmock.h" +#include "datasystem/common/constants.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h" +#include "datasystem/common/stream_cache/stream_data_page.h" +#include "datasystem/common/stream_cache/stream_fields.h" +#include "datasystem/worker/stream_cache/page_queue/page_queue_handler.h" +#include "datasystem/worker/stream_cache/page_queue/shared_page_queue.h" +#include "datasystem/worker/stream_cache/remote_worker_manager.h" +#include "datasystem/common/util/uuid_generator.h" + +DS_DECLARE_uint32(sc_shared_page_size_mb); + +using namespace ::testing; +using namespace datasystem::worker::stream_cache; + +namespace datasystem { +namespace ut { +constexpr int K_TWO = 2; +constexpr uint64_t SHM_CAP = 128L * 1024L * 1024L; +class SharedPageQueueTest : public CommonTest { +public: + void SetUp() override + { + datasystem::memory::Allocator::Instance()->Init(SHM_CAP); + allocate_ = std::make_shared(nullptr); + HostPort hostPort; + LOG_IF_ERROR(hostPort.ParseString("127.0.0.1:9000"), "ParseString failed"); + svc_ = std::make_shared(hostPort, hostPort, nullptr, nullptr, allocate_); + svc_->Init(); + CommonTest::SetUp(); + } + void TearDown() override + { + CommonTest::TearDown(); + } + + std::shared_ptr CreateSharedPageQueue() + { + HostPort hostPort; + LOG_IF_ERROR(hostPort.ParseString("127.0.0.1:9000"), "ParseString failed"); + FLAGS_sc_shared_page_size_mb = 1; + return std::make_shared("tenant", std::move(hostPort), 0, allocate_, svc_.get()); + } + +protected: + std::shared_ptr allocate_; + std::shared_ptr svc_; +}; + +TEST_F(SharedPageQueueTest, TestAllocateSharedPage) +{ + StreamFields streamFields; + streamFields.streamMode_ = StreamMode::MPSC; + PageQueueHandler handler(nullptr, Optional(streamFields)); + auto pageQueue = CreateSharedPageQueue(); + handler.SetSharedPageQueue(pageQueue); + // create cursor + std::shared_ptr out; + ShmView view; + DS_ASSERT_OK(handler.AddCursor("id", true, out, view)); + // the client will report its cursor version as K_CURSOR_SIZE_V2 + out->SetClientVersion(Cursor::K_CURSOR_SIZE_V2); + + // allocate shared page. + const int timeoutMs = 3000; + ShmView preLastView; + std::shared_ptr lastPage; + DS_ASSERT_OK(pageQueue->CreateOrGetLastDataPage(timeoutMs, preLastView, lastPage, false)); + ASSERT_TRUE(lastPage->IsSharedPage()); + preLastView = lastPage->GetShmView(); + std::shared_ptr lastPage2; + DS_ASSERT_OK(pageQueue->CreateOrGetLastDataPage(timeoutMs, preLastView, lastPage2, false)); + ASSERT_TRUE(lastPage2->IsSharedPage()); + ASSERT_EQ(lastPage, lastPage2); + + // create a client shm unit info. + auto shmUnitInfo = lastPage->GetShmUnitInfo(); + shmUnitInfo->pointer = reinterpret_cast(shmUnitInfo->pointer) - shmUnitInfo->offset; + int lockId = 12; + auto clientPage = std::make_shared(shmUnitInfo, lockId, true, false); + DS_ASSERT_OK(clientPage->Init()); + ASSERT_TRUE(clientPage->IsSharedPage()); + + int lockTimeout = 10000; // 10s. + // before lock, client will record page. + out->SetLastLockedPage(clientPage->GetShmView(), lockTimeout); + // lock page. + StreamPageLock pageLock(clientPage); + DS_ASSERT_OK(pageLock.Lock(lockTimeout)); + + // unlock. + pageQueue->TryUnlockByLockId(lockId); + + // lock again.
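+    // After TryUnlockByLockId has released lock id 12, the same client page can be locked once more without hitting the 10s timeout.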
+ StreamPageLock pageLock2(clientPage); + DS_ASSERT_OK(pageLock2.Lock(lockTimeout)); +} + +TEST_F(SharedPageQueueTest, TestReadWrite) +{ + auto pageQueue = CreateSharedPageQueue(); + const int timeoutMs = 3000; + const size_t eleSz1 = 102ul; + const size_t eleSz2 = 1004ul; + std::string data1(eleSz1, 'a'); + std::string data2(eleSz2, 'b'); + ShmView preLastView; + std::shared_ptr lastPage; + DS_ASSERT_OK(pageQueue->CreateOrGetLastDataPage(timeoutMs, preLastView, lastPage, false)); + + uint64_t streamNo1 = 111; + uint64_t streamNo2 = 112; + HeaderAndData element1((uint8_t *)data1.c_str(), data1.size(), streamNo1); + HeaderAndData element2((uint8_t *)data2.c_str(), data2.size(), streamNo2); + InsertFlags flags = InsertFlags::NONE; + DS_ASSERT_OK(lastPage->Insert(element1, timeoutMs, flags)); + DS_ASSERT_OK(lastPage->Insert(element2, timeoutMs, flags)); + + std::vector out; + DS_ASSERT_OK(lastPage->Receive(0, timeoutMs, out)); + ASSERT_EQ(out.size(), 2ul); + std::string r1((char *)out[0].ptr, out[0].size); + std::string r2((char *)out[1].ptr, out[1].size); + ASSERT_EQ(data1, r1); + ASSERT_EQ(out[0].streamNo_, streamNo1); + ASSERT_EQ(data2, r2); + ASSERT_EQ(out[1].streamNo_, streamNo2); +} + +TEST_F(SharedPageQueueTest, TestWriteMaxSize) +{ + auto pageQueue = CreateSharedPageQueue(); + const int timeoutMs = 3000; + InsertFlags flags = InsertFlags::NONE; + ShmView preLastView; + std::shared_ptr lastPage; + DS_ASSERT_OK(pageQueue->CreateOrGetLastDataPage(timeoutMs, preLastView, lastPage, false)); + + const size_t elementMaxSize = FLAGS_sc_shared_page_size_mb * MB_TO_BYTES - StreamDataPage::PageOverhead(true); + std::string data1(elementMaxSize + 1, 'a'); + uint64_t streamNo1 = 111; + HeaderAndData element1((uint8_t *)data1.c_str(), data1.size(), streamNo1); + DS_ASSERT_NOT_OK(lastPage->Insert(element1, timeoutMs, flags)); + + std::string data2(elementMaxSize, 'a'); + uint64_t streamNo2 = 112; + HeaderAndData element2((uint8_t *)data2.c_str(), data2.size(), streamNo2); + DS_ASSERT_OK(lastPage->Insert(element2, timeoutMs, flags)); + + std::vector out; + DS_ASSERT_OK(lastPage->Receive(0, timeoutMs, out)); + ASSERT_EQ(out.size(), 1ul); + std::string r2((char *)out[0].ptr, out[0].size); + ASSERT_EQ(data2, r2); + ASSERT_EQ(out[0].streamNo_, streamNo2); +} + +TEST_F(SharedPageQueueTest, TestWriteFull) +{ + auto pageQueue = CreateSharedPageQueue(); + const int timeoutMs = 3000; + InsertFlags flags = InsertFlags::NONE; + ShmView preLastView; + std::shared_ptr lastPage; + DS_ASSERT_OK(pageQueue->CreateOrGetLastDataPage(timeoutMs, preLastView, lastPage, false)); + + size_t remainingSize = lastPage->PagePayloadSize(); + size_t count = 30; + size_t perSize = remainingSize % count; + uint64_t streamNo = 111; + size_t metaSize = StreamDataPage::GetMetaSize(true); + while (remainingSize > StreamDataPage::GetMetaSize(true)) { + size_t dataWithMetaSize = std::max(perSize, remainingSize); + size_t dataSize = dataWithMetaSize - metaSize; + remainingSize -= dataWithMetaSize; + std::string data(dataSize, 'a'); + HeaderAndData element((uint8_t *)data.c_str(), data.size(), streamNo); + DS_ASSERT_OK(lastPage->Insert(element, timeoutMs, flags)); + } + ASSERT_EQ(lastPage->GetFreeSpaceSize(), remainingSize); + + std::string data(1, 'a'); + uint64_t streamNo1 = 111; + HeaderAndData element((uint8_t *)data.c_str(), data.size(), streamNo1); + DS_ASSERT_NOT_OK(lastPage->Insert(element, timeoutMs, flags)); + + std::vector out; + DS_ASSERT_OK(lastPage->Receive(0, timeoutMs, out)); + size_t recvSize = 0; + for (const auto &e : 
out) { + ASSERT_EQ(e.streamNo_, streamNo); + recvSize += e.size + metaSize; + } + ASSERT_EQ(recvSize + remainingSize, lastPage->PagePayloadSize()); +} + +TEST_F(SharedPageQueueTest, TestAllocMemeoryAndRelease) +{ + BINEXPECT_CALL(&PageQueueHandler::CreateExclusivePageQueue, (_, _)).Times(1).WillOnce(Return(nullptr)); + PageQueueHandler handler(nullptr, Optional()); + auto pageQueue = CreateSharedPageQueue(); + handler.SetSharedPageQueue(pageQueue); + ASSERT_EQ(handler.GetPageSize(), FLAGS_sc_shared_page_size_mb * MB_TO_BYTES); + size_t pageSize = 1024 * 1024 * 5; + std::shared_ptr pageUnitInfo; + DS_ASSERT_OK(handler.AllocMemory(pageSize, true, pageUnitInfo, false)); + handler.DumpPoolPages(1); + ShmView shmView{ + .fd = pageUnitInfo->fd, .mmapSz = pageUnitInfo->mmapSize, .off = pageUnitInfo->offset, .sz = pageUnitInfo->size + }; + DS_ASSERT_OK(handler.ReleaseMemory(shmView)); + handler.DumpPoolPages(1); +} + +TEST_F(SharedPageQueueTest, TestRemoteAck) +{ + auto pageQueue = CreateSharedPageQueue(); + const int timeoutMs = 3000; + const size_t eleSz1 = 102ul; + const size_t eleSz2 = 1004ul; + std::string data1(eleSz1, 'a'); + std::string data2(eleSz2, 'b'); + ShmView preLastView; + std::shared_ptr lastPage; + DS_ASSERT_OK(pageQueue->CreateOrGetLastDataPage(timeoutMs, preLastView, lastPage, false)); + + uint64_t streamNo1 = 111; + uint64_t streamNo2 = 112; + HeaderAndData element1((uint8_t *)data1.c_str(), data1.size(), streamNo1); + HeaderAndData element2((uint8_t *)data2.c_str(), data2.size(), streamNo2); + InsertFlags flags = InsertFlags::NONE; + DS_ASSERT_OK(lastPage->Insert(element1, timeoutMs, flags)); + DS_ASSERT_OK(lastPage->Insert(element2, timeoutMs, flags)); + + std::vector out; + DS_ASSERT_OK(lastPage->Receive(0, timeoutMs, out)); + ASSERT_EQ(out.size(), 2ul); + std::string r1((char *)out[0].ptr, out[0].size); + std::string r2((char *)out[1].ptr, out[1].size); + ASSERT_EQ(data1, r1); + ASSERT_EQ(out[0].streamNo_, streamNo1); + ASSERT_EQ(data2, r2); + ASSERT_EQ(out[1].streamNo_, streamNo2); + + BINEXPECT_CALL(&RemoteWorkerManager::GetLastAckCursor, (_)).WillRepeatedly(Return(K_TWO)); + DS_ASSERT_OK(pageQueue->RemoteAck()); +} +} // namespace ut +} // namespace datasystem diff --git a/tests/ut/worker/stream_cache/stream_bufferpool_test.cpp b/tests/ut/worker/stream_cache/stream_bufferpool_test.cpp new file mode 100644 index 0000000..368db07 --- /dev/null +++ b/tests/ut/worker/stream_cache/stream_bufferpool_test.cpp @@ -0,0 +1,168 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * Description: Stream buffer pool test + */ + +#include +#include +#include +#include +#include +#include + +#include "datasystem/worker/stream_cache/buffer_pool.h" +#include "common.h" + +using namespace datasystem::worker::stream_cache; +using namespace std::placeholders; +namespace datasystem { +namespace ut { +constexpr static int numDirtyList = 8; +class StreamData : public BaseBufferData { +public: + explicit StreamData(std::string streamName) : streamName_(std::move(streamName)) + { + } + ~StreamData() override = default; + std::string StreamName() const override + { + return streamName_; + } + + std::string ProducerName() const override + { + return streamName_; + } + + std::string ProducerInstanceId() const override + { + return streamName_; + } + + Status ReleasePage() override + { + return Status::OK(); + } + + uint64_t StreamHash() const override + { + return std::hash{}(streamName_); + } + + Status GetStreamStatus() + { + return Status::OK(); + } + + bool IfEOSReply() + { + return false; + } + +private: + std::string streamName_; +}; +class BufferPoolTest : public CommonTest { +public: + BufferPoolTest() : flushCount_(numDirtyList) + { + } + ~BufferPoolTest() override = default; + + void SetUp() override + { + bp_ = + std::make_unique(numDirtyList, "BufferPool", std::bind(&BufferPoolTest::FlushFn, this, _1, _2)); + for (auto i = 0; i < numDirtyList; ++i) { + flushCount_[i] = 0; + } + CommonTest::SetUp(); + } + void TearDown() override + { + if (bp_) { + bp_->Stop(); + bp_.reset(); + } + CommonTest::TearDown(); + } + + Status FlushFn(int id, PendingFlushList &flushList) + { + (void)id; + for (auto &ele : flushList) { + auto &datalist = ele.second; + flushCount_[id] += datalist.size(); + datalist.clear(); + } + return Status::OK(); + } + +protected: + std::unique_ptr bp_; + std::vector flushCount_; + +private: +}; + +TEST_F(BufferPoolTest, TestOneStream) +{ + DS_ASSERT_OK(bp_->Init()); + const std::string streamName(RandomData().GetRandomString(8)); + const int numElements = 1000; + for (auto i = 0; i < numElements; ++i) { + auto data = std::make_shared(streamName); + bp_->Insert(data); + } + sleep(1); + int total = 0; + for (auto i = 0; i < numDirtyList; ++i) { + total += flushCount_[i]; + } + ASSERT_EQ(total, numElements); +} + +TEST_F(BufferPoolTest, TestMultiStreams) +{ + DS_ASSERT_OK(bp_->Init()); + const int numStreams = 100; + std::vector streamNames; + streamNames.reserve(numStreams); + for (int i = 0; i < numStreams; ++i) { + streamNames.emplace_back(RandomData().GetRandomString(8)); + } + const int numElements = 1000; + auto threadPool = std::make_unique(5); + for (auto i = 0; i < numElements; ++i) { + for (auto j = 0; j < numStreams; ++j) { + threadPool->Execute([this, &streamNames, j]() { + auto data = std::make_shared(streamNames[j]); + bp_->Insert(data); + }); + } + } + threadPool.reset(); + sleep(1); + int total = 0; + for (auto i = 0; i < numDirtyList; ++i) { + total += flushCount_[i]; + } + EXPECT_EQ(total, numElements * numStreams); +} +} // namespace ut +} // namespace datasystem diff --git a/tests/ut/worker/stream_cache/stream_cursor_test.cpp b/tests/ut/worker/stream_cache/stream_cursor_test.cpp new file mode 100644 index 0000000..8056930 --- /dev/null +++ b/tests/ut/worker/stream_cache/stream_cursor_test.cpp @@ -0,0 +1,124 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Test StreamPage StreamPageOwner classes. + */ + +#include "common.h" +#include "datasystem/common/constants.h" +#include "datasystem/common/shared_memory/allocator.h" +#include "datasystem/common/stream_cache/cursor.h" +#include "datasystem/common/util/thread_pool.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/worker/stream_cache/producer.h" + +namespace datasystem { +namespace ut { +constexpr size_t NUM_THREADS = 4; +constexpr uint64_t SHM_CAP = 64L * 1024L * 1024L; + +class StreamCursorTest : public CommonTest { +protected: + void SetUp() override; + void TearDown() override; + StreamCursorTest() : pool_(NUM_THREADS) + { + } + ~StreamCursorTest() override = default; + std::shared_ptr shmUnit_; + std::shared_ptr cursor_; + ThreadPool pool_; +}; + +void StreamCursorTest::SetUp() +{ + FLAGS_v = datasystem::SC_INTERNAL_LOG_LEVEL; + datasystem::memory::Allocator::Instance()->Init(SHM_CAP); + shmUnit_ = std::make_shared(); + DS_ASSERT_OK(shmUnit_->AllocateMemory("CursorTest", Cursor::K_CURSOR_SIZE_V2, false)); + cursor_ = std::make_shared(shmUnit_->GetPointer(), Cursor::K_CURSOR_SIZE_V2, 0); + ASSERT_EQ(cursor_->Init(), Status::OK()); +} + +void StreamCursorTest::TearDown() +{ +} + +TEST_F(StreamCursorTest, TestShmViewAndFutexArea) +{ + cursor_->InitFutexArea(); + const int32_t val = K_OUT_OF_MEMORY; + auto fut = pool_.Submit([this]() { + ShmView view; + Timer t(DEFAULT_TIMEOUT_MS); + while (t.GetRemainingTimeMs() > 0) { + RETURN_IF_NOT_OK(cursor_->GetLastPageView(view, DEFAULT_TIMEOUT_MS)); + if (view != ShmView()) { + break; + } + } + LOG(INFO) << "LastPage view: " << view.ToStr(); + CHECK_FAIL_RETURN_STATUS(view == shmUnit_->GetShmView(), K_RUNTIME_ERROR, "GetLastPageView error"); + size_t numWaiter; + RETURN_IF_NOT_OK(cursor_->Wake(val, numWaiter)); + LOG(INFO) << "Wake up " << numWaiter << " waiters"; + return Status::OK(); + }); + // Send a ShmView + DS_ASSERT_OK(cursor_->SetLastPage(shmUnit_->GetShmView(), DEFAULT_TIMEOUT_MS)); + // Get a reply + int32_t fetchVal; + DS_ASSERT_OK(cursor_->Wait(DEFAULT_TIMEOUT_MS, fetchVal)); + ASSERT_EQ(val, fetchVal); + DS_ASSERT_OK(fut.get()); +} + +TEST_F(StreamCursorTest, TestForceClose) +{ + auto fut = pool_.Submit([this]() { + Timer t(DEFAULT_TIMEOUT_MS); + while (t.GetRemainingTimeMs() > 0) { + if (cursor_->ForceClose()) { + return Status::OK(); + } + } + RETURN_STATUS(K_RUNTIME_ERROR, "Force close error"); + }); + cursor_->SetForceClose(); + DS_ASSERT_OK(fut.get()); +} + +TEST_F(StreamCursorTest, TestSetCusorToProducer) +{ + auto producer = std::make_shared("producerId", "StreamName", nullptr); + auto copyCursor = cursor_; + producer->SetCursor(std::move(copyCursor)); + + auto fut = pool_.Submit([this]() { + Timer t(DEFAULT_TIMEOUT_MS); + while (t.GetRemainingTimeMs() > 0) { + if (cursor_->ForceClose()) { + return Status::OK(); + } + } + RETURN_STATUS(K_RUNTIME_ERROR, "Force close error"); + }); + producer->SetForceClose(); + DS_ASSERT_OK(fut.get()); +} +} // namespace ut +} // namespace datasystem diff --git a/tests/ut/worker/stream_cache/stream_data_page_test.cpp 
b/tests/ut/worker/stream_cache/stream_data_page_test.cpp new file mode 100644 index 0000000..fe0e878 --- /dev/null +++ b/tests/ut/worker/stream_cache/stream_data_page_test.cpp @@ -0,0 +1,269 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Test StreamPage StreamPageOwner classes. + */ + +#include "common.h" +#include "datasystem/common/perf/perf_manager.h" +#include "datasystem/common/shared_memory/allocator.h" +#include "datasystem/common/stream_cache/stream_data_page.h" +#include "datasystem/stream/stream_config.h" +#include "datasystem/worker/stream_cache/buffer_pool.h" + +namespace datasystem { +namespace ut { +using namespace datasystem::worker::stream_cache; +constexpr static int FOUR_K = 4096; +constexpr uint64_t SHM_CAP = 128L * 1024L * 1024L; + +class StreamDataPageTest : public CommonTest { +protected: + void SetUp() override; + void TearDown() override; + + RandomData randomData_; + std::shared_ptr pageUnit_; + std::shared_ptr page_; + static bool VerifyElement(const std::string &expected, const Element &ele); + std::string GetRandomString(); + Status InsertUntilPageFull(std::shared_ptr &pageUnit, uint32_t lockId, bool isClient, + std::vector &out); + Status ReceiveUntilTimeout(std::shared_ptr &pageUnit, uint64_t &lastRecvCursor, uint32_t lockId, + uint64_t timeoutMs, std::vector &out); + static std::string ShmUnitInfoToStr(std::shared_ptr &shm) + { + ShmView v = { .fd = shm->fd, .mmapSz = shm->mmapSize, .off = shm->offset, .sz = shm->size }; + return v.ToStr(); + } +}; + +void StreamDataPageTest::SetUp() +{ + FLAGS_v = datasystem::SC_INTERNAL_LOG_LEVEL; + datasystem::memory::Allocator::Instance()->Init(SHM_CAP); + pageUnit_ = std::make_shared(); + DS_ASSERT_OK(pageUnit_->AllocateMemory("", FOUR_K + StreamDataPage::PageOverhead(), false)); + page_ = std::make_shared(pageUnit_, 0, false); + ASSERT_EQ(page_->Init(), Status::OK()); +} + +void StreamDataPageTest::TearDown() +{ +} + +bool StreamDataPageTest::VerifyElement(const std::string &expected, const Element &ele) +{ + std::string str(reinterpret_cast(ele.ptr), ele.size); + return expected == str; +} + +std::string StreamDataPageTest::GetRandomString() +{ + const int maxLen = 20; + auto str = randomData_.GetRandomString((randomData_.GetRandomUint8() % maxLen) + 1); + return str; +} + +Status StreamDataPageTest::InsertUntilPageFull(std::shared_ptr &pageUnit, uint32_t lockId, bool isClient, + std::vector &out) +{ + auto page = std::make_shared(pageUnit, lockId, isClient); + RETURN_IF_NOT_OK(page->Init()); + Status rc; + while (rc.IsOk()) { + auto str = GetRandomString(); + HeaderAndData ele(reinterpret_cast(str.data()), str.length(), 0); + auto flag = InsertFlags::NONE; + rc = page->Insert(ele, 0, flag); + if (rc.GetCode() == K_NO_SPACE) { + rc = Status::OK(); + break; + } + if (rc.GetCode() == K_TRY_AGAIN) { + rc = Status::OK(); + continue; + } + RETURN_IF_NOT_OK(rc); + out.emplace_back(std::move(str)); + } + 
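// K_NO_SPACE means the page is full and is the normal exit; K_TRY_AGAIN retries; any other status propagates to the caller. +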
+ +Status StreamDataPageTest::InsertUntilPageFull(std::shared_ptr<ShmUnit> &pageUnit, uint32_t lockId, bool isClient, + std::vector<std::string> &out) +{ + auto page = std::make_shared<StreamDataPage>(pageUnit, lockId, isClient); + RETURN_IF_NOT_OK(page->Init()); + Status rc; + while (rc.IsOk()) { + auto str = GetRandomString(); + HeaderAndData ele(reinterpret_cast<uint8_t *>(str.data()), str.length(), 0); + auto flag = InsertFlags::NONE; + rc = page->Insert(ele, 0, flag); + if (rc.GetCode() == K_NO_SPACE) { + rc = Status::OK(); + break; + } + if (rc.GetCode() == K_TRY_AGAIN) { + rc = Status::OK(); + continue; + } + RETURN_IF_NOT_OK(rc); + out.emplace_back(std::move(str)); + } + return rc; +} + +Status StreamDataPageTest::ReceiveUntilTimeout(std::shared_ptr<ShmUnit> &pageUnit, uint64_t &lastRecvCursor, + uint32_t lockId, uint64_t timeoutMs, std::vector<Element> &out) +{ + auto page = std::make_shared<StreamDataPage>(pageUnit, lockId, true); + RETURN_IF_NOT_OK(page->Init()); + Status rc; + while (rc.IsOk()) { + std::vector<Element> elements; + rc = page->Receive(lastRecvCursor, timeoutMs, elements); + if (rc.GetCode() == K_TRY_AGAIN) { + rc = Status::OK(); + break; + } + RETURN_IF_NOT_OK(rc); + LOG(INFO) << "Receive " << elements.size() << " elements"; + lastRecvCursor += elements.size(); + out.insert(out.end(), elements.begin(), elements.end()); + } + return rc; +} + +TEST_F(StreamDataPageTest, TestCreateDataPageSuccess) +{ + ASSERT_EQ(page_->InitEmptyPage(), Status::OK()); +} + +TEST_F(StreamDataPageTest, TestCreateDataPageFail) +{ + const int NotBigEnough = 48; + auto pageUnit = std::make_shared<ShmUnit>(); + DS_ASSERT_OK(pageUnit->AllocateMemory("", NotBigEnough, false)); + auto page = std::make_shared<StreamDataPage>(pageUnit, 0, false); + Status rc = page->Init(); + LOG(INFO) << rc.ToString(); + ASSERT_NE(rc, Status::OK()); +} + +TEST_F(StreamDataPageTest, TestMultiElementsRW) +{ + ASSERT_EQ(page_->InitEmptyPage(), Status::OK()); + std::vector<std::string> strs; + Status rc = InsertUntilPageFull(pageUnit_, 0, false, strs); + ASSERT_EQ(rc, Status::OK()); + ASSERT_GT(strs.size(), 0u); + LOG(INFO) << FormatString("Number of elements inserted %zu", strs.size()); + // Receive all the elements on the page (lastRecvCursor starts at 0). + std::vector<Element> v; + rc = page_->Receive(0, 0, v); + ASSERT_EQ(rc, Status::OK()); + ASSERT_EQ(strs.size(), v.size()); + // Compare with what we inserted + for (size_t i = 0; i < strs.size(); ++i) { + ASSERT_TRUE(VerifyElement(strs[i], v[i])); + } +}
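// The SPSC test below drives one writer and one reader over the same shared
// page, with the reader advancing lastRecvCursor by the number of elements it
// received. The same cursor discipline in miniature, sketched as a lock-free
// single-producer/single-consumer ring on std::atomic counters; the real
// StreamDataPage additionally handles variable-size elements, shared-memory
// locks and timeouts, all of which this sketch deliberately omits.
#include <atomic>
#include <cstddef>
#include <cstdio>
#include <thread>

constexpr size_t CAP = 1024;  // ring capacity; a power of two for the mask
constexpr int TOTAL = 100000;

int buf[CAP];
std::atomic<size_t> writeCursor{0};
std::atomic<size_t> readCursor{0};

bool TryPush(int v)
{
    size_t w = writeCursor.load(std::memory_order_relaxed);
    if (w - readCursor.load(std::memory_order_acquire) == CAP) {
        return false;  // ring is full; caller retries, like K_TRY_AGAIN
    }
    buf[w & (CAP - 1)] = v;
    writeCursor.store(w + 1, std::memory_order_release);
    return true;
}

bool TryPop(int &v)
{
    size_t r = readCursor.load(std::memory_order_relaxed);
    if (r == writeCursor.load(std::memory_order_acquire)) {
        return false;  // nothing new to receive yet
    }
    v = buf[r & (CAP - 1)];
    readCursor.store(r + 1, std::memory_order_release);
    return true;
}

int main()
{
    std::thread producer([] {
        for (int i = 0; i < TOTAL; ) {
            if (TryPush(i)) {
                ++i;
            }
        }
    });
    long long sum = 0;
    for (int received = 0, v = 0; received < TOTAL; ) {
        if (TryPop(v)) {  // each pop advances the reader's cursor by one
            sum += v;
            ++received;
        }
    }
    producer.join();
    std::printf("received %d elements, sum=%lld\n", TOTAL, sum);
    return 0;
}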
+ +TEST_F(StreamDataPageTest, TestMultiElementsSPSC) +{ + // Two threads: one simulates the producer and one simulates the consumer. + const int poolSz = 2; + auto pool = std::make_unique<ThreadPool>(poolSz); + ASSERT_EQ(page_->InitEmptyPage(), Status::OK()); + std::vector<std::string> strs; + // Producer simulation + auto res1 = pool->Submit([this, &strs]() { + // Pick lockID 1 + DS_ASSERT_OK(InsertUntilPageFull(pageUnit_, 1, true, strs)); + ASSERT_GT(strs.size(), 0u); + LOG(INFO) << FormatString("Number of elements inserted %zu", strs.size()); + }); + // Consumer simulation + std::vector<Element> v; + uint64_t lastRecvCursor = 0; + auto res2 = pool->Submit([this, &v, &lastRecvCursor]() { + // Pick lockID 2 + // We will pick a long timeout value (5s), keep receiving until we time out + const uint64_t timeoutMs = 5 * 1000; + DS_ASSERT_OK(ReceiveUntilTimeout(pageUnit_, lastRecvCursor, 2, timeoutMs, v)); + }); + // Wait for both threads to come back. + res1.get(); + res2.get(); + ASSERT_EQ(strs.size(), v.size()); + // Compare with what we inserted + for (size_t i = 0; i < strs.size(); ++i) { + ASSERT_TRUE(VerifyElement(strs[i], v[i])); + } +} + +TEST_F(StreamDataPageTest, TestMultiElementsMPSC) +{ + // A few producers and one consumer + const int numProducers = 2; + const int numConsumers = 1; + auto pool = std::make_unique<ThreadPool>(numProducers + numConsumers); + ASSERT_EQ(page_->InitEmptyPage(), Status::OK()); + std::atomic<uint32_t> lockId(1); + // Producer simulation + std::vector<std::future<void>> producerRes; + std::vector<std::vector<std::string>> strs(numProducers); + for (auto i = 0; i < numProducers; ++i) { + producerRes.emplace_back(pool->Submit([this, &lockId, &strs, i]() { + DS_ASSERT_OK(InsertUntilPageFull(pageUnit_, lockId.fetch_add(1), true, strs[i])); + ASSERT_GT(strs[i].size(), 0u); + LOG(INFO) << FormatString("Number of elements inserted %zu", strs[i].size()); + })); + } + // Consumer simulation + std::vector<Element> v; + uint64_t lastRecvCursor = 0; + auto res2 = pool->Submit([this, &v, &lastRecvCursor, &lockId]() { + // We will pick a long timeout value (5s), keep receiving until we time out + const uint64_t timeoutMs = 5 * 1000; + DS_ASSERT_OK(ReceiveUntilTimeout(pageUnit_, lastRecvCursor, lockId.fetch_add(1), timeoutMs, v)); + }); + for (auto &res : producerRes) { + res.get(); + } + res2.get(); + LOG(INFO) << FormatString("Consumer receives %zu elements", v.size()); + size_t totalElements = 0; + for (auto &str : strs) { + totalElements += str.size(); + } + ASSERT_EQ(totalElements, v.size()); +} + +TEST_F(StreamDataPageTest, LEVEL2_TestMillionInsert) +{ + PerfManager *perfManager = PerfManager::Instance(); + FLAGS_v = 0; + const size_t pageSize = 1048576ul; + const size_t eleSz = 1024ul; + std::string a(eleSz, 'a'); + const size_t numElements = 1'280'000ul; + auto pageUnit = std::make_shared<ShmUnit>(); + DS_ASSERT_OK(pageUnit->AllocateMemory("", pageSize + StreamDataPage::PageOverhead(), false)); + auto page = std::make_shared<StreamDataPage>(pageUnit, 0, false); + DS_ASSERT_OK(page->Init()); + DS_ASSERT_OK(page->InitEmptyPage()); + Status rc; + HeaderAndData ele(reinterpret_cast<uint8_t *>(a.data()), a.size(), 0); + Timer t; + size_t numPagesNeeded = 1; + auto flags = InsertFlags::NONE; + for (size_t i = 0; i < numElements; ++i) { + rc = page->Insert(ele, 0, flags); + if (rc.GetCode() == K_NO_SPACE) { + DS_ASSERT_OK(page->InitEmptyPage()); + ++numPagesNeeded; + DS_ASSERT_OK(page->Insert(ele, 0, flags)); + continue; + } + DS_ASSERT_OK(rc); + } + LOG(INFO) << FormatString("Elapsed time [%.6lf]s. Total pages %zu", t.ElapsedSecond(), numPagesNeeded); + if (perfManager != nullptr) { + perfManager->Tick(); + perfManager->PrintPerfLog(); + } +} +} // namespace ut +} // namespace datasystem
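// LEVEL2_TestMillionInsert above measures insert throughput by re-initialising
// an empty page whenever Insert reports K_NO_SPACE and counting how many
// pages the run consumed. The same rollover arithmetic as a self-contained
// sketch: with a 1 MiB page and 1 KiB elements, 1'280'000 inserts need at
// least ceil(1'280'000 / 1024) = 1250 payload pages (the test's
// numPagesNeeded can come out higher once per-element header overhead on the
// page is counted, which this sketch ignores).
#include <cstddef>
#include <cstdio>

int main()
{
    const size_t pageSize = 1048576;   // 1 MiB of payload per page
    const size_t eleSize = 1024;       // payload bytes per element
    const size_t numElements = 1280000;

    size_t used = 0;
    size_t pages = 1;
    for (size_t i = 0; i < numElements; ++i) {
        if (used + eleSize > pageSize) {  // the K_NO_SPACE branch in the test
            ++pages;                      // roll over to a fresh empty page
            used = 0;
        }
        used += eleSize;
    }
    std::printf("elements per page: %zu, pages needed: %zu\n", pageSize / eleSize, pages);
    return 0;
}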
diff --git a/tests/ut/worker/stream_cache/stream_usagemonitor_test.cpp b/tests/ut/worker/stream_cache/stream_usagemonitor_test.cpp new file mode 100644 index 0000000..040e8ab --- /dev/null +++ b/tests/ut/worker/stream_cache/stream_usagemonitor_test.cpp @@ -0,0 +1,291 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Description: Usage Monitor test + */ + +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "datasystem/common/util/status_helper.h" +#include "datasystem/worker/stream_cache/client_worker_sc_service_impl.h" +#include "datasystem/worker/stream_cache/usage_monitor.h" + +using namespace datasystem::worker::stream_cache; +namespace datasystem { +namespace worker { +namespace stream_cache { +class ClientWorkerSCServiceImplMock : public ClientWorkerSCServiceImpl { +public: + ClientWorkerSCServiceImplMock() : ClientWorkerSCServiceImpl(HostPort(), HostPort(), nullptr, nullptr, nullptr) {} + Status SendBlockProducerReq(const std::string &streamName, const std::string &remoteWorkerAddr) + { + std::string id = streamName + remoteWorkerAddr; + blockedProducers.insert(id); + return Status::OK(); + } + Status SendUnBlockProducerReq(const std::string &streamName, const std::string &remoteWorkerAddr) + { + std::string id = streamName + remoteWorkerAddr; + blockedProducers.erase(id); + return Status::OK(); + } + std::multiset<std::string> blockedProducers; +}; +} // namespace stream_cache +} // namespace worker +namespace ut { +static constexpr uint64_t DEFAULT_TOTAL_SIZE = 10; +class UsageMonitorTest : public CommonTest { +public: + UsageMonitorTest() + { + } + ~UsageMonitorTest() override = default; + + void SetUp() override + { + FLAGS_v = 2; // verbose log level 2. + cliWorkerMockPtr_ = std::make_shared<ClientWorkerSCServiceImplMock>(); + impl_ = cliWorkerMockPtr_.get(); + usageMonitor_ = std::make_unique<UsageMonitor>(impl_, DEFAULT_TOTAL_SIZE); + DS_ASSERT_OK(usageMonitor_->Init()); + CommonTest::SetUp(); + } + void TearDown() override + { + if (usageMonitor_) { + usageMonitor_->Stop(); + usageMonitor_.reset(); + } + CommonTest::TearDown(); + } + std::shared_ptr<ClientWorkerSCServiceImplMock> cliWorkerMockPtr_; + ClientWorkerSCServiceImpl *impl_; + std::unique_ptr<UsageMonitor> usageMonitor_; +}; + +TEST_F(UsageMonitorTest, TestNormalCase) +{ + const std::string streamName(RandomData().GetRandomString(8)); + const std::string remoteWorkerAddr(RandomData().GetRandomString(8)); + // Dummy Reserve with size of 0 + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName, 0)); + // Should not be over used + usageMonitor_->IncUsage(streamName, remoteWorkerAddr, 1); + usageMonitor_->DecUsage(streamName, remoteWorkerAddr, 1); + ASSERT_FALSE(usageMonitor_->CheckOverUsed().IsError()); + + // 90% of memory is used: above an 80% threshold, below 100% + usageMonitor_->IncUsage(streamName, remoteWorkerAddr, 4); + usageMonitor_->IncUsage(streamName, remoteWorkerAddr, 5); + ASSERT_TRUE(usageMonitor_->CheckOverUsed(0.8).IsError()); + ASSERT_FALSE(usageMonitor_->CheckOverUsed(1).IsError()); + + // Should be over used + usageMonitor_->IncUsage(streamName, remoteWorkerAddr, 11); + ASSERT_TRUE(usageMonitor_->CheckOverUsed().IsError()); +}
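// The CheckOverUsed() assertions above reduce to a ratio comparison: with a
// budget of 10 units, a usage of 9 trips an 80% threshold but not a 100% one.
// A stand-in for that arithmetic (an assumed reading of the checks; the real
// UsageMonitor also tracks usage per stream and per remote worker, which this
// sketch leaves out):
#include <cassert>
#include <cstdint>

struct UsageMeter {
    uint64_t total;
    uint64_t used = 0;
    void Inc(uint64_t n) { used += n; }
    void Dec(uint64_t n) { used -= n; }
    // Over-used once usage reaches threshold * total.
    bool OverUsed(double threshold = 1.0) const
    {
        return static_cast<double>(used) >= threshold * static_cast<double>(total);
    }
};

int main()
{
    UsageMeter m{10};
    m.Inc(1);
    m.Dec(1);
    assert(!m.OverUsed());     // back to 0 of 10 used
    m.Inc(4);
    m.Inc(5);
    assert(m.OverUsed(0.8));   // 9 of 10 used: trips the 80% threshold
    assert(!m.OverUsed(1.0));  // but still under the full budget
    m.Inc(11);
    assert(m.OverUsed());      // 20 of 10: over by any measure
    return 0;
}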
+ +TEST_F(UsageMonitorTest, TestNormalCaseUsagePerStream) +{ + const std::string streamName(RandomData().GetRandomString(8)); + const std::string remoteWorkerAddr1(RandomData().GetRandomString(8)); + const std::string remoteWorkerAddr2(RandomData().GetRandomString(8)); + // Dummy Reserve with size of 0 + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName, 0)); + // Should not be over used + usageMonitor_->IncUsage(streamName, remoteWorkerAddr1, 1); + usageMonitor_->DecUsage(streamName, remoteWorkerAddr1, 1); + ASSERT_FALSE(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 0, 1, 0).IsError()); + + // 90% of memory is used: above an 80% threshold, below 100% + usageMonitor_->IncUsage(streamName, remoteWorkerAddr1, 4); + usageMonitor_->IncUsage(streamName, remoteWorkerAddr2, 5); + const double threshold = 0.8; + ASSERT_TRUE(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 0, threshold, 0).IsError()); + ASSERT_FALSE(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 0, 1, 0).IsError()); + + // Should be over used + usageMonitor_->IncUsage(streamName, remoteWorkerAddr1, 11); + ASSERT_TRUE(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 0, 1, 0).IsError()); +} + +TEST_F(UsageMonitorTest, TestBGThreadLogic) +{ + const std::string streamName(RandomData().GetRandomString(8)); + const std::string remoteWorkerAddr(RandomData().GetRandomString(8)); + // Dummy Reserve with size of 0 + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName, 0)); + // Not over used + usageMonitor_->IncUsage(streamName, remoteWorkerAddr, 1); + std::shared_ptr usageItem; + usageMonitor_->GetMostUsed(usageItem); + ASSERT_FALSE(usageMonitor_->CheckOverUsed().IsError()); + sleep(1); + // usage not blocked by background thread + ASSERT_FALSE(usageItem->usageBlocked); + + // Should be over used + usageMonitor_->IncUsage(streamName, remoteWorkerAddr, 11); + ASSERT_TRUE(usageMonitor_->CheckOverUsed().IsError()); + sleep(1); + // usage blocked by background thread + ASSERT_TRUE(usageItem->usageBlocked); + + // Next largest item + const std::string remoteWorkerAddr1(RandomData().GetRandomString(8)); + usageMonitor_->IncUsage(streamName, remoteWorkerAddr1, 1); + std::shared_ptr usageItem1; + usageMonitor_->GetMostUsed(usageItem1); + sleep(1); + // usage blocked by background thread + ASSERT_TRUE(usageItem1->usageBlocked); + + usageMonitor_->DecUsage(streamName, remoteWorkerAddr, 12); + sleep(1); + // Background thread unblocks all items + ASSERT_FALSE(usageItem->usageBlocked); + ASSERT_FALSE(usageItem1->usageBlocked); +} + +TEST_F(UsageMonitorTest, TestGetMostUsed) +{ + const std::string streamName(RandomData().GetRandomString(8)); + const std::string remoteWorkerAddr1(RandomData().GetRandomString(8)); + // Dummy Reserve with size of 0 + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName, 0)); + ASSERT_FALSE(usageMonitor_->CheckOverUsed().IsError()); + // Should be over used + usageMonitor_->IncUsage(streamName, remoteWorkerAddr1, 1); + + const std::string remoteWorkerAddr2(RandomData().GetRandomString(8)); + usageMonitor_->IncUsage(streamName, remoteWorkerAddr2, 10); + + const std::string remoteWorkerAddr3(RandomData().GetRandomString(8)); + usageMonitor_->IncUsage(streamName, remoteWorkerAddr3, 1); + + // Background thread blocked the most used producer + std::shared_ptr usageItem; + usageMonitor_->GetMostUsed(usageItem); + ASSERT_EQ(usageItem->streamName, streamName); + ASSERT_EQ(usageItem->remoteWorkerAddr, remoteWorkerAddr2); + ASSERT_EQ(usageItem->usage, DEFAULT_TOTAL_SIZE); +}
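// TestBasicReserve1 and TestBasicReserve2 below rely on a quota split: each
// stream may reserve part of the total, and a stream's effective limit is its
// own reservation plus whatever is left unreserved. With a total of 10,
// stream1 reserving 1 and stream2 reserving 6 leaves 3 unreserved, so stream1
// may use at most 1 + 3 = 4 and stream2 at most 6 + 3 = 9. A sketch of that
// accounting under those assumptions (the real UsageMonitor also folds
// current usage and thresholds into CheckNIncOverUsedForStream):
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

struct ReservePool {
    uint64_t total;
    uint64_t reserved = 0;
    std::map<std::string, uint64_t> perStream;

    bool Reserve(const std::string &stream, uint64_t n)
    {
        uint64_t cur = perStream[stream];
        // Adjusting a reservation must keep the grand total in bound.
        if (reserved - cur + n > total) {
            return false;
        }
        reserved = reserved - cur + n;
        perStream[stream] = n;
        return true;
    }
    uint64_t LimitOf(const std::string &stream)
    {
        return perStream[stream] + (total - reserved);  // own share + free pool
    }
};

int main()
{
    ReservePool pool{10};
    assert(pool.Reserve("stream1", 1));
    assert(pool.Reserve("stream2", 6));
    assert(pool.LimitOf("stream1") == 4);
    assert(pool.LimitOf("stream2") == 9);
    assert(!pool.Reserve("stream1", 5));  // 5 + 6 > 10: rejected
    return 0;
}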
+ +TEST_F(UsageMonitorTest, DISABLED_TestGetMostUsedBlocked) +{ + const std::string streamName(RandomData().GetRandomString(8)); + const std::string remoteWorkerAddr1(RandomData().GetRandomString(8)); + // Dummy Reserve with size of 0 + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName, 0)); + ASSERT_FALSE(usageMonitor_->CheckOverUsed().IsError()); + // Should be over used + usageMonitor_->IncUsage(streamName, remoteWorkerAddr1, 1); + + const std::string remoteWorkerAddr2(RandomData().GetRandomString(8)); + usageMonitor_->IncUsage(streamName, remoteWorkerAddr2, 10); + std::shared_ptr usageItem; + usageMonitor_->GetMostUsed(usageItem); + usageItem->usageBlocked = true; + + const std::string remoteWorkerAddr3(RandomData().GetRandomString(8)); + usageMonitor_->IncUsage(streamName, remoteWorkerAddr3, 2); + + // As the previous most used is blocked, we should get the next most used + std::shared_ptr usageItem1; + usageMonitor_->GetMostUsed(usageItem1); + ASSERT_EQ(usageItem1->streamName, streamName); + ASSERT_EQ(usageItem1->remoteWorkerAddr, remoteWorkerAddr3); + ASSERT_EQ(usageItem1->usage, (uint64_t)2); +} + +TEST_F(UsageMonitorTest, TestBasicReserve1) +{ + const int moreThanHalf = 6; + const std::string streamName(RandomData().GetRandomString(8)); + const std::string streamName2(RandomData().GetRandomString(8)); + const std::string remoteWorkerAddr1(RandomData().GetRandomString(8)); + + // First test that memory cannot be reserved if most of it is already reserved + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName, moreThanHalf)); + DS_ASSERT_NOT_OK(usageMonitor_->ReserveMemory(streamName2, moreThanHalf)); + usageMonitor_->UndoReserveMemory(streamName); + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName2, moreThanHalf)); + usageMonitor_->UndoReserveMemory(streamName2); + + // Then test that memory can be reserved even if not enough memory is available upfront + // As long as the total reserved memory is still in bound, it is allowed to reserve + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName, 1)); + DS_ASSERT_OK(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 0, 1, moreThanHalf)); + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName2, moreThanHalf)); + + // Also test that one stream cannot use another stream's reserved memory + // Stream1 reserved 1, Stream2 reserved 6, so stream1 can use at most 4, while stream2 can use at most 9 + DS_ASSERT_NOT_OK(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 0, 1, moreThanHalf)); + DS_ASSERT_NOT_OK( + usageMonitor_->CheckNIncOverUsedForStream(streamName2, remoteWorkerAddr1, 0, 1, DEFAULT_TOTAL_SIZE)); + DS_ASSERT_OK( + usageMonitor_->CheckNIncOverUsedForStream(streamName2, remoteWorkerAddr1, 0, 1, DEFAULT_TOTAL_SIZE - 1)); +} + +TEST_F(UsageMonitorTest, TestBasicReserve2) +{ + const int moreThanHalf = 6; + const std::string streamName(RandomData().GetRandomString(8)); + const std::string streamName2(RandomData().GetRandomString(8)); + const std::string remoteWorkerAddr1(RandomData().GetRandomString(8)); + + // Test that a reserved-memory adjustment cannot go over the limit + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName, moreThanHalf)); + DS_ASSERT_NOT_OK(usageMonitor_->ReserveMemory(streamName, DEFAULT_TOTAL_SIZE + 1)); + + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName2, 1)); + DS_ASSERT_NOT_OK(usageMonitor_->ReserveMemory(streamName, DEFAULT_TOTAL_SIZE)); +} + +TEST_F(UsageMonitorTest, TestLowerBound) +{ + const std::string streamName(RandomData().GetRandomString(8)); + const std::string remoteWorkerAddr1(RandomData().GetRandomString(8)); + const std::string remoteWorkerAddr2(RandomData().GetRandomString(8)); + // Dummy Reserve with size of 0 + DS_ASSERT_OK(usageMonitor_->ReserveMemory(streamName, 0)); + + // Should not be over used when the lower bound is 1, even though the ratio is 0.01 + usageMonitor_->IncUsage(streamName, remoteWorkerAddr1, 1); + const double lowerTh = 0.01; + ASSERT_TRUE(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 0, lowerTh, 0).IsError()); + ASSERT_FALSE(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 1, lowerTh, 0).IsError()); +
LOG_IF_ERROR(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 0, lowerTh, 0), "OOM"); + // 80% of memory is used + usageMonitor_->IncUsage(streamName, remoteWorkerAddr1, 4); + usageMonitor_->IncUsage(streamName, remoteWorkerAddr2, 5); + const double higherTh = 0.8; + ASSERT_TRUE(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 1, higherTh, 0).IsError()); + ASSERT_FALSE(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 1, 1, 0).IsError()); + + // Should be over used + usageMonitor_->IncUsage(streamName, remoteWorkerAddr1, 11); + ASSERT_TRUE(usageMonitor_->CheckNIncOverUsedForStream(streamName, remoteWorkerAddr1, 1, 1, 0).IsError()); +} +} // namespace ut +} // namespace datasystem diff --git a/third_party/P2P-Transfer/include/communicator/P2PCommunicatorManager.h b/third_party/P2P-Transfer/include/communicator/P2PCommunicatorManager.h index 1acc652..f563d93 100644 --- a/third_party/P2P-Transfer/include/communicator/P2PCommunicatorManager.h +++ b/third_party/P2P-Transfer/include/communicator/P2PCommunicatorManager.h @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/include/external/ra.h b/third_party/P2P-Transfer/include/external/ra.h index dbf8f10..7de2772 100644 --- a/third_party/P2P-Transfer/include/external/ra.h +++ b/third_party/P2P-Transfer/include/external/ra.h @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/include/external/tsd.h b/third_party/P2P-Transfer/include/external/tsd.h index 1cdcc5e..fa75285 100644 --- a/third_party/P2P-Transfer/include/external/tsd.h +++ b/third_party/P2P-Transfer/include/external/tsd.h @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/include/tools/host-interface.h b/third_party/P2P-Transfer/include/tools/host-interface.h index 254872d..63cc010 100644 --- a/third_party/P2P-Transfer/include/tools/host-interface.h +++ b/third_party/P2P-Transfer/include/tools/host-interface.h @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/communication/TcpClient.cpp b/third_party/P2P-Transfer/source/communication/TcpClient.cpp index 36a4047..f716183 100644 --- a/third_party/P2P-Transfer/source/communication/TcpClient.cpp +++ b/third_party/P2P-Transfer/source/communication/TcpClient.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/communication/TcpServer.cpp b/third_party/P2P-Transfer/source/communication/TcpServer.cpp index 4b9cbc9..9abda1c 100644 --- a/third_party/P2P-Transfer/source/communication/TcpServer.cpp +++ b/third_party/P2P-Transfer/source/communication/TcpServer.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/communicator/P2PCommunicator.cpp b/third_party/P2P-Transfer/source/communicator/P2PCommunicator.cpp index c5f11e9..60b41d0 100644 --- a/third_party/P2P-Transfer/source/communicator/P2PCommunicator.cpp +++ b/third_party/P2P-Transfer/source/communicator/P2PCommunicator.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
* diff --git a/third_party/P2P-Transfer/source/communicator/hccs-ipc/HccsReceiver.cpp b/third_party/P2P-Transfer/source/communicator/hccs-ipc/HccsReceiver.cpp index 7d34654..238e5ab 100644 --- a/third_party/P2P-Transfer/source/communicator/hccs-ipc/HccsReceiver.cpp +++ b/third_party/P2P-Transfer/source/communicator/hccs-ipc/HccsReceiver.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/communicator/hccs-ipc/HccsSender.cpp b/third_party/P2P-Transfer/source/communicator/hccs-ipc/HccsSender.cpp index 9377cce..5275d4c 100644 --- a/third_party/P2P-Transfer/source/communicator/hccs-ipc/HccsSender.cpp +++ b/third_party/P2P-Transfer/source/communicator/hccs-ipc/HccsSender.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/communicator/roce/RoceReceiver.cpp b/third_party/P2P-Transfer/source/communicator/roce/RoceReceiver.cpp index f89f95b..5c17fbc 100644 --- a/third_party/P2P-Transfer/source/communicator/roce/RoceReceiver.cpp +++ b/third_party/P2P-Transfer/source/communicator/roce/RoceReceiver.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/communicator/roce/RoceSender.cpp b/third_party/P2P-Transfer/source/communicator/roce/RoceSender.cpp index 600db9f..fcb8fc1 100644 --- a/third_party/P2P-Transfer/source/communicator/roce/RoceSender.cpp +++ b/third_party/P2P-Transfer/source/communicator/roce/RoceSender.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/communicator/test.cpp b/third_party/P2P-Transfer/source/communicator/test.cpp index 4a6f1b6..8ca8c72 100644 --- a/third_party/P2P-Transfer/source/communicator/test.cpp +++ b/third_party/P2P-Transfer/source/communicator/test.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/npu/Hccp.cpp b/third_party/P2P-Transfer/source/npu/Hccp.cpp index 9e60277..9a23aeb 100644 --- a/third_party/P2P-Transfer/source/npu/Hccp.cpp +++ b/third_party/P2P-Transfer/source/npu/Hccp.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/npu/LocalNotify.cpp b/third_party/P2P-Transfer/source/npu/LocalNotify.cpp index c6a67fe..4c57d40 100644 --- a/third_party/P2P-Transfer/source/npu/LocalNotify.cpp +++ b/third_party/P2P-Transfer/source/npu/LocalNotify.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/npu/P2PMem.cpp b/third_party/P2P-Transfer/source/npu/P2PMem.cpp index d02fa80..bcdd568 100644 --- a/third_party/P2P-Transfer/source/npu/P2PMem.cpp +++ b/third_party/P2P-Transfer/source/npu/P2PMem.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/npu/P2PNotify.cpp b/third_party/P2P-Transfer/source/npu/P2PNotify.cpp index 92b238b..5e9985b 100644 --- a/third_party/P2P-Transfer/source/npu/P2PNotify.cpp +++ b/third_party/P2P-Transfer/source/npu/P2PNotify.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
* diff --git a/third_party/P2P-Transfer/source/npu/P2PStream.cpp b/third_party/P2P-Transfer/source/npu/P2PStream.cpp index d71376a..e9af436 100644 --- a/third_party/P2P-Transfer/source/npu/P2PStream.cpp +++ b/third_party/P2P-Transfer/source/npu/P2PStream.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/npu/PeerManager.cpp b/third_party/P2P-Transfer/source/npu/PeerManager.cpp index 7f58e23..5b28381 100644 --- a/third_party/P2P-Transfer/source/npu/PeerManager.cpp +++ b/third_party/P2P-Transfer/source/npu/PeerManager.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/npu/RaWrapper.cpp b/third_party/P2P-Transfer/source/npu/RaWrapper.cpp index 81d7799..b6fe13e 100644 --- a/third_party/P2P-Transfer/source/npu/RaWrapper.cpp +++ b/third_party/P2P-Transfer/source/npu/RaWrapper.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/npu/RdmaErrCollector.cpp b/third_party/P2P-Transfer/source/npu/RdmaErrCollector.cpp index c6c5cef..eaae426 100644 --- a/third_party/P2P-Transfer/source/npu/RdmaErrCollector.cpp +++ b/third_party/P2P-Transfer/source/npu/RdmaErrCollector.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/npu/RdmaNotify.cpp b/third_party/P2P-Transfer/source/npu/RdmaNotify.cpp index 0320500..0f8889c 100644 --- a/third_party/P2P-Transfer/source/npu/RdmaNotify.cpp +++ b/third_party/P2P-Transfer/source/npu/RdmaNotify.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/p2p.cpp b/third_party/P2P-Transfer/source/p2p.cpp index 26d8ef1..6d38168 100644 --- a/third_party/P2P-Transfer/source/p2p.cpp +++ b/third_party/P2P-Transfer/source/p2p.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/tools/env.cpp b/third_party/P2P-Transfer/source/tools/env.cpp index e9d6d3b..3a420c8 100644 --- a/third_party/P2P-Transfer/source/tools/env.cpp +++ b/third_party/P2P-Transfer/source/tools/env.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/tools/hccl-convert.cpp b/third_party/P2P-Transfer/source/tools/hccl-convert.cpp index 73df318..f6049a3 100644 --- a/third_party/P2P-Transfer/source/tools/hccl-convert.cpp +++ b/third_party/P2P-Transfer/source/tools/hccl-convert.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/source/tools/host-interface.cpp b/third_party/P2P-Transfer/source/tools/host-interface.cpp index ec4f02f..5a01b2f 100644 --- a/third_party/P2P-Transfer/source/tools/host-interface.cpp +++ b/third_party/P2P-Transfer/source/tools/host-interface.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
* diff --git a/third_party/P2P-Transfer/test/source/p2p-transfer_test.cpp b/third_party/P2P-Transfer/test/source/p2p-transfer_test.cpp index bc2784b..f053682 100644 --- a/third_party/P2P-Transfer/test/source/p2p-transfer_test.cpp +++ b/third_party/P2P-Transfer/test/source/p2p-transfer_test.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch.cpp b/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch.cpp index d84f32c..85a3ca5 100644 --- a/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch.cpp +++ b/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_recv.cpp b/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_recv.cpp index a62b1c9..f0b52b8 100644 --- a/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_recv.cpp +++ b/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_recv.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_send.cpp b/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_send.cpp index 181c9fb..cf94c61 100644 --- a/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_send.cpp +++ b/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_send.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_thread.cpp b/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_thread.cpp index 12ecbd9..b2892e5 100644 --- a/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_thread.cpp +++ b/third_party/P2P-Transfer/test/source/p2p-transfer_test_batch_thread.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/test/source/p2p-transfer_test_init.cpp b/third_party/P2P-Transfer/test/source/p2p-transfer_test_init.cpp index 1912cb9..0384360 100644 --- a/third_party/P2P-Transfer/test/source/p2p-transfer_test_init.cpp +++ b/third_party/P2P-Transfer/test/source/p2p-transfer_test_init.cpp @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/test/source/test-tools/barrier.h b/third_party/P2P-Transfer/test/source/test-tools/barrier.h index 2cce479..da3752e 100644 --- a/third_party/P2P-Transfer/test/source/test-tools/barrier.h +++ b/third_party/P2P-Transfer/test/source/test-tools/barrier.h @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/test/source/test-tools/fifo.h b/third_party/P2P-Transfer/test/source/test-tools/fifo.h index 48d50a8..fb9b05b 100644 --- a/third_party/P2P-Transfer/test/source/test-tools/fifo.h +++ b/third_party/P2P-Transfer/test/source/test-tools/fifo.h @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
* diff --git a/third_party/P2P-Transfer/test/source/test-tools/measure.h b/third_party/P2P-Transfer/test/source/test-tools/measure.h index 9fea079..3782fd2 100644 --- a/third_party/P2P-Transfer/test/source/test-tools/measure.h +++ b/third_party/P2P-Transfer/test/source/test-tools/measure.h @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/test/source/test-tools/measurementSeries.h b/third_party/P2P-Transfer/test/source/test-tools/measurementSeries.h index 0672e3c..d1ab06e 100644 --- a/third_party/P2P-Transfer/test/source/test-tools/measurementSeries.h +++ b/third_party/P2P-Transfer/test/source/test-tools/measurementSeries.h @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/P2P-Transfer/test/source/test-tools/tools.h b/third_party/P2P-Transfer/test/source/test-tools/tools.h index 983396e..f9a672d 100644 --- a/third_party/P2P-Transfer/test/source/test-tools/tools.h +++ b/third_party/P2P-Transfer/test/source/test-tools/tools.h @@ -1,4 +1,3 @@ - /** * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * diff --git a/third_party/patches/curl/8.8.0/support_old_cmake.patch b/third_party/patches/curl/8.8.0/support_old_cmake.patch new file mode 100644 index 0000000..a2adb68 --- /dev/null +++ b/third_party/patches/curl/8.8.0/support_old_cmake.patch @@ -0,0 +1,12 @@ +diff --color -Npur curl-curl-8_8_0/CMake/curl-config.cmake.in curl-curl-8_8_0_new/CMake/curl-config.cmake.in +--- curl-curl-8_8_0/CMake/curl-config.cmake.in 2024-05-22 13:54:25.000000000 +0800 ++++ curl-curl-8_8_0_new/CMake/curl-config.cmake.in 2025-11-10 21:12:29.139731588 +0800 +@@ -36,5 +36,8 @@ check_required_components("@PROJECT_NAME + + # Alias for either shared or static library + if(NOT TARGET @PROJECT_NAME@::libcurl) ++ if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.11 AND CMAKE_VERSION VERSION_LESS 3.18) ++ set_target_properties(@PROJECT_NAME@::@LIB_SELECTED@ PROPERTIES IMPORTED_GLOBAL TRUE) ++ endif() + add_library(@PROJECT_NAME@::libcurl ALIAS @PROJECT_NAME@::@LIB_SELECTED@) + endif() diff --git a/third_party/patches/obs/3.24.3/obs-sdk-change-spdlog.patch b/third_party/patches/obs/3.24.3/obs-sdk-change-spdlog.patch new file mode 100644 index 0000000..e2f80b4 --- /dev/null +++ b/third_party/patches/obs/3.24.3/obs-sdk-change-spdlog.patch @@ -0,0 +1,119 @@ +diff --git a/platform/eSDK_LogAPI_V2.1.10/eSDKLogAPI/eSDKLog.cpp b/platform/eSDK_LogAPI_V2.1.10/eSDKLogAPI/eSDKLog.cpp +index 0847954..d94fc6f 100644 +--- a/platform/eSDK_LogAPI_V2.1.10/eSDKLogAPI/eSDKLog.cpp ++++ b/platform/eSDK_LogAPI_V2.1.10/eSDKLogAPI/eSDKLog.cpp +@@ -111,9 +111,9 @@ wchar_t *GetWcharFromChar(const char *char_str) + } + + static std::mutex locker; +-static std::shared_ptr runLogger; +-static std::shared_ptr operationLogger; +-static std::shared_ptr interfaceLogger; ++static std::shared_ptr runLogger; ++static std::shared_ptr operationLogger; ++static std::shared_ptr interfaceLogger; + + #ifdef WIN32 + bool eSDKLog::InitSPDLOG(const std::string& product, unsigned int logLevel[LOG_CATEGORY], const std::wstring& logPath, int mode) { +@@ -159,41 +159,41 @@ bool eSDKLog::InitSPDLOG(const std::string& product, unsigned int logLevel[LOG_C + } + + try { +- runLogger = spdlog::rotating_logger_mt(m_InstanceRunName, logPath + L"/" + instanceRunName_W + L".log" ++ runLogger = ds_spdlog::rotating_logger_mt(m_InstanceRunName, logPath + L"/" + instanceRunName_W + L".log" + , 
ConfigMgrInstance().GetLogSize_Run() * 1024, ConfigMgrInstance().GetLogNum_Run()); +- runLogger->set_level(spdlog::level::level_enum(spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Run())); ++ runLogger->set_level(ds_spdlog::level::level_enum(ds_spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Run())); + //设置日志格式为 时间(精确到毫秒) 线程号 日志名 日志级别 自定义信息 + runLogger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [tid:%t] [%n] [%l] %v"); + initiateSuccessfully = true; + } +- catch (const spdlog::spdlog_ex &ex) { ++ catch (const ds_spdlog::spdlog_ex &ex) { + initiateSuccessfully = false; + } + + if (initiateSuccessfully) { + try { +- operationLogger = spdlog::rotating_logger_mt(m_InstanceOperationName, logPath + L"/" + instanceOperationName_W + L".log" ++ operationLogger = ds_spdlog::rotating_logger_mt(m_InstanceOperationName, logPath + L"/" + instanceOperationName_W + L".log" + , ConfigMgrInstance().GetLogSize_Run() * 1024, ConfigMgrInstance().GetLogNum_Run()); +- operationLogger->set_level(spdlog::level::level_enum(spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Operation())); ++ operationLogger->set_level(ds_spdlog::level::level_enum(ds_spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Operation())); + //设置日志格式为 时间(精确到毫秒) 线程号 日志名 日志级别 自定义信息 + operationLogger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [tid:%t] [%n] [%l] %v"); + initiateSuccessfully = true; + } +- catch (const spdlog::spdlog_ex &ex) { ++ catch (const ds_spdlog::spdlog_ex &ex) { + initiateSuccessfully = false; + } + } + + if (initiateSuccessfully) { + try { +- interfaceLogger = spdlog::rotating_logger_mt(m_InstanceInterfaceName, logPath + L"/" + instanceInterfaceName_W + L".log" ++ interfaceLogger = ds_spdlog::rotating_logger_mt(m_InstanceInterfaceName, logPath + L"/" + instanceInterfaceName_W + L".log" + , ConfigMgrInstance().GetLogSize_Run() * 1024, ConfigMgrInstance().GetLogNum_Run()); +- interfaceLogger->set_level(spdlog::level::level_enum(spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Interface())); ++ interfaceLogger->set_level(ds_spdlog::level::level_enum(ds_spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Interface())); + //设置日志格式为 时间(精确到毫秒) 线程号 日志名 日志级别 自定义信息 + interfaceLogger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [tid:%t] [%n] [%l] %v"); + initiateSuccessfully = true; + } +- catch (const spdlog::spdlog_ex &ex) { ++ catch (const ds_spdlog::spdlog_ex &ex) { + initiateSuccessfully = false; + } + } +@@ -212,41 +212,41 @@ bool eSDKLog::InitSPDLOG(const std::string& product, unsigned int logLevel[LOG_C + m_InstanceRunName = product + LOG_RUN_INSTANCE; + + try { +- runLogger = spdlog::rotating_logger_mt(m_InstanceRunName, logPath + "/" + m_InstanceRunName + ".log" ++ runLogger = ds_spdlog::rotating_logger_mt(m_InstanceRunName, logPath + "/" + m_InstanceRunName + ".log" + , ConfigMgrInstance().GetLogSize_Run() * 1024, ConfigMgrInstance().GetLogNum_Run()); +- runLogger->set_level(spdlog::level::level_enum(spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Run())); ++ runLogger->set_level(ds_spdlog::level::level_enum(ds_spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Run())); + //设置日志格式为 时间(精确到毫秒) 线程号 日志名 日志级别 自定义信息 + runLogger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [tid:%t] [%n] [%l] %v"); + initiateSuccessfully = true; + } +- catch (const spdlog::spdlog_ex &ex) { ++ catch (const ds_spdlog::spdlog_ex &ex) { + initiateSuccessfully = false; + } + + if (initiateSuccessfully) { + try { +- operationLogger = spdlog::rotating_logger_mt(m_InstanceOperationName, logPath + "/" + m_InstanceOperationName + ".log" ++ 
operationLogger = ds_spdlog::rotating_logger_mt(m_InstanceOperationName, logPath + "/" + m_InstanceOperationName + ".log" + , ConfigMgrInstance().GetLogSize_Run() * 1024, ConfigMgrInstance().GetLogNum_Run()); +- operationLogger->set_level(spdlog::level::level_enum(spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Operation())); ++ operationLogger->set_level(ds_spdlog::level::level_enum(ds_spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Operation())); + //设置日志格式为 时间(精确到毫秒) 线程号 日志名 日志级别 自定义信息 + operationLogger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [tid:%t] [%n] [%l] %v"); + initiateSuccessfully = true; + } +- catch (const spdlog::spdlog_ex &ex) { ++ catch (const ds_spdlog::spdlog_ex &ex) { + initiateSuccessfully = false; + } + } + + if (initiateSuccessfully) { + try { +- interfaceLogger = spdlog::rotating_logger_mt(m_InstanceInterfaceName, logPath + "/" + m_InstanceInterfaceName + ".log" ++ interfaceLogger = ds_spdlog::rotating_logger_mt(m_InstanceInterfaceName, logPath + "/" + m_InstanceInterfaceName + ".log" + , ConfigMgrInstance().GetLogSize_Run() * 1024, ConfigMgrInstance().GetLogNum_Run()); +- interfaceLogger->set_level(spdlog::level::level_enum(spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Interface())); ++ interfaceLogger->set_level(ds_spdlog::level::level_enum(ds_spdlog::level::debug + ConfigMgrInstance().GetLogLevel_Interface())); + //设置日志格式为 时间(精确到毫秒) 线程号 日志名 日志级别 自定义信息 + interfaceLogger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [tid:%t] [%n] [%l] %v"); + initiateSuccessfully = true; + } +- catch (const spdlog::spdlog_ex &ex) { ++ catch (const ds_spdlog::spdlog_ex &ex) { + initiateSuccessfully = false; + } + } diff --git a/third_party/patches/spdlog/change-namespace.patch b/third_party/patches/spdlog/change-namespace.patch new file mode 100644 index 0000000..821ef4b --- /dev/null +++ b/third_party/patches/spdlog/change-namespace.patch @@ -0,0 +1,7274 @@ +From 7cc230e08365aa3db70ebb815b90ce7da68558e4 Mon Sep 17 00:00:00 2001 +From: yangsonglin +Date: Mon, 16 Jun 2025 21:35:14 +0800 +Subject: [PATCH] change namespace and library name with ds + +--- + CMakeLists.txt | 4 +- + README.md | 144 +++++------ + bench/CMakeLists.txt | 8 +- + bench/async_bench.cpp | 62 ++--- + bench/bench.cpp | 120 +++++----- + bench/formatter-bench.cpp | 12 +- + bench/latency.cpp | 86 +++---- + example/example.cpp | 162 ++++++------- + include/spdlog/async.h | 10 +- + include/spdlog/async_logger-inl.h | 16 +- + include/spdlog/async_logger.h | 4 +- + include/spdlog/cfg/argv.h | 4 +- + include/spdlog/cfg/env.h | 4 +- + include/spdlog/cfg/helpers-inl.h | 4 +- + include/spdlog/cfg/helpers.h | 4 +- + include/spdlog/common-inl.h | 10 +- + include/spdlog/common.h | 36 +-- + include/spdlog/details/backtracer-inl.h | 4 +- + include/spdlog/details/backtracer.h | 4 +- + include/spdlog/details/circular_q.h | 4 +- + include/spdlog/details/console_globals.h | 4 +- + include/spdlog/details/file_helper-inl.h | 4 +- + include/spdlog/details/file_helper.h | 4 +- + include/spdlog/details/fmt_helper.h | 6 +- + include/spdlog/details/log_msg-inl.h | 12 +- + include/spdlog/details/log_msg.h | 4 +- + include/spdlog/details/log_msg_buffer-inl.h | 4 +- + include/spdlog/details/log_msg_buffer.h | 4 +- + include/spdlog/details/mpmc_blocking_q.h | 6 +- + include/spdlog/details/null_mutex.h | 4 +- + include/spdlog/details/os-inl.h | 6 +- + include/spdlog/details/os.h | 6 +- + include/spdlog/details/periodic_worker-inl.h | 4 +- + include/spdlog/details/periodic_worker.h | 4 +- + include/spdlog/details/registry-inl.h | 10 +- 
+ include/spdlog/details/registry.h | 10 +- + include/spdlog/details/synchronous_factory.h | 8 +- + include/spdlog/details/tcp_client-windows.h | 4 +- + include/spdlog/details/tcp_client.h | 4 +- + include/spdlog/details/thread_pool-inl.h | 6 +- + include/spdlog/details/thread_pool.h | 6 +- + include/spdlog/details/udp_client-windows.h | 4 +- + include/spdlog/details/udp_client.h | 4 +- + include/spdlog/fmt/bin_to_hex.h | 16 +- + include/spdlog/formatter.h | 4 +- + include/spdlog/fwd.h | 4 +- + include/spdlog/logger-inl.h | 8 +- + include/spdlog/logger.h | 10 +- + include/spdlog/pattern_formatter-inl.h | 8 +- + include/spdlog/pattern_formatter.h | 8 +- + include/spdlog/sinks/android_sink.h | 22 +- + include/spdlog/sinks/ansicolor_sink-inl.h | 10 +- + include/spdlog/sinks/ansicolor_sink.h | 8 +- + include/spdlog/sinks/base_sink-inl.h | 20 +- + include/spdlog/sinks/base_sink.h | 12 +- + include/spdlog/sinks/basic_file_sink-inl.h | 4 +- + include/spdlog/sinks/basic_file_sink.h | 8 +- + include/spdlog/sinks/callback_sink.h | 8 +- + include/spdlog/sinks/daily_file_sink.h | 18 +- + include/spdlog/sinks/dist_sink.h | 8 +- + include/spdlog/sinks/dup_filter_sink.h | 6 +- + include/spdlog/sinks/hourly_file_sink.h | 10 +- + include/spdlog/sinks/kafka_sink.h | 22 +- + include/spdlog/sinks/mongo_sink.h | 12 +- + include/spdlog/sinks/msvc_sink.h | 4 +- + include/spdlog/sinks/null_sink.h | 8 +- + include/spdlog/sinks/ostream_sink.h | 4 +- + include/spdlog/sinks/qt_sinks.h | 20 +- + include/spdlog/sinks/ringbuffer_sink.h | 4 +- + include/spdlog/sinks/rotating_file_sink-inl.h | 4 +- + include/spdlog/sinks/rotating_file_sink.h | 8 +- + include/spdlog/sinks/sink-inl.h | 8 +- + include/spdlog/sinks/sink.h | 6 +- + include/spdlog/sinks/stdout_color_sinks-inl.h | 4 +- + include/spdlog/sinks/stdout_color_sinks.h | 12 +- + include/spdlog/sinks/stdout_sinks-inl.h | 12 +- + include/spdlog/sinks/stdout_sinks.h | 16 +- + include/spdlog/sinks/syslog_sink.h | 22 +- + include/spdlog/sinks/systemd_sink.h | 22 +- + include/spdlog/sinks/tcp_sink.h | 14 +- + include/spdlog/sinks/udp_sink.h | 16 +- + include/spdlog/sinks/win_eventlog_sink.h | 4 +- + include/spdlog/sinks/wincolor_sink-inl.h | 10 +- + include/spdlog/sinks/wincolor_sink.h | 8 +- + include/spdlog/spdlog-inl.h | 14 +- + include/spdlog/spdlog.h | 70 +++--- + include/spdlog/stopwatch.h | 16 +- + src/color_sinks.cpp | 40 ++-- + src/file_sinks.cpp | 8 +- + src/spdlog.cpp | 6 +- + src/stdout_sinks.cpp | 28 +-- + tests/CMakeLists.txt | 4 +- + tests/test_async.cpp | 74 +++--- + tests/test_backtrace.cpp | 16 +- + tests/test_bin_to_hex.cpp | 68 +++--- + tests/test_cfg.cpp | 150 ++++++------ + tests/test_create_dir.cpp | 10 +- + tests/test_custom_callbacks.cpp | 14 +- + tests/test_daily_logger.cpp | 64 ++--- + tests/test_dup_filter.cpp | 40 ++-- + tests/test_errors.cpp | 46 ++-- + tests/test_eventlog.cpp | 6 +- + tests/test_file_helper.cpp | 46 ++-- + tests/test_file_logging.cpp | 36 +-- + tests/test_fmt_helper.cpp | 12 +- + tests/test_macros.cpp | 20 +- + tests/test_misc.cpp | 130 +++++----- + tests/test_mpmc_q.cpp | 16 +- + tests/test_pattern_formatter.cpp | 226 +++++++++--------- + tests/test_registry.cpp | 106 ++++---- + tests/test_sink.h | 4 +- + tests/test_stdout_api.cpp | 46 ++-- + tests/test_stopwatch.cpp | 8 +- + tests/test_systemd.cpp | 6 +- + tests/test_time_point.cpp | 20 +- + tests/utils.cpp | 4 +- + 116 files changed, 1308 insertions(+), 1308 deletions(-) + +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 6556144b..2ae01781 100644 +--- a/CMakeLists.txt 
++++ b/CMakeLists.txt +@@ -344,8 +344,8 @@ if(SPDLOG_INSTALL) + # --------------------------------------------------------------------------------------- + # Install CMake config files + # --------------------------------------------------------------------------------------- +- export(TARGETS spdlog NAMESPACE spdlog:: FILE "${CMAKE_CURRENT_BINARY_DIR}/${config_targets_file}") +- install(EXPORT spdlog DESTINATION ${export_dest_dir} NAMESPACE spdlog:: FILE ${config_targets_file}) ++ export(TARGETS spdlog NAMESPACE ds_spdlog:: FILE "${CMAKE_CURRENT_BINARY_DIR}/${config_targets_file}") ++ install(EXPORT spdlog DESTINATION ${export_dest_dir} NAMESPACE ds_spdlog:: FILE ${config_targets_file}) + + include(CMakePackageConfigHelpers) + configure_package_config_file("${project_config_in}" "${project_config_out}" INSTALL_DESTINATION ${export_dest_dir}) +diff --git a/README.md b/README.md +index 6ce32fb6..ff4b893c 100644 +--- a/README.md ++++ b/README.md +@@ -64,20 +64,20 @@ see example [CMakeLists.txt](https://github.com/gabime/spdlog/blob/v1.x/example/ + + int main() + { +- spdlog::info("Welcome to spdlog!"); +- spdlog::error("Some error message with arg: {}", 1); ++ ds_spdlog::info("Welcome to spdlog!"); ++ ds_spdlog::error("Some error message with arg: {}", 1); + +- spdlog::warn("Easy padding in numbers like {:08d}", 12); +- spdlog::critical("Support for int: {0:d}; hex: {0:x}; oct: {0:o}; bin: {0:b}", 42); +- spdlog::info("Support for floats {:03.2f}", 1.23456); +- spdlog::info("Positional args are {1} {0}..", "too", "supported"); +- spdlog::info("{:<30}", "left aligned"); ++ ds_spdlog::warn("Easy padding in numbers like {:08d}", 12); ++ ds_spdlog::critical("Support for int: {0:d}; hex: {0:x}; oct: {0:o}; bin: {0:b}", 42); ++ ds_spdlog::info("Support for floats {:03.2f}", 1.23456); ++ ds_spdlog::info("Positional args are {1} {0}..", "too", "supported"); ++ ds_spdlog::info("{:<30}", "left aligned"); + +- spdlog::set_level(spdlog::level::debug); // Set global log level to debug +- spdlog::debug("This message should be displayed.."); ++ ds_spdlog::set_level(ds_spdlog::level::debug); // Set global log level to debug ++ ds_spdlog::debug("This message should be displayed.."); + + // change log pattern +- spdlog::set_pattern("[%H:%M:%S %z] [%n] [%^---%L---%$] [thread %t] %v"); ++ ds_spdlog::set_pattern("[%H:%M:%S %z] [%n] [%^---%L---%$] [thread %t] %v"); + + // Compile time log levels + // define SPDLOG_ACTIVE_LEVEL to desired level +@@ -94,9 +94,9 @@ int main() + void stdout_example() + { + // create a color multi-threaded logger +- auto console = spdlog::stdout_color_mt("console"); +- auto err_logger = spdlog::stderr_color_mt("stderr"); +- spdlog::get("console")->info("loggers can be retrieved from a global registry using the spdlog::get(logger_name)"); ++ auto console = ds_spdlog::stdout_color_mt("console"); ++ auto err_logger = ds_spdlog::stderr_color_mt("stderr"); ++ ds_spdlog::get("console")->info("loggers can be retrieved from a global registry using the ds_spdlog::get(logger_name)"); + } + ``` + +@@ -108,9 +108,9 @@ void basic_logfile_example() + { + try + { +- auto logger = spdlog::basic_logger_mt("basic_logger", "logs/basic-log.txt"); ++ auto logger = ds_spdlog::basic_logger_mt("basic_logger", "logs/basic-log.txt"); + } +- catch (const spdlog::spdlog_ex &ex) ++ catch (const ds_spdlog::spdlog_ex &ex) + { + std::cout << "Log init failed: " << ex.what() << std::endl; + } +@@ -125,7 +125,7 @@ void rotating_example() + // Create a file rotating logger with 5 MB size max and 3 rotated files + auto 
max_size = 1048576 * 5; + auto max_files = 3; +- auto logger = spdlog::rotating_logger_mt("some_logger_name", "logs/rotating.txt", max_size, max_files); ++ auto logger = ds_spdlog::rotating_logger_mt("some_logger_name", "logs/rotating.txt", max_size, max_files); + } + ``` + +@@ -137,7 +137,7 @@ void rotating_example() + void daily_example() + { + // Create a daily logger - a new file is created every day at 2:30 am +- auto logger = spdlog::daily_logger_mt("daily_logger", "logs/daily.txt", 2, 30); ++ auto logger = ds_spdlog::daily_logger_mt("daily_logger", "logs/daily.txt", 2, 30); + } + + ``` +@@ -149,14 +149,14 @@ void daily_example() + // This is useful to display debug logs only when needed (e.g. when an error happens). + // When needed, call dump_backtrace() to dump them to your log. + +-spdlog::enable_backtrace(32); // Store the latest 32 messages in a buffer. ++ds_spdlog::enable_backtrace(32); // Store the latest 32 messages in a buffer. + // or my_logger->enable_backtrace(32).. + for(int i = 0; i < 100; i++) + { +- spdlog::debug("Backtrace message {}", i); // not logged yet.. ++ ds_spdlog::debug("Backtrace message {}", i); // not logged yet.. + } + // e.g. if some error happened: +-spdlog::dump_backtrace(); // log them now! show the last 32 messages ++ds_spdlog::dump_backtrace(); // log them now! show the last 32 messages + // or my_logger->dump_backtrace(32).. + ``` + +@@ -165,7 +165,7 @@ spdlog::dump_backtrace(); // log them now! show the last 32 messages + ```c++ + // periodically flush all *registered* loggers every 3 seconds: + // warning: only use if all your loggers are thread-safe ("_mt" loggers) +-spdlog::flush_every(std::chrono::seconds(3)); ++ds_spdlog::flush_every(std::chrono::seconds(3)); + + ``` + +@@ -176,9 +176,9 @@ spdlog::flush_every(std::chrono::seconds(3)); + #include "spdlog/stopwatch.h" + void stopwatch_example() + { +- spdlog::stopwatch sw; +- spdlog::debug("Elapsed {}", sw); +- spdlog::debug("Elapsed {:.3}", sw); ++ ds_spdlog::stopwatch sw; ++ ds_spdlog::debug("Elapsed {}", sw); ++ ds_spdlog::debug("Elapsed {:.3}", sw); + } + + ``` +@@ -199,14 +199,14 @@ void stopwatch_example() + + void binary_example() + { +- auto console = spdlog::get("console"); ++ auto console = ds_spdlog::get("console"); + std::array buf; +- console->info("Binary example: {}", spdlog::to_hex(buf)); +- console->info("Another binary example:{:n}", spdlog::to_hex(std::begin(buf), std::begin(buf) + 10)); ++ console->info("Binary example: {}", ds_spdlog::to_hex(buf)); ++ console->info("Another binary example:{:n}", ds_spdlog::to_hex(std::begin(buf), std::begin(buf) + 10)); + // more examples: +- // logger->info("uppercase: {:X}", spdlog::to_hex(buf)); +- // logger->info("uppercase, no delimiters: {:Xs}", spdlog::to_hex(buf)); +- // logger->info("uppercase, no delimiters, no position info: {:Xsp}", spdlog::to_hex(buf)); ++ // logger->info("uppercase: {:X}", ds_spdlog::to_hex(buf)); ++ // logger->info("uppercase, no delimiters: {:Xs}", ds_spdlog::to_hex(buf)); ++ // logger->info("uppercase, no delimiters, no position info: {:Xsp}", ds_spdlog::to_hex(buf)); + } + + ``` +@@ -219,15 +219,15 @@ void binary_example() + // The console will show only warnings or errors, while the file will log all. 
+ void multi_sink_example() + { +- auto console_sink = std::make_shared(); +- console_sink->set_level(spdlog::level::warn); ++ auto console_sink = std::make_shared(); ++ console_sink->set_level(ds_spdlog::level::warn); + console_sink->set_pattern("[multi_sink_example] [%^%l%$] %v"); + +- auto file_sink = std::make_shared("logs/multisink.txt", true); +- file_sink->set_level(spdlog::level::trace); ++ auto file_sink = std::make_shared("logs/multisink.txt", true); ++ file_sink->set_level(ds_spdlog::level::trace); + +- spdlog::logger logger("multi_sink", {console_sink, file_sink}); +- logger.set_level(spdlog::level::debug); ++ ds_spdlog::logger logger("multi_sink", {console_sink, file_sink}); ++ logger.set_level(ds_spdlog::level::debug); + logger.warn("this should appear in both console and file"); + logger.info("this message should not appear in the console, only in the file"); + } +@@ -241,13 +241,13 @@ void multi_sink_example() + // each time something is logged to the logger + void callback_example() + { +- auto callback_sink = std::make_shared([](const spdlog::details::log_msg &msg) { ++ auto callback_sink = std::make_shared([](const ds_spdlog::details::log_msg &msg) { + // for example you can be notified by sending an email to yourself + }); +- callback_sink->set_level(spdlog::level::err); ++ callback_sink->set_level(ds_spdlog::level::err); + +- auto console_sink = std::make_shared(); +- spdlog::logger logger("custom_callback_logger", {console_sink, callback_sink}); ++ auto console_sink = std::make_shared(); ++ ds_spdlog::logger logger("custom_callback_logger", {console_sink, callback_sink}); + + logger.info("some info log"); + logger.error("critical issue"); // will notify you +@@ -262,10 +262,10 @@ void callback_example() + void async_example() + { + // default thread pool settings can be modified *before* creating the async logger: +- // spdlog::init_thread_pool(8192, 1); // queue with 8k items and 1 backing thread. +- auto async_file = spdlog::basic_logger_mt("async_file_logger", "logs/async_log.txt"); ++ // ds_spdlog::init_thread_pool(8192, 1); // queue with 8k items and 1 backing thread. 
++ auto async_file = ds_spdlog::basic_logger_mt("async_file_logger", "logs/async_log.txt"); + // alternatively: +- // auto async_file = spdlog::create_async("async_file_logger", "logs/async_log.txt"); ++ // auto async_file = ds_spdlog::create_async("async_file_logger", "logs/async_log.txt"); + } + + ``` +@@ -278,12 +278,12 @@ void async_example() + + void multi_sink_example2() + { +- spdlog::init_thread_pool(8192, 1); +- auto stdout_sink = std::make_shared(); +- auto rotating_sink = std::make_shared("mylog.txt", 1024*1024*10, 3); +- std::vector sinks {stdout_sink, rotating_sink}; +- auto logger = std::make_shared("loggername", sinks.begin(), sinks.end(), spdlog::thread_pool(), spdlog::async_overflow_policy::block); +- spdlog::register_logger(logger); ++ ds_spdlog::init_thread_pool(8192, 1); ++ auto stdout_sink = std::make_shared(); ++ auto rotating_sink = std::make_shared("mylog.txt", 1024*1024*10, 3); ++ std::vector sinks {stdout_sink, rotating_sink}; ++ auto logger = std::make_shared("loggername", sinks.begin(), sinks.end(), ds_spdlog::thread_pool(), ds_spdlog::async_overflow_policy::block); ++ ds_spdlog::register_logger(logger); + } + ``` + +@@ -301,7 +301,7 @@ struct fmt::formatter : fmt::formatter + + void user_defined_example() + { +- spdlog::info("user defined type: {}", my_type(14)); ++ ds_spdlog::info("user defined type: {}", my_type(14)); + } + + ``` +@@ -312,10 +312,10 @@ void user_defined_example() + // Log patterns can contain custom flags. + // the following example will add new flag '%*' - which will be bound to a instance. + #include "spdlog/pattern_formatter.h" +-class my_formatter_flag : public spdlog::custom_flag_formatter ++class my_formatter_flag : public ds_spdlog::custom_flag_formatter + { + public: +- void format(const spdlog::details::log_msg &, const std::tm &, spdlog::memory_buf_t &dest) override ++ void format(const ds_spdlog::details::log_msg &, const std::tm &, ds_spdlog::memory_buf_t &dest) override + { + std::string some_txt = "custom-flag"; + dest.append(some_txt.data(), some_txt.data() + some_txt.size()); +@@ -323,15 +323,15 @@ public: + + std::unique_ptr clone() const override + { +- return spdlog::details::make_unique(); ++ return ds_spdlog::details::make_unique(); + } + }; + + void custom_flags_example() + { +- auto formatter = std::make_unique(); ++ auto formatter = std::make_unique(); + formatter->add_flag('*').set_pattern("[%n] [%*] [%^%l%$] %v"); +- spdlog::set_formatter(std::move(formatter)); ++ ds_spdlog::set_formatter(std::move(formatter)); + } + + ``` +@@ -342,8 +342,8 @@ void custom_flags_example() + void err_handler_example() + { + // can be set globally or per logger(logger->set_error_handler(..)) +- spdlog::set_error_handler([](const std::string &msg) { spdlog::get("console")->error("*** LOGGER ERROR ***: {}", msg); }); +- spdlog::get("console")->info("some invalid message to trigger an error {}{}{}{}", 3); ++ ds_spdlog::set_error_handler([](const std::string &msg) { ds_spdlog::get("console")->error("*** LOGGER ERROR ***: {}", msg); }); ++ ds_spdlog::get("console")->info("some invalid message to trigger an error {}{}{}{}", 3); + } + + ``` +@@ -355,7 +355,7 @@ void err_handler_example() + void syslog_example() + { + std::string ident = "spdlog-example"; +- auto syslog_logger = spdlog::syslog_logger_mt("syslog", ident, LOG_PID); ++ auto syslog_logger = ds_spdlog::syslog_logger_mt("syslog", ident, LOG_PID); + syslog_logger->warn("This is warning that will end up in syslog."); + } + ``` +@@ -366,7 +366,7 @@ void syslog_example() + void 
android_example() + { + std::string tag = "spdlog-android"; +- auto android_logger = spdlog::android_logger_mt("android", tag); ++ auto android_logger = ds_spdlog::android_logger_mt("android", tag); + android_logger->critical("Use \"adb shell logcat\" to view this message."); + } + ``` +@@ -378,11 +378,11 @@ void android_example() + #include "spdlog/cfg/env.h" + int main (int argc, char *argv[]) + { +- spdlog::cfg::load_env_levels(); ++ ds_spdlog::cfg::load_env_levels(); + // or from the command line: + // ./example SPDLOG_LEVEL=info,mylogger=trace + // #include "spdlog/cfg/argv.h" // for loading levels from argv +- // spdlog::cfg::load_argv_levels(argc, argv); ++ // ds_spdlog::cfg::load_argv_levels(argc, argv); + } + ``` + So then you can: +@@ -400,13 +400,13 @@ $ ./example + // This is useful for cleanup procedures or for adding something to the start/end of the log file. + void file_events_example() + { +- // pass the spdlog::file_event_handlers to file sinks for open/close log file notifications +- spdlog::file_event_handlers handlers; +- handlers.before_open = [](spdlog::filename_t filename) { spdlog::info("Before opening {}", filename); }; +- handlers.after_open = [](spdlog::filename_t filename, std::FILE *fstream) { fputs("After opening\n", fstream); }; +- handlers.before_close = [](spdlog::filename_t filename, std::FILE *fstream) { fputs("Before closing\n", fstream); }; +- handlers.after_close = [](spdlog::filename_t filename) { spdlog::info("After closing {}", filename); }; +- auto my_logger = spdlog::basic_logger_st("some_logger", "logs/events-sample.txt", true, handlers); ++ // pass the ds_spdlog::file_event_handlers to file sinks for open/close log file notifications ++ ds_spdlog::file_event_handlers handlers; ++ handlers.before_open = [](ds_spdlog::filename_t filename) { ds_spdlog::info("Before opening {}", filename); }; ++ handlers.after_open = [](ds_spdlog::filename_t filename, std::FILE *fstream) { fputs("After opening\n", fstream); }; ++ handlers.before_close = [](ds_spdlog::filename_t filename, std::FILE *fstream) { fputs("Before closing\n", fstream); }; ++ handlers.after_close = [](ds_spdlog::filename_t filename) { ds_spdlog::info("After closing {}", filename); }; ++ auto my_logger = ds_spdlog::basic_logger_st("some_logger", "logs/events-sample.txt", true, handlers); + } + ``` + +@@ -415,9 +415,9 @@ void file_events_example() + ```c++ + void replace_default_logger_example() + { +- auto new_logger = spdlog::basic_logger_mt("new_default_logger", "logs/new-default-log.txt", true); +- spdlog::set_default_logger(new_logger); +- spdlog::info("new logger log message"); ++ auto new_logger = ds_spdlog::basic_logger_mt("new_default_logger", "logs/new-default-log.txt", true); ++ ds_spdlog::set_default_logger(new_logger); ++ ds_spdlog::info("new logger log message"); + } + ``` + +@@ -432,7 +432,7 @@ MainWindow::MainWindow(QWidget *parent) : QMainWindow(parent) + auto log_widget = new QTextEdit(this); + setCentralWidget(log_widget); + int max_lines = 500; // keep the text widget to max 500 lines. remove old lines if needed. 
+- auto logger = spdlog::qt_color_logger_mt("qt_logger", log_widget, max_lines); ++ auto logger = ds_spdlog::qt_color_logger_mt("qt_logger", log_widget, max_lines); + logger->info("Some info message"); + } + ``` +diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt +index 8003886a..1c7c2c50 100644 +--- a/bench/CMakeLists.txt ++++ b/bench/CMakeLists.txt +@@ -29,13 +29,13 @@ endif() + + add_executable(bench bench.cpp) + spdlog_enable_warnings(bench) +-target_link_libraries(bench PRIVATE spdlog::spdlog) ++target_link_libraries(bench PRIVATE ds_spdlog::spdlog) + + add_executable(async_bench async_bench.cpp) +-target_link_libraries(async_bench PRIVATE spdlog::spdlog) ++target_link_libraries(async_bench PRIVATE ds_spdlog::spdlog) + + add_executable(latency latency.cpp) +-target_link_libraries(latency PRIVATE benchmark::benchmark spdlog::spdlog) ++target_link_libraries(latency PRIVATE benchmark::benchmark ds_spdlog::spdlog) + + add_executable(formatter-bench formatter-bench.cpp) +-target_link_libraries(formatter-bench PRIVATE benchmark::benchmark spdlog::spdlog) ++target_link_libraries(formatter-bench PRIVATE benchmark::benchmark ds_spdlog::spdlog) +diff --git a/bench/async_bench.cpp b/bench/async_bench.cpp +index cf4d9754..33592aec 100644 +--- a/bench/async_bench.cpp ++++ b/bench/async_bench.cpp +@@ -27,11 +27,11 @@ + + using namespace std; + using namespace std::chrono; +-using namespace spdlog; +-using namespace spdlog::sinks; ++using namespace ds_spdlog; ++using namespace ds_spdlog::sinks; + using namespace utils; + +-void bench_mt(int howmany, std::shared_ptr log, int thread_count); ++void bench_mt(int howmany, std::shared_ptr log, int thread_count); + + #ifdef _MSC_VER + # pragma warning(push) +@@ -55,14 +55,14 @@ int count_lines(const char *filename) + + void verify_file(const char *filename, int expected_count) + { +- spdlog::info("Verifying {} to contain {} line..", filename, expected_count); ++ ds_spdlog::info("Verifying {} to contain {} line..", filename, expected_count); + auto count = count_lines(filename); + if (count != expected_count) + { +- spdlog::error("Test failed. {} has {} lines instead of {}", filename, count, expected_count); ++ ds_spdlog::error("Test failed. 
{} has {} lines instead of {}", filename, count, expected_count);
+ exit(1);
+ }
+- spdlog::info("Line count OK ({})\n", count);
++ ds_spdlog::info("Line count OK ({})\n", count);
+ }
+
+ #ifdef _MSC_VER
+@@ -79,10 +79,10 @@ int main(int argc, char *argv[])
+
+ try
+ {
+- spdlog::set_pattern("[%^%l%$] %v");
++ ds_spdlog::set_pattern("[%^%l%$] %v");
+ if (argc == 1)
+ {
+- spdlog::info("Usage: {} <message_count> <threads> <q_size> <iterations>", argv[0]);
++ ds_spdlog::info("Usage: {} <message_count> <threads> <q_size> <iterations>", argv[0]);
+ return 0;
+ }
+
+@@ -95,7 +95,7 @@ int main(int argc, char *argv[])
+ queue_size = atoi(argv[3]);
+ if (queue_size > 500000)
+ {
+- spdlog::error("Max queue size allowed: 500,000");
++ ds_spdlog::error("Max queue size allowed: 500,000");
+ exit(1);
+ }
+ }
+@@ -103,44 +103,44 @@ int main(int argc, char *argv[])
+ if (argc > 4)
+ iters = atoi(argv[4]);
+
+- auto slot_size = sizeof(spdlog::details::async_msg);
+- spdlog::info("-------------------------------------------------");
+- spdlog::info("Messages : {:L}", howmany);
+- spdlog::info("Threads : {:L}", threads);
+- spdlog::info("Queue : {:L} slots", queue_size);
+- spdlog::info("Queue memory : {:L} x {:L} = {:L} KB ", queue_size, slot_size, (queue_size * slot_size) / 1024);
+- spdlog::info("Total iters : {:L}", iters);
+- spdlog::info("-------------------------------------------------");
++ auto slot_size = sizeof(ds_spdlog::details::async_msg);
++ ds_spdlog::info("-------------------------------------------------");
++ ds_spdlog::info("Messages : {:L}", howmany);
++ ds_spdlog::info("Threads : {:L}", threads);
++ ds_spdlog::info("Queue : {:L} slots", queue_size);
++ ds_spdlog::info("Queue memory : {:L} x {:L} = {:L} KB ", queue_size, slot_size, (queue_size * slot_size) / 1024);
++ ds_spdlog::info("Total iters : {:L}", iters);
++ ds_spdlog::info("-------------------------------------------------");
+
+ const char *filename = "logs/basic_async.log";
+- spdlog::info("");
+- spdlog::info("*********************************");
+- spdlog::info("Queue Overflow Policy: block");
+- spdlog::info("*********************************");
++ ds_spdlog::info("");
++ ds_spdlog::info("*********************************");
++ ds_spdlog::info("Queue Overflow Policy: block");
++ ds_spdlog::info("*********************************");
+ for (int i = 0; i < iters; i++)
+ {
+ auto tp = std::make_shared<details::thread_pool>(queue_size, 1);
+- auto file_sink = std::make_shared<spdlog::sinks::basic_file_sink_mt>(filename, true);
++ auto file_sink = std::make_shared<ds_spdlog::sinks::basic_file_sink_mt>(filename, true);
+ auto logger = std::make_shared<async_logger>("async_logger", std::move(file_sink), std::move(tp), async_overflow_policy::block);
+ bench_mt(howmany, std::move(logger), threads);
+ // verify_file(filename, howmany);
+ }
+
+- spdlog::info("");
+- spdlog::info("*********************************");
+- spdlog::info("Queue Overflow Policy: overrun");
+- spdlog::info("*********************************");
++ ds_spdlog::info("");
++ ds_spdlog::info("*********************************");
++ ds_spdlog::info("Queue Overflow Policy: overrun");
++ ds_spdlog::info("*********************************");
+ // do same test but discard oldest if queue is full instead of blocking
+ filename = "logs/basic_async-overrun.log";
+ for (int i = 0; i < iters; i++)
+ {
+ auto tp = std::make_shared<details::thread_pool>(queue_size, 1);
+- auto file_sink = std::make_shared<spdlog::sinks::basic_file_sink_mt>(filename, true);
++ auto file_sink = std::make_shared<ds_spdlog::sinks::basic_file_sink_mt>(filename, true);
+ auto logger =
+ std::make_shared<async_logger>("async_logger", std::move(file_sink), std::move(tp), async_overflow_policy::overrun_oldest);
+ bench_mt(howmany, std::move(logger), threads);
+ }
+- spdlog::shutdown();
++ ds_spdlog::shutdown();
+ }
+ catch (std::exception &ex)
+ {
+@@ -151,7 +151,7 @@ int main(int argc, char *argv[])
+ return 0;
+ }
+
+-void thread_fun(std::shared_ptr<spdlog::async_logger> logger, int howmany)
++void thread_fun(std::shared_ptr<ds_spdlog::async_logger> logger, int howmany)
+ {
+ for (int i = 0; i < howmany; i++)
+ {
+@@ -159,7 +159,7 @@ void thread_fun(std::shared_ptr<spdlog::async_logger> logger, int howmany)
+ }
+ }
+
+-void bench_mt(int howmany, std::shared_ptr<spdlog::async_logger> logger, int thread_count)
++void bench_mt(int howmany, std::shared_ptr<ds_spdlog::async_logger> logger, int thread_count)
+ {
+ using std::chrono::high_resolution_clock;
+ vector<std::thread> threads;
+@@ -182,5 +182,5 @@ void bench_mt(int howmany, std::shared_ptr<spdlog::async_logger> logger, int thread_co
+
+ auto delta = high_resolution_clock::now() - start;
+ auto delta_d = duration_cast<duration<double>>(delta).count();
+- spdlog::info("Elapsed: {} secs\t {:L}/sec", delta_d, int(howmany / delta_d));
++ ds_spdlog::info("Elapsed: {} secs\t {:L}/sec", delta_d, int(howmany / delta_d));
+ }
+diff --git a/bench/bench.cpp b/bench/bench.cpp
+index ae47f047..77cc5884 100644
+--- a/bench/bench.cpp
++++ b/bench/bench.cpp
+@@ -27,11 +27,11 @@
+ #include
+ #include
+
+-void bench(int howmany, std::shared_ptr<spdlog::logger> log);
+-void bench_mt(int howmany, std::shared_ptr<spdlog::logger> log, size_t thread_count);
++void bench(int howmany, std::shared_ptr<ds_spdlog::logger> log);
++void bench_mt(int howmany, std::shared_ptr<ds_spdlog::logger> log, size_t thread_count);
+
+-// void bench_default_api(int howmany, std::shared_ptr<spdlog::logger> log);
+-// void bench_c_string(int howmany, std::shared_ptr<spdlog::logger> log);
++// void bench_default_api(int howmany, std::shared_ptr<ds_spdlog::logger> log);
++// void bench_c_string(int howmany, std::shared_ptr<ds_spdlog::logger> log);
+
+ static const size_t file_size = 30 * 1024 * 1024;
+ static const size_t rotating_files = 5;
+@@ -39,81 +39,81 @@ static const int max_threads = 1000;
+
+ void bench_threaded_logging(size_t threads, int iters)
+ {
+- spdlog::info("**************************************************************");
+- spdlog::info(spdlog::fmt_lib::format(std::locale("en_US.UTF-8"), "Multi threaded: {:L} threads, {:L} messages", threads, iters));
+- spdlog::info("**************************************************************");
++ ds_spdlog::info("**************************************************************");
++ ds_spdlog::info(ds_spdlog::fmt_lib::format(std::locale("en_US.UTF-8"), "Multi threaded: {:L} threads, {:L} messages", threads, iters));
++ ds_spdlog::info("**************************************************************");
+
+- auto basic_mt = spdlog::basic_logger_mt("basic_mt", "logs/basic_mt.log", true);
++ auto basic_mt = ds_spdlog::basic_logger_mt("basic_mt", "logs/basic_mt.log", true);
+ bench_mt(iters, std::move(basic_mt), threads);
+- auto basic_mt_tracing = spdlog::basic_logger_mt("basic_mt/backtrace-on", "logs/basic_mt.log", true);
++ auto basic_mt_tracing = ds_spdlog::basic_logger_mt("basic_mt/backtrace-on", "logs/basic_mt.log", true);
+ basic_mt_tracing->enable_backtrace(32);
+ bench_mt(iters, std::move(basic_mt_tracing), threads);
+
+- spdlog::info("");
+- auto rotating_mt = spdlog::rotating_logger_mt("rotating_mt", "logs/rotating_mt.log", file_size, rotating_files);
++ ds_spdlog::info("");
++ auto rotating_mt = ds_spdlog::rotating_logger_mt("rotating_mt", "logs/rotating_mt.log", file_size, rotating_files);
+ bench_mt(iters, std::move(rotating_mt), threads);
+- auto rotating_mt_tracing = spdlog::rotating_logger_mt("rotating_mt/backtrace-on", "logs/rotating_mt.log", file_size, rotating_files);
++ auto rotating_mt_tracing = ds_spdlog::rotating_logger_mt("rotating_mt/backtrace-on", "logs/rotating_mt.log", file_size,
rotating_files); + rotating_mt_tracing->enable_backtrace(32); + bench_mt(iters, std::move(rotating_mt_tracing), threads); + +- spdlog::info(""); +- auto daily_mt = spdlog::daily_logger_mt("daily_mt", "logs/daily_mt.log"); ++ ds_spdlog::info(""); ++ auto daily_mt = ds_spdlog::daily_logger_mt("daily_mt", "logs/daily_mt.log"); + bench_mt(iters, std::move(daily_mt), threads); +- auto daily_mt_tracing = spdlog::daily_logger_mt("daily_mt/backtrace-on", "logs/daily_mt.log"); ++ auto daily_mt_tracing = ds_spdlog::daily_logger_mt("daily_mt/backtrace-on", "logs/daily_mt.log"); + daily_mt_tracing->enable_backtrace(32); + bench_mt(iters, std::move(daily_mt_tracing), threads); + +- spdlog::info(""); +- auto empty_logger = std::make_shared("level-off"); +- empty_logger->set_level(spdlog::level::off); ++ ds_spdlog::info(""); ++ auto empty_logger = std::make_shared("level-off"); ++ empty_logger->set_level(ds_spdlog::level::off); + bench(iters, empty_logger); +- auto empty_logger_tracing = std::make_shared("level-off/backtrace-on"); +- empty_logger_tracing->set_level(spdlog::level::off); ++ auto empty_logger_tracing = std::make_shared("level-off/backtrace-on"); ++ empty_logger_tracing->set_level(ds_spdlog::level::off); + empty_logger_tracing->enable_backtrace(32); + bench(iters, empty_logger_tracing); + } + + void bench_single_threaded(int iters) + { +- spdlog::info("**************************************************************"); +- spdlog::info(spdlog::fmt_lib::format(std::locale("en_US.UTF-8"), "Single threaded: {} messages", iters)); +- spdlog::info("**************************************************************"); ++ ds_spdlog::info("**************************************************************"); ++ ds_spdlog::info(ds_spdlog::fmt_lib::format(std::locale("en_US.UTF-8"), "Single threaded: {} messages", iters)); ++ ds_spdlog::info("**************************************************************"); + +- auto basic_st = spdlog::basic_logger_st("basic_st", "logs/basic_st.log", true); ++ auto basic_st = ds_spdlog::basic_logger_st("basic_st", "logs/basic_st.log", true); + bench(iters, std::move(basic_st)); + +- auto basic_st_tracing = spdlog::basic_logger_st("basic_st/backtrace-on", "logs/basic_st.log", true); ++ auto basic_st_tracing = ds_spdlog::basic_logger_st("basic_st/backtrace-on", "logs/basic_st.log", true); + bench(iters, std::move(basic_st_tracing)); + +- spdlog::info(""); +- auto rotating_st = spdlog::rotating_logger_st("rotating_st", "logs/rotating_st.log", file_size, rotating_files); ++ ds_spdlog::info(""); ++ auto rotating_st = ds_spdlog::rotating_logger_st("rotating_st", "logs/rotating_st.log", file_size, rotating_files); + bench(iters, std::move(rotating_st)); +- auto rotating_st_tracing = spdlog::rotating_logger_st("rotating_st/backtrace-on", "logs/rotating_st.log", file_size, rotating_files); ++ auto rotating_st_tracing = ds_spdlog::rotating_logger_st("rotating_st/backtrace-on", "logs/rotating_st.log", file_size, rotating_files); + rotating_st_tracing->enable_backtrace(32); + bench(iters, std::move(rotating_st_tracing)); + +- spdlog::info(""); +- auto daily_st = spdlog::daily_logger_st("daily_st", "logs/daily_st.log"); ++ ds_spdlog::info(""); ++ auto daily_st = ds_spdlog::daily_logger_st("daily_st", "logs/daily_st.log"); + bench(iters, std::move(daily_st)); +- auto daily_st_tracing = spdlog::daily_logger_st("daily_st/backtrace-on", "logs/daily_st.log"); ++ auto daily_st_tracing = ds_spdlog::daily_logger_st("daily_st/backtrace-on", "logs/daily_st.log"); + 
daily_st_tracing->enable_backtrace(32); + bench(iters, std::move(daily_st_tracing)); + +- spdlog::info(""); +- auto empty_logger = std::make_shared("level-off"); +- empty_logger->set_level(spdlog::level::off); ++ ds_spdlog::info(""); ++ auto empty_logger = std::make_shared("level-off"); ++ empty_logger->set_level(ds_spdlog::level::off); + bench(iters, empty_logger); + +- auto empty_logger_tracing = std::make_shared("level-off/backtrace-on"); +- empty_logger_tracing->set_level(spdlog::level::off); ++ auto empty_logger_tracing = std::make_shared("level-off/backtrace-on"); ++ empty_logger_tracing->set_level(ds_spdlog::level::off); + empty_logger_tracing->enable_backtrace(32); + bench(iters, empty_logger_tracing); + } + + int main(int argc, char *argv[]) + { +- spdlog::set_automatic_registration(false); +- spdlog::default_logger()->set_pattern("[%^%l%$] %v"); ++ ds_spdlog::set_automatic_registration(false); ++ ds_spdlog::default_logger()->set_pattern("[%^%l%$] %v"); + int iters = 250000; + size_t threads = 4; + try +@@ -130,7 +130,7 @@ int main(int argc, char *argv[]) + + if (threads > max_threads) + { +- throw std::runtime_error(spdlog::fmt_lib::format("Number of threads exceeds maximum({})", max_threads)); ++ throw std::runtime_error(ds_spdlog::fmt_lib::format("Number of threads exceeds maximum({})", max_threads)); + } + + bench_single_threaded(iters); +@@ -139,13 +139,13 @@ int main(int argc, char *argv[]) + } + catch (std::exception &ex) + { +- spdlog::error(ex.what()); ++ ds_spdlog::error(ex.what()); + return EXIT_FAILURE; + } + return EXIT_SUCCESS; + } + +-void bench(int howmany, std::shared_ptr log) ++void bench(int howmany, std::shared_ptr log) + { + using std::chrono::duration; + using std::chrono::duration_cast; +@@ -160,12 +160,12 @@ void bench(int howmany, std::shared_ptr log) + auto delta = high_resolution_clock::now() - start; + auto delta_d = duration_cast>(delta).count(); + +- spdlog::info(spdlog::fmt_lib::format( ++ ds_spdlog::info(ds_spdlog::fmt_lib::format( + std::locale("en_US.UTF-8"), "{:<30} Elapsed: {:0.2f} secs {:>16L}/sec", log->name(), delta_d, int(howmany / delta_d))); +- spdlog::drop(log->name()); ++ ds_spdlog::drop(log->name()); + } + +-void bench_mt(int howmany, std::shared_ptr log, size_t thread_count) ++void bench_mt(int howmany, std::shared_ptr log, size_t thread_count) + { + using std::chrono::duration; + using std::chrono::duration_cast; +@@ -191,34 +191,34 @@ void bench_mt(int howmany, std::shared_ptr log, size_t thread_co + + auto delta = high_resolution_clock::now() - start; + auto delta_d = duration_cast>(delta).count(); +- spdlog::info(spdlog::fmt_lib::format( ++ ds_spdlog::info(ds_spdlog::fmt_lib::format( + std::locale("en_US.UTF-8"), "{:<30} Elapsed: {:0.2f} secs {:>16L}/sec", log->name(), delta_d, int(howmany / delta_d))); +- spdlog::drop(log->name()); ++ ds_spdlog::drop(log->name()); + } + + /* +-void bench_default_api(int howmany, std::shared_ptr log) ++void bench_default_api(int howmany, std::shared_ptr log) + { + using std::chrono::high_resolution_clock; + using std::chrono::duration; + using std::chrono::duration_cast; + +- auto orig_default = spdlog::default_logger(); +- spdlog::set_default_logger(log); ++ auto orig_default = ds_spdlog::default_logger(); ++ ds_spdlog::set_default_logger(log); + auto start = high_resolution_clock::now(); + for (auto i = 0; i < howmany; ++i) + { +- spdlog::info("Hello logger: msg number {}", i); ++ ds_spdlog::info("Hello logger: msg number {}", i); + } + + auto delta = high_resolution_clock::now() - start; + auto 
delta_d = duration_cast>(delta).count(); +- spdlog::drop(log->name()); +- spdlog::set_default_logger(std::move(orig_default)); +- spdlog::info("{:<30} Elapsed: {:0.2f} secs {:>16}/sec", log->name(), delta_d, int(howmany / delta_d)); ++ ds_spdlog::drop(log->name()); ++ ds_spdlog::set_default_logger(std::move(orig_default)); ++ ds_spdlog::info("{:<30} Elapsed: {:0.2f} secs {:>16}/sec", log->name(), delta_d, int(howmany / delta_d)); + } + +-void bench_c_string(int howmany, std::shared_ptr log) ++void bench_c_string(int howmany, std::shared_ptr log) + { + using std::chrono::high_resolution_clock; + using std::chrono::duration; +@@ -230,19 +230,19 @@ void bench_c_string(int howmany, std::shared_ptr log) + "augue pretium, nec scelerisque est maximus. Nullam convallis, sem nec blandit maximus, nisi turpis ornare " + "nisl, sit amet volutpat neque massa eu odio. Maecenas malesuada quam ex, posuere congue nibh turpis duis."; + +- auto orig_default = spdlog::default_logger(); +- spdlog::set_default_logger(log); ++ auto orig_default = ds_spdlog::default_logger(); ++ ds_spdlog::set_default_logger(log); + auto start = high_resolution_clock::now(); + for (auto i = 0; i < howmany; ++i) + { +- spdlog::log(spdlog::level::info, msg); ++ ds_spdlog::log(ds_spdlog::level::info, msg); + } + + auto delta = high_resolution_clock::now() - start; + auto delta_d = duration_cast>(delta).count(); +- spdlog::drop(log->name()); +- spdlog::set_default_logger(std::move(orig_default)); +- spdlog::info("{:<30} Elapsed: {:0.2f} secs {:>16}/sec", log->name(), delta_d, int(howmany / delta_d)); ++ ds_spdlog::drop(log->name()); ++ ds_spdlog::set_default_logger(std::move(orig_default)); ++ ds_spdlog::info("{:<30} Elapsed: {:0.2f} secs {:>16}/sec", log->name(), delta_d, int(howmany / delta_d)); + } + + */ +\ No newline at end of file +diff --git a/bench/formatter-bench.cpp b/bench/formatter-bench.cpp +index 1454c6bb..bb9b17e5 100644 +--- a/bench/formatter-bench.cpp ++++ b/bench/formatter-bench.cpp +@@ -10,13 +10,13 @@ + + void bench_formatter(benchmark::State &state, std::string pattern) + { +- auto formatter = spdlog::details::make_unique(pattern); +- spdlog::memory_buf_t dest; ++ auto formatter = ds_spdlog::details::make_unique(pattern); ++ ds_spdlog::memory_buf_t dest; + std::string logger_name = "logger-name"; + const char *text = "Hello. This is some message with length of 80 "; + +- spdlog::source_loc source_loc{"a/b/c/d/myfile.cpp", 123, "some_func()"}; +- spdlog::details::log_msg msg(source_loc, logger_name, spdlog::level::info, text); ++ ds_spdlog::source_loc source_loc{"a/b/c/d/myfile.cpp", 123, "some_func()"}; ++ ds_spdlog::details::log_msg msg(source_loc, logger_name, ds_spdlog::level::info, text); + + for (auto _ : state) + { +@@ -59,10 +59,10 @@ void bench_formatters() + int main(int argc, char *argv[]) + { + +- spdlog::set_pattern("[%^%l%$] %v"); ++ ds_spdlog::set_pattern("[%^%l%$] %v"); + if (argc != 2) + { +- spdlog::error("Usage: {} (or \"all\" to bench all)", argv[0]); ++ ds_spdlog::error("Usage: {} (or \"all\" to bench all)", argv[0]); + exit(1); + } + +diff --git a/bench/latency.cpp b/bench/latency.cpp +index 8f002ee1..2f2b5108 100644 +--- a/bench/latency.cpp ++++ b/bench/latency.cpp +@@ -16,7 +16,7 @@ + #include "spdlog/sinks/null_sink.h" + #include "spdlog/sinks/rotating_file_sink.h" + +-void bench_c_string(benchmark::State &state, std::shared_ptr logger) ++void bench_c_string(benchmark::State &state, std::shared_ptr logger) + { + const char *msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. 
Vestibulum pharetra metus cursus " + "lacus placerat congue. Nulla egestas, mauris a tincidunt tempus, enim lectus volutpat mi, eu consequat sem " +@@ -30,7 +30,7 @@ void bench_c_string(benchmark::State &state, std::shared_ptr log + } + } + +-void bench_logger(benchmark::State &state, std::shared_ptr logger) ++void bench_logger(benchmark::State &state, std::shared_ptr logger) + { + int i = 0; + for (auto _ : state) +@@ -38,17 +38,17 @@ void bench_logger(benchmark::State &state, std::shared_ptr logge + logger->info("Hello logger: msg number {}...............", ++i); + } + } +-void bench_global_logger(benchmark::State &state, std::shared_ptr logger) ++void bench_global_logger(benchmark::State &state, std::shared_ptr logger) + { +- spdlog::set_default_logger(std::move(logger)); ++ ds_spdlog::set_default_logger(std::move(logger)); + int i = 0; + for (auto _ : state) + { +- spdlog::info("Hello logger: msg number {}...............", ++i); ++ ds_spdlog::info("Hello logger: msg number {}...............", ++i); + } + } + +-void bench_disabled_macro(benchmark::State &state, std::shared_ptr logger) ++void bench_disabled_macro(benchmark::State &state, std::shared_ptr logger) + { + int i = 0; + benchmark::DoNotOptimize(i); // prevent unused warnings +@@ -59,9 +59,9 @@ void bench_disabled_macro(benchmark::State &state, std::shared_ptr logger) ++void bench_disabled_macro_global_logger(benchmark::State &state, std::shared_ptr logger) + { +- spdlog::set_default_logger(std::move(logger)); ++ ds_spdlog::set_default_logger(std::move(logger)); + int i = 0; + benchmark::DoNotOptimize(i); // prevent unused warnings + benchmark::DoNotOptimize(logger); // prevent unused warnings +@@ -74,20 +74,20 @@ void bench_disabled_macro_global_logger(benchmark::State &state, std::shared_ptr + #ifdef __linux__ + void bench_dev_null() + { +- auto dev_null_st = spdlog::basic_logger_st("/dev/null_st", "/dev/null"); ++ auto dev_null_st = ds_spdlog::basic_logger_st("/dev/null_st", "/dev/null"); + benchmark::RegisterBenchmark("/dev/null_st", bench_logger, std::move(dev_null_st))->UseRealTime(); +- spdlog::drop("/dev/null_st"); ++ ds_spdlog::drop("/dev/null_st"); + +- auto dev_null_mt = spdlog::basic_logger_mt("/dev/null_mt", "/dev/null"); ++ auto dev_null_mt = ds_spdlog::basic_logger_mt("/dev/null_mt", "/dev/null"); + benchmark::RegisterBenchmark("/dev/null_mt", bench_logger, std::move(dev_null_mt))->UseRealTime(); +- spdlog::drop("/dev/null_mt"); ++ ds_spdlog::drop("/dev/null_mt"); + } + #endif // __linux__ + + int main(int argc, char *argv[]) + { +- using spdlog::sinks::null_sink_mt; +- using spdlog::sinks::null_sink_st; ++ using ds_spdlog::sinks::null_sink_mt; ++ using ds_spdlog::sinks::null_sink_st; + + size_t file_size = 30 * 1024 * 1024; + size_t rotating_files = 5; +@@ -96,23 +96,23 @@ int main(int argc, char *argv[]) + auto full_bench = argc > 1 && std::string(argv[1]) == "full"; + + // disabled loggers +- auto disabled_logger = std::make_shared("bench", std::make_shared()); +- disabled_logger->set_level(spdlog::level::off); ++ auto disabled_logger = std::make_shared("bench", std::make_shared()); ++ disabled_logger->set_level(ds_spdlog::level::off); + benchmark::RegisterBenchmark("disabled-at-compile-time", bench_disabled_macro, disabled_logger); + benchmark::RegisterBenchmark("disabled-at-compile-time (global logger)", bench_disabled_macro_global_logger, disabled_logger); + benchmark::RegisterBenchmark("disabled-at-runtime", bench_logger, disabled_logger); + benchmark::RegisterBenchmark("disabled-at-runtime (global 
logger)", bench_global_logger, disabled_logger); + // with backtrace of 64 +- auto tracing_disabled_logger = std::make_shared("bench", std::make_shared()); ++ auto tracing_disabled_logger = std::make_shared("bench", std::make_shared()); + tracing_disabled_logger->enable_backtrace(64); + benchmark::RegisterBenchmark("disabled-at-runtime/backtrace", bench_logger, tracing_disabled_logger); + +- auto null_logger_st = std::make_shared("bench", std::make_shared()); ++ auto null_logger_st = std::make_shared("bench", std::make_shared()); + benchmark::RegisterBenchmark("null_sink_st (500_bytes c_str)", bench_c_string, std::move(null_logger_st)); + benchmark::RegisterBenchmark("null_sink_st", bench_logger, null_logger_st); + benchmark::RegisterBenchmark("null_sink_st (global logger)", bench_global_logger, null_logger_st); + // with backtrace of 64 +- auto tracing_null_logger_st = std::make_shared("bench", std::make_shared()); ++ auto tracing_null_logger_st = std::make_shared("bench", std::make_shared()); + tracing_null_logger_st->enable_backtrace(64); + benchmark::RegisterBenchmark("null_sink_st/backtrace", bench_logger, tracing_null_logger_st); + +@@ -123,64 +123,64 @@ int main(int argc, char *argv[]) + if (full_bench) + { + // basic_st +- auto basic_st = spdlog::basic_logger_st("basic_st", "latency_logs/basic_st.log", true); ++ auto basic_st = ds_spdlog::basic_logger_st("basic_st", "latency_logs/basic_st.log", true); + benchmark::RegisterBenchmark("basic_st", bench_logger, std::move(basic_st))->UseRealTime(); +- spdlog::drop("basic_st"); ++ ds_spdlog::drop("basic_st"); + // with backtrace of 64 +- auto tracing_basic_st = spdlog::basic_logger_st("tracing_basic_st", "latency_logs/tracing_basic_st.log", true); ++ auto tracing_basic_st = ds_spdlog::basic_logger_st("tracing_basic_st", "latency_logs/tracing_basic_st.log", true); + tracing_basic_st->enable_backtrace(64); + benchmark::RegisterBenchmark("basic_st/backtrace", bench_logger, std::move(tracing_basic_st))->UseRealTime(); +- spdlog::drop("tracing_basic_st"); ++ ds_spdlog::drop("tracing_basic_st"); + + // rotating st +- auto rotating_st = spdlog::rotating_logger_st("rotating_st", "latency_logs/rotating_st.log", file_size, rotating_files); ++ auto rotating_st = ds_spdlog::rotating_logger_st("rotating_st", "latency_logs/rotating_st.log", file_size, rotating_files); + benchmark::RegisterBenchmark("rotating_st", bench_logger, std::move(rotating_st))->UseRealTime(); +- spdlog::drop("rotating_st"); ++ ds_spdlog::drop("rotating_st"); + // with backtrace of 64 + auto tracing_rotating_st = +- spdlog::rotating_logger_st("tracing_rotating_st", "latency_logs/tracing_rotating_st.log", file_size, rotating_files); ++ ds_spdlog::rotating_logger_st("tracing_rotating_st", "latency_logs/tracing_rotating_st.log", file_size, rotating_files); + benchmark::RegisterBenchmark("rotating_st/backtrace", bench_logger, std::move(tracing_rotating_st))->UseRealTime(); +- spdlog::drop("tracing_rotating_st"); ++ ds_spdlog::drop("tracing_rotating_st"); + + // daily st +- auto daily_st = spdlog::daily_logger_mt("daily_st", "latency_logs/daily_st.log"); ++ auto daily_st = ds_spdlog::daily_logger_mt("daily_st", "latency_logs/daily_st.log"); + benchmark::RegisterBenchmark("daily_st", bench_logger, std::move(daily_st))->UseRealTime(); +- spdlog::drop("daily_st"); +- auto tracing_daily_st = spdlog::daily_logger_mt("tracing_daily_st", "latency_logs/daily_st.log"); ++ ds_spdlog::drop("daily_st"); ++ auto tracing_daily_st = ds_spdlog::daily_logger_mt("tracing_daily_st", 
"latency_logs/daily_st.log"); + benchmark::RegisterBenchmark("daily_st/backtrace", bench_logger, std::move(tracing_daily_st))->UseRealTime(); +- spdlog::drop("tracing_daily_st"); ++ ds_spdlog::drop("tracing_daily_st"); + + // + // Multi threaded bench, 10 loggers using same logger concurrently + // +- auto null_logger_mt = std::make_shared("bench", std::make_shared()); ++ auto null_logger_mt = std::make_shared("bench", std::make_shared()); + benchmark::RegisterBenchmark("null_sink_mt", bench_logger, null_logger_mt)->Threads(n_threads)->UseRealTime(); + + // basic_mt +- auto basic_mt = spdlog::basic_logger_mt("basic_mt", "latency_logs/basic_mt.log", true); ++ auto basic_mt = ds_spdlog::basic_logger_mt("basic_mt", "latency_logs/basic_mt.log", true); + benchmark::RegisterBenchmark("basic_mt", bench_logger, std::move(basic_mt))->Threads(n_threads)->UseRealTime(); +- spdlog::drop("basic_mt"); ++ ds_spdlog::drop("basic_mt"); + + // rotating mt +- auto rotating_mt = spdlog::rotating_logger_mt("rotating_mt", "latency_logs/rotating_mt.log", file_size, rotating_files); ++ auto rotating_mt = ds_spdlog::rotating_logger_mt("rotating_mt", "latency_logs/rotating_mt.log", file_size, rotating_files); + benchmark::RegisterBenchmark("rotating_mt", bench_logger, std::move(rotating_mt))->Threads(n_threads)->UseRealTime(); +- spdlog::drop("rotating_mt"); ++ ds_spdlog::drop("rotating_mt"); + + // daily mt +- auto daily_mt = spdlog::daily_logger_mt("daily_mt", "latency_logs/daily_mt.log"); ++ auto daily_mt = ds_spdlog::daily_logger_mt("daily_mt", "latency_logs/daily_mt.log"); + benchmark::RegisterBenchmark("daily_mt", bench_logger, std::move(daily_mt))->Threads(n_threads)->UseRealTime(); +- spdlog::drop("daily_mt"); ++ ds_spdlog::drop("daily_mt"); + } + + // async + auto queue_size = 1024 * 1024 * 3; +- auto tp = std::make_shared(queue_size, 1); +- auto async_logger = std::make_shared( +- "async_logger", std::make_shared(), std::move(tp), spdlog::async_overflow_policy::overrun_oldest); ++ auto tp = std::make_shared(queue_size, 1); ++ auto async_logger = std::make_shared( ++ "async_logger", std::make_shared(), std::move(tp), ds_spdlog::async_overflow_policy::overrun_oldest); + benchmark::RegisterBenchmark("async_logger", bench_logger, async_logger)->Threads(n_threads)->UseRealTime(); + +- auto async_logger_tracing = std::make_shared( +- "async_logger_tracing", std::make_shared(), std::move(tp), spdlog::async_overflow_policy::overrun_oldest); ++ auto async_logger_tracing = std::make_shared( ++ "async_logger_tracing", std::make_shared(), std::move(tp), ds_spdlog::async_overflow_policy::overrun_oldest); + async_logger_tracing->enable_backtrace(32); + benchmark::RegisterBenchmark("async_logger/tracing", bench_logger, async_logger_tracing)->Threads(n_threads)->UseRealTime(); + +diff --git a/example/example.cpp b/example/example.cpp +index d6609ed5..a358dd04 100644 +--- a/example/example.cpp ++++ b/example/example.cpp +@@ -37,36 +37,36 @@ int main(int, char *[]) + // Log levels can be loaded from argv/env using "SPDLOG_LEVEL" + load_levels_example(); + +- spdlog::info("Welcome to spdlog version {}.{}.{} !", SPDLOG_VER_MAJOR, SPDLOG_VER_MINOR, SPDLOG_VER_PATCH); ++ ds_spdlog::info("Welcome to spdlog version {}.{}.{} !", SPDLOG_VER_MAJOR, SPDLOG_VER_MINOR, SPDLOG_VER_PATCH); + +- spdlog::warn("Easy padding in numbers like {:08d}", 12); +- spdlog::critical("Support for int: {0:d}; hex: {0:x}; oct: {0:o}; bin: {0:b}", 42); +- spdlog::info("Support for floats {:03.2f}", 1.23456); +- spdlog::info("Positional args are {1} 
{0}..", "too", "supported"); +- spdlog::info("{:>8} aligned, {:<8} aligned", "right", "left"); ++ ds_spdlog::warn("Easy padding in numbers like {:08d}", 12); ++ ds_spdlog::critical("Support for int: {0:d}; hex: {0:x}; oct: {0:o}; bin: {0:b}", 42); ++ ds_spdlog::info("Support for floats {:03.2f}", 1.23456); ++ ds_spdlog::info("Positional args are {1} {0}..", "too", "supported"); ++ ds_spdlog::info("{:>8} aligned, {:<8} aligned", "right", "left"); + + // Runtime log levels +- spdlog::set_level(spdlog::level::info); // Set global log level to info +- spdlog::debug("This message should not be displayed!"); +- spdlog::set_level(spdlog::level::trace); // Set specific logger's log level +- spdlog::debug("This message should be displayed.."); ++ ds_spdlog::set_level(ds_spdlog::level::info); // Set global log level to info ++ ds_spdlog::debug("This message should not be displayed!"); ++ ds_spdlog::set_level(ds_spdlog::level::trace); // Set specific logger's log level ++ ds_spdlog::debug("This message should be displayed.."); + + // Customize msg format for all loggers +- spdlog::set_pattern("[%H:%M:%S %z] [%^%L%$] [thread %t] %v"); +- spdlog::info("This an info message with custom format"); +- spdlog::set_pattern("%+"); // back to default format +- spdlog::set_level(spdlog::level::info); ++ ds_spdlog::set_pattern("[%H:%M:%S %z] [%^%L%$] [thread %t] %v"); ++ ds_spdlog::info("This an info message with custom format"); ++ ds_spdlog::set_pattern("%+"); // back to default format ++ ds_spdlog::set_level(ds_spdlog::level::info); + + // Backtrace support + // Loggers can store in a ring buffer all messages (including debug/trace) for later inspection. + // When needed, call dump_backtrace() to see what happened: +- spdlog::enable_backtrace(10); // create ring buffer with capacity of 10 messages ++ ds_spdlog::enable_backtrace(10); // create ring buffer with capacity of 10 messages + for (int i = 0; i < 100; i++) + { +- spdlog::debug("Backtrace message {}", i); // not logged.. ++ ds_spdlog::debug("Backtrace message {}", i); // not logged.. + } + // e.g. if some error happened: +- spdlog::dump_backtrace(); // log them now! ++ ds_spdlog::dump_backtrace(); // log them now! + + try + { +@@ -90,18 +90,18 @@ int main(int, char *[]) + + // Flush all *registered* loggers using a worker thread every 3 seconds. + // note: registered loggers *must* be thread safe for this to work correctly! +- spdlog::flush_every(std::chrono::seconds(3)); ++ ds_spdlog::flush_every(std::chrono::seconds(3)); + + // Apply some function on all registered loggers +- spdlog::apply_all([&](std::shared_ptr l) { l->info("End of example."); }); ++ ds_spdlog::apply_all([&](std::shared_ptr l) { l->info("End of example."); }); + + // Release all spdlog resources, and drop all loggers in the registry. + // This is optional (only mandatory if using windows + async log). +- spdlog::shutdown(); ++ ds_spdlog::shutdown(); + } + + // Exceptions will only be thrown upon failed logger or sink construction (not during logging). +- catch (const spdlog::spdlog_ex &ex) ++ catch (const ds_spdlog::spdlog_ex &ex) + { + std::printf("Log initialization failed: %s\n", ex.what()); + return 1; +@@ -113,37 +113,37 @@ int main(int, char *[]) + void stdout_logger_example() + { + // Create color multi threaded logger. 
+- auto console = spdlog::stdout_color_mt("console"); ++ auto console = ds_spdlog::stdout_color_mt("console"); + // or for stderr: +- // auto console = spdlog::stderr_color_mt("error-logger"); ++ // auto console = ds_spdlog::stderr_color_mt("error-logger"); + } + + #include "spdlog/sinks/basic_file_sink.h" + void basic_example() + { + // Create basic file logger (not rotated). +- auto my_logger = spdlog::basic_logger_mt("file_logger", "logs/basic-log.txt", true); ++ auto my_logger = ds_spdlog::basic_logger_mt("file_logger", "logs/basic-log.txt", true); + } + + #include "spdlog/sinks/rotating_file_sink.h" + void rotating_example() + { + // Create a file rotating logger with 5mb size max and 3 rotated files. +- auto rotating_logger = spdlog::rotating_logger_mt("some_logger_name", "logs/rotating.txt", 1048576 * 5, 3); ++ auto rotating_logger = ds_spdlog::rotating_logger_mt("some_logger_name", "logs/rotating.txt", 1048576 * 5, 3); + } + + #include "spdlog/sinks/daily_file_sink.h" + void daily_example() + { + // Create a daily logger - a new file is created every day on 2:30am. +- auto daily_logger = spdlog::daily_logger_mt("daily_logger", "logs/daily.txt", 2, 30); ++ auto daily_logger = ds_spdlog::daily_logger_mt("daily_logger", "logs/daily.txt", 2, 30); + } + + #include "spdlog/sinks/callback_sink.h" + void callback_example() + { + // Create the logger +- auto logger = spdlog::callback_logger_mt("custom_callback_logger", [](const spdlog::details::log_msg & /*msg*/) { ++ auto logger = ds_spdlog::callback_logger_mt("custom_callback_logger", [](const ds_spdlog::details::log_msg & /*msg*/) { + // do what you need to do with msg + }); + } +@@ -153,21 +153,21 @@ void load_levels_example() + { + // Set the log level to "info" and mylogger to "trace": + // SPDLOG_LEVEL=info,mylogger=trace && ./example +- spdlog::cfg::load_env_levels(); ++ ds_spdlog::cfg::load_env_levels(); + // or from command line: + // ./example SPDLOG_LEVEL=info,mylogger=trace + // #include "spdlog/cfg/argv.h" // for loading levels from argv +- // spdlog::cfg::load_argv_levels(args, argv); ++ // ds_spdlog::cfg::load_argv_levels(args, argv); + } + + #include "spdlog/async.h" + void async_example() + { + // Default thread pool settings can be modified *before* creating the async logger: +- // spdlog::init_thread_pool(32768, 1); // queue with max 32k items 1 backing thread. +- auto async_file = spdlog::basic_logger_mt("async_file_logger", "logs/async_log.txt"); ++ // ds_spdlog::init_thread_pool(32768, 1); // queue with max 32k items 1 backing thread. 
++ auto async_file = ds_spdlog::basic_logger_mt<ds_spdlog::async_factory>("async_file_logger", "logs/async_log.txt");
+ // alternatively:
+- // auto async_file = spdlog::create_async<spdlog::sinks::basic_file_sink_mt>("async_file_logger", "logs/async_log.txt");
++ // auto async_file = ds_spdlog::create_async<ds_spdlog::sinks::basic_file_sink_mt>("async_file_logger", "logs/async_log.txt");
+
+ for (int i = 1; i < 101; ++i)
+ {
+@@ -193,14 +193,14 @@ void binary_example()
+ {
+ buf.push_back(static_cast<char>(i & 0xff));
+ }
+- spdlog::info("Binary example: {}", spdlog::to_hex(buf));
+- spdlog::info("Another binary example:{:n}", spdlog::to_hex(std::begin(buf), std::begin(buf) + 10));
++ ds_spdlog::info("Binary example: {}", ds_spdlog::to_hex(buf));
++ ds_spdlog::info("Another binary example:{:n}", ds_spdlog::to_hex(std::begin(buf), std::begin(buf) + 10));
+ // more examples:
+- // logger->info("uppercase: {:X}", spdlog::to_hex(buf));
+- // logger->info("uppercase, no delimiters: {:Xs}", spdlog::to_hex(buf));
+- // logger->info("uppercase, no delimiters, no position info: {:Xsp}", spdlog::to_hex(buf));
+- // logger->info("hexdump style: {:a}", spdlog::to_hex(buf));
+- // logger->info("hexdump style, 20 chars per line {:a}", spdlog::to_hex(buf, 20));
++ // logger->info("uppercase: {:X}", ds_spdlog::to_hex(buf));
++ // logger->info("uppercase, no delimiters: {:Xs}", ds_spdlog::to_hex(buf));
++ // logger->info("uppercase, no delimiters, no position info: {:Xsp}", ds_spdlog::to_hex(buf));
++ // logger->info("hexdump style: {:a}", ds_spdlog::to_hex(buf));
++ // logger->info("hexdump style, 20 chars per line {:a}", ds_spdlog::to_hex(buf, 20));
+ }
+ #else
+ void binary_example() {
+@@ -214,7 +214,7 @@ void binary_example() {
+ void vector_example()
+ {
+ std::vector<int> vec = {1, 2, 3};
+- spdlog::info("Vector example: {}", vec);
++ ds_spdlog::info("Vector example: {}", vec);
+ }
+
+ #else
+@@ -233,7 +233,7 @@ void trace_example()
+ SPDLOG_DEBUG("Some debug message.. {} ,{}", 1, 3.23);
+
+ // trace from logger object
+- auto logger = spdlog::get("file_logger");
++ auto logger = ds_spdlog::get("file_logger");
+ SPDLOG_LOGGER_TRACE(logger, "another trace message");
+ }
+
+@@ -242,32 +242,32 @@ void trace_example()
+ #include
+ void stopwatch_example()
+ {
+- spdlog::stopwatch sw;
+ std::this_thread::sleep_for(std::chrono::milliseconds(123));
+- spdlog::info("Stopwatch: {} seconds", sw);
++ ds_spdlog::stopwatch sw;
+ std::this_thread::sleep_for(std::chrono::milliseconds(123));
++ ds_spdlog::info("Stopwatch: {} seconds", sw);
+ }
+
+ #include "spdlog/sinks/udp_sink.h"
+ void udp_example()
+ {
+- spdlog::sinks::udp_sink_config cfg("127.0.0.1", 11091);
+- auto my_logger = spdlog::udp_logger_mt("udplog", cfg);
+- my_logger->set_level(spdlog::level::debug);
++ ds_spdlog::sinks::udp_sink_config cfg("127.0.0.1", 11091);
++ auto my_logger = ds_spdlog::udp_logger_mt("udplog", cfg);
++ my_logger->set_level(ds_spdlog::level::debug);
+ my_logger->info("hello world");
+ }
+
+ // A logger with multiple sinks (stdout and file) - each with a different format and log level.
+ void multi_sink_example()
+ {
+- auto console_sink = std::make_shared<spdlog::sinks::stdout_color_sink_mt>();
+- console_sink->set_level(spdlog::level::warn);
++ auto console_sink = std::make_shared<ds_spdlog::sinks::stdout_color_sink_mt>();
++ console_sink->set_level(ds_spdlog::level::warn);
+ console_sink->set_pattern("[multi_sink_example] [%^%l%$] %v");
+
+- auto file_sink = std::make_shared<spdlog::sinks::basic_file_sink_mt>("logs/multisink.txt", true);
+- file_sink->set_level(spdlog::level::trace);
++ auto file_sink = std::make_shared<ds_spdlog::sinks::basic_file_sink_mt>("logs/multisink.txt", true);
++ file_sink->set_level(ds_spdlog::level::trace);
+
+- spdlog::logger logger("multi_sink", {console_sink, file_sink});
+- logger.set_level(spdlog::level::debug);
++ ds_spdlog::logger logger("multi_sink", {console_sink, file_sink});
++ logger.set_level(ds_spdlog::level::debug);
+ logger.warn("this should appear in both console and file");
+ logger.info("this message should not appear in the console, only in the file");
+ }
+@@ -303,14 +303,14 @@ struct std::formatter<my_type> : std::formatter<std::string>
+
+ void user_defined_example()
+ {
+- spdlog::info("user defined type: {}", my_type(14));
++ ds_spdlog::info("user defined type: {}", my_type(14));
+ }
+
+ // Custom error handler. Will be triggered on log failure.
+ void err_handler_example()
+ {
+ // can be set globally or per logger(logger->set_error_handler(..))
+- spdlog::set_error_handler([](const std::string &msg) { printf("*** Custom log error handler: %s ***\n", msg.c_str()); });
++ ds_spdlog::set_error_handler([](const std::string &msg) { printf("*** Custom log error handler: %s ***\n", msg.c_str()); });
+ }
+
+ // syslog example (linux/osx/freebsd)
+@@ -319,7 +319,7 @@ void err_handler_example()
+ void syslog_example()
+ {
+ std::string ident = "spdlog-example";
+- auto syslog_logger = spdlog::syslog_logger_mt("syslog", ident, LOG_PID);
++ auto syslog_logger = ds_spdlog::syslog_logger_mt("syslog", ident, LOG_PID);
+ syslog_logger->warn("This is warning that will end up in syslog.");
+ }
+ #endif
+@@ -330,7 +330,7 @@ void syslog_example()
+ void android_example()
+ {
+ std::string tag = "spdlog-android";
+- auto android_logger = spdlog::android_logger_mt("android", tag);
++ auto android_logger = ds_spdlog::android_logger_mt("android", tag);
+ android_logger->critical("Use \"adb shell logcat\" to view this message.");
+ }
+ #endif
+@@ -338,10 +338,10 @@ void android_example()
+ // Log patterns can contain custom flags.
+ // this will add custom flag '%*' which will be bound to a <my_formatter_flag> instance
+ #include "spdlog/pattern_formatter.h"
+-class my_formatter_flag : public spdlog::custom_flag_formatter
++class my_formatter_flag : public ds_spdlog::custom_flag_formatter
+ {
+ public:
+- void format(const spdlog::details::log_msg &, const std::tm &, spdlog::memory_buf_t &dest) override
++ void format(const ds_spdlog::details::log_msg &, const std::tm &, ds_spdlog::memory_buf_t &dest) override
+ {
+ std::string some_txt = "custom-flag";
+ dest.append(some_txt.data(), some_txt.data() + some_txt.size());
+@@ -349,50 +349,50 @@ public:
+
+ std::unique_ptr<custom_flag_formatter> clone() const override
+ {
+- return spdlog::details::make_unique<my_formatter_flag>();
++ return ds_spdlog::details::make_unique<my_formatter_flag>();
+ }
+ };
+
+ void custom_flags_example()
+ {
+
+- using spdlog::details::make_unique; // for pre c++14
+- auto formatter = make_unique<spdlog::pattern_formatter>();
++ using ds_spdlog::details::make_unique; // for pre c++14
++ auto formatter = make_unique<ds_spdlog::pattern_formatter>();
+ formatter->add_flag<my_formatter_flag>('*').set_pattern("[%n] [%*] [%^%l%$] %v");
+- // set the new formatter using spdlog::set_formatter(formatter) or logger->set_formatter(formatter)
+- // spdlog::set_formatter(std::move(formatter));
++ // set the new formatter using ds_spdlog::set_formatter(formatter) or logger->set_formatter(formatter)
++ // ds_spdlog::set_formatter(std::move(formatter));
+ }
+
+ void file_events_example()
+ {
+- // pass the spdlog::file_event_handlers to file sinks for open/close log file notifications
+- spdlog::file_event_handlers handlers;
+- handlers.before_open = [](spdlog::filename_t filename) { spdlog::info("Before opening {}", filename); };
+- handlers.after_open = [](spdlog::filename_t filename, std::FILE *fstream) {
+- spdlog::info("After opening {}", filename);
++ // pass the ds_spdlog::file_event_handlers to file sinks for open/close log file notifications
++ ds_spdlog::file_event_handlers handlers;
++ handlers.before_open = [](ds_spdlog::filename_t filename) { ds_spdlog::info("Before opening {}", filename); };
++ handlers.after_open = [](ds_spdlog::filename_t filename, std::FILE *fstream) {
++ ds_spdlog::info("After opening {}", filename);
+ fputs("After opening\n", fstream);
+ };
+- handlers.before_close = [](spdlog::filename_t filename, std::FILE *fstream) {
+- spdlog::info("Before closing {}", filename);
++ handlers.before_close = [](ds_spdlog::filename_t filename, std::FILE *fstream) {
++ ds_spdlog::info("Before closing {}", filename);
+ fputs("Before closing\n", fstream);
+ };
+- handlers.after_close = [](spdlog::filename_t filename) { spdlog::info("After closing {}", filename); };
+- auto file_sink = std::make_shared<spdlog::sinks::basic_file_sink_mt>("logs/events-sample.txt", true, handlers);
+- spdlog::logger my_logger("some_logger", file_sink);
++ handlers.after_close = [](ds_spdlog::filename_t filename) { ds_spdlog::info("After closing {}", filename); };
++ auto file_sink = std::make_shared<ds_spdlog::sinks::basic_file_sink_mt>("logs/events-sample.txt", true, handlers);
++ ds_spdlog::logger my_logger("some_logger", file_sink);
+ my_logger.info("Some log line");
+ }
+
+ void replace_default_logger_example()
+ {
+ // store the old logger so we don't break other examples.
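// ---------------------------------------------------------------------------
// Aside: the custom '%*' flag from the hunk above can also be scoped to a
// single logger instead of being installed globally; the hunk's own comment
// points at logger->set_formatter(). A minimal sketch under the renamed
// namespace (the logger name "flagged" is illustrative only):
//
//     auto formatter = ds_spdlog::details::make_unique<ds_spdlog::pattern_formatter>();
//     formatter->add_flag<my_formatter_flag>('*').set_pattern("[%n] [%*] [%^%l%$] %v");
//     auto flagged = ds_spdlog::stdout_color_mt("flagged");
//     flagged->set_formatter(std::move(formatter)); // unique_ptr<pattern_formatter> converts to unique_ptr<formatter>
//     flagged->info("message rendered with the custom flag");
// ---------------------------------------------------------------------------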
+- auto old_logger = spdlog::default_logger();
++ auto old_logger = ds_spdlog::default_logger();
+
+- auto new_logger = spdlog::basic_logger_mt("new_default_logger", "logs/new-default-log.txt", true);
+- spdlog::set_default_logger(new_logger);
+- spdlog::set_level(spdlog::level::info);
+- spdlog::debug("This message should not be displayed!");
+- spdlog::set_level(spdlog::level::trace);
+- spdlog::debug("This message should be displayed..");
++ auto new_logger = ds_spdlog::basic_logger_mt("new_default_logger", "logs/new-default-log.txt", true);
++ ds_spdlog::set_default_logger(new_logger);
++ ds_spdlog::set_level(ds_spdlog::level::info);
++ ds_spdlog::debug("This message should not be displayed!");
++ ds_spdlog::set_level(ds_spdlog::level::trace);
++ ds_spdlog::debug("This message should be displayed..");
+
+- spdlog::set_default_logger(old_logger);
++ ds_spdlog::set_default_logger(old_logger);
+ }
+diff --git a/include/spdlog/async.h b/include/spdlog/async.h
+index 94f9f6d9..89041879 100644
+--- a/include/spdlog/async.h
++++ b/include/spdlog/async.h
+@@ -22,7 +22,7 @@
+ #include
+ #include
+
+-namespace spdlog {
++namespace ds_spdlog {
+
+ namespace details {
+ static const size_t default_async_q_size = 8192;
+@@ -61,13 +61,13 @@ using async_factory = async_factory_impl<async_overflow_policy::block>;
+ using async_factory_nonblock = async_factory_impl<async_overflow_policy::overrun_oldest>;
+
+ template<typename Sink, typename... SinkArgs>
+-inline std::shared_ptr<spdlog::logger> create_async(std::string logger_name, SinkArgs &&...sink_args)
++inline std::shared_ptr<ds_spdlog::logger> create_async(std::string logger_name, SinkArgs &&...sink_args)
+ {
+ return async_factory::create<Sink>(std::move(logger_name), std::forward<SinkArgs>(sink_args)...);
+ }
+
+ template<typename Sink, typename... SinkArgs>
+-inline std::shared_ptr<spdlog::logger> create_async_nb(std::string logger_name, SinkArgs &&...sink_args)
++inline std::shared_ptr<ds_spdlog::logger> create_async_nb(std::string logger_name, SinkArgs &&...sink_args)
+ {
+ return async_factory_nonblock::create<Sink>(std::move(logger_name), std::forward<SinkArgs>(sink_args)...);
+ }
+@@ -92,8 +92,8 @@ inline void init_thread_pool(size_t q_size, size_t thread_count)
+ }
+
+ // get the global thread pool.
+-inline std::shared_ptr<spdlog::details::thread_pool> thread_pool()
++inline std::shared_ptr<ds_spdlog::details::thread_pool> thread_pool()
+ {
+ return details::registry::instance().get_tp();
+ }
+-} // namespace spdlog
++} // namespace ds_spdlog
+diff --git a/include/spdlog/async_logger-inl.h b/include/spdlog/async_logger-inl.h
+index 4de8382a..86cf3785 100644
+--- a/include/spdlog/async_logger-inl.h
++++ b/include/spdlog/async_logger-inl.h
+@@ -13,18 +13,18 @@
+ #include
+ #include
+
+-SPDLOG_INLINE spdlog::async_logger::async_logger(
++SPDLOG_INLINE ds_spdlog::async_logger::async_logger(
+ std::string logger_name, sinks_init_list sinks_list, std::weak_ptr<details::thread_pool> tp, async_overflow_policy overflow_policy)
+ : async_logger(std::move(logger_name), sinks_list.begin(), sinks_list.end(), std::move(tp), overflow_policy)
+ {}
+
+-SPDLOG_INLINE spdlog::async_logger::async_logger(
++SPDLOG_INLINE ds_spdlog::async_logger::async_logger(
+ std::string logger_name, sink_ptr single_sink, std::weak_ptr<details::thread_pool> tp, async_overflow_policy overflow_policy)
+ : async_logger(std::move(logger_name), {std::move(single_sink)}, std::move(tp), overflow_policy)
+ {}
+
+ // send the log message to the thread pool
+-SPDLOG_INLINE void spdlog::async_logger::sink_it_(const details::log_msg &msg){
++SPDLOG_INLINE void ds_spdlog::async_logger::sink_it_(const details::log_msg &msg){
+ SPDLOG_TRY{if (auto pool_ptr = thread_pool_.lock()){pool_ptr->post_log(shared_from_this(), msg, overflow_policy_);
+ }
+ else
+@@ -36,7 +36,7 @@ SPDLOG_LOGGER_CATCH(msg.source)
+ }
+
+ // send flush request to the thread pool
+-SPDLOG_INLINE void spdlog::async_logger::flush_(){
++SPDLOG_INLINE void ds_spdlog::async_logger::flush_(){
+ SPDLOG_TRY{if (auto pool_ptr = thread_pool_.lock()){pool_ptr->post_flush(shared_from_this(), overflow_policy_);
+ }
+ else
+@@ -50,7 +50,7 @@ SPDLOG_LOGGER_CATCH(source_loc())
+ //
+ // backend functions - called from the thread pool to do the actual job
+ //
+-SPDLOG_INLINE void spdlog::async_logger::backend_sink_it_(const details::log_msg &msg)
++SPDLOG_INLINE void ds_spdlog::async_logger::backend_sink_it_(const details::log_msg &msg)
+ {
+ for (auto &sink : sinks_)
+ {
+@@ -70,7 +70,7 @@ SPDLOG_INLINE void spdlog::async_logger::backend_sink_it_(const details::log_msg
+ }
+ }
+
+-SPDLOG_INLINE void spdlog::async_logger::backend_flush_()
++SPDLOG_INLINE void ds_spdlog::async_logger::backend_flush_()
+ {
+ for (auto &sink : sinks_)
+ {
+@@ -82,9 +82,9 @@ SPDLOG_INLINE void spdlog::async_logger::backend_flush_()
+ }
+ }
+
+-SPDLOG_INLINE std::shared_ptr<spdlog::logger> spdlog::async_logger::clone(std::string new_name)
++SPDLOG_INLINE std::shared_ptr<ds_spdlog::logger> ds_spdlog::async_logger::clone(std::string new_name)
+ {
+- auto cloned = std::make_shared<spdlog::async_logger>(*this);
++ auto cloned = std::make_shared<ds_spdlog::async_logger>(*this);
+ cloned->name_ = std::move(new_name);
+ return cloned;
+ }
+diff --git a/include/spdlog/async_logger.h b/include/spdlog/async_logger.h
+index 91a93fcb..50949223 100644
+--- a/include/spdlog/async_logger.h
++++ b/include/spdlog/async_logger.h
+@@ -16,7 +16,7 @@
+
+ #include
+
+-namespace spdlog {
++namespace ds_spdlog {
+
+ // Async overflow policy - block by default.
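// ---------------------------------------------------------------------------
// Aside: bench/async_bench.cpp above exercises both of the overflow policies
// declared below - 'block' applies back-pressure to logging threads when the
// queue is full, while 'overrun_oldest' drops the oldest queued messages
// instead. With the renamed factories from async.h, the non-blocking variant
// is selected like this (a sketch; the logger name and path are illustrative):
//
//     #include "spdlog/async.h"
//     #include "spdlog/sinks/basic_file_sink.h"
//
//     auto lossy = ds_spdlog::create_async_nb<ds_spdlog::sinks::basic_file_sink_mt>(
//         "lossy_async", "logs/lossy.txt"); // uses async_overflow_policy::overrun_oldest
// ---------------------------------------------------------------------------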
+ enum class async_overflow_policy +@@ -61,7 +61,7 @@ private: + std::weak_ptr thread_pool_; + async_overflow_policy overflow_policy_; + }; +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "async_logger-inl.h" +diff --git a/include/spdlog/cfg/argv.h b/include/spdlog/cfg/argv.h +index 36d9f1c4..aecc8b25 100644 +--- a/include/spdlog/cfg/argv.h ++++ b/include/spdlog/cfg/argv.h +@@ -17,7 +17,7 @@ + // turn off all logging except for logger1 and logger2: + // example.exe "SPDLOG_LEVEL=off,logger1=debug,logger2=info" + +-namespace spdlog { ++namespace ds_spdlog { + namespace cfg { + + // search for SPDLOG_LEVEL= in the args and use it to init the levels +@@ -41,4 +41,4 @@ inline void load_argv_levels(int argc, char **argv) + } + + } // namespace cfg +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/cfg/env.h b/include/spdlog/cfg/env.h +index 1f39ebbb..bb80f581 100644 +--- a/include/spdlog/cfg/env.h ++++ b/include/spdlog/cfg/env.h +@@ -23,7 +23,7 @@ + // turn off all logging except for logger1 and logger2: + // export SPDLOG_LEVEL="off,logger1=debug,logger2=info" + +-namespace spdlog { ++namespace ds_spdlog { + namespace cfg { + inline void load_env_levels() + { +@@ -35,4 +35,4 @@ inline void load_env_levels() + } + + } // namespace cfg +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/cfg/helpers-inl.h b/include/spdlog/cfg/helpers-inl.h +index 675a13af..8cefe4b1 100644 +--- a/include/spdlog/cfg/helpers-inl.h ++++ b/include/spdlog/cfg/helpers-inl.h +@@ -16,7 +16,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace cfg { + namespace helpers { + +@@ -117,4 +117,4 @@ SPDLOG_INLINE void load_levels(const std::string &input) + + } // namespace helpers + } // namespace cfg +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/cfg/helpers.h b/include/spdlog/cfg/helpers.h +index ab7584e0..3fdc7034 100644 +--- a/include/spdlog/cfg/helpers.h ++++ b/include/spdlog/cfg/helpers.h +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace cfg { + namespace helpers { + // +@@ -22,7 +22,7 @@ SPDLOG_API void load_levels(const std::string &txt); + } // namespace helpers + + } // namespace cfg +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "helpers-inl.h" +diff --git a/include/spdlog/common-inl.h b/include/spdlog/common-inl.h +index 728f9831..9bbae234 100644 +--- a/include/spdlog/common-inl.h ++++ b/include/spdlog/common-inl.h +@@ -10,7 +10,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace level { + + #if __cplusplus >= 201703L +@@ -20,17 +20,17 @@ constexpr + + static const char *short_level_names[] SPDLOG_SHORT_LEVEL_NAMES; + +-SPDLOG_INLINE const string_view_t &to_string_view(spdlog::level::level_enum l) SPDLOG_NOEXCEPT ++SPDLOG_INLINE const string_view_t &to_string_view(ds_spdlog::level::level_enum l) SPDLOG_NOEXCEPT + { + return level_string_views[l]; + } + +-SPDLOG_INLINE const char *to_short_c_str(spdlog::level::level_enum l) SPDLOG_NOEXCEPT ++SPDLOG_INLINE const char *to_short_c_str(ds_spdlog::level::level_enum l) SPDLOG_NOEXCEPT + { + return short_level_names[l]; + } + +-SPDLOG_INLINE spdlog::level::level_enum from_str(const std::string &name) SPDLOG_NOEXCEPT ++SPDLOG_INLINE ds_spdlog::level::level_enum from_str(const std::string &name) SPDLOG_NOEXCEPT + { + auto it = std::find(std::begin(level_string_views), 
std::end(level_string_views), name); + if (it != std::end(level_string_views)) +@@ -79,4 +79,4 @@ SPDLOG_INLINE void throw_spdlog_ex(std::string msg) + SPDLOG_THROW(spdlog_ex(std::move(msg))); + } + +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/common.h b/include/spdlog/common.h +index 0a262eb2..a683ba9d 100644 +--- a/include/spdlog/common.h ++++ b/include/spdlog/common.h +@@ -111,7 +111,7 @@ + {} + #endif + +-namespace spdlog { ++namespace ds_spdlog { + + class formatter; + +@@ -242,13 +242,13 @@ enum level_enum : int + n_levels + }; + +-#define SPDLOG_LEVEL_NAME_TRACE spdlog::string_view_t("trace", 5) +-#define SPDLOG_LEVEL_NAME_DEBUG spdlog::string_view_t("debug", 5) +-#define SPDLOG_LEVEL_NAME_INFO spdlog::string_view_t("info", 4) +-#define SPDLOG_LEVEL_NAME_WARNING spdlog::string_view_t("warning", 7) +-#define SPDLOG_LEVEL_NAME_ERROR spdlog::string_view_t("error", 5) +-#define SPDLOG_LEVEL_NAME_CRITICAL spdlog::string_view_t("critical", 8) +-#define SPDLOG_LEVEL_NAME_OFF spdlog::string_view_t("off", 3) ++#define SPDLOG_LEVEL_NAME_TRACE ds_spdlog::string_view_t("trace", 5) ++#define SPDLOG_LEVEL_NAME_DEBUG ds_spdlog::string_view_t("debug", 5) ++#define SPDLOG_LEVEL_NAME_INFO ds_spdlog::string_view_t("info", 4) ++#define SPDLOG_LEVEL_NAME_WARNING ds_spdlog::string_view_t("warning", 7) ++#define SPDLOG_LEVEL_NAME_ERROR ds_spdlog::string_view_t("error", 5) ++#define SPDLOG_LEVEL_NAME_CRITICAL ds_spdlog::string_view_t("critical", 8) ++#define SPDLOG_LEVEL_NAME_OFF ds_spdlog::string_view_t("off", 3) + + #if !defined(SPDLOG_LEVEL_NAMES) + # define SPDLOG_LEVEL_NAMES \ +@@ -266,9 +266,9 @@ enum level_enum : int + } + #endif + +-SPDLOG_API const string_view_t &to_string_view(spdlog::level::level_enum l) SPDLOG_NOEXCEPT; +-SPDLOG_API const char *to_short_c_str(spdlog::level::level_enum l) SPDLOG_NOEXCEPT; +-SPDLOG_API spdlog::level::level_enum from_str(const std::string &name) SPDLOG_NOEXCEPT; ++SPDLOG_API const string_view_t &to_string_view(ds_spdlog::level::level_enum l) SPDLOG_NOEXCEPT; ++SPDLOG_API const char *to_short_c_str(ds_spdlog::level::level_enum l) SPDLOG_NOEXCEPT; ++SPDLOG_API ds_spdlog::level::level_enum from_str(const std::string &name) SPDLOG_NOEXCEPT; + + } // namespace level + +@@ -346,23 +346,23 @@ namespace details { + + // to_string_view + +-SPDLOG_CONSTEXPR_FUNC spdlog::string_view_t to_string_view(const memory_buf_t &buf) SPDLOG_NOEXCEPT ++SPDLOG_CONSTEXPR_FUNC ds_spdlog::string_view_t to_string_view(const memory_buf_t &buf) SPDLOG_NOEXCEPT + { +- return spdlog::string_view_t{buf.data(), buf.size()}; ++ return ds_spdlog::string_view_t{buf.data(), buf.size()}; + } + +-SPDLOG_CONSTEXPR_FUNC spdlog::string_view_t to_string_view(spdlog::string_view_t str) SPDLOG_NOEXCEPT ++SPDLOG_CONSTEXPR_FUNC ds_spdlog::string_view_t to_string_view(ds_spdlog::string_view_t str) SPDLOG_NOEXCEPT + { + return str; + } + + #if defined(SPDLOG_WCHAR_FILENAMES) || defined(SPDLOG_WCHAR_TO_UTF8_SUPPORT) +-SPDLOG_CONSTEXPR_FUNC spdlog::wstring_view_t to_string_view(const wmemory_buf_t &buf) SPDLOG_NOEXCEPT ++SPDLOG_CONSTEXPR_FUNC ds_spdlog::wstring_view_t to_string_view(const wmemory_buf_t &buf) SPDLOG_NOEXCEPT + { +- return spdlog::wstring_view_t{buf.data(), buf.size()}; ++ return ds_spdlog::wstring_view_t{buf.data(), buf.size()}; + } + +-SPDLOG_CONSTEXPR_FUNC spdlog::wstring_view_t to_string_view(spdlog::wstring_view_t str) SPDLOG_NOEXCEPT ++SPDLOG_CONSTEXPR_FUNC ds_spdlog::wstring_view_t to_string_view(ds_spdlog::wstring_view_t str) SPDLOG_NOEXCEPT + { + return 
str; + } +@@ -413,7 +413,7 @@ constexpr T conditional_static_cast(U value) + } + + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "common-inl.h" +diff --git a/include/spdlog/details/backtracer-inl.h b/include/spdlog/details/backtracer-inl.h +index 40eba408..9a739be4 100644 +--- a/include/spdlog/details/backtracer-inl.h ++++ b/include/spdlog/details/backtracer-inl.h +@@ -6,7 +6,7 @@ + #ifndef SPDLOG_HEADER_ONLY + # include + #endif +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + SPDLOG_INLINE backtracer::backtracer(const backtracer &other) + { +@@ -72,4 +72,4 @@ SPDLOG_INLINE void backtracer::foreach_pop(std::function + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + template + class circular_q +@@ -143,4 +143,4 @@ private: + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/console_globals.h b/include/spdlog/details/console_globals.h +index 665201dd..ebb07285 100644 +--- a/include/spdlog/details/console_globals.h ++++ b/include/spdlog/details/console_globals.h +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + struct console_mutex +@@ -29,4 +29,4 @@ struct console_nullmutex + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/file_helper-inl.h b/include/spdlog/details/file_helper-inl.h +index 74c89a87..86d88030 100644 +--- a/include/spdlog/details/file_helper-inl.h ++++ b/include/spdlog/details/file_helper-inl.h +@@ -17,7 +17,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + SPDLOG_INLINE file_helper::file_helper(const file_event_handlers &event_handlers) +@@ -177,4 +177,4 @@ SPDLOG_INLINE std::tuple file_helper::split_by_extension + } + + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/file_helper.h b/include/spdlog/details/file_helper.h +index f42a5eb1..8a4bde47 100644 +--- a/include/spdlog/details/file_helper.h ++++ b/include/spdlog/details/file_helper.h +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + // Helper class for file sinks. 
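
Note: every hunk above follows one pattern: the enclosing "namespace spdlog" becomes "namespace ds_spdlog", while include paths and all APIs stay as they are. A minimal sketch of calling code against the renamed copy; the logger name and message text are illustrative assumptions, not part of this patch:

    #include <spdlog/spdlog.h>                    // include paths are unchanged
    #include <spdlog/sinks/stdout_color_sinks.h>

    int main() {
        // Factory helpers resolve under ds_spdlog:: after the rename, so a host
        // application can link a stock spdlog of its own without symbol clashes.
        auto logger = ds_spdlog::stdout_color_mt("example");
        logger->info("hello from the vendored logger");
        return 0;
    }
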
+@@ -55,7 +55,7 @@ private: + file_event_handlers event_handlers_; + }; + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "file_helper-inl.h" +diff --git a/include/spdlog/details/fmt_helper.h b/include/spdlog/details/fmt_helper.h +index d9867180..22adef57 100644 +--- a/include/spdlog/details/fmt_helper.h ++++ b/include/spdlog/details/fmt_helper.h +@@ -14,11 +14,11 @@ + #endif + + // Some fmt helpers to efficiently format and pad ints and strings +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + namespace fmt_helper { + +-inline void append_string_view(spdlog::string_view_t view, memory_buf_t &dest) ++inline void append_string_view(ds_spdlog::string_view_t view, memory_buf_t &dest) + { + auto *buf_ptr = view.data(); + dest.append(buf_ptr, buf_ptr + view.size()); +@@ -161,4 +161,4 @@ inline ToDuration time_fraction(log_clock::time_point tp) + + } // namespace fmt_helper + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/log_msg-inl.h b/include/spdlog/details/log_msg-inl.h +index c6e8a7e0..9717efdf 100644 +--- a/include/spdlog/details/log_msg-inl.h ++++ b/include/spdlog/details/log_msg-inl.h +@@ -9,11 +9,11 @@ + + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + +-SPDLOG_INLINE log_msg::log_msg(spdlog::log_clock::time_point log_time, spdlog::source_loc loc, string_view_t a_logger_name, +- spdlog::level::level_enum lvl, spdlog::string_view_t msg) ++SPDLOG_INLINE log_msg::log_msg(ds_spdlog::log_clock::time_point log_time, ds_spdlog::source_loc loc, string_view_t a_logger_name, ++ ds_spdlog::level::level_enum lvl, ds_spdlog::string_view_t msg) + : logger_name(a_logger_name) + , level(lvl) + , time(log_time) +@@ -25,13 +25,13 @@ SPDLOG_INLINE log_msg::log_msg(spdlog::log_clock::time_point log_time, spdlog::s + {} + + SPDLOG_INLINE log_msg::log_msg( +- spdlog::source_loc loc, string_view_t a_logger_name, spdlog::level::level_enum lvl, spdlog::string_view_t msg) ++ ds_spdlog::source_loc loc, string_view_t a_logger_name, ds_spdlog::level::level_enum lvl, ds_spdlog::string_view_t msg) + : log_msg(os::now(), loc, a_logger_name, lvl, msg) + {} + +-SPDLOG_INLINE log_msg::log_msg(string_view_t a_logger_name, spdlog::level::level_enum lvl, spdlog::string_view_t msg) ++SPDLOG_INLINE log_msg::log_msg(string_view_t a_logger_name, ds_spdlog::level::level_enum lvl, ds_spdlog::string_view_t msg) + : log_msg(os::now(), source_loc{}, a_logger_name, lvl, msg) + {} + + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/log_msg.h b/include/spdlog/details/log_msg.h +index fed51abd..06dd9773 100644 +--- a/include/spdlog/details/log_msg.h ++++ b/include/spdlog/details/log_msg.h +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + struct SPDLOG_API log_msg + { +@@ -30,7 +30,7 @@ struct SPDLOG_API log_msg + string_view_t payload; + }; + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "log_msg-inl.h" +diff --git a/include/spdlog/details/log_msg_buffer-inl.h b/include/spdlog/details/log_msg_buffer-inl.h +index 84d83dc2..7b1c91f9 100644 +--- a/include/spdlog/details/log_msg_buffer-inl.h ++++ b/include/spdlog/details/log_msg_buffer-inl.h +@@ -7,7 +7,7 @@ + # include + #endif + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + SPDLOG_INLINE 
log_msg_buffer::log_msg_buffer(const log_msg &orig_msg) +@@ -55,4 +55,4 @@ SPDLOG_INLINE void log_msg_buffer::update_string_views() + } + + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/log_msg_buffer.h b/include/spdlog/details/log_msg_buffer.h +index 81055065..24350729 100644 +--- a/include/spdlog/details/log_msg_buffer.h ++++ b/include/spdlog/details/log_msg_buffer.h +@@ -5,7 +5,7 @@ + + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + // Extend log_msg with internal buffer to store its payload. +@@ -26,7 +26,7 @@ public: + }; + + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "log_msg_buffer-inl.h" +diff --git a/include/spdlog/details/mpmc_blocking_q.h b/include/spdlog/details/mpmc_blocking_q.h +index 101ea8c0..66ae9ae3 100644 +--- a/include/spdlog/details/mpmc_blocking_q.h ++++ b/include/spdlog/details/mpmc_blocking_q.h +@@ -15,7 +15,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + template +@@ -148,7 +148,7 @@ private: + std::mutex queue_mutex_; + std::condition_variable push_cv_; + std::condition_variable pop_cv_; +- spdlog::details::circular_q q_; ++ ds_spdlog::details::circular_q q_; + }; + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/null_mutex.h b/include/spdlog/details/null_mutex.h +index 6550a7bf..b34fb2d1 100644 +--- a/include/spdlog/details/null_mutex.h ++++ b/include/spdlog/details/null_mutex.h +@@ -7,7 +7,7 @@ + #include + // null, no cost dummy "mutex" and dummy "atomic" int + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + struct null_mutex + { +@@ -42,4 +42,4 @@ struct null_atomic_int + }; + + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/os-inl.h b/include/spdlog/details/os-inl.h +index ea8864ea..c6a3f02b 100644 +--- a/include/spdlog/details/os-inl.h ++++ b/include/spdlog/details/os-inl.h +@@ -70,11 +70,11 @@ + # define __has_feature(x) 0 // Compatibility with non-clang compilers. 
+ #endif + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + namespace os { + +-SPDLOG_INLINE spdlog::log_clock::time_point now() SPDLOG_NOEXCEPT ++SPDLOG_INLINE ds_spdlog::log_clock::time_point now() SPDLOG_NOEXCEPT + { + + #if defined __linux__ && defined SPDLOG_CLOCK_COARSE +@@ -632,4 +632,4 @@ SPDLOG_INLINE bool fsync(FILE *fp) + + } // namespace os + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/os.h b/include/spdlog/details/os.h +index 37b00874..fcd1c3f5 100644 +--- a/include/spdlog/details/os.h ++++ b/include/spdlog/details/os.h +@@ -6,11 +6,11 @@ + #include + #include // std::time_t + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + namespace os { + +-SPDLOG_API spdlog::log_clock::time_point now() SPDLOG_NOEXCEPT; ++SPDLOG_API ds_spdlog::log_clock::time_point now() SPDLOG_NOEXCEPT; + + SPDLOG_API std::tm localtime(const std::time_t &time_tt) SPDLOG_NOEXCEPT; + +@@ -115,7 +115,7 @@ SPDLOG_API bool fsync(FILE *fp); + + } // namespace os + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "os-inl.h" +diff --git a/include/spdlog/details/periodic_worker-inl.h b/include/spdlog/details/periodic_worker-inl.h +index 520a2b33..9e71e3e5 100644 +--- a/include/spdlog/details/periodic_worker-inl.h ++++ b/include/spdlog/details/periodic_worker-inl.h +@@ -7,7 +7,7 @@ + # include + #endif + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + // stop the worker thread and join it +@@ -25,4 +25,4 @@ SPDLOG_INLINE periodic_worker::~periodic_worker() + } + + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/periodic_worker.h b/include/spdlog/details/periodic_worker.h +index d7d69b28..02794722 100644 +--- a/include/spdlog/details/periodic_worker.h ++++ b/include/spdlog/details/periodic_worker.h +@@ -14,7 +14,7 @@ + #include + #include + #include +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + class SPDLOG_API periodic_worker +@@ -53,7 +53,7 @@ private: + std::condition_variable cv_; + }; + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "periodic_worker-inl.h" +diff --git a/include/spdlog/details/registry-inl.h b/include/spdlog/details/registry-inl.h +index cb1fe84f..27af503f 100644 +--- a/include/spdlog/details/registry-inl.h ++++ b/include/spdlog/details/registry-inl.h +@@ -27,7 +27,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + SPDLOG_INLINE registry::registry() +@@ -43,7 +43,7 @@ SPDLOG_INLINE registry::registry() + # endif + + const char *default_logger_name = ""; +- default_logger_ = std::make_shared(default_logger_name, std::move(color_sink)); ++ default_logger_ = std::make_shared(default_logger_name, std::move(color_sink)); + loggers_[default_logger_name] = default_logger_; + + #endif // SPDLOG_DISABLE_DEFAULT_LOGGER +@@ -99,9 +99,9 @@ SPDLOG_INLINE std::shared_ptr registry::default_logger() + } + + // Return raw ptr to the default logger. +-// To be used directly by the spdlog default api (e.g. spdlog::info) ++// To be used directly by the spdlog default api (e.g. ds_spdlog::info) + // This make the default API faster, but cannot be used concurrently with set_default_logger(). +-// e.g do not call set_default_logger() from one thread while calling spdlog::info() from another. 
++// e.g do not call set_default_logger() from one thread while calling ds_spdlog::info() from another. + SPDLOG_INLINE logger *registry::get_default_raw() + { + return default_logger_.get(); +@@ -312,4 +312,4 @@ SPDLOG_INLINE void registry::register_logger_(std::shared_ptr new_logger + } + + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/registry.h b/include/spdlog/details/registry.h +index 4666fa29..fda00b68 100644 +--- a/include/spdlog/details/registry.h ++++ b/include/spdlog/details/registry.h +@@ -18,7 +18,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + class logger; + + namespace details { +@@ -37,9 +37,9 @@ public: + std::shared_ptr default_logger(); + + // Return raw ptr to the default logger. +- // To be used directly by the spdlog default api (e.g. spdlog::info) ++ // To be used directly by the spdlog default api (e.g. ds_spdlog::info) + // This make the default API faster, but cannot be used concurrently with set_default_logger(). +- // e.g do not call set_default_logger() from one thread while calling spdlog::info() from another. ++ // e.g do not call set_default_logger() from one thread while calling ds_spdlog::info() from another. + logger *get_default_raw(); + + // set default logger. +@@ -105,7 +105,7 @@ private: + std::unordered_map> loggers_; + log_levels log_levels_; + std::unique_ptr formatter_; +- spdlog::level::level_enum global_log_level_ = level::info; ++ ds_spdlog::level::level_enum global_log_level_ = level::info; + level::level_enum flush_level_ = level::off; + err_handler err_handler_; + std::shared_ptr tp_; +@@ -116,7 +116,7 @@ private: + }; + + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "registry-inl.h" +diff --git a/include/spdlog/details/synchronous_factory.h b/include/spdlog/details/synchronous_factory.h +index e1e42268..29da160b 100644 +--- a/include/spdlog/details/synchronous_factory.h ++++ b/include/spdlog/details/synchronous_factory.h +@@ -5,7 +5,7 @@ + + #include "registry.h" + +-namespace spdlog { ++namespace ds_spdlog { + + // Default logger factory- creates synchronous loggers + class logger; +@@ -13,12 +13,12 @@ class logger; + struct synchronous_factory + { + template +- static std::shared_ptr create(std::string logger_name, SinkArgs &&...args) ++ static std::shared_ptr create(std::string logger_name, SinkArgs &&...args) + { + auto sink = std::make_shared(std::forward(args)...); +- auto new_logger = std::make_shared(std::move(logger_name), std::move(sink)); ++ auto new_logger = std::make_shared(std::move(logger_name), std::move(sink)); + details::registry::instance().initialize_logger(new_logger); + return new_logger; + } + }; +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/tcp_client-windows.h b/include/spdlog/details/tcp_client-windows.h +index 968b2570..60438ca8 100644 +--- a/include/spdlog/details/tcp_client-windows.h ++++ b/include/spdlog/details/tcp_client-windows.h +@@ -19,7 +19,7 @@ + #pragma comment(lib, "Mswsock.lib") + #pragma comment(lib, "AdvApi32.lib") + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + class tcp_client + { +@@ -157,4 +157,4 @@ public: + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/tcp_client.h b/include/spdlog/details/tcp_client.h +index 8b11dfd2..1c20833e 100644 +--- a/include/spdlog/details/tcp_client.h ++++ 
b/include/spdlog/details/tcp_client.h +@@ -20,7 +20,7 @@ + + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + class tcp_client + { +@@ -143,4 +143,4 @@ public: + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/thread_pool-inl.h b/include/spdlog/details/thread_pool-inl.h +index dbd424ff..0b0652a4 100644 +--- a/include/spdlog/details/thread_pool-inl.h ++++ b/include/spdlog/details/thread_pool-inl.h +@@ -10,7 +10,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + SPDLOG_INLINE thread_pool::thread_pool( +@@ -19,7 +19,7 @@ SPDLOG_INLINE thread_pool::thread_pool( + { + if (threads_n == 0 || threads_n > 1000) + { +- throw_spdlog_ex("spdlog::thread_pool(): invalid threads_n param (valid " ++ throw_spdlog_ex("ds_spdlog::thread_pool(): invalid threads_n param (valid " + "range is 1-1000)"); + } + for (size_t i = 0; i < threads_n; i++) +@@ -134,4 +134,4 @@ bool SPDLOG_INLINE thread_pool::process_next_msg_() + } + + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/thread_pool.h b/include/spdlog/details/thread_pool.h +index 52c569b8..b6ab3e51 100644 +--- a/include/spdlog/details/thread_pool.h ++++ b/include/spdlog/details/thread_pool.h +@@ -13,12 +13,12 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + class async_logger; + + namespace details { + +-using async_logger_ptr = std::shared_ptr; ++using async_logger_ptr = std::shared_ptr; + + enum class async_msg_type + { +@@ -115,7 +115,7 @@ private: + }; + + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "thread_pool-inl.h" +diff --git a/include/spdlog/details/udp_client-windows.h b/include/spdlog/details/udp_client-windows.h +index 10894ee6..30968d93 100644 +--- a/include/spdlog/details/udp_client-windows.h ++++ b/include/spdlog/details/udp_client-windows.h +@@ -21,7 +21,7 @@ + # pragma comment(lib, "AdvApi32.lib") + #endif + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + class udp_client + { +@@ -110,4 +110,4 @@ public: + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/details/udp_client.h b/include/spdlog/details/udp_client.h +index e8c2cccf..61f19910 100644 +--- a/include/spdlog/details/udp_client.h ++++ b/include/spdlog/details/udp_client.h +@@ -22,7 +22,7 @@ + + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + class udp_client +@@ -91,4 +91,4 @@ public: + } + }; + } // namespace details +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/fmt/bin_to_hex.h b/include/spdlog/fmt/bin_to_hex.h +index 3bf003d4..b3bcd69d 100644 +--- a/include/spdlog/fmt/bin_to_hex.h ++++ b/include/spdlog/fmt/bin_to_hex.h +@@ -31,12 +31,12 @@ + // Examples: + // + // std::vector v(200, 0x0b); +-// logger->info("Some buffer {}", spdlog::to_hex(v)); ++// logger->info("Some buffer {}", ds_spdlog::to_hex(v)); + // char buf[128]; +-// logger->info("Some buffer {:X}", spdlog::to_hex(std::begin(buf), std::end(buf))); +-// logger->info("Some buffer {:X}", spdlog::to_hex(std::begin(buf), std::end(buf), 16)); ++// logger->info("Some buffer {:X}", ds_spdlog::to_hex(std::begin(buf), std::end(buf))); ++// logger->info("Some buffer {:X}", ds_spdlog::to_hex(std::begin(buf), std::end(buf), 16)); + +-namespace spdlog { ++namespace ds_spdlog { + 
namespace details { + + template +@@ -99,7 +99,7 @@ inline details::dump_info to_hex(const It range_begin, const It range_end, s + return details::dump_info(range_begin, range_end, size_per_line); + } + +-} // namespace spdlog ++} // namespace ds_spdlog + + namespace + #ifdef SPDLOG_USE_STD_FORMAT +@@ -110,7 +110,7 @@ namespace + { + + template +-struct formatter, char> ++struct formatter, char> + { + const char delimiter = ' '; + bool put_newlines = true; +@@ -156,7 +156,7 @@ struct formatter, char> + + // format the given bytes range as hex + template +- auto format(const spdlog::details::dump_info &the_range, FormatContext &ctx) const -> decltype(ctx.out()) ++ auto format(const ds_spdlog::details::dump_info &the_range, FormatContext &ctx) const -> decltype(ctx.out()) + { + SPDLOG_CONSTEXPR const char *hex_upper = "0123456789ABCDEF"; + SPDLOG_CONSTEXPR const char *hex_lower = "0123456789abcdef"; +@@ -241,7 +241,7 @@ struct formatter, char> + + if (put_positions) + { +- spdlog::fmt_lib::format_to(inserter, SPDLOG_FMT_STRING("{:04X}: "), pos); ++ ds_spdlog::fmt_lib::format_to(inserter, SPDLOG_FMT_STRING("{:04X}: "), pos); + } + } + }; +diff --git a/include/spdlog/formatter.h b/include/spdlog/formatter.h +index 5086fb21..e10cc585 100644 +--- a/include/spdlog/formatter.h ++++ b/include/spdlog/formatter.h +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + + class formatter + { +@@ -15,4 +15,4 @@ public: + virtual void format(const details::log_msg &msg, memory_buf_t &dest) = 0; + virtual std::unique_ptr clone() const = 0; + }; +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/fwd.h b/include/spdlog/fwd.h +index d2588257..1120ea7a 100644 +--- a/include/spdlog/fwd.h ++++ b/include/spdlog/fwd.h +@@ -3,7 +3,7 @@ + + #pragma once + +-namespace spdlog { ++namespace ds_spdlog { + class logger; + class formatter; + +@@ -15,4 +15,4 @@ namespace level { + enum level_enum : int; + } + +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/logger-inl.h b/include/spdlog/logger-inl.h +index 227cec43..907d8a73 100644 +--- a/include/spdlog/logger-inl.h ++++ b/include/spdlog/logger-inl.h +@@ -13,7 +13,7 @@ + + #include + +-namespace spdlog { ++namespace ds_spdlog { + + // public methods + SPDLOG_INLINE logger::logger(const logger &other) +@@ -40,7 +40,7 @@ SPDLOG_INLINE logger &logger::operator=(logger other) SPDLOG_NOEXCEPT + return *this; + } + +-SPDLOG_INLINE void logger::swap(spdlog::logger &other) SPDLOG_NOEXCEPT ++SPDLOG_INLINE void logger::swap(ds_spdlog::logger &other) SPDLOG_NOEXCEPT + { + name_.swap(other.name_); + sinks_.swap(other.sinks_); +@@ -163,7 +163,7 @@ SPDLOG_INLINE std::shared_ptr logger::clone(std::string logger_name) + } + + // protected methods +-SPDLOG_INLINE void logger::log_it_(const spdlog::details::log_msg &log_msg, bool log_enabled, bool traceback_enabled) ++SPDLOG_INLINE void logger::log_it_(const ds_spdlog::details::log_msg &log_msg, bool log_enabled, bool traceback_enabled) + { + if (log_enabled) + { +@@ -254,4 +254,4 @@ SPDLOG_INLINE void logger::err_handler_(const std::string &msg) + #endif + } + } +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/logger.h b/include/spdlog/logger.h +index 0802a5d9..f9d85094 100644 +--- a/include/spdlog/logger.h ++++ b/include/spdlog/logger.h +@@ -49,7 +49,7 @@ + # define SPDLOG_LOGGER_CATCH(location) + #endif + +-namespace spdlog { ++namespace ds_spdlog { + + class SPDLOG_API logger + { +@@ -82,7 +82,7 @@ public: + 
logger(const logger &other); + logger(logger &&other) SPDLOG_NOEXCEPT; + logger &operator=(logger other) SPDLOG_NOEXCEPT; +- void swap(spdlog::logger &other) SPDLOG_NOEXCEPT; ++ void swap(ds_spdlog::logger &other) SPDLOG_NOEXCEPT; + + template + void log(source_loc loc, level::level_enum lvl, format_string_t fmt, Args &&...args) +@@ -350,8 +350,8 @@ public: + protected: + std::string name_; + std::vector sinks_; +- spdlog::level_t level_{level::info}; +- spdlog::level_t flush_level_{level::off}; ++ ds_spdlog::level_t level_{level::info}; ++ ds_spdlog::level_t flush_level_{level::off}; + err_handler custom_err_handler_{nullptr}; + details::backtracer tracer_; + +@@ -420,7 +420,7 @@ protected: + + void swap(logger &a, logger &b); + +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "logger-inl.h" +diff --git a/include/spdlog/pattern_formatter-inl.h b/include/spdlog/pattern_formatter-inl.h +index 01afbe6f..7ef28f0f 100644 +--- a/include/spdlog/pattern_formatter-inl.h ++++ b/include/spdlog/pattern_formatter-inl.h +@@ -27,7 +27,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + /////////////////////////////////////////////////////////////////////// +@@ -1318,8 +1318,8 @@ SPDLOG_INLINE void pattern_formatter::handle_flag_(char flag, details::padding_i + formatters_.push_back((std::move(unknown_flag))); + } + // fix issue #1617 (prev char was '!' and should have been treated as funcname flag instead of truncating flag) +- // spdlog::set_pattern("[%10!] %v") => "[ main] some message" +- // spdlog::set_pattern("[%3!!] %v") => "[mai] some message" ++ // ds_spdlog::set_pattern("[%10!] %v") => "[ main] some message" ++ // ds_spdlog::set_pattern("[%3!!] %v") => "[mai] some message" + else + { + padding.truncate_ = false; +@@ -1433,4 +1433,4 @@ SPDLOG_INLINE void pattern_formatter::compile_pattern_(const std::string &patter + formatters_.push_back(std::move(user_chars)); + } + } +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/pattern_formatter.h b/include/spdlog/pattern_formatter.h +index 4c87b21e..59503535 100644 +--- a/include/spdlog/pattern_formatter.h ++++ b/include/spdlog/pattern_formatter.h +@@ -16,7 +16,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace details { + + // padding information. 
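
Note: in the pattern_formatter hunks that follow, only namespace qualifiers change (for example on the default_eol default argument); every %-flag keeps its meaning. A short sketch of applying a pattern through the renamed namespace; the pattern string is an arbitrary example:

    #include <spdlog/spdlog.h>

    void configure(std::shared_ptr<ds_spdlog::logger> logger) {
        // Standard flags: %Y-%m-%d date, %H:%M:%S.%e time, %n logger name, %l level, %v message.
        logger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%n] [%l] %v");
    }
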
+@@ -80,10 +80,10 @@ public: + using custom_flags = std::unordered_map>; + + explicit pattern_formatter(std::string pattern, pattern_time_type time_type = pattern_time_type::local, +- std::string eol = spdlog::details::os::default_eol, custom_flags custom_user_flags = custom_flags()); ++ std::string eol = ds_spdlog::details::os::default_eol, custom_flags custom_user_flags = custom_flags()); + + // use default pattern is not given +- explicit pattern_formatter(pattern_time_type time_type = pattern_time_type::local, std::string eol = spdlog::details::os::default_eol); ++ explicit pattern_formatter(pattern_time_type time_type = pattern_time_type::local, std::string eol = ds_spdlog::details::os::default_eol); + + pattern_formatter(const pattern_formatter &other) = delete; + pattern_formatter &operator=(const pattern_formatter &other) = delete; +@@ -121,7 +121,7 @@ private: + + void compile_pattern_(const std::string &pattern); + }; +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "pattern_formatter-inl.h" +diff --git a/include/spdlog/sinks/android_sink.h b/include/spdlog/sinks/android_sink.h +index 0087e953..2c742ac6 100644 +--- a/include/spdlog/sinks/android_sink.h ++++ b/include/spdlog/sinks/android_sink.h +@@ -22,7 +22,7 @@ + # define SPDLOG_ANDROID_RETRIES 2 + # endif + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + /* +@@ -92,21 +92,21 @@ private: + return __android_log_buf_write(ID, prio, tag, text); + } + +- static android_LogPriority convert_to_android_(spdlog::level::level_enum level) ++ static android_LogPriority convert_to_android_(ds_spdlog::level::level_enum level) + { + switch (level) + { +- case spdlog::level::trace: ++ case ds_spdlog::level::trace: + return ANDROID_LOG_VERBOSE; +- case spdlog::level::debug: ++ case ds_spdlog::level::debug: + return ANDROID_LOG_DEBUG; +- case spdlog::level::info: ++ case ds_spdlog::level::info: + return ANDROID_LOG_INFO; +- case spdlog::level::warn: ++ case ds_spdlog::level::warn: + return ANDROID_LOG_WARN; +- case spdlog::level::err: ++ case ds_spdlog::level::err: + return ANDROID_LOG_ERROR; +- case spdlog::level::critical: ++ case ds_spdlog::level::critical: + return ANDROID_LOG_FATAL; + default: + return ANDROID_LOG_DEFAULT; +@@ -129,18 +129,18 @@ using android_sink_buf_st = android_sink; + + // Create and register android syslog logger + +-template ++template + inline std::shared_ptr android_logger_mt(const std::string &logger_name, const std::string &tag = "spdlog") + { + return Factory::template create(logger_name, tag); + } + +-template ++template + inline std::shared_ptr android_logger_st(const std::string &logger_name, const std::string &tag = "spdlog") + { + return Factory::template create(logger_name, tag); + } + +-} // namespace spdlog ++} // namespace ds_spdlog + + #endif // __ANDROID__ +diff --git a/include/spdlog/sinks/ansicolor_sink-inl.h b/include/spdlog/sinks/ansicolor_sink-inl.h +index c924fc5b..ebfc2665 100644 +--- a/include/spdlog/sinks/ansicolor_sink-inl.h ++++ b/include/spdlog/sinks/ansicolor_sink-inl.h +@@ -10,14 +10,14 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + template + SPDLOG_INLINE ansicolor_sink::ansicolor_sink(FILE *target_file, color_mode mode) + : target_file_(target_file) + , mutex_(ConsoleMutex::mutex()) +- , formatter_(details::make_unique()) ++ , formatter_(details::make_unique()) + + { + set_color_mode(mode); +@@ -76,11 +76,11 @@ template + SPDLOG_INLINE void 
ansicolor_sink<ConsoleMutex>::set_pattern(const std::string &pattern)
+ {
+ std::lock_guard<mutex_t> lock(mutex_);
+- formatter_ = std::unique_ptr<spdlog::formatter>(new pattern_formatter(pattern));
++ formatter_ = std::unique_ptr<ds_spdlog::formatter>(new pattern_formatter(pattern));
+ }
+
+ template<typename ConsoleMutex>
+-SPDLOG_INLINE void ansicolor_sink<ConsoleMutex>::set_formatter(std::unique_ptr<spdlog::formatter> sink_formatter)
++SPDLOG_INLINE void ansicolor_sink<ConsoleMutex>::set_formatter(std::unique_ptr<ds_spdlog::formatter> sink_formatter)
+ {
+ std::lock_guard<mutex_t> lock(mutex_);
+ formatter_ = std::move(sink_formatter);
+@@ -142,4 +142,4 @@ SPDLOG_INLINE ansicolor_stderr_sink<ConsoleMutex>::ansicolor_stderr_sink(color_mode mode)
+ {}
+
+ } // namespace sinks
+-} // namespace spdlog
++} // namespace ds_spdlog
+diff --git a/include/spdlog/sinks/ansicolor_sink.h b/include/spdlog/sinks/ansicolor_sink.h
+index 39d966bc..3ebd07e7 100644
+--- a/include/spdlog/sinks/ansicolor_sink.h
++++ b/include/spdlog/sinks/ansicolor_sink.h
+@@ -11,7 +11,7 @@
+ #include
+ #include
+
+-namespace spdlog {
++namespace ds_spdlog {
+ namespace sinks {
+
+ /**
+@@ -42,7 +42,7 @@ public:
+ void log(const details::log_msg &msg) override;
+ void flush() override;
+ void set_pattern(const std::string &pattern) final;
+- void set_formatter(std::unique_ptr<spdlog::formatter> sink_formatter) override;
++ void set_formatter(std::unique_ptr<ds_spdlog::formatter> sink_formatter) override;
+
+ // Formatting codes
+ const string_view_t reset = "\033[m";
+@@ -83,7 +83,7 @@ private:
+ FILE *target_file_;
+ mutex_t &mutex_;
+ bool should_do_colors_;
+- std::unique_ptr<spdlog::formatter> formatter_;
++ std::unique_ptr<ds_spdlog::formatter> formatter_;
+ std::array<std::string, level::n_levels> colors_;
+ void print_ccode_(const string_view_t &color_code);
+ void print_range_(const memory_buf_t &formatted, size_t start, size_t end);
+@@ -111,7 +111,7 @@ using ansicolor_stderr_sink_mt = ansicolor_stderr_sink<details::console_mutex>;
+ using ansicolor_stderr_sink_st = ansicolor_stderr_sink<details::console_nullmutex>;
+
+ } // namespace sinks
+-} // namespace spdlog
++} // namespace ds_spdlog
+
+ #ifdef SPDLOG_HEADER_ONLY
+ # include "ansicolor_sink-inl.h"
+diff --git a/include/spdlog/sinks/base_sink-inl.h b/include/spdlog/sinks/base_sink-inl.h
+index 421fdf9d..7a8e9512 100644
+--- a/include/spdlog/sinks/base_sink-inl.h
++++ b/include/spdlog/sinks/base_sink-inl.h
+@@ -13,51 +13,51 @@
+ #include
+
+ template<typename Mutex>
+-SPDLOG_INLINE spdlog::sinks::base_sink<Mutex>::base_sink()
+- : formatter_{details::make_unique<spdlog::pattern_formatter>()}
++SPDLOG_INLINE ds_spdlog::sinks::base_sink<Mutex>::base_sink()
++ : formatter_{details::make_unique<ds_spdlog::pattern_formatter>()}
+ {}
+
+ template<typename Mutex>
+-SPDLOG_INLINE spdlog::sinks::base_sink<Mutex>::base_sink(std::unique_ptr<spdlog::formatter> formatter)
++SPDLOG_INLINE ds_spdlog::sinks::base_sink<Mutex>::base_sink(std::unique_ptr<ds_spdlog::formatter> formatter)
+ : formatter_{std::move(formatter)}
+ {}
+
+ template<typename Mutex>
+-void SPDLOG_INLINE spdlog::sinks::base_sink<Mutex>::log(const details::log_msg &msg)
++void SPDLOG_INLINE ds_spdlog::sinks::base_sink<Mutex>::log(const details::log_msg &msg)
+ {
+ std::lock_guard<Mutex> lock(mutex_);
+ sink_it_(msg);
+ }
+
+ template<typename Mutex>
+-void SPDLOG_INLINE spdlog::sinks::base_sink<Mutex>::flush()
++void SPDLOG_INLINE ds_spdlog::sinks::base_sink<Mutex>::flush()
+ {
+ std::lock_guard<Mutex> lock(mutex_);
+ flush_();
+ }
+
+ template<typename Mutex>
+-void SPDLOG_INLINE spdlog::sinks::base_sink<Mutex>::set_pattern(const std::string &pattern)
++void SPDLOG_INLINE ds_spdlog::sinks::base_sink<Mutex>::set_pattern(const std::string &pattern)
+ {
+ std::lock_guard<Mutex> lock(mutex_);
+ set_pattern_(pattern);
+ }
+
+ template<typename Mutex>
+-void SPDLOG_INLINE spdlog::sinks::base_sink<Mutex>::set_formatter(std::unique_ptr<spdlog::formatter> sink_formatter)
++void SPDLOG_INLINE ds_spdlog::sinks::base_sink<Mutex>::set_formatter(std::unique_ptr<ds_spdlog::formatter> sink_formatter)
+ {
+ std::lock_guard<Mutex> lock(mutex_);
+ set_formatter_(std::move(sink_formatter));
+ }
+
+ template<typename Mutex>
+-void SPDLOG_INLINE spdlog::sinks::base_sink<Mutex>::set_pattern_(const std::string &pattern)
++void SPDLOG_INLINE ds_spdlog::sinks::base_sink<Mutex>::set_pattern_(const std::string &pattern)
+ {
+- set_formatter_(details::make_unique<spdlog::pattern_formatter>(pattern));
++ set_formatter_(details::make_unique<ds_spdlog::pattern_formatter>(pattern));
+ }
+
+ template<typename Mutex>
+-void SPDLOG_INLINE spdlog::sinks::base_sink<Mutex>::set_formatter_(std::unique_ptr<spdlog::formatter> sink_formatter)
++void SPDLOG_INLINE ds_spdlog::sinks::base_sink<Mutex>::set_formatter_(std::unique_ptr<ds_spdlog::formatter> sink_formatter)
+ {
+ formatter_ = std::move(sink_formatter);
+ }
+diff --git a/include/spdlog/sinks/base_sink.h b/include/spdlog/sinks/base_sink.h
+index 2e795f59..26c571bc 100644
+--- a/include/spdlog/sinks/base_sink.h
++++ b/include/spdlog/sinks/base_sink.h
+@@ -13,14 +13,14 @@
+ #include
+ #include
+
+-namespace spdlog {
++namespace ds_spdlog {
+ namespace sinks {
+ template<typename Mutex>
+ class SPDLOG_API base_sink : public sink
+ {
+ public:
+ base_sink();
+- explicit base_sink(std::unique_ptr<spdlog::formatter> formatter);
++ explicit base_sink(std::unique_ptr<ds_spdlog::formatter> formatter);
+ ~base_sink() override = default;
+
+ base_sink(const base_sink &) = delete;
+@@ -32,20 +32,20 @@ public:
+ void log(const details::log_msg &msg) final;
+ void flush() final;
+ void set_pattern(const std::string &pattern) final;
+- void set_formatter(std::unique_ptr<spdlog::formatter> sink_formatter) final;
++ void set_formatter(std::unique_ptr<ds_spdlog::formatter> sink_formatter) final;
+
+ protected:
+ // sink formatter
+- std::unique_ptr<spdlog::formatter> formatter_;
++ std::unique_ptr<ds_spdlog::formatter> formatter_;
+ Mutex mutex_;
+
+ virtual void sink_it_(const details::log_msg &msg) = 0;
+ virtual void flush_() = 0;
+ virtual void set_pattern_(const std::string &pattern);
+- virtual void set_formatter_(std::unique_ptr<spdlog::formatter> sink_formatter);
++ virtual void set_formatter_(std::unique_ptr<ds_spdlog::formatter> sink_formatter);
+ };
+ } // namespace sinks
+-} // namespace spdlog
++} // namespace ds_spdlog
+
+ #ifdef SPDLOG_HEADER_ONLY
+ # include "base_sink-inl.h"
+diff --git a/include/spdlog/sinks/basic_file_sink-inl.h b/include/spdlog/sinks/basic_file_sink-inl.h
+index 8d23f96d..1d7a0b8c 100644
+--- a/include/spdlog/sinks/basic_file_sink-inl.h
++++ b/include/spdlog/sinks/basic_file_sink-inl.h
+@@ -10,7 +10,7 @@
+ #include
+ #include
+
+-namespace spdlog {
++namespace ds_spdlog {
+ namespace sinks {
+
+ template<typename Mutex>
+@@ -41,4 +41,4 @@ SPDLOG_INLINE void basic_file_sink<Mutex>::flush_()
+ }
+
+ } // namespace sinks
+-} // namespace spdlog
++} // namespace ds_spdlog
+diff --git a/include/spdlog/sinks/basic_file_sink.h b/include/spdlog/sinks/basic_file_sink.h
+index aacc993b..5aa6a3e6 100644
+--- a/include/spdlog/sinks/basic_file_sink.h
++++ b/include/spdlog/sinks/basic_file_sink.h
+@@ -11,7 +11,7 @@
+ #include
+ #include
+
+-namespace spdlog {
++namespace ds_spdlog {
+ namespace sinks {
+ /*
+ * Trivial file sink with single file as target
+@@ -39,21 +39,21 @@ using basic_file_sink_st = basic_file_sink<details::null_mutex>;
+ //
+ // factory functions
+ //
+-template<typename Factory = spdlog::synchronous_factory>
++template<typename Factory = ds_spdlog::synchronous_factory>
+ inline std::shared_ptr<logger> basic_logger_mt(
+ const std::string &logger_name, const filename_t &filename, bool truncate = false, const file_event_handlers &event_handlers = {})
+ {
+ return Factory::template create<sinks::basic_file_sink_mt>(logger_name, filename, truncate, event_handlers);
+ }
+
+-template<typename Factory = spdlog::synchronous_factory>
++template<typename Factory = ds_spdlog::synchronous_factory>
+ inline std::shared_ptr<logger> basic_logger_st(
+ const std::string &logger_name, const filename_t &filename, bool truncate = false, const file_event_handlers &event_handlers = {})
+ {
+ return Factory::template create<sinks::basic_file_sink_st>(logger_name, filename, truncate, event_handlers);
+ }
+
+-} // namespace spdlog
++} // namespace ds_spdlog
+
+ #ifdef SPDLOG_HEADER_ONLY
+ # 
include "basic_file_sink-inl.h" +diff --git a/include/spdlog/sinks/callback_sink.h b/include/spdlog/sinks/callback_sink.h +index bcd31383..2af1bae7 100644 +--- a/include/spdlog/sinks/callback_sink.h ++++ b/include/spdlog/sinks/callback_sink.h +@@ -10,7 +10,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + + // callbacks type + typedef std::function custom_log_callback; +@@ -46,16 +46,16 @@ using callback_sink_st = callback_sink; + // + // factory functions + // +-template ++template + inline std::shared_ptr callback_logger_mt(const std::string &logger_name, const custom_log_callback &callback) + { + return Factory::template create(logger_name, callback); + } + +-template ++template + inline std::shared_ptr callback_logger_st(const std::string &logger_name, const custom_log_callback &callback) + { + return Factory::template create(logger_name, callback); + } + +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/daily_file_sink.h b/include/spdlog/sinks/daily_file_sink.h +index 0770380c..10a1f7cb 100644 +--- a/include/spdlog/sinks/daily_file_sink.h ++++ b/include/spdlog/sinks/daily_file_sink.h +@@ -20,7 +20,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + /* +@@ -41,8 +41,8 @@ struct daily_filename_calculator + /* + * Generator of daily log file names with strftime format. + * Usages: +- * auto sink = std::make_shared("myapp-%Y-%m-%d:%H:%M:%S.log", hour, minute);" +- * auto logger = spdlog::daily_logger_format_mt("loggername, "myapp-%Y-%m-%d:%X.log", hour, minute)" ++ * auto sink = std::make_shared("myapp-%Y-%m-%d:%H:%M:%S.log", hour, minute);" ++ * auto logger = ds_spdlog::daily_logger_format_mt("loggername, "myapp-%Y-%m-%d:%X.log", hour, minute)" + * + */ + struct daily_filename_format_calculator +@@ -155,7 +155,7 @@ private: + tm now_tm(log_clock::time_point tp) + { + time_t tnow = log_clock::to_time_t(tp); +- return spdlog::details::os::localtime(tnow); ++ return ds_spdlog::details::os::localtime(tnow); + } + + log_clock::time_point next_rotation_tp_() +@@ -215,14 +215,14 @@ using daily_file_format_sink_st = daily_file_sink ++template + inline std::shared_ptr daily_logger_mt(const std::string &logger_name, const filename_t &filename, int hour = 0, int minute = 0, + bool truncate = false, uint16_t max_files = 0, const file_event_handlers &event_handlers = {}) + { + return Factory::template create(logger_name, filename, hour, minute, truncate, max_files, event_handlers); + } + +-template ++template + inline std::shared_ptr daily_logger_format_mt(const std::string &logger_name, const filename_t &filename, int hour = 0, + int minute = 0, bool truncate = false, uint16_t max_files = 0, const file_event_handlers &event_handlers = {}) + { +@@ -230,18 +230,18 @@ inline std::shared_ptr daily_logger_format_mt(const std::string &logger_ + logger_name, filename, hour, minute, truncate, max_files, event_handlers); + } + +-template ++template + inline std::shared_ptr daily_logger_st(const std::string &logger_name, const filename_t &filename, int hour = 0, int minute = 0, + bool truncate = false, uint16_t max_files = 0, const file_event_handlers &event_handlers = {}) + { + return Factory::template create(logger_name, filename, hour, minute, truncate, max_files, event_handlers); + } + +-template ++template + inline std::shared_ptr daily_logger_format_st(const std::string &logger_name, const filename_t &filename, int hour = 0, + int minute = 0, bool truncate = false, uint16_t max_files = 0, const 
file_event_handlers &event_handlers = {}) + { + return Factory::template create( + logger_name, filename, hour, minute, truncate, max_files, event_handlers); + } +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/dist_sink.h b/include/spdlog/sinks/dist_sink.h +index 7ec3a2ec..e93f4d39 100644 +--- a/include/spdlog/sinks/dist_sink.h ++++ b/include/spdlog/sinks/dist_sink.h +@@ -16,7 +16,7 @@ + // Distribution sink (mux). Stores a vector of sinks which get called when log + // is called + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + template +@@ -76,10 +76,10 @@ protected: + + void set_pattern_(const std::string &pattern) override + { +- set_formatter_(details::make_unique(pattern)); ++ set_formatter_(details::make_unique(pattern)); + } + +- void set_formatter_(std::unique_ptr sink_formatter) override ++ void set_formatter_(std::unique_ptr sink_formatter) override + { + base_sink::formatter_ = std::move(sink_formatter); + for (auto &sub_sink : sinks_) +@@ -94,4 +94,4 @@ using dist_sink_mt = dist_sink; + using dist_sink_st = dist_sink; + + } // namespace sinks +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/dup_filter_sink.h b/include/spdlog/sinks/dup_filter_sink.h +index 3c96549c..80f90fdd 100644 +--- a/include/spdlog/sinks/dup_filter_sink.h ++++ b/include/spdlog/sinks/dup_filter_sink.h +@@ -22,7 +22,7 @@ + // int main() { + // auto dup_filter = std::make_shared(std::chrono::seconds(5), level::info); + // dup_filter->add_sink(std::make_shared()); +-// spdlog::logger l("logger", dup_filter); ++// ds_spdlog::logger l("logger", dup_filter); + // l.info("Hello"); + // l.info("Hello"); + // l.info("Hello"); +@@ -34,7 +34,7 @@ + // [2019-06-25 17:50:56.512] [logger] [info] Skipped 3 duplicate messages.. 
+ // [2019-06-25 17:50:56.512] [logger] [info] Different Hello + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + template + class dup_filter_sink : public dist_sink +@@ -93,4 +93,4 @@ using dup_filter_sink_mt = dup_filter_sink; + using dup_filter_sink_st = dup_filter_sink; + + } // namespace sinks +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/hourly_file_sink.h b/include/spdlog/sinks/hourly_file_sink.h +index 33dd8948..4c3c1cba 100644 +--- a/include/spdlog/sinks/hourly_file_sink.h ++++ b/include/spdlog/sinks/hourly_file_sink.h +@@ -18,7 +18,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + /* +@@ -132,7 +132,7 @@ private: + tm now_tm(log_clock::time_point tp) + { + time_t tnow = log_clock::to_time_t(tp); +- return spdlog::details::os::localtime(tnow); ++ return ds_spdlog::details::os::localtime(tnow); + } + + log_clock::time_point next_rotation_tp_() +@@ -188,17 +188,17 @@ using hourly_file_sink_st = hourly_file_sink; + // + // factory functions + // +-template ++template + inline std::shared_ptr hourly_logger_mt(const std::string &logger_name, const filename_t &filename, bool truncate = false, + uint16_t max_files = 0, const file_event_handlers &event_handlers = {}) + { + return Factory::template create(logger_name, filename, truncate, max_files, event_handlers); + } + +-template ++template + inline std::shared_ptr hourly_logger_st(const std::string &logger_name, const filename_t &filename, bool truncate = false, + uint16_t max_files = 0, const file_event_handlers &event_handlers = {}) + { + return Factory::template create(logger_name, filename, truncate, max_files, event_handlers); + } +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/kafka_sink.h b/include/spdlog/sinks/kafka_sink.h +index ce740efc..61b2984f 100644 +--- a/include/spdlog/sinks/kafka_sink.h ++++ b/include/spdlog/sinks/kafka_sink.h +@@ -21,7 +21,7 @@ + // kafka header + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + struct kafka_sink_config +@@ -102,32 +102,32 @@ private: + }; + + using kafka_sink_mt = kafka_sink; +-using kafka_sink_st = kafka_sink; ++using kafka_sink_st = kafka_sink; + + } // namespace sinks + +-template +-inline std::shared_ptr kafka_logger_mt(const std::string &logger_name, spdlog::sinks::kafka_sink_config config) ++template ++inline std::shared_ptr kafka_logger_mt(const std::string &logger_name, ds_spdlog::sinks::kafka_sink_config config) + { + return Factory::template create(logger_name, config); + } + +-template +-inline std::shared_ptr kafka_logger_st(const std::string &logger_name, spdlog::sinks::kafka_sink_config config) ++template ++inline std::shared_ptr kafka_logger_st(const std::string &logger_name, ds_spdlog::sinks::kafka_sink_config config) + { + return Factory::template create(logger_name, config); + } + +-template +-inline std::shared_ptr kafka_logger_async_mt(std::string logger_name, spdlog::sinks::kafka_sink_config config) ++template ++inline std::shared_ptr kafka_logger_async_mt(std::string logger_name, ds_spdlog::sinks::kafka_sink_config config) + { + return Factory::template create(logger_name, config); + } + +-template +-inline std::shared_ptr kafka_logger_async_st(std::string logger_name, spdlog::sinks::kafka_sink_config config) ++template ++inline std::shared_ptr kafka_logger_async_st(std::string logger_name, ds_spdlog::sinks::kafka_sink_config config) + { + return Factory::template create(logger_name, 
config); + } + +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/mongo_sink.h b/include/spdlog/sinks/mongo_sink.h +index 6a8927f5..30c94c56 100644 +--- a/include/spdlog/sinks/mongo_sink.h ++++ b/include/spdlog/sinks/mongo_sink.h +@@ -23,7 +23,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + template + class mongo_sink : public base_sink +@@ -45,7 +45,7 @@ public: + { + try + { +- client_ = spdlog::details::make_unique(mongocxx::uri{uri}); ++ client_ = ds_spdlog::details::make_unique(mongocxx::uri{uri}); + } + catch (const std::exception &e) + { +@@ -86,22 +86,22 @@ private: + #include "spdlog/details/null_mutex.h" + #include + using mongo_sink_mt = mongo_sink; +-using mongo_sink_st = mongo_sink; ++using mongo_sink_st = mongo_sink; + + } // namespace sinks + +-template ++template + inline std::shared_ptr mongo_logger_mt(const std::string &logger_name, const std::string &db_name, + const std::string &collection_name, const std::string &uri = "mongodb://localhost:27017") + { + return Factory::template create(logger_name, db_name, collection_name, uri); + } + +-template ++template + inline std::shared_ptr mongo_logger_st(const std::string &logger_name, const std::string &db_name, + const std::string &collection_name, const std::string &uri = "mongodb://localhost:27017") + { + return Factory::template create(logger_name, db_name, collection_name, uri); + } + +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/msvc_sink.h b/include/spdlog/sinks/msvc_sink.h +index bf68ae88..7addab94 100644 +--- a/include/spdlog/sinks/msvc_sink.h ++++ b/include/spdlog/sinks/msvc_sink.h +@@ -22,7 +22,7 @@ extern "C" __declspec(dllimport) void __stdcall OutputDebugStringA(const char *l + # endif + extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent(); + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + /* + * MSVC sink (logging using OutputDebugStringA) +@@ -66,6 +66,6 @@ using windebug_sink_mt = msvc_sink_mt; + using windebug_sink_st = msvc_sink_st; + + } // namespace sinks +-} // namespace spdlog ++} // namespace ds_spdlog + + #endif +diff --git a/include/spdlog/sinks/null_sink.h b/include/spdlog/sinks/null_sink.h +index eb832801..13977cbb 100644 +--- a/include/spdlog/sinks/null_sink.h ++++ b/include/spdlog/sinks/null_sink.h +@@ -9,7 +9,7 @@ + + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + template +@@ -25,7 +25,7 @@ using null_sink_st = null_sink; + + } // namespace sinks + +-template ++template + inline std::shared_ptr null_logger_mt(const std::string &logger_name) + { + auto null_logger = Factory::template create(logger_name); +@@ -33,7 +33,7 @@ inline std::shared_ptr null_logger_mt(const std::string &logger_name) + return null_logger; + } + +-template ++template + inline std::shared_ptr null_logger_st(const std::string &logger_name) + { + auto null_logger = Factory::template create(logger_name); +@@ -41,4 +41,4 @@ inline std::shared_ptr null_logger_st(const std::string &logger_name) + return null_logger; + } + +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/ostream_sink.h b/include/spdlog/sinks/ostream_sink.h +index 95c1e962..6f49e37c 100644 +--- a/include/spdlog/sinks/ostream_sink.h ++++ b/include/spdlog/sinks/ostream_sink.h +@@ -9,7 +9,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + template + class ostream_sink final : public base_sink +@@ 
-47,4 +47,4 @@ using ostream_sink_mt = ostream_sink; + using ostream_sink_st = ostream_sink; + + } // namespace sinks +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/qt_sinks.h b/include/spdlog/sinks/qt_sinks.h +index f801ac34..7782b111 100644 +--- a/include/spdlog/sinks/qt_sinks.h ++++ b/include/spdlog/sinks/qt_sinks.h +@@ -24,7 +24,7 @@ + // + // qt_sink class + // +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + template + class qt_sink : public base_sink +@@ -237,56 +237,56 @@ using qt_color_sink_st = qt_color_sink; + // + + // log to QTextEdit +-template ++template + inline std::shared_ptr qt_logger_mt(const std::string &logger_name, QTextEdit *qt_object, const std::string &meta_method = "append") + { + return Factory::template create(logger_name, qt_object, meta_method); + } + +-template ++template + inline std::shared_ptr qt_logger_st(const std::string &logger_name, QTextEdit *qt_object, const std::string &meta_method = "append") + { + return Factory::template create(logger_name, qt_object, meta_method); + } + + // log to QPlainTextEdit +-template ++template + inline std::shared_ptr qt_logger_mt( + const std::string &logger_name, QPlainTextEdit *qt_object, const std::string &meta_method = "appendPlainText") + { + return Factory::template create(logger_name, qt_object, meta_method); + } + +-template ++template + inline std::shared_ptr qt_logger_st( + const std::string &logger_name, QPlainTextEdit *qt_object, const std::string &meta_method = "appendPlainText") + { + return Factory::template create(logger_name, qt_object, meta_method); + } + // log to QObject +-template ++template + inline std::shared_ptr qt_logger_mt(const std::string &logger_name, QObject *qt_object, const std::string &meta_method) + { + return Factory::template create(logger_name, qt_object, meta_method); + } + +-template ++template + inline std::shared_ptr qt_logger_st(const std::string &logger_name, QObject *qt_object, const std::string &meta_method) + { + return Factory::template create(logger_name, qt_object, meta_method); + } + + // log to QTextEdit with colorize output +-template ++template + inline std::shared_ptr qt_color_logger_mt(const std::string &logger_name, QTextEdit *qt_text_edit, int max_lines) + { + return Factory::template create(logger_name, qt_text_edit, max_lines); + } + +-template ++template + inline std::shared_ptr qt_color_logger_st(const std::string &logger_name, QTextEdit *qt_text_edit, int max_lines) + { + return Factory::template create(logger_name, qt_text_edit, max_lines); + } + +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/ringbuffer_sink.h b/include/spdlog/sinks/ringbuffer_sink.h +index 3ca47c6f..26f8c7cc 100644 +--- a/include/spdlog/sinks/ringbuffer_sink.h ++++ b/include/spdlog/sinks/ringbuffer_sink.h +@@ -12,7 +12,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + /* + * Ring buffer sink +@@ -71,4 +71,4 @@ using ringbuffer_sink_st = ringbuffer_sink; + + } // namespace sinks + +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/rotating_file_sink-inl.h b/include/spdlog/sinks/rotating_file_sink-inl.h +index cf8b9d5c..24aa70a7 100644 +--- a/include/spdlog/sinks/rotating_file_sink-inl.h ++++ b/include/spdlog/sinks/rotating_file_sink-inl.h +@@ -20,7 +20,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + template +@@ -149,4 +149,4 @@ SPDLOG_INLINE bool 
rotating_file_sink::rename_file_(const filename_t &src + } + + } // namespace sinks +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/rotating_file_sink.h b/include/spdlog/sinks/rotating_file_sink.h +index ce0d7b1e..97f3c87a 100644 +--- a/include/spdlog/sinks/rotating_file_sink.h ++++ b/include/spdlog/sinks/rotating_file_sink.h +@@ -12,7 +12,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + // +@@ -59,7 +59,7 @@ using rotating_file_sink_st = rotating_file_sink; + // factory functions + // + +-template ++template + inline std::shared_ptr rotating_logger_mt(const std::string &logger_name, const filename_t &filename, size_t max_file_size, + size_t max_files, bool rotate_on_open = false, const file_event_handlers &event_handlers = {}) + { +@@ -67,14 +67,14 @@ inline std::shared_ptr rotating_logger_mt(const std::string &logger_name + logger_name, filename, max_file_size, max_files, rotate_on_open, event_handlers); + } + +-template ++template + inline std::shared_ptr rotating_logger_st(const std::string &logger_name, const filename_t &filename, size_t max_file_size, + size_t max_files, bool rotate_on_open = false, const file_event_handlers &event_handlers = {}) + { + return Factory::template create( + logger_name, filename, max_file_size, max_files, rotate_on_open, event_handlers); + } +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "rotating_file_sink-inl.h" +diff --git a/include/spdlog/sinks/sink-inl.h b/include/spdlog/sinks/sink-inl.h +index df07adda..c703e020 100644 +--- a/include/spdlog/sinks/sink-inl.h ++++ b/include/spdlog/sinks/sink-inl.h +@@ -9,17 +9,17 @@ + + #include + +-SPDLOG_INLINE bool spdlog::sinks::sink::should_log(spdlog::level::level_enum msg_level) const ++SPDLOG_INLINE bool ds_spdlog::sinks::sink::should_log(ds_spdlog::level::level_enum msg_level) const + { + return msg_level >= level_.load(std::memory_order_relaxed); + } + +-SPDLOG_INLINE void spdlog::sinks::sink::set_level(level::level_enum log_level) ++SPDLOG_INLINE void ds_spdlog::sinks::sink::set_level(level::level_enum log_level) + { + level_.store(log_level, std::memory_order_relaxed); + } + +-SPDLOG_INLINE spdlog::level::level_enum spdlog::sinks::sink::level() const ++SPDLOG_INLINE ds_spdlog::level::level_enum ds_spdlog::sinks::sink::level() const + { +- return static_cast(level_.load(std::memory_order_relaxed)); ++ return static_cast(level_.load(std::memory_order_relaxed)); + } +diff --git a/include/spdlog/sinks/sink.h b/include/spdlog/sinks/sink.h +index 0a28cccc..e18c5fc4 100644 +--- a/include/spdlog/sinks/sink.h ++++ b/include/spdlog/sinks/sink.h +@@ -6,7 +6,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + + namespace sinks { + class SPDLOG_API sink +@@ -16,7 +16,7 @@ public: + virtual void log(const details::log_msg &msg) = 0; + virtual void flush() = 0; + virtual void set_pattern(const std::string &pattern) = 0; +- virtual void set_formatter(std::unique_ptr sink_formatter) = 0; ++ virtual void set_formatter(std::unique_ptr sink_formatter) = 0; + + void set_level(level::level_enum log_level); + level::level_enum level() const; +@@ -28,7 +28,7 @@ protected: + }; + + } // namespace sinks +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "sink-inl.h" +diff --git a/include/spdlog/sinks/stdout_color_sinks-inl.h b/include/spdlog/sinks/stdout_color_sinks-inl.h +index 066df182..344d3dfa 100644 +--- 
a/include/spdlog/sinks/stdout_color_sinks-inl.h ++++ b/include/spdlog/sinks/stdout_color_sinks-inl.h +@@ -10,7 +10,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + + template + SPDLOG_INLINE std::shared_ptr stdout_color_mt(const std::string &logger_name, color_mode mode) +@@ -35,4 +35,4 @@ SPDLOG_INLINE std::shared_ptr stderr_color_st(const std::string &logger_ + { + return Factory::template create(logger_name, mode); + } +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/stdout_color_sinks.h b/include/spdlog/sinks/stdout_color_sinks.h +index 420b13ab..de3e73e3 100644 +--- a/include/spdlog/sinks/stdout_color_sinks.h ++++ b/include/spdlog/sinks/stdout_color_sinks.h +@@ -11,7 +11,7 @@ + + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + #ifdef _WIN32 + using stdout_color_sink_mt = wincolor_stdout_sink_mt; +@@ -26,19 +26,19 @@ using stderr_color_sink_st = ansicolor_stderr_sink_st; + #endif + } // namespace sinks + +-template ++template + std::shared_ptr stdout_color_mt(const std::string &logger_name, color_mode mode = color_mode::automatic); + +-template ++template + std::shared_ptr stdout_color_st(const std::string &logger_name, color_mode mode = color_mode::automatic); + +-template ++template + std::shared_ptr stderr_color_mt(const std::string &logger_name, color_mode mode = color_mode::automatic); + +-template ++template + std::shared_ptr stderr_color_st(const std::string &logger_name, color_mode mode = color_mode::automatic); + +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "stdout_color_sinks-inl.h" +diff --git a/include/spdlog/sinks/stdout_sinks-inl.h b/include/spdlog/sinks/stdout_sinks-inl.h +index c1754370..ed7e2f5e 100644 +--- a/include/spdlog/sinks/stdout_sinks-inl.h ++++ b/include/spdlog/sinks/stdout_sinks-inl.h +@@ -24,7 +24,7 @@ + # include // _fileno(..) + #endif // WIN32 + +-namespace spdlog { ++namespace ds_spdlog { + + namespace sinks { + +@@ -32,7 +32,7 @@ template + SPDLOG_INLINE stdout_sink_base::stdout_sink_base(FILE *file) + : mutex_(ConsoleMutex::mutex()) + , file_(file) +- , formatter_(details::make_unique()) ++ , formatter_(details::make_unique()) + { + #ifdef _WIN32 + // get windows handle from the FILE* object +@@ -44,7 +44,7 @@ SPDLOG_INLINE stdout_sink_base::stdout_sink_base(FILE *file) + // throw only if non stdout/stderr target is requested (probably regular file and not console). 
+ if (handle_ == INVALID_HANDLE_VALUE && file != stdout && file != stderr) + { +- throw_spdlog_ex("spdlog::stdout_sink_base: _get_osfhandle() failed", errno); ++ throw_spdlog_ex("ds_spdlog::stdout_sink_base: _get_osfhandle() failed", errno); + } + #endif // WIN32 + } +@@ -87,11 +87,11 @@ template + SPDLOG_INLINE void stdout_sink_base::set_pattern(const std::string &pattern) + { + std::lock_guard lock(mutex_); +- formatter_ = std::unique_ptr(new pattern_formatter(pattern)); ++ formatter_ = std::unique_ptr(new pattern_formatter(pattern)); + } + + template +-SPDLOG_INLINE void stdout_sink_base::set_formatter(std::unique_ptr sink_formatter) ++SPDLOG_INLINE void stdout_sink_base::set_formatter(std::unique_ptr sink_formatter) + { + std::lock_guard lock(mutex_); + formatter_ = std::move(sink_formatter); +@@ -135,4 +135,4 @@ SPDLOG_INLINE std::shared_ptr stderr_logger_st(const std::string &logger + { + return Factory::template create(logger_name); + } +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/stdout_sinks.h b/include/spdlog/sinks/stdout_sinks.h +index 6fdc0de3..f0a428ee 100644 +--- a/include/spdlog/sinks/stdout_sinks.h ++++ b/include/spdlog/sinks/stdout_sinks.h +@@ -12,7 +12,7 @@ + # include + #endif + +-namespace spdlog { ++namespace ds_spdlog { + + namespace sinks { + +@@ -34,12 +34,12 @@ public: + void flush() override; + void set_pattern(const std::string &pattern) override; + +- void set_formatter(std::unique_ptr sink_formatter) override; ++ void set_formatter(std::unique_ptr sink_formatter) override; + + protected: + mutex_t &mutex_; + FILE *file_; +- std::unique_ptr formatter_; ++ std::unique_ptr formatter_; + #ifdef _WIN32 + HANDLE handle_; + #endif // WIN32 +@@ -68,19 +68,19 @@ using stderr_sink_st = stderr_sink; + } // namespace sinks + + // factory methods +-template ++template + std::shared_ptr stdout_logger_mt(const std::string &logger_name); + +-template ++template + std::shared_ptr stdout_logger_st(const std::string &logger_name); + +-template ++template + std::shared_ptr stderr_logger_mt(const std::string &logger_name); + +-template ++template + std::shared_ptr stderr_logger_st(const std::string &logger_name); + +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "stdout_sinks-inl.h" +diff --git a/include/spdlog/sinks/syslog_sink.h b/include/spdlog/sinks/syslog_sink.h +index 7c38fcb5..59a87162 100644 +--- a/include/spdlog/sinks/syslog_sink.h ++++ b/include/spdlog/sinks/syslog_sink.h +@@ -11,7 +11,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + /** + * Sink that write to syslog using the `syscall()` library call. 
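The syslog_sink hunk that follows only re-comments its severity table under the new namespace, but the table itself is the interesting part: one syslog(3) priority per logger severity, indexed by the numeric value of the level enum. A POSIX-only sketch with a toy level enum (the real sink stores this as a member array named syslog_levels_, exactly as visible in the hunk):

    // Maps each logger severity to a syslog(3) priority; trace has no finer
    // syslog equivalent than LOG_DEBUG, and "off" never reaches the sink.
    #include <array>
    #include <cstddef>
    #include <syslog.h>

    enum class level : int { trace, debug, info, warn, err, critical, off };

    constexpr std::array<int, 7> syslog_levels{
        LOG_DEBUG,   // trace
        LOG_DEBUG,   // debug
        LOG_INFO,    // info
        LOG_WARNING, // warn
        LOG_ERR,     // err
        LOG_CRIT,    // critical
        LOG_INFO,    // off
    };

    inline void emit(level lvl, const char *msg)
    {
        ::syslog(syslog_levels[static_cast<std::size_t>(lvl)], "%s", msg);
    }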
+@@ -23,13 +23,13 @@ class syslog_sink : public base_sink + public: + syslog_sink(std::string ident, int syslog_option, int syslog_facility, bool enable_formatting) + : enable_formatting_{enable_formatting} +- , syslog_levels_{{/* spdlog::level::trace */ LOG_DEBUG, +- /* spdlog::level::debug */ LOG_DEBUG, +- /* spdlog::level::info */ LOG_INFO, +- /* spdlog::level::warn */ LOG_WARNING, +- /* spdlog::level::err */ LOG_ERR, +- /* spdlog::level::critical */ LOG_CRIT, +- /* spdlog::level::off */ LOG_INFO}} ++ , syslog_levels_{{/* ds_spdlog::level::trace */ LOG_DEBUG, ++ /* ds_spdlog::level::debug */ LOG_DEBUG, ++ /* ds_spdlog::level::info */ LOG_INFO, ++ /* ds_spdlog::level::warn */ LOG_WARNING, ++ /* ds_spdlog::level::err */ LOG_ERR, ++ /* ds_spdlog::level::critical */ LOG_CRIT, ++ /* ds_spdlog::level::off */ LOG_INFO}} + , ident_{std::move(ident)} + { + // set ident to be program name if empty +@@ -93,17 +93,17 @@ using syslog_sink_st = syslog_sink; + } // namespace sinks + + // Create and register a syslog logger +-template ++template + inline std::shared_ptr syslog_logger_mt(const std::string &logger_name, const std::string &syslog_ident = "", int syslog_option = 0, + int syslog_facility = LOG_USER, bool enable_formatting = false) + { + return Factory::template create(logger_name, syslog_ident, syslog_option, syslog_facility, enable_formatting); + } + +-template ++template + inline std::shared_ptr syslog_logger_st(const std::string &logger_name, const std::string &syslog_ident = "", int syslog_option = 0, + int syslog_facility = LOG_USER, bool enable_formatting = false) + { + return Factory::template create(logger_name, syslog_ident, syslog_option, syslog_facility, enable_formatting); + } +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/systemd_sink.h b/include/spdlog/sinks/systemd_sink.h +index b00a95f2..e7cb0cdc 100644 +--- a/include/spdlog/sinks/systemd_sink.h ++++ b/include/spdlog/sinks/systemd_sink.h +@@ -14,7 +14,7 @@ + #endif + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + /** +@@ -27,13 +27,13 @@ public: + systemd_sink(std::string ident = "", bool enable_formatting = false) + : ident_{std::move(ident)} + , enable_formatting_{enable_formatting} +- , syslog_levels_{{/* spdlog::level::trace */ LOG_DEBUG, +- /* spdlog::level::debug */ LOG_DEBUG, +- /* spdlog::level::info */ LOG_INFO, +- /* spdlog::level::warn */ LOG_WARNING, +- /* spdlog::level::err */ LOG_ERR, +- /* spdlog::level::critical */ LOG_CRIT, +- /* spdlog::level::off */ LOG_INFO}} ++ , syslog_levels_{{/* ds_spdlog::level::trace */ LOG_DEBUG, ++ /* ds_spdlog::level::debug */ LOG_DEBUG, ++ /* ds_spdlog::level::info */ LOG_INFO, ++ /* ds_spdlog::level::warn */ LOG_WARNING, ++ /* ds_spdlog::level::err */ LOG_ERR, ++ /* ds_spdlog::level::critical */ LOG_CRIT, ++ /* ds_spdlog::level::off */ LOG_INFO}} + {} + + ~systemd_sink() override {} +@@ -110,17 +110,17 @@ using systemd_sink_st = systemd_sink; + } // namespace sinks + + // Create and register a syslog logger +-template ++template + inline std::shared_ptr systemd_logger_mt( + const std::string &logger_name, const std::string &ident = "", bool enable_formatting = false) + { + return Factory::template create(logger_name, ident, enable_formatting); + } + +-template ++template + inline std::shared_ptr systemd_logger_st( + const std::string &logger_name, const std::string &ident = "", bool enable_formatting = false) + { + return Factory::template create(logger_name, ident, enable_formatting); + } +-} // namespace 
spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/tcp_sink.h b/include/spdlog/sinks/tcp_sink.h +index e0efb31d..4f9a9459 100644 +--- a/include/spdlog/sinks/tcp_sink.h ++++ b/include/spdlog/sinks/tcp_sink.h +@@ -24,7 +24,7 @@ + // Will attempt to reconnect if connection drops. + // If more complicated behaviour is needed (i.e get responses), you can inherit it and override the sink_it_ method. + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + struct tcp_sink_config +@@ -40,7 +40,7 @@ struct tcp_sink_config + }; + + template +-class tcp_sink : public spdlog::sinks::base_sink ++class tcp_sink : public ds_spdlog::sinks::base_sink + { + public: + // connect to tcp host/port or throw if failed +@@ -58,10 +58,10 @@ public: + ~tcp_sink() override = default; + + protected: +- void sink_it_(const spdlog::details::log_msg &msg) override ++ void sink_it_(const ds_spdlog::details::log_msg &msg) override + { +- spdlog::memory_buf_t formatted; +- spdlog::sinks::base_sink::formatter_->format(msg, formatted); ++ ds_spdlog::memory_buf_t formatted; ++ ds_spdlog::sinks::base_sink::formatter_->format(msg, formatted); + if (!client_.is_connected()) + { + client_.connect(config_.server_host, config_.server_port); +@@ -75,7 +75,7 @@ protected: + }; + + using tcp_sink_mt = tcp_sink; +-using tcp_sink_st = tcp_sink; ++using tcp_sink_st = tcp_sink; + + } // namespace sinks +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/udp_sink.h b/include/spdlog/sinks/udp_sink.h +index ccbce2be..5385db19 100644 +--- a/include/spdlog/sinks/udp_sink.h ++++ b/include/spdlog/sinks/udp_sink.h +@@ -20,7 +20,7 @@ + // Simple udp client sink + // Sends formatted log via udp + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + struct udp_sink_config +@@ -35,7 +35,7 @@ struct udp_sink_config + }; + + template +-class udp_sink : public spdlog::sinks::base_sink ++class udp_sink : public ds_spdlog::sinks::base_sink + { + public: + // host can be hostname or ip address +@@ -46,10 +46,10 @@ public: + ~udp_sink() override = default; + + protected: +- void sink_it_(const spdlog::details::log_msg &msg) override ++ void sink_it_(const ds_spdlog::details::log_msg &msg) override + { +- spdlog::memory_buf_t formatted; +- spdlog::sinks::base_sink::formatter_->format(msg, formatted); ++ ds_spdlog::memory_buf_t formatted; ++ ds_spdlog::sinks::base_sink::formatter_->format(msg, formatted); + client_.send(formatted.data(), formatted.size()); + } + +@@ -58,17 +58,17 @@ protected: + }; + + using udp_sink_mt = udp_sink; +-using udp_sink_st = udp_sink; ++using udp_sink_st = udp_sink; + + } // namespace sinks + + // + // factory functions + // +-template ++template + inline std::shared_ptr udp_logger_mt(const std::string &logger_name, sinks::udp_sink_config skin_config) + { + return Factory::template create(logger_name, skin_config); + } + +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/win_eventlog_sink.h b/include/spdlog/sinks/win_eventlog_sink.h +index d23d00a8..45ee2e03 100644 +--- a/include/spdlog/sinks/win_eventlog_sink.h ++++ b/include/spdlog/sinks/win_eventlog_sink.h +@@ -40,7 +40,7 @@ Windows Registry Editor Version 5.00 + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + + namespace win_eventlog { +@@ -286,4 +286,4 @@ using win_eventlog_sink_mt = win_eventlog::win_eventlog_sink; + using win_eventlog_sink_st = win_eventlog::win_eventlog_sink; + + } // namespace sinks +-} // namespace 
spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/wincolor_sink-inl.h b/include/spdlog/sinks/wincolor_sink-inl.h +index 8311929e..1a038cc8 100644 +--- a/include/spdlog/sinks/wincolor_sink-inl.h ++++ b/include/spdlog/sinks/wincolor_sink-inl.h +@@ -13,13 +13,13 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + template + SPDLOG_INLINE wincolor_sink::wincolor_sink(void *out_handle, color_mode mode) + : out_handle_(out_handle) + , mutex_(ConsoleMutex::mutex()) +- , formatter_(details::make_unique()) ++ , formatter_(details::make_unique()) + { + + set_color_mode_impl(mode); +@@ -88,11 +88,11 @@ template + void SPDLOG_INLINE wincolor_sink::set_pattern(const std::string &pattern) + { + std::lock_guard lock(mutex_); +- formatter_ = std::unique_ptr(new pattern_formatter(pattern)); ++ formatter_ = std::unique_ptr(new pattern_formatter(pattern)); + } + + template +-void SPDLOG_INLINE wincolor_sink::set_formatter(std::unique_ptr sink_formatter) ++void SPDLOG_INLINE wincolor_sink::set_formatter(std::unique_ptr sink_formatter) + { + std::lock_guard lock(mutex_); + formatter_ = std::move(sink_formatter); +@@ -172,4 +172,4 @@ SPDLOG_INLINE wincolor_stderr_sink::wincolor_stderr_sink(color_mod + : wincolor_sink(::GetStdHandle(STD_ERROR_HANDLE), mode) + {} + } // namespace sinks +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/sinks/wincolor_sink.h b/include/spdlog/sinks/wincolor_sink.h +index 9b030fc1..9e479733 100644 +--- a/include/spdlog/sinks/wincolor_sink.h ++++ b/include/spdlog/sinks/wincolor_sink.h +@@ -14,7 +14,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + namespace sinks { + /* + * Windows color console sink. Uses WriteConsoleA to write to the console with +@@ -35,7 +35,7 @@ public: + void log(const details::log_msg &msg) final override; + void flush() final override; + void set_pattern(const std::string &pattern) override final; +- void set_formatter(std::unique_ptr sink_formatter) override final; ++ void set_formatter(std::unique_ptr sink_formatter) override final; + void set_color_mode(color_mode mode); + + protected: +@@ -43,7 +43,7 @@ protected: + void *out_handle_; + mutex_t &mutex_; + bool should_do_colors_; +- std::unique_ptr formatter_; ++ std::unique_ptr formatter_; + std::array colors_; + + // set foreground color and return the orig console attributes (for resetting later) +@@ -78,7 +78,7 @@ using wincolor_stdout_sink_st = wincolor_stdout_sink + using wincolor_stderr_sink_mt = wincolor_stderr_sink; + using wincolor_stderr_sink_st = wincolor_stderr_sink; + } // namespace sinks +-} // namespace spdlog ++} // namespace ds_spdlog + + #ifdef SPDLOG_HEADER_ONLY + # include "wincolor_sink-inl.h" +diff --git a/include/spdlog/spdlog-inl.h b/include/spdlog/spdlog-inl.h +index 22ea22bb..56b13720 100644 +--- a/include/spdlog/spdlog-inl.h ++++ b/include/spdlog/spdlog-inl.h +@@ -10,7 +10,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + + SPDLOG_INLINE void initialize_logger(std::shared_ptr logger) + { +@@ -22,14 +22,14 @@ SPDLOG_INLINE std::shared_ptr get(const std::string &name) + return details::registry::instance().get(name); + } + +-SPDLOG_INLINE void set_formatter(std::unique_ptr formatter) ++SPDLOG_INLINE void set_formatter(std::unique_ptr formatter) + { + details::registry::instance().set_formatter(std::move(formatter)); + } + + SPDLOG_INLINE void set_pattern(std::string pattern, pattern_time_type time_type) + { +- 
set_formatter(std::unique_ptr(new pattern_formatter(std::move(pattern), time_type))); ++ set_formatter(std::unique_ptr(new pattern_formatter(std::move(pattern), time_type))); + } + + SPDLOG_INLINE void enable_backtrace(size_t n_messages) +@@ -102,17 +102,17 @@ SPDLOG_INLINE void set_automatic_registration(bool automatic_registration) + details::registry::instance().set_automatic_registration(automatic_registration); + } + +-SPDLOG_INLINE std::shared_ptr default_logger() ++SPDLOG_INLINE std::shared_ptr default_logger() + { + return details::registry::instance().default_logger(); + } + +-SPDLOG_INLINE spdlog::logger *default_logger_raw() ++SPDLOG_INLINE ds_spdlog::logger *default_logger_raw() + { + return details::registry::instance().get_default_raw(); + } + +-SPDLOG_INLINE void set_default_logger(std::shared_ptr default_logger) ++SPDLOG_INLINE void set_default_logger(std::shared_ptr default_logger) + { + details::registry::instance().set_default_logger(std::move(default_logger)); + } +@@ -122,4 +122,4 @@ SPDLOG_INLINE void apply_logger_env_levels(std::shared_ptr logger) + details::registry::instance().apply_logger_env_levels(std::move(logger)); + } + +-} // namespace spdlog ++} // namespace ds_spdlog +diff --git a/include/spdlog/spdlog.h b/include/spdlog/spdlog.h +index fbfe5fb8..ad8749b7 100644 +--- a/include/spdlog/spdlog.h ++++ b/include/spdlog/spdlog.h +@@ -20,7 +20,7 @@ + #include + #include + +-namespace spdlog { ++namespace ds_spdlog { + + using default_factory = synchronous_factory; + +@@ -29,9 +29,9 @@ using default_factory = synchronous_factory; + // global settings. + // + // Example: +-// spdlog::create("logger_name", "dailylog_filename", 11, 59); ++// ds_spdlog::create("logger_name", "dailylog_filename", 11, 59); + template +-inline std::shared_ptr create(std::string logger_name, SinkArgs &&...sink_args) ++inline std::shared_ptr create(std::string logger_name, SinkArgs &&...sink_args) + { + return default_factory::create(std::move(logger_name), std::forward(sink_args)...); + } +@@ -42,20 +42,20 @@ inline std::shared_ptr create(std::string logger_name, SinkArgs + // Useful for initializing manually created loggers with the global settings. + // + // Example: +-// auto mylogger = std::make_shared("mylogger", ...); +-// spdlog::initialize_logger(mylogger); ++// auto mylogger = std::make_shared("mylogger", ...); ++// ds_spdlog::initialize_logger(mylogger); + SPDLOG_API void initialize_logger(std::shared_ptr logger); + + // Return an existing logger or nullptr if a logger with such name doesn't + // exist. +-// example: spdlog::get("my_logger")->info("hello {}", "world"); ++// example: ds_spdlog::get("my_logger")->info("hello {}", "world"); + SPDLOG_API std::shared_ptr get(const std::string &name); + + // Set global formatter. Each sink in each logger will get a clone of this object +-SPDLOG_API void set_formatter(std::unique_ptr formatter); ++SPDLOG_API void set_formatter(std::unique_ptr formatter); + + // Set global format string. 
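The spdlog-inl.h hunks above all follow one pattern worth making explicit: the free functions keep no state of their own and simply forward to details::registry::instance(), a process-wide singleton that fans each setting out to every registered logger (each sink receiving its own clone of the formatter), and set_pattern() is just sugar over set_formatter(). A minimal sketch of that indirection with toy formatter/logger types standing in for the real ones:

    // Process-wide registry singleton; set_pattern() wraps the pattern in a
    // formatter and delegates, mirroring the hunk above.
    #include <map>
    #include <memory>
    #include <mutex>
    #include <string>

    struct formatter { std::string pattern; };
    struct logger    { std::shared_ptr<formatter> fmt; };

    class registry
    {
    public:
        static registry &instance()
        {
            static registry inst; // construction is thread-safe since C++11
            return inst;
        }

        void set_formatter(std::unique_ptr<formatter> f)
        {
            std::lock_guard<std::mutex> lock(mutex_);
            for (auto &entry : loggers_)  // every logger gets its own copy,
            {                             // like spdlog's per-sink clone
                entry.second->fmt = std::make_shared<formatter>(*f);
            }
        }

    private:
        std::mutex mutex_;
        std::map<std::string, std::shared_ptr<logger>> loggers_;
    };

    inline void set_pattern(std::string pattern)
    {
        registry::instance().set_formatter(
            std::unique_ptr<formatter>(new formatter{std::move(pattern)}));
    }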
+-// example: spdlog::set_pattern("%Y-%m-%d %H:%M:%S.%e %l : %v"); ++// example: ds_spdlog::set_pattern("%Y-%m-%d %H:%M:%S.%e %l : %v"); + SPDLOG_API void set_pattern(std::string pattern, pattern_time_type time_type = pattern_time_type::local); + + // enable global backtrace support +@@ -95,7 +95,7 @@ SPDLOG_API void register_logger(std::shared_ptr logger); + + // Apply a user defined function on all registered loggers + // Example: +-// spdlog::apply_all([&](std::shared_ptr l) {l->flush();}); ++// ds_spdlog::apply_all([&](std::shared_ptr l) {l->flush();}); + SPDLOG_API void apply_all(const std::function)> &fun); + + // Drop the reference to the given logger +@@ -107,37 +107,37 @@ SPDLOG_API void drop_all(); + // stop any running threads started by spdlog and clean registry loggers + SPDLOG_API void shutdown(); + +-// Automatic registration of loggers when using spdlog::create() or spdlog::create_async ++// Automatic registration of loggers when using ds_spdlog::create() or ds_spdlog::create_async + SPDLOG_API void set_automatic_registration(bool automatic_registration); + + // API for using default logger (stdout_color_mt), +-// e.g: spdlog::info("Message {}", 1); ++// e.g: ds_spdlog::info("Message {}", 1); + // +-// The default logger object can be accessed using the spdlog::default_logger(): ++// The default logger object can be accessed using the ds_spdlog::default_logger(): + // For example, to add another sink to it: +-// spdlog::default_logger()->sinks().push_back(some_sink); ++// ds_spdlog::default_logger()->sinks().push_back(some_sink); + // +-// The default logger can replaced using spdlog::set_default_logger(new_logger). ++// The default logger can replaced using ds_spdlog::set_default_logger(new_logger). + // For example, to replace it with a file logger. + // + // IMPORTANT: + // The default API is thread safe (for _mt loggers), but: + // set_default_logger() *should not* be used concurrently with the default API. +-// e.g do not call set_default_logger() from one thread while calling spdlog::info() from another. ++// e.g do not call set_default_logger() from one thread while calling ds_spdlog::info() from another. + +-SPDLOG_API std::shared_ptr default_logger(); ++SPDLOG_API std::shared_ptr default_logger(); + +-SPDLOG_API spdlog::logger *default_logger_raw(); ++SPDLOG_API ds_spdlog::logger *default_logger_raw(); + +-SPDLOG_API void set_default_logger(std::shared_ptr default_logger); ++SPDLOG_API void set_default_logger(std::shared_ptr default_logger); + + // Initialize logger level based on environment configs. + // + // Useful for applying SPDLOG_LEVEL to manually created loggers. + // + // Example: +-// auto mylogger = std::make_shared("mylogger", ...); +-// spdlog::apply_logger_env_levels(mylogger); ++// auto mylogger = std::make_shared("mylogger", ...); ++// ds_spdlog::apply_logger_env_levels(mylogger); + SPDLOG_API void apply_logger_env_levels(std::shared_ptr logger); + + template +@@ -286,7 +286,7 @@ inline void critical(const T &msg) + default_logger_raw()->critical(msg); + } + +-} // namespace spdlog ++} // namespace ds_spdlog + + // + // enable/disable log calls at compile time according to global level. +@@ -303,54 +303,54 @@ inline void critical(const T &msg) + + #ifndef SPDLOG_NO_SOURCE_LOC + # define SPDLOG_LOGGER_CALL(logger, level, ...) 
\ +- (logger)->log(spdlog::source_loc{__FILE__, __LINE__, SPDLOG_FUNCTION}, level, __VA_ARGS__) ++ (logger)->log(ds_spdlog::source_loc{__FILE__, __LINE__, SPDLOG_FUNCTION}, level, __VA_ARGS__) + #else +-# define SPDLOG_LOGGER_CALL(logger, level, ...) (logger)->log(spdlog::source_loc{}, level, __VA_ARGS__) ++# define SPDLOG_LOGGER_CALL(logger, level, ...) (logger)->log(ds_spdlog::source_loc{}, level, __VA_ARGS__) + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_TRACE +-# define SPDLOG_LOGGER_TRACE(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::trace, __VA_ARGS__) +-# define SPDLOG_TRACE(...) SPDLOG_LOGGER_TRACE(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_TRACE(logger, ...) SPDLOG_LOGGER_CALL(logger, ds_spdlog::level::trace, __VA_ARGS__) ++# define SPDLOG_TRACE(...) SPDLOG_LOGGER_TRACE(ds_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_TRACE(logger, ...) (void)0 + # define SPDLOG_TRACE(...) (void)0 + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_DEBUG +-# define SPDLOG_LOGGER_DEBUG(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::debug, __VA_ARGS__) +-# define SPDLOG_DEBUG(...) SPDLOG_LOGGER_DEBUG(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_DEBUG(logger, ...) SPDLOG_LOGGER_CALL(logger, ds_spdlog::level::debug, __VA_ARGS__) ++# define SPDLOG_DEBUG(...) SPDLOG_LOGGER_DEBUG(ds_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_DEBUG(logger, ...) (void)0 + # define SPDLOG_DEBUG(...) (void)0 + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_INFO +-# define SPDLOG_LOGGER_INFO(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::info, __VA_ARGS__) +-# define SPDLOG_INFO(...) SPDLOG_LOGGER_INFO(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_INFO(logger, ...) SPDLOG_LOGGER_CALL(logger, ds_spdlog::level::info, __VA_ARGS__) ++# define SPDLOG_INFO(...) SPDLOG_LOGGER_INFO(ds_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_INFO(logger, ...) (void)0 + # define SPDLOG_INFO(...) (void)0 + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_WARN +-# define SPDLOG_LOGGER_WARN(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::warn, __VA_ARGS__) +-# define SPDLOG_WARN(...) SPDLOG_LOGGER_WARN(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_WARN(logger, ...) SPDLOG_LOGGER_CALL(logger, ds_spdlog::level::warn, __VA_ARGS__) ++# define SPDLOG_WARN(...) SPDLOG_LOGGER_WARN(ds_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_WARN(logger, ...) (void)0 + # define SPDLOG_WARN(...) (void)0 + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_ERROR +-# define SPDLOG_LOGGER_ERROR(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::err, __VA_ARGS__) +-# define SPDLOG_ERROR(...) SPDLOG_LOGGER_ERROR(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_ERROR(logger, ...) SPDLOG_LOGGER_CALL(logger, ds_spdlog::level::err, __VA_ARGS__) ++# define SPDLOG_ERROR(...) SPDLOG_LOGGER_ERROR(ds_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_ERROR(logger, ...) (void)0 + # define SPDLOG_ERROR(...) (void)0 + #endif + + #if SPDLOG_ACTIVE_LEVEL <= SPDLOG_LEVEL_CRITICAL +-# define SPDLOG_LOGGER_CRITICAL(logger, ...) SPDLOG_LOGGER_CALL(logger, spdlog::level::critical, __VA_ARGS__) +-# define SPDLOG_CRITICAL(...) SPDLOG_LOGGER_CRITICAL(spdlog::default_logger_raw(), __VA_ARGS__) ++# define SPDLOG_LOGGER_CRITICAL(logger, ...) 
SPDLOG_LOGGER_CALL(logger, ds_spdlog::level::critical, __VA_ARGS__) ++# define SPDLOG_CRITICAL(...) SPDLOG_LOGGER_CRITICAL(ds_spdlog::default_logger_raw(), __VA_ARGS__) + #else + # define SPDLOG_LOGGER_CRITICAL(logger, ...) (void)0 + # define SPDLOG_CRITICAL(...) (void)0 +diff --git a/include/spdlog/stopwatch.h b/include/spdlog/stopwatch.h +index bea7b8a7..b68efdb2 100644 +--- a/include/spdlog/stopwatch.h ++++ b/include/spdlog/stopwatch.h +@@ -11,10 +11,10 @@ + // + // Usage: + // +-// spdlog::stopwatch sw; ++// ds_spdlog::stopwatch sw; + // ... +-// spdlog::debug("Elapsed: {} seconds", sw); => "Elapsed 0.005116733 seconds" +-// spdlog::info("Elapsed: {:.6} seconds", sw); => "Elapsed 0.005163 seconds" ++// ds_spdlog::debug("Elapsed: {} seconds", sw); => "Elapsed 0.005116733 seconds" ++// ds_spdlog::info("Elapsed: {:.6} seconds", sw); => "Elapsed 0.005163 seconds" + // + // + // If other units are needed (e.g. millis instead of double), include "fmt/chrono.h" and use "duration_cast<..>(sw.elapsed())": +@@ -23,9 +23,9 @@ + //.. + // using std::chrono::duration_cast; + // using std::chrono::milliseconds; +-// spdlog::info("Elapsed {}", duration_cast(sw.elapsed())); => "Elapsed 5ms" ++// ds_spdlog::info("Elapsed {}", duration_cast(sw.elapsed())); => "Elapsed 5ms" + +-namespace spdlog { ++namespace ds_spdlog { + class stopwatch + { + using clock = std::chrono::steady_clock; +@@ -46,7 +46,7 @@ public: + start_tp_ = clock::now(); + } + }; +-} // namespace spdlog ++} // namespace ds_spdlog + + // Support for fmt formatting (e.g. "{:012.9}" or just "{}") + namespace +@@ -58,10 +58,10 @@ namespace + { + + template<> +-struct formatter : formatter ++struct formatter : formatter + { + template +- auto format(const spdlog::stopwatch &sw, FormatContext &ctx) const -> decltype(ctx.out()) ++ auto format(const ds_spdlog::stopwatch &sw, FormatContext &ctx) const -> decltype(ctx.out()) + { + return formatter::format(sw.elapsed().count(), ctx); + } +diff --git a/src/color_sinks.cpp b/src/color_sinks.cpp +index 38fa308c..71f9171c 100644 +--- a/src/color_sinks.cpp ++++ b/src/color_sinks.cpp +@@ -14,38 +14,38 @@ + // + #ifdef _WIN32 + # include +-template class SPDLOG_API spdlog::sinks::wincolor_sink; +-template class SPDLOG_API spdlog::sinks::wincolor_sink; +-template class SPDLOG_API spdlog::sinks::wincolor_stdout_sink; +-template class SPDLOG_API spdlog::sinks::wincolor_stdout_sink; +-template class SPDLOG_API spdlog::sinks::wincolor_stderr_sink; +-template class SPDLOG_API spdlog::sinks::wincolor_stderr_sink; ++template class SPDLOG_API ds_spdlog::sinks::wincolor_sink; ++template class SPDLOG_API ds_spdlog::sinks::wincolor_sink; ++template class SPDLOG_API ds_spdlog::sinks::wincolor_stdout_sink; ++template class SPDLOG_API ds_spdlog::sinks::wincolor_stdout_sink; ++template class SPDLOG_API ds_spdlog::sinks::wincolor_stderr_sink; ++template class SPDLOG_API ds_spdlog::sinks::wincolor_stderr_sink; + #else + # include "spdlog/sinks/ansicolor_sink-inl.h" +-template class SPDLOG_API spdlog::sinks::ansicolor_sink; +-template class SPDLOG_API spdlog::sinks::ansicolor_sink; +-template class SPDLOG_API spdlog::sinks::ansicolor_stdout_sink; +-template class SPDLOG_API spdlog::sinks::ansicolor_stdout_sink; +-template class SPDLOG_API spdlog::sinks::ansicolor_stderr_sink; +-template class SPDLOG_API spdlog::sinks::ansicolor_stderr_sink; ++template class SPDLOG_API ds_spdlog::sinks::ansicolor_sink; ++template class SPDLOG_API ds_spdlog::sinks::ansicolor_sink; ++template class SPDLOG_API 
ds_spdlog::sinks::ansicolor_stdout_sink; ++template class SPDLOG_API ds_spdlog::sinks::ansicolor_stdout_sink; ++template class SPDLOG_API ds_spdlog::sinks::ansicolor_stderr_sink; ++template class SPDLOG_API ds_spdlog::sinks::ansicolor_stderr_sink; + #endif + + // factory methods for color loggers + #include "spdlog/sinks/stdout_color_sinks-inl.h" +-template SPDLOG_API std::shared_ptr spdlog::stdout_color_mt( ++template SPDLOG_API std::shared_ptr ds_spdlog::stdout_color_mt( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stdout_color_st( ++template SPDLOG_API std::shared_ptr ds_spdlog::stdout_color_st( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stderr_color_mt( ++template SPDLOG_API std::shared_ptr ds_spdlog::stderr_color_mt( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stderr_color_st( ++template SPDLOG_API std::shared_ptr ds_spdlog::stderr_color_st( + const std::string &logger_name, color_mode mode); + +-template SPDLOG_API std::shared_ptr spdlog::stdout_color_mt( ++template SPDLOG_API std::shared_ptr ds_spdlog::stdout_color_mt( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stdout_color_st( ++template SPDLOG_API std::shared_ptr ds_spdlog::stdout_color_st( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stderr_color_mt( ++template SPDLOG_API std::shared_ptr ds_spdlog::stderr_color_mt( + const std::string &logger_name, color_mode mode); +-template SPDLOG_API std::shared_ptr spdlog::stderr_color_st( ++template SPDLOG_API std::shared_ptr ds_spdlog::stderr_color_st( + const std::string &logger_name, color_mode mode); +diff --git a/src/file_sinks.cpp b/src/file_sinks.cpp +index 10ffba60..e29bcbd6 100644 +--- a/src/file_sinks.cpp ++++ b/src/file_sinks.cpp +@@ -12,9 +12,9 @@ + + #include + +-template class SPDLOG_API spdlog::sinks::basic_file_sink; +-template class SPDLOG_API spdlog::sinks::basic_file_sink; ++template class SPDLOG_API ds_spdlog::sinks::basic_file_sink; ++template class SPDLOG_API ds_spdlog::sinks::basic_file_sink; + + #include +-template class SPDLOG_API spdlog::sinks::rotating_file_sink; +-template class SPDLOG_API spdlog::sinks::rotating_file_sink; ++template class SPDLOG_API ds_spdlog::sinks::rotating_file_sink; ++template class SPDLOG_API ds_spdlog::sinks::rotating_file_sink; +diff --git a/src/spdlog.cpp b/src/spdlog.cpp +index c86d8fff..e848030a 100644 +--- a/src/spdlog.cpp ++++ b/src/spdlog.cpp +@@ -21,6 +21,6 @@ + #include + + // template instantiate logger constructor with sinks init list +-template SPDLOG_API spdlog::logger::logger(std::string name, sinks_init_list::iterator begin, sinks_init_list::iterator end); +-template class SPDLOG_API spdlog::sinks::base_sink; +-template class SPDLOG_API spdlog::sinks::base_sink; ++template SPDLOG_API ds_spdlog::logger::logger(std::string name, sinks_init_list::iterator begin, sinks_init_list::iterator end); ++template class SPDLOG_API ds_spdlog::sinks::base_sink; ++template class SPDLOG_API ds_spdlog::sinks::base_sink; +diff --git a/src/stdout_sinks.cpp b/src/stdout_sinks.cpp +index 2d5256a4..5d520bd6 100644 +--- a/src/stdout_sinks.cpp ++++ b/src/stdout_sinks.cpp +@@ -11,19 +11,19 @@ + #include + #include + +-template class SPDLOG_API spdlog::sinks::stdout_sink_base; +-template class SPDLOG_API spdlog::sinks::stdout_sink_base; +-template class SPDLOG_API 
spdlog::sinks::stdout_sink; +-template class SPDLOG_API spdlog::sinks::stdout_sink; +-template class SPDLOG_API spdlog::sinks::stderr_sink; +-template class SPDLOG_API spdlog::sinks::stderr_sink; ++template class SPDLOG_API ds_spdlog::sinks::stdout_sink_base; ++template class SPDLOG_API ds_spdlog::sinks::stdout_sink_base; ++template class SPDLOG_API ds_spdlog::sinks::stdout_sink; ++template class SPDLOG_API ds_spdlog::sinks::stdout_sink; ++template class SPDLOG_API ds_spdlog::sinks::stderr_sink; ++template class SPDLOG_API ds_spdlog::sinks::stderr_sink; + +-template SPDLOG_API std::shared_ptr spdlog::stdout_logger_mt(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stdout_logger_st(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stderr_logger_mt(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stderr_logger_st(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr ds_spdlog::stdout_logger_mt(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr ds_spdlog::stdout_logger_st(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr ds_spdlog::stderr_logger_mt(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr ds_spdlog::stderr_logger_st(const std::string &logger_name); + +-template SPDLOG_API std::shared_ptr spdlog::stdout_logger_mt(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stdout_logger_st(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stderr_logger_mt(const std::string &logger_name); +-template SPDLOG_API std::shared_ptr spdlog::stderr_logger_st(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr ds_spdlog::stdout_logger_mt(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr ds_spdlog::stdout_logger_st(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr ds_spdlog::stderr_logger_mt(const std::string &logger_name); ++template SPDLOG_API std::shared_ptr ds_spdlog::stderr_logger_st(const std::string &logger_name); +diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt +index 176578ad..5fc01654 100644 +--- a/tests/CMakeLists.txt ++++ b/tests/CMakeLists.txt +@@ -77,10 +77,10 @@ endfunction() + + # The compiled library tests + if(SPDLOG_BUILD_TESTS OR SPDLOG_BUILD_ALL) +- spdlog_prepare_test(spdlog-utests spdlog::spdlog) ++ spdlog_prepare_test(spdlog-utests ds_spdlog::spdlog) + endif() + + # The header-only library version tests + if(SPDLOG_BUILD_TESTS_HO OR SPDLOG_BUILD_ALL) +- spdlog_prepare_test(spdlog-utests-ho spdlog::spdlog_header_only) ++ spdlog_prepare_test(spdlog-utests-ho ds_spdlog::spdlog_header_only) + endif() +diff --git a/tests/test_async.cpp b/tests/test_async.cpp +index 06c5c921..cf741382 100644 +--- a/tests/test_async.cpp ++++ b/tests/test_async.cpp +@@ -7,13 +7,13 @@ + + TEST_CASE("basic async test ", "[async]") + { +- auto test_sink = std::make_shared(); ++ auto test_sink = std::make_shared(); + size_t overrun_counter = 0; + size_t queue_size = 128; + size_t messages = 256; + { +- auto tp = std::make_shared(queue_size, 1); +- auto logger = std::make_shared("as", test_sink, tp, spdlog::async_overflow_policy::block); ++ auto tp = std::make_shared(queue_size, 1); ++ auto logger = std::make_shared("as", test_sink, tp, ds_spdlog::async_overflow_policy::block); + for (size_t i = 0; i < messages; i++) + { + logger->info("Hello message #{}", i); +@@ -28,13 +28,13 @@ TEST_CASE("basic async test ", 
"[async]") + + TEST_CASE("discard policy ", "[async]") + { +- auto test_sink = std::make_shared(); ++ auto test_sink = std::make_shared(); + test_sink->set_delay(std::chrono::milliseconds(1)); + size_t queue_size = 4; + size_t messages = 1024; + +- auto tp = std::make_shared(queue_size, 1); +- auto logger = std::make_shared("as", test_sink, tp, spdlog::async_overflow_policy::overrun_oldest); ++ auto tp = std::make_shared(queue_size, 1); ++ auto logger = std::make_shared("as", test_sink, tp, ds_spdlog::async_overflow_policy::overrun_oldest); + for (size_t i = 0; i < messages; i++) + { + logger->info("Hello message"); +@@ -47,10 +47,10 @@ TEST_CASE("discard policy using factory ", "[async]") + { + size_t queue_size = 4; + size_t messages = 1024; +- spdlog::init_thread_pool(queue_size, 1); ++ ds_spdlog::init_thread_pool(queue_size, 1); + +- auto logger = spdlog::create_async_nb("as2"); +- auto test_sink = std::static_pointer_cast(logger->sinks()[0]); ++ auto logger = ds_spdlog::create_async_nb("as2"); ++ auto test_sink = std::static_pointer_cast(logger->sinks()[0]); + test_sink->set_delay(std::chrono::milliseconds(3)); + + for (size_t i = 0; i < messages; i++) +@@ -59,17 +59,17 @@ TEST_CASE("discard policy using factory ", "[async]") + } + + REQUIRE(test_sink->msg_counter() < messages); +- spdlog::drop_all(); ++ ds_spdlog::drop_all(); + } + + TEST_CASE("flush", "[async]") + { +- auto test_sink = std::make_shared(); ++ auto test_sink = std::make_shared(); + size_t queue_size = 256; + size_t messages = 256; + { +- auto tp = std::make_shared(queue_size, 1); +- auto logger = std::make_shared("as", test_sink, tp, spdlog::async_overflow_policy::block); ++ auto tp = std::make_shared(queue_size, 1); ++ auto logger = std::make_shared("as", test_sink, tp, ds_spdlog::async_overflow_policy::block); + for (size_t i = 0; i < messages; i++) + { + logger->info("Hello message #{}", i); +@@ -85,24 +85,24 @@ TEST_CASE("flush", "[async]") + TEST_CASE("async periodic flush", "[async]") + { + +- auto logger = spdlog::create_async("as"); +- auto test_sink = std::static_pointer_cast(logger->sinks()[0]); ++ auto logger = ds_spdlog::create_async("as"); ++ auto test_sink = std::static_pointer_cast(logger->sinks()[0]); + +- spdlog::flush_every(std::chrono::seconds(1)); ++ ds_spdlog::flush_every(std::chrono::seconds(1)); + std::this_thread::sleep_for(std::chrono::milliseconds(1700)); + REQUIRE(test_sink->flush_counter() == 1); +- spdlog::flush_every(std::chrono::seconds(0)); +- spdlog::drop_all(); ++ ds_spdlog::flush_every(std::chrono::seconds(0)); ++ ds_spdlog::drop_all(); + } + + TEST_CASE("tp->wait_empty() ", "[async]") + { +- auto test_sink = std::make_shared(); ++ auto test_sink = std::make_shared(); + test_sink->set_delay(std::chrono::milliseconds(5)); + size_t messages = 100; + +- auto tp = std::make_shared(messages, 2); +- auto logger = std::make_shared("as", test_sink, tp, spdlog::async_overflow_policy::block); ++ auto tp = std::make_shared(messages, 2); ++ auto logger = std::make_shared("as", test_sink, tp, ds_spdlog::async_overflow_policy::block); + for (size_t i = 0; i < messages; i++) + { + logger->info("Hello message #{}", i); +@@ -116,13 +116,13 @@ TEST_CASE("tp->wait_empty() ", "[async]") + + TEST_CASE("multi threads", "[async]") + { +- auto test_sink = std::make_shared(); ++ auto test_sink = std::make_shared(); + size_t queue_size = 128; + size_t messages = 256; + size_t n_threads = 10; + { +- auto tp = std::make_shared(queue_size, 1); +- auto logger = std::make_shared("as", test_sink, tp, 
spdlog::async_overflow_policy::block); ++ auto tp = std::make_shared(queue_size, 1); ++ auto logger = std::make_shared("as", test_sink, tp, ds_spdlog::async_overflow_policy::block); + + std::vector threads; + for (size_t i = 0; i < n_threads; i++) +@@ -151,11 +151,11 @@ TEST_CASE("to_file", "[async]") + prepare_logdir(); + size_t messages = 1024; + size_t tp_threads = 1; +- spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME); ++ ds_spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME); + { +- auto file_sink = std::make_shared(filename, true); +- auto tp = std::make_shared(messages, tp_threads); +- auto logger = std::make_shared("as", std::move(file_sink), std::move(tp)); ++ auto file_sink = std::make_shared(filename, true); ++ auto tp = std::make_shared(messages, tp_threads); ++ auto logger = std::make_shared("as", std::move(file_sink), std::move(tp)); + + for (size_t j = 0; j < messages; j++) + { +@@ -165,8 +165,8 @@ TEST_CASE("to_file", "[async]") + + require_message_count(TEST_FILENAME, messages); + auto contents = file_contents(TEST_FILENAME); +- using spdlog::details::os::default_eol; +- REQUIRE(ends_with(contents, spdlog::fmt_lib::format("Hello message #1023{}", default_eol))); ++ using ds_spdlog::details::os::default_eol; ++ REQUIRE(ends_with(contents, ds_spdlog::fmt_lib::format("Hello message #1023{}", default_eol))); + } + + TEST_CASE("to_file multi-workers", "[async]") +@@ -174,11 +174,11 @@ TEST_CASE("to_file multi-workers", "[async]") + prepare_logdir(); + size_t messages = 1024 * 10; + size_t tp_threads = 10; +- spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME); ++ ds_spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME); + { +- auto file_sink = std::make_shared(filename, true); +- auto tp = std::make_shared(messages, tp_threads); +- auto logger = std::make_shared("as", std::move(file_sink), std::move(tp)); ++ auto file_sink = std::make_shared(filename, true); ++ auto tp = std::make_shared(messages, tp_threads); ++ auto logger = std::make_shared("as", std::move(file_sink), std::move(tp)); + + for (size_t j = 0; j < messages; j++) + { +@@ -190,9 +190,9 @@ TEST_CASE("to_file multi-workers", "[async]") + + TEST_CASE("bad_tp", "[async]") + { +- auto test_sink = std::make_shared(); +- std::shared_ptr const empty_tp; +- auto logger = std::make_shared("as", test_sink, empty_tp); ++ auto test_sink = std::make_shared(); ++ std::shared_ptr const empty_tp; ++ auto logger = std::make_shared("as", test_sink, empty_tp); + logger->info("Please throw an exception"); + REQUIRE(test_sink->msg_counter() == 0); + } +diff --git a/tests/test_backtrace.cpp b/tests/test_backtrace.cpp +index 6cf9ec55..d2d5aa8d 100644 +--- a/tests/test_backtrace.cpp ++++ b/tests/test_backtrace.cpp +@@ -5,11 +5,11 @@ + TEST_CASE("bactrace1", "[bactrace]") + { + +- using spdlog::sinks::test_sink_st; ++ using ds_spdlog::sinks::test_sink_st; + auto test_sink = std::make_shared(); + size_t backtrace_size = 5; + +- spdlog::logger logger("test-backtrace", test_sink); ++ ds_spdlog::logger logger("test-backtrace", test_sink); + logger.set_pattern("%v"); + logger.enable_backtrace(backtrace_size); + +@@ -33,11 +33,11 @@ TEST_CASE("bactrace1", "[bactrace]") + + TEST_CASE("bactrace-empty", "[bactrace]") + { +- using spdlog::sinks::test_sink_st; ++ using ds_spdlog::sinks::test_sink_st; + auto test_sink = std::make_shared(); + size_t backtrace_size = 5; + +- spdlog::logger logger("test-backtrace", test_sink); ++ ds_spdlog::logger logger("test-backtrace", test_sink); + 
logger.set_pattern("%v"); + logger.enable_backtrace(backtrace_size); + logger.dump_backtrace(); +@@ -46,14 +46,14 @@ TEST_CASE("bactrace-empty", "[bactrace]") + + TEST_CASE("bactrace-async", "[bactrace]") + { +- using spdlog::sinks::test_sink_mt; ++ using ds_spdlog::sinks::test_sink_mt; + auto test_sink = std::make_shared(); +- using spdlog::details::os::sleep_for_millis; ++ using ds_spdlog::details::os::sleep_for_millis; + + size_t backtrace_size = 5; + +- spdlog::init_thread_pool(120, 1); +- auto logger = std::make_shared("test-bactrace-async", test_sink, spdlog::thread_pool()); ++ ds_spdlog::init_thread_pool(120, 1); ++ auto logger = std::make_shared("test-bactrace-async", test_sink, ds_spdlog::thread_pool()); + logger->set_pattern("%v"); + logger->enable_backtrace(backtrace_size); + +diff --git a/tests/test_bin_to_hex.cpp b/tests/test_bin_to_hex.cpp +index 3c50c74a..36bc11f3 100644 +--- a/tests/test_bin_to_hex.cpp ++++ b/tests/test_bin_to_hex.cpp +@@ -5,89 +5,89 @@ + TEST_CASE("to_hex", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ ds_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0xc, 0xff, 0xff}; +- oss_logger.info("{}", spdlog::to_hex(v)); ++ oss_logger.info("{}", ds_spdlog::to_hex(v)); + + auto output = oss.str(); +- REQUIRE(ends_with(output, "0000: 09 0a 0b 0c ff ff" + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(output, "0000: 09 0a 0b 0c ff ff" + std::string(ds_spdlog::details::os::default_eol))); + } + + TEST_CASE("to_hex_upper", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ ds_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0xc, 0xff, 0xff}; +- oss_logger.info("{:X}", spdlog::to_hex(v)); ++ oss_logger.info("{:X}", ds_spdlog::to_hex(v)); + + auto output = oss.str(); +- REQUIRE(ends_with(output, "0000: 09 0A 0B 0C FF FF" + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(output, "0000: 09 0A 0B 0C FF FF" + std::string(ds_spdlog::details::os::default_eol))); + } + + TEST_CASE("to_hex_no_delimiter", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ ds_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0xc, 0xff, 0xff}; +- oss_logger.info("{:sX}", spdlog::to_hex(v)); ++ oss_logger.info("{:sX}", ds_spdlog::to_hex(v)); + + auto output = oss.str(); +- REQUIRE(ends_with(output, "0000: 090A0B0CFFFF" + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(output, "0000: 090A0B0CFFFF" + std::string(ds_spdlog::details::os::default_eol))); + } + + TEST_CASE("to_hex_show_ascii", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ ds_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0x41, 0xc, 0x4b, 0xff, 0xff}; +- oss_logger.info("{:Xsa}", spdlog::to_hex(v, 8)); ++ oss_logger.info("{:Xsa}", ds_spdlog::to_hex(v, 8)); + +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF ...A.K.." + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF ...A.K.." 
+ std::string(ds_spdlog::details::os::default_eol))); + } + + TEST_CASE("to_hex_different_size_per_line", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ ds_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0x41, 0xc, 0x4b, 0xff, 0xff}; + +- oss_logger.info("{:Xsa}", spdlog::to_hex(v, 10)); +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF ...A.K.." + std::string(spdlog::details::os::default_eol))); ++ oss_logger.info("{:Xsa}", ds_spdlog::to_hex(v, 10)); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF ...A.K.." + std::string(ds_spdlog::details::os::default_eol))); + +- oss_logger.info("{:Xs}", spdlog::to_hex(v, 10)); +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF" + std::string(spdlog::details::os::default_eol))); ++ oss_logger.info("{:Xs}", ds_spdlog::to_hex(v, 10)); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF" + std::string(ds_spdlog::details::os::default_eol))); + +- oss_logger.info("{:Xsa}", spdlog::to_hex(v, 6)); +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4B ...A.K" + std::string(spdlog::details::os::default_eol) + "0006: FFFF .." + +- std::string(spdlog::details::os::default_eol))); ++ oss_logger.info("{:Xsa}", ds_spdlog::to_hex(v, 6)); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4B ...A.K" + std::string(ds_spdlog::details::os::default_eol) + "0006: FFFF .." + ++ std::string(ds_spdlog::details::os::default_eol))); + +- oss_logger.info("{:Xs}", spdlog::to_hex(v, 6)); +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4B" + std::string(spdlog::details::os::default_eol) + "0006: FFFF" + +- std::string(spdlog::details::os::default_eol))); ++ oss_logger.info("{:Xs}", ds_spdlog::to_hex(v, 6)); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4B" + std::string(ds_spdlog::details::os::default_eol) + "0006: FFFF" + ++ std::string(ds_spdlog::details::os::default_eol))); + } + + TEST_CASE("to_hex_no_ascii", "[to_hex]") + { + std::ostringstream oss; +- auto oss_sink = std::make_shared(oss); +- spdlog::logger oss_logger("oss", oss_sink); ++ auto oss_sink = std::make_shared(oss); ++ ds_spdlog::logger oss_logger("oss", oss_sink); + + std::vector v{9, 0xa, 0xb, 0x41, 0xc, 0x4b, 0xff, 0xff}; +- oss_logger.info("{:Xs}", spdlog::to_hex(v, 8)); ++ oss_logger.info("{:Xs}", ds_spdlog::to_hex(v, 8)); + +- REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF" + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(oss.str(), "0000: 090A0B410C4BFFFF" + std::string(ds_spdlog::details::os::default_eol))); + +- oss_logger.info("{:Xsna}", spdlog::to_hex(v, 8)); ++ oss_logger.info("{:Xsna}", ds_spdlog::to_hex(v, 8)); + +- REQUIRE(ends_with(oss.str(), "090A0B410C4BFFFF" + std::string(spdlog::details::os::default_eol))); ++ REQUIRE(ends_with(oss.str(), "090A0B410C4BFFFF" + std::string(ds_spdlog::details::os::default_eol))); + } +diff --git a/tests/test_cfg.cpp b/tests/test_cfg.cpp +index 11aefa20..ec3ec622 100644 +--- a/tests/test_cfg.cpp ++++ b/tests/test_cfg.cpp +@@ -5,179 +5,179 @@ + #include + #include + +-using spdlog::cfg::load_argv_levels; +-using spdlog::cfg::load_env_levels; +-using spdlog::sinks::test_sink_st; ++using ds_spdlog::cfg::load_argv_levels; ++using ds_spdlog::cfg::load_env_levels; ++using ds_spdlog::sinks::test_sink_st; + + TEST_CASE("env", "[cfg]") + { +- spdlog::drop("l1"); +- auto l1 = spdlog::create("l1"); ++ ds_spdlog::drop("l1"); ++ auto l1 = ds_spdlog::create("l1"); + #ifdef 
CATCH_PLATFORM_WINDOWS + _putenv_s("SPDLOG_LEVEL", "l1=warn"); + #else + setenv("SPDLOG_LEVEL", "l1=warn", 1); + #endif + load_env_levels(); +- REQUIRE(l1->level() == spdlog::level::warn); +- spdlog::set_default_logger(spdlog::create("cfg-default")); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ REQUIRE(l1->level() == ds_spdlog::level::warn); ++ ds_spdlog::set_default_logger(ds_spdlog::create("cfg-default")); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::info); + } + + TEST_CASE("argv1", "[cfg]") + { +- spdlog::drop("l1"); ++ ds_spdlog::drop("l1"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=warn"}; + load_argv_levels(2, argv); +- auto l1 = spdlog::create("l1"); +- REQUIRE(l1->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ auto l1 = ds_spdlog::create("l1"); ++ REQUIRE(l1->level() == ds_spdlog::level::warn); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::info); + } + + TEST_CASE("argv2", "[cfg]") + { +- spdlog::drop("l1"); ++ ds_spdlog::drop("l1"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=warn,trace"}; + load_argv_levels(2, argv); +- auto l1 = spdlog::create("l1"); +- REQUIRE(l1->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::trace); ++ auto l1 = ds_spdlog::create("l1"); ++ REQUIRE(l1->level() == ds_spdlog::level::warn); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::trace); + } + + TEST_CASE("argv3", "[cfg]") + { +- spdlog::set_level(spdlog::level::trace); ++ ds_spdlog::set_level(ds_spdlog::level::trace); + +- spdlog::drop("l1"); ++ ds_spdlog::drop("l1"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=junk_name=warn"}; + load_argv_levels(2, argv); +- auto l1 = spdlog::create("l1"); +- REQUIRE(l1->level() == spdlog::level::trace); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::trace); ++ auto l1 = ds_spdlog::create("l1"); ++ REQUIRE(l1->level() == ds_spdlog::level::trace); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::trace); + } + + TEST_CASE("argv4", "[cfg]") + { +- spdlog::set_level(spdlog::level::info); +- spdlog::drop("l1"); ++ ds_spdlog::set_level(ds_spdlog::level::info); ++ ds_spdlog::drop("l1"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=junk"}; + load_argv_levels(2, argv); +- auto l1 = spdlog::create("l1"); +- REQUIRE(l1->level() == spdlog::level::info); ++ auto l1 = ds_spdlog::create("l1"); ++ REQUIRE(l1->level() == ds_spdlog::level::info); + } + + TEST_CASE("argv5", "[cfg]") + { +- spdlog::set_level(spdlog::level::info); +- spdlog::drop("l1"); ++ ds_spdlog::set_level(ds_spdlog::level::info); ++ ds_spdlog::drop("l1"); + const char *argv[] = {"ignore", "ignore", "SPDLOG_LEVEL=l1=warn,trace"}; + load_argv_levels(3, argv); +- auto l1 = spdlog::create("l1"); +- REQUIRE(l1->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::trace); +- spdlog::set_level(spdlog::level::info); ++ auto l1 = ds_spdlog::create("l1"); ++ REQUIRE(l1->level() == ds_spdlog::level::warn); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::trace); ++ ds_spdlog::set_level(ds_spdlog::level::info); + } + + TEST_CASE("argv6", "[cfg]") + { +- spdlog::set_level(spdlog::level::err); ++ ds_spdlog::set_level(ds_spdlog::level::err); + const char *argv[] = {""}; + load_argv_levels(1, argv); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::err); +- spdlog::set_level(spdlog::level::info); ++ 
REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::err); ++ ds_spdlog::set_level(ds_spdlog::level::info); + } + + TEST_CASE("argv7", "[cfg]") + { +- spdlog::set_level(spdlog::level::err); ++ ds_spdlog::set_level(ds_spdlog::level::err); + const char *argv[] = {""}; + load_argv_levels(0, argv); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::err); +- spdlog::set_level(spdlog::level::info); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::err); ++ ds_spdlog::set_level(ds_spdlog::level::info); + } + + TEST_CASE("level-not-set-test1", "[cfg]") + { +- spdlog::drop("l1"); ++ ds_spdlog::drop("l1"); + const char *argv[] = {"ignore", ""}; + load_argv_levels(2, argv); +- auto l1 = spdlog::create("l1"); +- l1->set_level(spdlog::level::trace); +- REQUIRE(l1->level() == spdlog::level::trace); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ auto l1 = ds_spdlog::create("l1"); ++ l1->set_level(ds_spdlog::level::trace); ++ REQUIRE(l1->level() == ds_spdlog::level::trace); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::info); + } + + TEST_CASE("level-not-set-test2", "[cfg]") + { +- spdlog::drop("l1"); +- spdlog::drop("l2"); ++ ds_spdlog::drop("l1"); ++ ds_spdlog::drop("l2"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=trace"}; + +- auto l1 = spdlog::create("l1"); +- l1->set_level(spdlog::level::warn); +- auto l2 = spdlog::create("l2"); +- l2->set_level(spdlog::level::warn); ++ auto l1 = ds_spdlog::create("l1"); ++ l1->set_level(ds_spdlog::level::warn); ++ auto l2 = ds_spdlog::create("l2"); ++ l2->set_level(ds_spdlog::level::warn); + + load_argv_levels(2, argv); + +- REQUIRE(l1->level() == spdlog::level::trace); +- REQUIRE(l2->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ REQUIRE(l1->level() == ds_spdlog::level::trace); ++ REQUIRE(l2->level() == ds_spdlog::level::warn); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::info); + } + + TEST_CASE("level-not-set-test3", "[cfg]") + { +- spdlog::drop("l1"); +- spdlog::drop("l2"); ++ ds_spdlog::drop("l1"); ++ ds_spdlog::drop("l2"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=trace"}; + + load_argv_levels(2, argv); + +- auto l1 = spdlog::create("l1"); +- auto l2 = spdlog::create("l2"); ++ auto l1 = ds_spdlog::create("l1"); ++ auto l2 = ds_spdlog::create("l2"); + +- REQUIRE(l1->level() == spdlog::level::trace); +- REQUIRE(l2->level() == spdlog::level::info); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ REQUIRE(l1->level() == ds_spdlog::level::trace); ++ REQUIRE(l2->level() == ds_spdlog::level::info); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::info); + } + + TEST_CASE("level-not-set-test4", "[cfg]") + { +- spdlog::drop("l1"); +- spdlog::drop("l2"); ++ ds_spdlog::drop("l1"); ++ ds_spdlog::drop("l2"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=trace,warn"}; + + load_argv_levels(2, argv); + +- auto l1 = spdlog::create("l1"); +- auto l2 = spdlog::create("l2"); ++ auto l1 = ds_spdlog::create("l1"); ++ auto l2 = ds_spdlog::create("l2"); + +- REQUIRE(l1->level() == spdlog::level::trace); +- REQUIRE(l2->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::warn); ++ REQUIRE(l1->level() == ds_spdlog::level::trace); ++ REQUIRE(l2->level() == ds_spdlog::level::warn); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::warn); + } + + TEST_CASE("level-not-set-test5", "[cfg]") + 
{ +- spdlog::drop("l1"); +- spdlog::drop("l2"); ++ ds_spdlog::drop("l1"); ++ ds_spdlog::drop("l2"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=l1=junk,warn"}; + + load_argv_levels(2, argv); + +- auto l1 = spdlog::create("l1"); +- auto l2 = spdlog::create("l2"); ++ auto l1 = ds_spdlog::create("l1"); ++ auto l2 = ds_spdlog::create("l2"); + +- REQUIRE(l1->level() == spdlog::level::warn); +- REQUIRE(l2->level() == spdlog::level::warn); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::warn); ++ REQUIRE(l1->level() == ds_spdlog::level::warn); ++ REQUIRE(l2->level() == ds_spdlog::level::warn); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::warn); + } + + TEST_CASE("restore-to-default", "[cfg]") + { +- spdlog::drop("l1"); +- spdlog::drop("l2"); ++ ds_spdlog::drop("l1"); ++ ds_spdlog::drop("l2"); + const char *argv[] = {"ignore", "SPDLOG_LEVEL=info"}; + load_argv_levels(2, argv); +- REQUIRE(spdlog::default_logger()->level() == spdlog::level::info); ++ REQUIRE(ds_spdlog::default_logger()->level() == ds_spdlog::level::info); + } +diff --git a/tests/test_create_dir.cpp b/tests/test_create_dir.cpp +index f17126bc..35288f5a 100644 +--- a/tests/test_create_dir.cpp ++++ b/tests/test_create_dir.cpp +@@ -3,10 +3,10 @@ + */ + #include "includes.h" + +-using spdlog::details::os::create_dir; +-using spdlog::details::os::path_exists; ++using ds_spdlog::details::os::create_dir; ++using ds_spdlog::details::os::path_exists; + +-bool try_create_dir(const spdlog::filename_t &path, const spdlog::filename_t &normalized_path) ++bool try_create_dir(const ds_spdlog::filename_t &path, const ds_spdlog::filename_t &normalized_path) + { + auto rv = create_dir(path); + REQUIRE(rv == true); +@@ -36,7 +36,7 @@ TEST_CASE("create_dir", "[create_dir]") + TEST_CASE("create_invalid_dir", "[create_dir]") + { + REQUIRE(create_dir(SPDLOG_FILENAME_T("")) == false); +- REQUIRE(create_dir(spdlog::filename_t{}) == false); ++ REQUIRE(create_dir(ds_spdlog::filename_t{}) == false); + #ifdef __linux__ + REQUIRE(create_dir("/proc/spdlog-utest") == false); + #endif +@@ -44,7 +44,7 @@ TEST_CASE("create_invalid_dir", "[create_dir]") + + TEST_CASE("dir_name", "[create_dir]") + { +- using spdlog::details::os::dir_name; ++ using ds_spdlog::details::os::dir_name; + REQUIRE(dir_name(SPDLOG_FILENAME_T("")).empty()); + REQUIRE(dir_name(SPDLOG_FILENAME_T("dir")).empty()); + +diff --git a/tests/test_custom_callbacks.cpp b/tests/test_custom_callbacks.cpp +index 78babd79..91ad1571 100644 +--- a/tests/test_custom_callbacks.cpp ++++ b/tests/test_custom_callbacks.cpp +@@ -10,16 +10,16 @@ + TEST_CASE("custom_callback_logger", "[custom_callback_logger]") + { + std::vector lines; +- spdlog::pattern_formatter formatter; +- auto callback_logger = std::make_shared([&](const spdlog::details::log_msg &msg) { +- spdlog::memory_buf_t formatted; ++ ds_spdlog::pattern_formatter formatter; ++ auto callback_logger = std::make_shared([&](const ds_spdlog::details::log_msg &msg) { ++ ds_spdlog::memory_buf_t formatted; + formatter.format(msg, formatted); +- auto eol_len = strlen(spdlog::details::os::default_eol); ++ auto eol_len = strlen(ds_spdlog::details::os::default_eol); + lines.emplace_back(formatted.begin(), formatted.end() - eol_len); + }); +- std::shared_ptr test_sink(new spdlog::sinks::test_sink_st); ++ std::shared_ptr test_sink(new ds_spdlog::sinks::test_sink_st); + +- spdlog::logger logger("test-callback", {callback_logger, test_sink}); ++ ds_spdlog::logger logger("test-callback", {callback_logger, test_sink}); + + 
+     logger.info("test message 1");
+     logger.info("test message 2");
+@@ -30,5 +30,5 @@ TEST_CASE("custom_callback_logger", "[custom_callback_logger]")
+     REQUIRE(lines[0] == ref_lines[0]);
+     REQUIRE(lines[1] == ref_lines[1]);
+     REQUIRE(lines[2] == ref_lines[2]);
+-    spdlog::drop_all();
++    ds_spdlog::drop_all();
+ }
+diff --git a/tests/test_daily_logger.cpp b/tests/test_daily_logger.cpp
+index 82f28941..e6df10e5 100644
+--- a/tests/test_daily_logger.cpp
++++ b/tests/test_daily_logger.cpp
+@@ -4,16 +4,16 @@
+ #include "includes.h"
+
+ #ifdef SPDLOG_USE_STD_FORMAT
+-using filename_memory_buf_t = std::basic_string<spdlog::filename_t::value_type>;
++using filename_memory_buf_t = std::basic_string<ds_spdlog::filename_t::value_type>;
+ #else
+-using filename_memory_buf_t = fmt::basic_memory_buffer<spdlog::filename_t::value_type, 250>;
++using filename_memory_buf_t = fmt::basic_memory_buffer<ds_spdlog::filename_t::value_type, 250>;
+ #endif
+
+ #ifdef SPDLOG_WCHAR_FILENAMES
+ std::string filename_buf_to_utf8string(const filename_memory_buf_t &w)
+ {
+-    spdlog::memory_buf_t buf;
+-    spdlog::details::os::wstr_to_utf8buf(spdlog::wstring_view_t(w.data(), w.size()), buf);
++    ds_spdlog::memory_buf_t buf;
++    ds_spdlog::details::os::wstr_to_utf8buf(ds_spdlog::wstring_view_t(w.data(), w.size()), buf);
+     return SPDLOG_BUF_TO_STRING(buf);
+ }
+ #else
+@@ -25,18 +25,18 @@ std::string filename_buf_to_utf8string(const filename_memory_buf_t &w)
+
+ TEST_CASE("daily_logger with dateonly calculator", "[daily_logger]")
+ {
+-    using sink_type = spdlog::sinks::daily_file_sink<std::mutex, spdlog::sinks::daily_filename_calculator>;
++    using sink_type = ds_spdlog::sinks::daily_file_sink<std::mutex, ds_spdlog::sinks::daily_filename_calculator>;
+
+     prepare_logdir();
+
+     // calculate filename (time based)
+-    spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_dateonly");
+-    std::tm tm = spdlog::details::os::localtime();
++    ds_spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_dateonly");
++    std::tm tm = ds_spdlog::details::os::localtime();
+     filename_memory_buf_t w;
+-    spdlog::fmt_lib::format_to(
++    ds_spdlog::fmt_lib::format_to(
+         std::back_inserter(w), SPDLOG_FILENAME_T("{}_{:04d}-{:02d}-{:02d}"), basename, tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday);
+
+-    auto logger = spdlog::create<sink_type>("logger", basename, 0, 0);
++    auto logger = ds_spdlog::create<sink_type>("logger", basename, 0, 0);
+     for (int i = 0; i < 10; ++i)
+     {
+
+@@ -49,10 +49,10 @@ TEST_CASE("daily_logger with dateonly calculator", "[daily_logger]")
+
+ struct custom_daily_file_name_calculator
+ {
+-    static spdlog::filename_t calc_filename(const spdlog::filename_t &basename, const tm &now_tm)
++    static ds_spdlog::filename_t calc_filename(const ds_spdlog::filename_t &basename, const tm &now_tm)
+     {
+         filename_memory_buf_t w;
+-        spdlog::fmt_lib::format_to(std::back_inserter(w), SPDLOG_FILENAME_T("{}{:04d}{:02d}{:02d}"), basename, now_tm.tm_year + 1900,
++        ds_spdlog::fmt_lib::format_to(std::back_inserter(w), SPDLOG_FILENAME_T("{}{:04d}{:02d}{:02d}"), basename, now_tm.tm_year + 1900,
+             now_tm.tm_mon + 1, now_tm.tm_mday);
+
+         return SPDLOG_BUF_TO_STRING(w);
+@@ -61,18 +61,18 @@ struct custom_daily_file_name_calculator
+
+ TEST_CASE("daily_logger with custom calculator", "[daily_logger]")
+ {
+-    using sink_type = spdlog::sinks::daily_file_sink<std::mutex, custom_daily_file_name_calculator>;
++    using sink_type = ds_spdlog::sinks::daily_file_sink<std::mutex, custom_daily_file_name_calculator>;
+
+     prepare_logdir();
+
+     // calculate filename (time based)
+-    spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_dateonly");
+-    std::tm tm = spdlog::details::os::localtime();
++    ds_spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_dateonly");
++    std::tm tm = ds_spdlog::details::os::localtime();
+     filename_memory_buf_t w;
+-    spdlog::fmt_lib::format_to(
++    ds_spdlog::fmt_lib::format_to(
+         std::back_inserter(w), SPDLOG_FILENAME_T("{}{:04d}{:02d}{:02d}"), basename, tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday);
+
+-    auto logger = spdlog::create<sink_type>("logger", basename, 0, 0);
++    auto logger = ds_spdlog::create<sink_type>("logger", basename, 0, 0);
+     for (int i = 0; i < 10; ++i)
+     {
+         logger->info("Test message {}", i);
+@@ -89,19 +89,19 @@
+
+ TEST_CASE("rotating_file_sink::calc_filename1", "[rotating_file_sink]")
+ {
+-    auto filename = spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated.txt"), 3);
++    auto filename = ds_spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated.txt"), 3);
+     REQUIRE(filename == SPDLOG_FILENAME_T("rotated.3.txt"));
+ }
+
+ TEST_CASE("rotating_file_sink::calc_filename2", "[rotating_file_sink]")
+ {
+-    auto filename = spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated"), 3);
++    auto filename = ds_spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated"), 3);
+     REQUIRE(filename == SPDLOG_FILENAME_T("rotated.3"));
+ }
+
+ TEST_CASE("rotating_file_sink::calc_filename3", "[rotating_file_sink]")
+ {
+-    auto filename = spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated.txt"), 0);
++    auto filename = ds_spdlog::sinks::rotating_file_sink_st::calc_filename(SPDLOG_FILENAME_T("rotated.txt"), 0);
+     REQUIRE(filename == SPDLOG_FILENAME_T("rotated.txt"));
+ }
+
+@@ -114,43 +114,43 @@ TEST_CASE("daily_file_sink::daily_filename_calculator", "[daily_file_sink]")
+ {
+     // daily_YYYY-MM-DD_hh-mm.txt
+     auto filename =
+-        spdlog::sinks::daily_filename_calculator::calc_filename(SPDLOG_FILENAME_T("daily.txt"), spdlog::details::os::localtime());
++        ds_spdlog::sinks::daily_filename_calculator::calc_filename(SPDLOG_FILENAME_T("daily.txt"), ds_spdlog::details::os::localtime());
+     // date regex based on https://www.regular-expressions.info/dates.html
+-    std::basic_regex<spdlog::filename_t::value_type> re(
++    std::basic_regex<ds_spdlog::filename_t::value_type> re(
+         SPDLOG_FILENAME_T(R"(^daily_(19|20)\d\d-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01])\.txt$)"));
+-    std::match_results<spdlog::filename_t::const_iterator> match;
++    std::match_results<ds_spdlog::filename_t::const_iterator> match;
+     REQUIRE(std::regex_match(filename, match, re));
+ }
+ #endif
+
+ TEST_CASE("daily_file_sink::daily_filename_format_calculator", "[daily_file_sink]")
+ {
+-    std::tm tm = spdlog::details::os::localtime();
++    std::tm tm = ds_spdlog::details::os::localtime();
+     // example-YYYY-MM-DD.log
+-    auto filename = spdlog::sinks::daily_filename_format_calculator::calc_filename(SPDLOG_FILENAME_T("example-%Y-%m-%d.log"), tm);
++    auto filename = ds_spdlog::sinks::daily_filename_format_calculator::calc_filename(SPDLOG_FILENAME_T("example-%Y-%m-%d.log"), tm);
+
+     REQUIRE(filename ==
+-        spdlog::fmt_lib::format(SPDLOG_FILENAME_T("example-{:04d}-{:02d}-{:02d}.log"), tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday));
++        ds_spdlog::fmt_lib::format(SPDLOG_FILENAME_T("example-{:04d}-{:02d}-{:02d}.log"), tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday));
+ }
+
+ /* Test removal of old files */
+-static spdlog::details::log_msg create_msg(std::chrono::seconds offset)
++static ds_spdlog::details::log_msg create_msg(std::chrono::seconds offset)
+ {
+-    using spdlog::log_clock;
+-    spdlog::details::log_msg msg{"test", spdlog::level::info, "Hello Message"};
++    using ds_spdlog::log_clock;
++    ds_spdlog::details::log_msg msg{"test", ds_spdlog::level::info, "Hello Message"};
+     msg.time = log_clock::now() + offset;
+     return msg;
+ }
+
+ static void test_rotate(int days_to_run, uint16_t max_days, uint16_t expected_n_files)
+ {
+-    using spdlog::log_clock;
+-    using spdlog::details::log_msg;
+-    using spdlog::sinks::daily_file_sink_st;
++    using ds_spdlog::log_clock;
++    using ds_spdlog::details::log_msg;
++    using ds_spdlog::sinks::daily_file_sink_st;
+
+     prepare_logdir();
+
+-    spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_rotate.txt");
++    ds_spdlog::filename_t basename = SPDLOG_FILENAME_T("test_logs/daily_rotate.txt");
+     daily_file_sink_st sink{basename, 2, 30, true, max_days};
+
+     // simulate messages with 24 intervals
+diff --git a/tests/test_dup_filter.cpp b/tests/test_dup_filter.cpp
+index 8ae2ee60..380c859c 100644
+--- a/tests/test_dup_filter.cpp
++++ b/tests/test_dup_filter.cpp
+@@ -4,8 +4,8 @@
+
+ TEST_CASE("dup_filter_test1", "[dup_filter_sink]")
+ {
+-    using spdlog::sinks::dup_filter_sink_st;
+-    using spdlog::sinks::test_sink_mt;
++    using ds_spdlog::sinks::dup_filter_sink_st;
++    using ds_spdlog::sinks::test_sink_mt;
+
+     dup_filter_sink_st dup_sink{std::chrono::seconds{5}};
+     auto test_sink = std::make_shared<test_sink_mt>();
+@@ -13,7 +13,7 @@ TEST_CASE("dup_filter_test1", "[dup_filter_sink]")
+
+     for (int i = 0; i < 10; i++)
+     {
+-        dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"});
++        dup_sink.log(ds_spdlog::details::log_msg{"test", ds_spdlog::level::info, "message1"});
+     }
+
+     REQUIRE(test_sink->msg_counter() == 1);
+@@ -21,8 +21,8 @@ TEST_CASE("dup_filter_test2", "[dup_filter_sink]")
+ {
+-    using spdlog::sinks::dup_filter_sink_st;
+-    using spdlog::sinks::test_sink_mt;
++    using ds_spdlog::sinks::dup_filter_sink_st;
++    using ds_spdlog::sinks::test_sink_mt;
+
+     dup_filter_sink_st dup_sink{std::chrono::seconds{0}};
+     auto test_sink = std::make_shared<test_sink_mt>();
+@@ -30,7 +30,7 @@ TEST_CASE("dup_filter_test2", "[dup_filter_sink]")
+
+     for (int i = 0; i < 10; i++)
+     {
+-        dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"});
++        dup_sink.log(ds_spdlog::details::log_msg{"test", ds_spdlog::level::info, "message1"});
+         std::this_thread::sleep_for(std::chrono::milliseconds(5));
+     }
+
+@@ -39,8 +39,8 @@ TEST_CASE("dup_filter_test3", "[dup_filter_sink]")
+ {
+-    using spdlog::sinks::dup_filter_sink_st;
+-    using spdlog::sinks::test_sink_mt;
++    using ds_spdlog::sinks::dup_filter_sink_st;
++    using ds_spdlog::sinks::test_sink_mt;
+
+     dup_filter_sink_st dup_sink{std::chrono::seconds{1}};
+     auto test_sink = std::make_shared<test_sink_mt>();
+@@ -48,8 +48,8 @@ TEST_CASE("dup_filter_test3", "[dup_filter_sink]")
+
+     for (int i = 0; i < 10; i++)
+     {
+-        dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"});
+-        dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message2"});
++        dup_sink.log(ds_spdlog::details::log_msg{"test", ds_spdlog::level::info, "message1"});
++        dup_sink.log(ds_spdlog::details::log_msg{"test", ds_spdlog::level::info, "message2"});
+     }
+
+     REQUIRE(test_sink->msg_counter() == 20);
+@@ -57,33 +57,33 @@
+
+ TEST_CASE("dup_filter_test4", "[dup_filter_sink]")
+ {
+-    using spdlog::sinks::dup_filter_sink_mt;
+-    using spdlog::sinks::test_sink_mt;
++    using ds_spdlog::sinks::dup_filter_sink_mt;
++    using ds_spdlog::sinks::test_sink_mt;
+
+     dup_filter_sink_mt dup_sink{std::chrono::milliseconds{10}};
+     auto test_sink = std::make_shared<test_sink_mt>();
+     dup_sink.add_sink(test_sink);
+
+-    dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message"});
++    dup_sink.log(ds_spdlog::details::log_msg{"test", ds_spdlog::level::info, "message"});
+     std::this_thread::sleep_for(std::chrono::milliseconds(50));
+-    dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message"});
++    dup_sink.log(ds_spdlog::details::log_msg{"test", ds_spdlog::level::info, "message"});
+     REQUIRE(test_sink->msg_counter() == 2);
+ }
+
+ TEST_CASE("dup_filter_test5", "[dup_filter_sink]")
+ {
+-    using spdlog::sinks::dup_filter_sink_mt;
+-    using spdlog::sinks::test_sink_mt;
++    using ds_spdlog::sinks::dup_filter_sink_mt;
++    using ds_spdlog::sinks::test_sink_mt;
+
+     dup_filter_sink_mt dup_sink{std::chrono::seconds{5}};
+     auto test_sink = std::make_shared<test_sink_mt>();
+     test_sink->set_pattern("%v");
+     dup_sink.add_sink(test_sink);
+
+-    dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"});
+-    dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"});
+-    dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message1"});
+-    dup_sink.log(spdlog::details::log_msg{"test", spdlog::level::info, "message2"});
++    dup_sink.log(ds_spdlog::details::log_msg{"test", ds_spdlog::level::info, "message1"});
++    dup_sink.log(ds_spdlog::details::log_msg{"test", ds_spdlog::level::info, "message1"});
++    dup_sink.log(ds_spdlog::details::log_msg{"test", ds_spdlog::level::info, "message1"});
++    dup_sink.log(ds_spdlog::details::log_msg{"test", ds_spdlog::level::info, "message2"});
+
+     REQUIRE(test_sink->msg_counter() == 3); // skip 2 messages but log the "skipped.." message before message2
+     REQUIRE(test_sink->lines()[1] == "Skipped 2 duplicate messages..");
+diff --git a/tests/test_errors.cpp b/tests/test_errors.cpp
+index 78032482..6ef120f4 100644
+--- a/tests/test_errors.cpp
++++ b/tests/test_errors.cpp
+@@ -8,10 +8,10 @@
+ #define SIMPLE_LOG "test_logs/simple_log.txt"
+ #define SIMPLE_ASYNC_LOG "test_logs/simple_async_log.txt"
+
+-class failing_sink : public spdlog::sinks::base_sink<std::mutex>
++class failing_sink : public ds_spdlog::sinks::base_sink<std::mutex>
+ {
+ protected:
+-    void sink_it_(const spdlog::details::log_msg &) final
++    void sink_it_(const ds_spdlog::details::log_msg &) final
+     {
+         throw std::runtime_error("some error happened during log");
+     }
+@@ -28,24 +28,24 @@ struct custom_ex
+ TEST_CASE("default_error_handler", "[errors]")
+ {
+     prepare_logdir();
+-    spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG);
++    ds_spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG);
+
+-    auto logger = spdlog::create<spdlog::sinks::basic_file_sink_mt>("test-error", filename, true);
++    auto logger = ds_spdlog::create<ds_spdlog::sinks::basic_file_sink_mt>("test-error", filename, true);
+     logger->set_pattern("%v");
+     logger->info(SPDLOG_FMT_RUNTIME("Test message {} {}"), 1);
+     logger->info("Test message {}", 2);
+     logger->flush();
+-    using spdlog::details::os::default_eol;
+-    REQUIRE(file_contents(SIMPLE_LOG) == spdlog::fmt_lib::format("Test message 2{}", default_eol));
++    using ds_spdlog::details::os::default_eol;
++    REQUIRE(file_contents(SIMPLE_LOG) == ds_spdlog::fmt_lib::format("Test message 2{}", default_eol));
+     REQUIRE(count_lines(SIMPLE_LOG) == 1);
+ }
+
+ TEST_CASE("custom_error_handler", "[errors]")
+ {
+     prepare_logdir();
+-    spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG);
+-    auto logger = spdlog::create<spdlog::sinks::basic_file_sink_mt>("logger", filename, true);
+-    logger->flush_on(spdlog::level::info);
++    ds_spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG);
++    auto logger = ds_spdlog::create<ds_spdlog::sinks::basic_file_sink_mt>("logger", filename, true);
++    logger->flush_on(ds_spdlog::level::info);
+     logger->set_error_handler([=](const std::string &) { throw custom_ex(); });
+     logger->info("Good message #1");
+
+@@ -57,16 +57,16 @@ TEST_CASE("custom_error_handler", "[errors]")
+
+ TEST_CASE("default_error_handler2", "[errors]")
+ {
+-    spdlog::drop_all();
+-    auto logger = spdlog::create<failing_sink>("failed_logger");
++    ds_spdlog::drop_all();
++    auto logger = ds_spdlog::create<failing_sink>("failed_logger");
+     logger->set_error_handler([=](const std::string &) { throw custom_ex(); });
+     REQUIRE_THROWS_AS(logger->info("Some message"), custom_ex);
+ }
+
+ TEST_CASE("flush_error_handler", "[errors]")
+ {
+-    spdlog::drop_all();
+-    auto logger = spdlog::create<failing_sink>("failed_logger");
++    ds_spdlog::drop_all();
++    auto logger = ds_spdlog::create<failing_sink>("failed_logger");
+     logger->set_error_handler([=](const std::string &) { throw custom_ex(); });
+     REQUIRE_THROWS_AS(logger->flush(), custom_ex);
+ }
+@@ -77,10 +77,10 @@ TEST_CASE("async_error_handler", "[errors]")
+     prepare_logdir();
+     std::string err_msg("log failed with some msg");
+
+-    spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_ASYNC_LOG);
++    ds_spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_ASYNC_LOG);
+     {
+-        spdlog::init_thread_pool(128, 1);
+-        auto logger = spdlog::create_async<spdlog::sinks::basic_file_sink_mt>("logger", filename, true);
++        ds_spdlog::init_thread_pool(128, 1);
++        auto logger = ds_spdlog::create_async<ds_spdlog::sinks::basic_file_sink_mt>("logger", filename, true);
+         logger->set_error_handler([=](const std::string &) {
+             std::ofstream ofs("test_logs/custom_err.txt");
+             if (!ofs)
+@@ -92,9 +92,9 @@ TEST_CASE("async_error_handler", "[errors]")
+         logger->info("Good message #1");
+         logger->info(SPDLOG_FMT_RUNTIME("Bad format msg {} {}"), "xxx");
+         logger->info("Good message #2");
+-        spdlog::drop("logger"); // force logger to drain the queue and shutdown
++        ds_spdlog::drop("logger"); // force logger to drain the queue and shutdown
+     }
+-    spdlog::init_thread_pool(128, 1);
++    ds_spdlog::init_thread_pool(128, 1);
+     require_message_count(SIMPLE_ASYNC_LOG, 2);
+     REQUIRE(file_contents("test_logs/custom_err.txt") == err_msg);
+ }
+@@ -106,9 +106,9 @@ TEST_CASE("async_error_handler2", "[errors]")
+     prepare_logdir();
+     std::string err_msg("This is async handler error message");
+     {
+-        spdlog::details::os::create_dir(SPDLOG_FILENAME_T("test_logs"));
+-        spdlog::init_thread_pool(128, 1);
+-        auto logger = spdlog::create_async<failing_sink>("failed_logger");
++        ds_spdlog::details::os::create_dir(SPDLOG_FILENAME_T("test_logs"));
++        ds_spdlog::init_thread_pool(128, 1);
++        auto logger = ds_spdlog::create_async<failing_sink>("failed_logger");
+         logger->set_error_handler([=](const std::string &) {
+             std::ofstream ofs("test_logs/custom_err2.txt");
+             if (!ofs)
+@@ -116,9 +116,9 @@ TEST_CASE("async_error_handler2", "[errors]")
+             ofs << err_msg;
+         });
+         logger->info("Hello failure");
+-        spdlog::drop("failed_logger"); // force logger to drain the queue and shutdown
++        ds_spdlog::drop("failed_logger"); // force logger to drain the queue and shutdown
+     }
+
+-    spdlog::init_thread_pool(128, 1);
++    ds_spdlog::init_thread_pool(128, 1);
+     REQUIRE(file_contents("test_logs/custom_err2.txt") == err_msg);
+ }
+diff --git a/tests/test_eventlog.cpp b/tests/test_eventlog.cpp
+index 5253c5a7..f796f5a0 100644
+--- a/tests/test_eventlog.cpp
++++ b/tests/test_eventlog.cpp
+@@ -46,16 +46,16 @@ static void test_single_print(std::function<void(std::string const &)> do_log, s
+     REQUIRE((expected_time_generated - record->TimeGenerated) <= 3u);
+
+     std::string message_in_log(((char *)record + record->StringOffset));
+-    REQUIRE(message_in_log == expected_contents + spdlog::details::os::default_eol);
++    REQUIRE(message_in_log == expected_contents + ds_spdlog::details::os::default_eol);
+ }
+
+ TEST_CASE("eventlog", "[eventlog]")
+ {
+-    using namespace spdlog;
++    using namespace ds_spdlog;
+
+     auto test_sink = std::make_shared<sinks::win_eventlog_sink_mt>(TEST_SOURCE);
+
+-    spdlog::logger test_logger("eventlog", test_sink);
++    ds_spdlog::logger test_logger("eventlog", test_sink);
+     test_logger.set_level(level::trace);
+
+     test_sink->set_pattern("%v");
+diff --git a/tests/test_file_helper.cpp b/tests/test_file_helper.cpp
+index dd3ca4f8..b564b542 100644
+--- a/tests/test_file_helper.cpp
++++ b/tests/test_file_helper.cpp
+@@ -5,12 +5,12 @@
+
+ #define TEST_FILENAME "test_logs/file_helper_test.txt"
+
+-using spdlog::details::file_helper;
++using ds_spdlog::details::file_helper;
+
+ static void write_with_helper(file_helper &helper, size_t howmany)
+ {
+-    spdlog::memory_buf_t formatted;
+-    spdlog::fmt_lib::format_to(std::back_inserter(formatted), "{}", std::string(howmany, '1'));
++    ds_spdlog::memory_buf_t formatted;
++    ds_spdlog::fmt_lib::format_to(std::back_inserter(formatted), "{}", std::string(howmany, '1'));
+     helper.write(formatted);
+     helper.flush();
+ }
+@@ -20,7 +20,7 @@ TEST_CASE("file_helper_filename", "[file_helper::filename()]")
+     prepare_logdir();
+
+     file_helper helper;
+-    spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
++    ds_spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
+     helper.open(target_filename);
+     REQUIRE(helper.filename() == target_filename);
+ }
+@@ -28,7 +28,7 @@ TEST_CASE("file_helper_filename", "[file_helper::filename()]")
+ TEST_CASE("file_helper_size", "[file_helper::size()]")
+ {
+     prepare_logdir();
+-    spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
++    ds_spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
+     size_t expected_size = 123;
+     {
+         file_helper helper;
+@@ -42,7 +42,7 @@ TEST_CASE("file_helper_size", "[file_helper::size()]")
+ TEST_CASE("file_helper_reopen", "[file_helper::reopen()]")
+ {
+     prepare_logdir();
+-    spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
++    ds_spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
+     file_helper helper;
+     helper.open(target_filename);
+     write_with_helper(helper, 12);
+@@ -54,7 +54,7 @@ TEST_CASE("file_helper_reopen", "[file_helper::reopen()]")
+ TEST_CASE("file_helper_reopen2", "[file_helper::reopen(false)]")
+ {
+     prepare_logdir();
+-    spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
++    ds_spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
+     size_t expected_size = 14;
+     file_helper helper;
+     helper.open(target_filename);
+@@ -64,15 +64,15 @@ TEST_CASE("file_helper_reopen2", "[file_helper::reopen(false)]")
+     REQUIRE(helper.size() == expected_size);
+ }
+
+-static void test_split_ext(const spdlog::filename_t::value_type *fname, const spdlog::filename_t::value_type *expect_base,
+-    const spdlog::filename_t::value_type *expect_ext)
++static void test_split_ext(const ds_spdlog::filename_t::value_type *fname, const ds_spdlog::filename_t::value_type *expect_base,
++    const ds_spdlog::filename_t::value_type *expect_ext)
+ {
+-    spdlog::filename_t filename(fname);
+-    spdlog::filename_t expected_base(expect_base);
+-    spdlog::filename_t expected_ext(expect_ext);
++    ds_spdlog::filename_t filename(fname);
++    ds_spdlog::filename_t expected_base(expect_base);
++    ds_spdlog::filename_t expected_ext(expect_ext);
+
+-    spdlog::filename_t basename;
+-    spdlog::filename_t ext;
++    ds_spdlog::filename_t basename;
++    ds_spdlog::filename_t ext;
+     std::tie(basename, ext) = file_helper::split_by_extension(filename);
+     REQUIRE(basename == expected_base);
+     REQUIRE(ext == expected_ext);
+@@ -111,32 +111,32 @@ TEST_CASE("file_event_handlers", "[file_helper]")
+     };
+     prepare_logdir();
+
+-    spdlog::filename_t test_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
++    ds_spdlog::filename_t test_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
+     // define event handles that update vector of flags when called
+     std::vector<flags> events;
+-    spdlog::file_event_handlers handlers;
+-    handlers.before_open = [&](spdlog::filename_t filename) {
++    ds_spdlog::file_event_handlers handlers;
++    handlers.before_open = [&](ds_spdlog::filename_t filename) {
+         REQUIRE(filename == test_filename);
+         events.push_back(flags::before_open);
+     };
+-    handlers.after_open = [&](spdlog::filename_t filename, std::FILE *fstream) {
++    handlers.after_open = [&](ds_spdlog::filename_t filename, std::FILE *fstream) {
+         REQUIRE(filename == test_filename);
+         REQUIRE(fstream);
+         fputs("after_open\n", fstream);
+         events.push_back(flags::after_open);
+     };
+-    handlers.before_close = [&](spdlog::filename_t filename, std::FILE *fstream) {
++    handlers.before_close = [&](ds_spdlog::filename_t filename, std::FILE *fstream) {
+         REQUIRE(filename == test_filename);
+         REQUIRE(fstream);
+         fputs("before_close\n", fstream);
+         events.push_back(flags::before_close);
+     };
+-    handlers.after_close = [&](spdlog::filename_t filename) {
++    handlers.after_close = [&](ds_spdlog::filename_t filename) {
+         REQUIRE(filename == test_filename);
+         events.push_back(flags::after_close);
+     };
+     {
+-        spdlog::details::file_helper helper{handlers};
++        ds_spdlog::details::file_helper helper{handlers};
+         REQUIRE(events.empty());
+
+         helper.open(test_filename);
+@@ -158,11 +158,11 @@ TEST_CASE("file_event_handlers", "[file_helper]")
+ TEST_CASE("file_helper_open", "[file_helper]")
+ {
+     prepare_logdir();
+-    spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
++    ds_spdlog::filename_t target_filename = SPDLOG_FILENAME_T(TEST_FILENAME);
+     file_helper helper;
+     helper.open(target_filename);
+     helper.close();
+
+     target_filename += SPDLOG_FILENAME_T("/invalid");
+-    REQUIRE_THROWS_AS(helper.open(target_filename), spdlog::spdlog_ex);
++    REQUIRE_THROWS_AS(helper.open(target_filename), ds_spdlog::spdlog_ex);
+ }
+diff --git a/tests/test_file_logging.cpp b/tests/test_file_logging.cpp
+index 7a7119ad..d2fbe092 100644
+--- a/tests/test_file_logging.cpp
++++ b/tests/test_file_logging.cpp
+@@ -9,9 +9,9 @@
+ TEST_CASE("simple_file_logger", "[simple_logger]")
+ {
+     prepare_logdir();
+-    spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG);
++    ds_spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG);
+
+-    auto logger = spdlog::create<spdlog::sinks::basic_file_sink_mt>("logger", filename);
++    auto logger = ds_spdlog::create<ds_spdlog::sinks::basic_file_sink_mt>("logger", filename);
+     logger->set_pattern("%v");
+
+     logger->info("Test message {}", 1);
+@@ -19,19 +19,19 @@ TEST_CASE("simple_file_logger", "[simple_logger]")
+
+     logger->flush();
+     require_message_count(SIMPLE_LOG, 2);
+-    using spdlog::details::os::default_eol;
+-    REQUIRE(file_contents(SIMPLE_LOG) == spdlog::fmt_lib::format("Test message 1{}Test message 2{}", default_eol, default_eol));
++    using ds_spdlog::details::os::default_eol;
++    REQUIRE(file_contents(SIMPLE_LOG) == ds_spdlog::fmt_lib::format("Test message 1{}Test message 2{}", default_eol, default_eol));
+ }
+
+ TEST_CASE("flush_on", "[flush_on]")
+ {
+     prepare_logdir();
+-    spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG);
++    ds_spdlog::filename_t filename = SPDLOG_FILENAME_T(SIMPLE_LOG);
+
+-    auto logger = spdlog::create<spdlog::sinks::basic_file_sink_mt>("logger", filename);
++    auto logger = ds_spdlog::create<ds_spdlog::sinks::basic_file_sink_mt>("logger", filename);
+     logger->set_pattern("%v");
+-    logger->set_level(spdlog::level::trace);
+-    logger->flush_on(spdlog::level::info);
++    logger->set_level(ds_spdlog::level::trace);
++    logger->flush_on(ds_spdlog::level::info);
+     logger->trace("Should not be flushed");
+     REQUIRE(count_lines(SIMPLE_LOG) == 0);
+
+@@ -39,17 +39,17 @@ TEST_CASE("flush_on", "[flush_on]")
+     logger->info("Test message {}", 2);
+
+     require_message_count(SIMPLE_LOG, 3);
+-    using spdlog::details::os::default_eol;
++    using ds_spdlog::details::os::default_eol;
+     REQUIRE(file_contents(SIMPLE_LOG) ==
+-        spdlog::fmt_lib::format("Should not be flushed{}Test message 1{}Test message 2{}", default_eol, default_eol, default_eol));
++        ds_spdlog::fmt_lib::format("Should not be flushed{}Test message 1{}Test message 2{}", default_eol, default_eol, default_eol));
+ }
+
+ TEST_CASE("rotating_file_logger1", "[rotating_logger]")
+ {
+     prepare_logdir();
+     size_t max_size = 1024 * 10;
+-    spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG);
+-    auto logger = spdlog::rotating_logger_mt("logger", basename, max_size, 0);
++    ds_spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG);
++    auto logger = ds_spdlog::rotating_logger_mt("logger", basename, max_size, 0);
+
+     for (int i = 0; i < 10; ++i)
+     {
+@@ -64,21 +64,21 @@ TEST_CASE("rotating_file_logger2", "[rotating_logger]")
+ {
+     prepare_logdir();
+     size_t max_size = 1024 * 10;
+-    spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG);
++    ds_spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG);
+
+     {
+         // make an initial logger to create the first output file
+-        auto logger = spdlog::rotating_logger_mt("logger", basename, max_size, 2, true);
++        auto logger = ds_spdlog::rotating_logger_mt("logger", basename, max_size, 2, true);
+         for (int i = 0; i < 10; ++i)
+         {
+             logger->info("Test message {}", i);
+         }
+         // drop causes the logger destructor to be called, which is required so the
+         // next logger can rename the first output file.
+-        spdlog::drop(logger->name());
++        ds_spdlog::drop(logger->name());
+     }
+
+-    auto logger = spdlog::rotating_logger_mt("logger", basename, max_size, 2, true);
++    auto logger = ds_spdlog::rotating_logger_mt("logger", basename, max_size, 2, true);
+     for (int i = 0; i < 10; ++i)
+     {
+         logger->info("Test message {}", i);
+@@ -104,6 +104,6 @@ TEST_CASE("rotating_file_logger3", "[rotating_logger]")
+ {
+     prepare_logdir();
+     size_t max_size = 0;
+-    spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG);
+-    REQUIRE_THROWS_AS(spdlog::rotating_logger_mt("logger", basename, max_size, 0), spdlog::spdlog_ex);
++    ds_spdlog::filename_t basename = SPDLOG_FILENAME_T(ROTATING_LOG);
++    REQUIRE_THROWS_AS(ds_spdlog::rotating_logger_mt("logger", basename, max_size, 0), ds_spdlog::spdlog_ex);
+ }
+diff --git a/tests/test_fmt_helper.cpp b/tests/test_fmt_helper.cpp
+index 52141902..b0973534 100644
+--- a/tests/test_fmt_helper.cpp
++++ b/tests/test_fmt_helper.cpp
+@@ -1,13 +1,13 @@
+
+ #include "includes.h"
+
+-using spdlog::memory_buf_t;
+-using spdlog::details::to_string_view;
++using ds_spdlog::memory_buf_t;
++using ds_spdlog::details::to_string_view;
+
+ void test_pad2(int n, const char *expected)
+ {
+     memory_buf_t buf;
+-    spdlog::details::fmt_helper::pad2(n, buf);
++    ds_spdlog::details::fmt_helper::pad2(n, buf);
+
+     REQUIRE(to_string_view(buf) == expected);
+ }
+@@ -15,7 +15,7 @@ void test_pad2(int n, const char *expected)
+ void test_pad3(uint32_t n, const char *expected)
+ {
+     memory_buf_t buf;
+-    spdlog::details::fmt_helper::pad3(n, buf);
++    ds_spdlog::details::fmt_helper::pad3(n, buf);
+
+     REQUIRE(to_string_view(buf) == expected);
+ }
+@@ -23,7 +23,7 @@ void test_pad3(uint32_t n, const char *expected)
+ void test_pad6(std::size_t n, const char *expected)
+ {
+     memory_buf_t buf;
+-    spdlog::details::fmt_helper::pad6(n, buf);
++    ds_spdlog::details::fmt_helper::pad6(n, buf);
+
+     REQUIRE(to_string_view(buf) == expected);
+ }
+@@ -31,7 +31,7 @@ void test_pad6(std::size_t n, const char *expected)
+ void test_pad9(std::size_t n, const char *expected)
+ {
+     memory_buf_t buf;
+-    spdlog::details::fmt_helper::pad9(n, buf);
++    ds_spdlog::details::fmt_helper::pad9(n, buf);
+
+     REQUIRE(to_string_view(buf) == expected);
+ }
+diff --git a/tests/test_macros.cpp b/tests/test_macros.cpp
+index 36537958..1b4249da 100644
+--- a/tests/test_macros.cpp
++++ b/tests/test_macros.cpp
+@@ -14,30 +14,30 @@ TEST_CASE("debug and trace w/o format string", "[macros]")
+ {
+
+     prepare_logdir();
+-    spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME);
++    ds_spdlog::filename_t filename = SPDLOG_FILENAME_T(TEST_FILENAME);
+
+-    auto logger = spdlog::create<spdlog::sinks::basic_file_sink_mt>("logger", filename);
++    auto logger = ds_spdlog::create<ds_spdlog::sinks::basic_file_sink_mt>("logger", filename);
+     logger->set_pattern("%v");
+-    logger->set_level(spdlog::level::trace);
++    logger->set_level(ds_spdlog::level::trace);
+
+     SPDLOG_LOGGER_TRACE(logger, "Test message 1");
+     SPDLOG_LOGGER_DEBUG(logger, "Test message 2");
+     logger->flush();
+
+-    using spdlog::details::os::default_eol;
+-    REQUIRE(ends_with(file_contents(TEST_FILENAME), spdlog::fmt_lib::format("Test message 2{}", default_eol)));
++    using ds_spdlog::details::os::default_eol;
++    REQUIRE(ends_with(file_contents(TEST_FILENAME), ds_spdlog::fmt_lib::format("Test message 2{}", default_eol)));
+     REQUIRE(count_lines(TEST_FILENAME) == 1);
+
+-    auto orig_default_logger = spdlog::default_logger();
+-    spdlog::set_default_logger(logger);
++    auto orig_default_logger = ds_spdlog::default_logger();
++    ds_spdlog::set_default_logger(logger);
+
+     SPDLOG_TRACE("Test message 3");
+     SPDLOG_DEBUG("Test message {}", 4);
+     logger->flush();
+
+     require_message_count(TEST_FILENAME, 2);
+-    REQUIRE(ends_with(file_contents(TEST_FILENAME), spdlog::fmt_lib::format("Test message 4{}", default_eol)));
+-    spdlog::set_default_logger(std::move(orig_default_logger));
++    REQUIRE(ends_with(file_contents(TEST_FILENAME), ds_spdlog::fmt_lib::format("Test message 4{}", default_eol)));
++    ds_spdlog::set_default_logger(std::move(orig_default_logger));
+ }
+
+ TEST_CASE("disable param evaluation", "[macros]")
+@@ -47,7 +47,7 @@ TEST_CASE("disable param evaluation", "[macros]")
+
+ TEST_CASE("pass logger pointer", "[macros]")
+ {
+-    auto logger = spdlog::create<spdlog::sinks::null_sink_mt>("refmacro");
++    auto logger = ds_spdlog::create<ds_spdlog::sinks::null_sink_mt>("refmacro");
+     auto &ref = *logger;
+     SPDLOG_LOGGER_TRACE(&ref, "Test message 1");
+     SPDLOG_LOGGER_DEBUG(&ref, "Test message 2");
+diff --git a/tests/test_misc.cpp b/tests/test_misc.cpp
+index 9f3cb174..617a1933 100644
+--- a/tests/test_misc.cpp
++++ b/tests/test_misc.cpp
+@@ -2,18 +2,18 @@
+ #include "test_sink.h"
+
+ template<typename T>
+-std::string log_info(const T &what, spdlog::level::level_enum logger_level = spdlog::level::info)
++std::string log_info(const T &what, ds_spdlog::level::level_enum logger_level = ds_spdlog::level::info)
+ {
+
+     std::ostringstream oss;
+-    auto oss_sink = std::make_shared<spdlog::sinks::ostream_sink_st>(oss);
+
+-    spdlog::logger oss_logger("oss", oss_sink);
++    auto oss_sink = std::make_shared<ds_spdlog::sinks::ostream_sink_st>(oss);
++    ds_spdlog::logger oss_logger("oss", oss_sink);
+     oss_logger.set_level(logger_level);
+     oss_logger.set_pattern("%v");
+     oss_logger.info(what);
+
+-    return oss.str().substr(0, oss.str().length() - strlen(spdlog::details::os::default_eol));
++    return oss.str().substr(0, oss.str().length() - strlen(ds_spdlog::details::os::default_eol));
+ }
+
+ TEST_CASE("basic_logging ", "[basic_logging]")
+@@ -36,66 +36,66 @@ TEST_CASE("basic_logging ", "[basic_logging]")
+
+ TEST_CASE("log_levels", "[log_levels]")
+ {
+-    REQUIRE(log_info("Hello", spdlog::level::err).empty());
+-    REQUIRE(log_info("Hello", spdlog::level::critical).empty());
+-    REQUIRE(log_info("Hello", spdlog::level::info) == "Hello");
+-    REQUIRE(log_info("Hello", spdlog::level::debug) == "Hello");
+-    REQUIRE(log_info("Hello", spdlog::level::trace) == "Hello");
++    REQUIRE(log_info("Hello", ds_spdlog::level::err).empty());
++    REQUIRE(log_info("Hello", ds_spdlog::level::critical).empty());
++    REQUIRE(log_info("Hello", ds_spdlog::level::info) == "Hello");
++    REQUIRE(log_info("Hello", ds_spdlog::level::debug) == "Hello");
++    REQUIRE(log_info("Hello", ds_spdlog::level::trace) == "Hello");
+ }
+
+ TEST_CASE("level_to_string_view", "[convert_to_string_view")
+ {
+-    REQUIRE(spdlog::level::to_string_view(spdlog::level::trace) == "trace");
+-    REQUIRE(spdlog::level::to_string_view(spdlog::level::debug) == "debug");
+-    REQUIRE(spdlog::level::to_string_view(spdlog::level::info) == "info");
+-    REQUIRE(spdlog::level::to_string_view(spdlog::level::warn) == "warning");
+-    REQUIRE(spdlog::level::to_string_view(spdlog::level::err) == "error");
+-    REQUIRE(spdlog::level::to_string_view(spdlog::level::critical) == "critical");
+-    REQUIRE(spdlog::level::to_string_view(spdlog::level::off) == "off");
++    REQUIRE(ds_spdlog::level::to_string_view(ds_spdlog::level::trace) == "trace");
++    REQUIRE(ds_spdlog::level::to_string_view(ds_spdlog::level::debug) == "debug");
++    REQUIRE(ds_spdlog::level::to_string_view(ds_spdlog::level::info) == "info");
++    REQUIRE(ds_spdlog::level::to_string_view(ds_spdlog::level::warn) == "warning");
++    REQUIRE(ds_spdlog::level::to_string_view(ds_spdlog::level::err) == "error");
++    REQUIRE(ds_spdlog::level::to_string_view(ds_spdlog::level::critical) == "critical");
++    REQUIRE(ds_spdlog::level::to_string_view(ds_spdlog::level::off) == "off");
+ }
+
+ TEST_CASE("to_short_c_str", "[convert_to_short_c_str]")
+ {
+-    REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::trace)) == "T");
+-    REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::debug)) == "D");
+-    REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::info)) == "I");
+-    REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::warn)) == "W");
+-    REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::err)) == "E");
+-    REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::critical)) == "C");
+-    REQUIRE(std::string(spdlog::level::to_short_c_str(spdlog::level::off)) == "O");
++    REQUIRE(std::string(ds_spdlog::level::to_short_c_str(ds_spdlog::level::trace)) == "T");
++    REQUIRE(std::string(ds_spdlog::level::to_short_c_str(ds_spdlog::level::debug)) == "D");
++    REQUIRE(std::string(ds_spdlog::level::to_short_c_str(ds_spdlog::level::info)) == "I");
++    REQUIRE(std::string(ds_spdlog::level::to_short_c_str(ds_spdlog::level::warn)) == "W");
++    REQUIRE(std::string(ds_spdlog::level::to_short_c_str(ds_spdlog::level::err)) == "E");
++    REQUIRE(std::string(ds_spdlog::level::to_short_c_str(ds_spdlog::level::critical)) == "C");
++    REQUIRE(std::string(ds_spdlog::level::to_short_c_str(ds_spdlog::level::off)) == "O");
+ }
+
+ TEST_CASE("to_level_enum", "[convert_to_level_enum]")
+ {
+-    REQUIRE(spdlog::level::from_str("trace") == spdlog::level::trace);
+-    REQUIRE(spdlog::level::from_str("debug") == spdlog::level::debug);
+-    REQUIRE(spdlog::level::from_str("info") == spdlog::level::info);
+-    REQUIRE(spdlog::level::from_str("warning") == spdlog::level::warn);
+-    REQUIRE(spdlog::level::from_str("warn") == spdlog::level::warn);
+-    REQUIRE(spdlog::level::from_str("error") == spdlog::level::err);
+-    REQUIRE(spdlog::level::from_str("critical") == spdlog::level::critical);
+-    REQUIRE(spdlog::level::from_str("off") == spdlog::level::off);
+-    REQUIRE(spdlog::level::from_str("null") == spdlog::level::off);
++    REQUIRE(ds_spdlog::level::from_str("trace") == ds_spdlog::level::trace);
++    REQUIRE(ds_spdlog::level::from_str("debug") == ds_spdlog::level::debug);
++    REQUIRE(ds_spdlog::level::from_str("info") == ds_spdlog::level::info);
++    REQUIRE(ds_spdlog::level::from_str("warning") == ds_spdlog::level::warn);
++    REQUIRE(ds_spdlog::level::from_str("warn") == ds_spdlog::level::warn);
++    REQUIRE(ds_spdlog::level::from_str("error") == ds_spdlog::level::err);
++    REQUIRE(ds_spdlog::level::from_str("critical") == ds_spdlog::level::critical);
++    REQUIRE(ds_spdlog::level::from_str("off") == ds_spdlog::level::off);
++    REQUIRE(ds_spdlog::level::from_str("null") == ds_spdlog::level::off);
+ }
+
+ TEST_CASE("periodic flush", "[periodic_flush]")
+ {
+-    using spdlog::sinks::test_sink_mt;
+-    auto logger = spdlog::create<test_sink_mt>("periodic_flush");
++    using ds_spdlog::sinks::test_sink_mt;
++    auto logger = ds_spdlog::create<test_sink_mt>("periodic_flush");
+     auto test_sink = std::static_pointer_cast<test_sink_mt>(logger->sinks()[0]);
+
+-    spdlog::flush_every(std::chrono::seconds(1));
++    ds_spdlog::flush_every(std::chrono::seconds(1));
+     std::this_thread::sleep_for(std::chrono::milliseconds(1250));
+     REQUIRE(test_sink->flush_counter() == 1);
+-    spdlog::flush_every(std::chrono::seconds(0));
+-    spdlog::drop_all();
++    ds_spdlog::flush_every(std::chrono::seconds(0));
++    ds_spdlog::drop_all();
+ }
+
+ TEST_CASE("clone-logger", "[clone]")
+ {
+-    using spdlog::sinks::test_sink_mt;
++    using ds_spdlog::sinks::test_sink_mt;
+     auto test_sink = std::make_shared<test_sink_mt>();
+-    auto logger = std::make_shared<spdlog::logger>("orig", test_sink);
++    auto logger = std::make_shared<ds_spdlog::logger>("orig", test_sink);
+     logger->set_pattern("%v");
+     auto cloned = logger->clone("clone");
+
+@@ -110,15 +110,15 @@ TEST_CASE("clone-logger", "[clone]")
+     REQUIRE(test_sink->lines()[0] == "Some message 1");
+     REQUIRE(test_sink->lines()[1] == "Some message 2");
+
+-    spdlog::drop_all();
++    ds_spdlog::drop_all();
+ }
+
+ TEST_CASE("clone async", "[clone]")
+ {
+-    using spdlog::sinks::test_sink_st;
+-    spdlog::init_thread_pool(4, 1);
++    using ds_spdlog::sinks::test_sink_st;
++    ds_spdlog::init_thread_pool(4, 1);
+     auto test_sink = std::make_shared<test_sink_st>();
+-    auto logger = std::make_shared<spdlog::async_logger>("orig", test_sink, spdlog::thread_pool());
++    auto logger = std::make_shared<ds_spdlog::async_logger>("orig", test_sink, ds_spdlog::thread_pool());
+     logger->set_pattern("%v");
+     auto cloned = logger->clone("clone");
+
+@@ -130,51 +130,51 @@ TEST_CASE("clone async", "[clone]")
+     logger->info("Some message 1");
+     cloned->info("Some message 2");
+
+-    spdlog::details::os::sleep_for_millis(100);
++    ds_spdlog::details::os::sleep_for_millis(100);
+
+     REQUIRE(test_sink->lines().size() == 2);
+     REQUIRE(test_sink->lines()[0] == "Some message 1");
+     REQUIRE(test_sink->lines()[1] == "Some message 2");
+
+-    spdlog::drop_all();
++    ds_spdlog::drop_all();
+ }
+
+ TEST_CASE("default logger API", "[default logger]")
+ {
+     std::ostringstream oss;
+-    auto oss_sink = std::make_shared<spdlog::sinks::ostream_sink_st>(oss);
++    auto oss_sink = std::make_shared<ds_spdlog::sinks::ostream_sink_st>(oss);
+
+-    spdlog::set_default_logger(std::make_shared<spdlog::logger>("oss", oss_sink));
+-    spdlog::set_pattern("*** %v");
++    ds_spdlog::set_default_logger(std::make_shared<ds_spdlog::logger>("oss", oss_sink));
++    ds_spdlog::set_pattern("*** %v");
+
+-    spdlog::default_logger()->set_level(spdlog::level::trace);
+-    spdlog::trace("hello trace");
+-    REQUIRE(oss.str() == "*** hello trace" + std::string(spdlog::details::os::default_eol));
++    ds_spdlog::default_logger()->set_level(ds_spdlog::level::trace);
++    ds_spdlog::trace("hello trace");
++    REQUIRE(oss.str() == "*** hello trace" + std::string(ds_spdlog::details::os::default_eol));
+
+     oss.str("");
+-    spdlog::debug("hello debug");
+-    REQUIRE(oss.str() == "*** hello debug" + std::string(spdlog::details::os::default_eol));
++    ds_spdlog::debug("hello debug");
++    REQUIRE(oss.str() == "*** hello debug" + std::string(ds_spdlog::details::os::default_eol));
+
+     oss.str("");
+-    spdlog::info("Hello");
+-    REQUIRE(oss.str() == "*** Hello" + std::string(spdlog::details::os::default_eol));
++    ds_spdlog::info("Hello");
++    REQUIRE(oss.str() == "*** Hello" + std::string(ds_spdlog::details::os::default_eol));
+
+     oss.str("");
+-    spdlog::warn("Hello again {}", 2);
+-    REQUIRE(oss.str() == "*** Hello again 2" + std::string(spdlog::details::os::default_eol));
++    ds_spdlog::warn("Hello again {}", 2);
++    REQUIRE(oss.str() == "*** Hello again 2" + std::string(ds_spdlog::details::os::default_eol));
+
+     oss.str("");
+-    spdlog::error(123);
+-    REQUIRE(oss.str() == "*** 123" + std::string(spdlog::details::os::default_eol));
++    ds_spdlog::error(123);
++    REQUIRE(oss.str() == "*** 123" + std::string(ds_spdlog::details::os::default_eol));
+
+     oss.str("");
+-    spdlog::critical(std::string("some string"));
+-    REQUIRE(oss.str() == "*** some string" + std::string(spdlog::details::os::default_eol));
++    ds_spdlog::critical(std::string("some string"));
++    REQUIRE(oss.str() == "*** some string" + std::string(ds_spdlog::details::os::default_eol));
+
+     oss.str("");
+-    spdlog::set_level(spdlog::level::info);
+-    spdlog::debug("should not be logged");
++    ds_spdlog::set_level(ds_spdlog::level::info);
++    ds_spdlog::debug("should not be logged");
+     REQUIRE(oss.str().empty());
+-    spdlog::drop_all();
+-    spdlog::set_pattern("%v");
++    ds_spdlog::drop_all();
++    ds_spdlog::set_pattern("%v");
+ }
+diff --git a/tests/test_mpmc_q.cpp b/tests/test_mpmc_q.cpp
+index 1540dcc8..c15df92c 100644
+--- a/tests/test_mpmc_q.cpp
++++ b/tests/test_mpmc_q.cpp
+@@ -11,7 +11,7 @@ TEST_CASE("dequeue-empty-nowait", "[mpmc_blocking_q]")
+ {
+     size_t q_size = 100;
+     milliseconds tolerance_wait(20);
+-    spdlog::details::mpmc_blocking_queue<int> q(q_size);
++    ds_spdlog::details::mpmc_blocking_queue<int> q(q_size);
+     int popped_item = 0;
+
+     auto start = test_clock::now();
+@@ -30,7 +30,7 @@ TEST_CASE("dequeue-empty-wait", "[mpmc_blocking_q]")
+     milliseconds wait_ms(250);
+     milliseconds tolerance_wait(250);
+
+-    spdlog::details::mpmc_blocking_queue<int> q(q_size);
++    ds_spdlog::details::mpmc_blocking_queue<int> q(q_size);
+     int popped_item = 0;
+     auto start = test_clock::now();
+     auto rv = q.dequeue_for(popped_item, wait_ms);
+@@ -45,7 +45,7 @@ TEST_CASE("dequeue-empty-wait", "[mpmc_blocking_q]")
+
+ TEST_CASE("dequeue-full-nowait", "[mpmc_blocking_q]")
+ {
+-    spdlog::details::mpmc_blocking_queue<int> q(1);
++    ds_spdlog::details::mpmc_blocking_queue<int> q(1);
+     q.enqueue(42);
+
+     int item = 0;
+@@ -55,7 +55,7 @@ TEST_CASE("dequeue-full-wait", "[mpmc_blocking_q]")
+ {
+-    spdlog::details::mpmc_blocking_queue<int> q(1);
++    ds_spdlog::details::mpmc_blocking_queue<int> q(1);
+     q.enqueue(42);
+
+     int item = 0;
+@@ -67,7 +67,7 @@ TEST_CASE("enqueue_nowait", "[mpmc_blocking_q]")
+ {
+
+     size_t q_size = 1;
+-    spdlog::details::mpmc_blocking_queue<int> q(q_size);
++    ds_spdlog::details::mpmc_blocking_queue<int> q(q_size);
+     milliseconds tolerance_wait(10);
+
+     q.enqueue(1);
+@@ -85,7 +85,7 @@ TEST_CASE("enqueue_nowait", "[mpmc_blocking_q]")
+ TEST_CASE("bad_queue", "[mpmc_blocking_q]")
+ {
+     size_t q_size = 0;
+-    spdlog::details::mpmc_blocking_queue<int> q(q_size);
++    ds_spdlog::details::mpmc_blocking_queue<int> q(q_size);
+     q.enqueue_nowait(1);
+     REQUIRE(q.overrun_counter() == 1);
+     int i = 0;
+@@ -95,7 +95,7 @@ TEST_CASE("bad_queue", "[mpmc_blocking_q]")
+ TEST_CASE("empty_queue", "[mpmc_blocking_q]")
+ {
+     size_t q_size = 10;
+-    spdlog::details::mpmc_blocking_queue<int> q(q_size);
++    ds_spdlog::details::mpmc_blocking_queue<int> q(q_size);
+     int i = 0;
+     REQUIRE(q.dequeue_for(i, milliseconds(10)) == false);
+ }
+@@ -103,7 +103,7 @@ TEST_CASE("empty_queue", "[mpmc_blocking_q]")
+ TEST_CASE("full_queue", "[mpmc_blocking_q]")
+ {
+     size_t q_size = 100;
+-    spdlog::details::mpmc_blocking_queue<int> q(q_size);
++    ds_spdlog::details::mpmc_blocking_queue<int> q(q_size);
+     for (int i = 0; i < static_cast<int>(q_size); i++)
+     {
+         q.enqueue(i + 0); // i+0 to force rvalue and avoid tidy warnings on the same time if we std::move(i) instead
+diff --git a/tests/test_pattern_formatter.cpp b/tests/test_pattern_formatter.cpp
+index bafea884..4f034223 100644
+--- a/tests/test_pattern_formatter.cpp
++++ b/tests/test_pattern_formatter.cpp
+@@ -1,19 +1,19 @@
+ #include "includes.h"
+ #include "test_sink.h"
+
+-using spdlog::memory_buf_t;
+-using spdlog::details::to_string_view;
++using ds_spdlog::memory_buf_t;
++using ds_spdlog::details::to_string_view;
+
+ // log to str and return it
+ template<typename... Args>
+ static std::string log_to_str(const std::string &msg, const Args &...args)
+ {
+     std::ostringstream oss;
+-    auto oss_sink = std::make_shared<spdlog::sinks::ostream_sink_mt>(oss);
+-    spdlog::logger oss_logger("pattern_tester", oss_sink);
+-    oss_logger.set_level(spdlog::level::info);
++    auto oss_sink = std::make_shared<ds_spdlog::sinks::ostream_sink_mt>(oss);
++    ds_spdlog::logger oss_logger("pattern_tester", oss_sink);
++    oss_logger.set_level(ds_spdlog::level::info);
+
+-    oss_logger.set_formatter(std::unique_ptr<spdlog::formatter>(new spdlog::pattern_formatter(args...)));
++    oss_logger.set_formatter(std::unique_ptr<ds_spdlog::formatter>(new ds_spdlog::pattern_formatter(args...)));
+
+     oss_logger.info(msg);
+     return oss.str();
+@@ -23,75 +23,75 @@ TEST_CASE("custom eol", "[pattern_formatter]")
+ {
+     std::string msg = "Hello custom eol test";
+     std::string eol = ";)";
+-    REQUIRE(log_to_str(msg, "%v", spdlog::pattern_time_type::local, ";)") == msg + eol);
++    REQUIRE(log_to_str(msg, "%v", ds_spdlog::pattern_time_type::local, ";)") == msg + eol);
+ }
+
+ TEST_CASE("empty format", "[pattern_formatter]")
+ {
+-    REQUIRE(log_to_str("Some message", "", spdlog::pattern_time_type::local, "").empty());
++    REQUIRE(log_to_str("Some message", "", ds_spdlog::pattern_time_type::local, "").empty());
+ }
+
+ TEST_CASE("empty format2", "[pattern_formatter]")
+ {
+-    REQUIRE(log_to_str("Some message", "", spdlog::pattern_time_type::local, "\n") == "\n");
++    REQUIRE(log_to_str("Some message", "", ds_spdlog::pattern_time_type::local, "\n") == "\n");
+ }
+
+ TEST_CASE("level", "[pattern_formatter]")
+ {
+-    REQUIRE(log_to_str("Some message", "[%l] %v", spdlog::pattern_time_type::local, "\n") == "[info] Some message\n");
++    REQUIRE(log_to_str("Some message", "[%l] %v", ds_spdlog::pattern_time_type::local, "\n") == "[info] Some message\n");
+ }
+
+ TEST_CASE("short level", "[pattern_formatter]")
+ {
+-    REQUIRE(log_to_str("Some message", "[%L] %v", spdlog::pattern_time_type::local, "\n") == "[I] Some message\n");
++    REQUIRE(log_to_str("Some message", "[%L] %v", ds_spdlog::pattern_time_type::local, "\n") == "[I] Some message\n");
+ }
+
+ TEST_CASE("name", "[pattern_formatter]")
+ {
+-    REQUIRE(log_to_str("Some message", "[%n] %v", spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n");
++    REQUIRE(log_to_str("Some message", "[%n] %v", ds_spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n");
+ }
+
+ TEST_CASE("date MM/DD/YY ", "[pattern_formatter]")
+ {
+-    auto now_tm = spdlog::details::os::localtime();
++    auto now_tm = ds_spdlog::details::os::localtime();
+     std::stringstream oss;
+     oss << std::setfill('0') << std::setw(2) << now_tm.tm_mon + 1 << "/" << std::setw(2) << now_tm.tm_mday << "/" << std::setw(2)
+         << (now_tm.tm_year + 1900) % 1000 << " Some message\n";
+-    REQUIRE(log_to_str("Some message", "%D %v", spdlog::pattern_time_type::local, "\n") == oss.str());
++    REQUIRE(log_to_str("Some message", "%D %v", ds_spdlog::pattern_time_type::local, "\n") == oss.str());
+ }
+
+ TEST_CASE("color range test1", "[pattern_formatter]")
+ {
+-    auto formatter = std::make_shared<spdlog::pattern_formatter>("%^%v%$", spdlog::pattern_time_type::local, "\n");
++    auto formatter = std::make_shared<ds_spdlog::pattern_formatter>("%^%v%$", ds_spdlog::pattern_time_type::local, "\n");
+
+     memory_buf_t buf;
+-    spdlog::fmt_lib::format_to(std::back_inserter(buf), "Hello");
++    ds_spdlog::fmt_lib::format_to(std::back_inserter(buf), "Hello");
+     memory_buf_t formatted;
+     std::string logger_name = "test";
+-    spdlog::details::log_msg msg(logger_name, spdlog::level::info, spdlog::string_view_t(buf.data(), buf.size()));
++    ds_spdlog::details::log_msg msg(logger_name, ds_spdlog::level::info, ds_spdlog::string_view_t(buf.data(), buf.size()));
+     formatter->format(msg, formatted);
+     REQUIRE(msg.color_range_start == 0);
+     REQUIRE(msg.color_range_end == 5);
+-    REQUIRE(log_to_str("hello", "%^%v%$", spdlog::pattern_time_type::local, "\n") == "hello\n");
++    REQUIRE(log_to_str("hello", "%^%v%$", ds_spdlog::pattern_time_type::local, "\n") == "hello\n");
+ }
+
+ TEST_CASE("color range test2", "[pattern_formatter]")
+ {
+-    auto formatter = std::make_shared<spdlog::pattern_formatter>("%^%$", spdlog::pattern_time_type::local, "\n");
++    auto formatter = std::make_shared<ds_spdlog::pattern_formatter>("%^%$", ds_spdlog::pattern_time_type::local, "\n");
+     std::string logger_name = "test";
+-    spdlog::details::log_msg msg(logger_name, spdlog::level::info, "");
++    ds_spdlog::details::log_msg msg(logger_name, ds_spdlog::level::info, "");
+     memory_buf_t formatted;
+     formatter->format(msg, formatted);
+     REQUIRE(msg.color_range_start == 0);
+     REQUIRE(msg.color_range_end == 0);
+-    REQUIRE(log_to_str("", "%^%$", spdlog::pattern_time_type::local, "\n") == "\n");
++    REQUIRE(log_to_str("", "%^%$", ds_spdlog::pattern_time_type::local, "\n") == "\n");
+ }
+
+ TEST_CASE("color range test3", "[pattern_formatter]")
+ {
+-    auto formatter = std::make_shared<spdlog::pattern_formatter>("%^***%$");
++    auto formatter = std::make_shared<ds_spdlog::pattern_formatter>("%^***%$");
+     std::string logger_name = "test";
+-    spdlog::details::log_msg msg(logger_name, spdlog::level::info, "ignored");
++    ds_spdlog::details::log_msg msg(logger_name, ds_spdlog::level::info, "ignored");
+     memory_buf_t formatted;
+     formatter->format(msg, formatted);
+     REQUIRE(msg.color_range_start == 0);
+@@ -100,22 +100,22 @@ TEST_CASE("color range test3", "[pattern_formatter]")
+
+ TEST_CASE("color range test4", "[pattern_formatter]")
+ {
+-    auto formatter = std::make_shared<spdlog::pattern_formatter>("XX%^YYY%$", spdlog::pattern_time_type::local, "\n");
++    auto formatter = std::make_shared<ds_spdlog::pattern_formatter>("XX%^YYY%$", ds_spdlog::pattern_time_type::local, "\n");
+     std::string logger_name = "test";
+-    spdlog::details::log_msg msg(logger_name, spdlog::level::info, "ignored");
++    ds_spdlog::details::log_msg msg(logger_name, ds_spdlog::level::info, "ignored");
+
+     memory_buf_t formatted;
+     formatter->format(msg, formatted);
+     REQUIRE(msg.color_range_start == 2);
+     REQUIRE(msg.color_range_end == 5);
+-    REQUIRE(log_to_str("ignored", "XX%^YYY%$", spdlog::pattern_time_type::local, "\n") == "XXYYY\n");
++    REQUIRE(log_to_str("ignored", "XX%^YYY%$", ds_spdlog::pattern_time_type::local, "\n") == "XXYYY\n");
+ }
+
+ TEST_CASE("color range test5", "[pattern_formatter]")
+ {
+-    auto formatter = std::make_shared<spdlog::pattern_formatter>("**%^");
++    auto formatter = std::make_shared<ds_spdlog::pattern_formatter>("**%^");
+     std::string logger_name = "test";
+-    spdlog::details::log_msg msg(logger_name, spdlog::level::info, "ignored");
++    ds_spdlog::details::log_msg msg(logger_name, ds_spdlog::level::info, "ignored");
+     memory_buf_t formatted;
+     formatter->format(msg, formatted);
+     REQUIRE(msg.color_range_start == 2);
+@@ -124,9 +124,9 @@ TEST_CASE("color range test5", "[pattern_formatter]")
+
+ TEST_CASE("color range test6", "[pattern_formatter]")
+ {
+-    auto formatter = std::make_shared<spdlog::pattern_formatter>("**%$");
++    auto formatter = std::make_shared<ds_spdlog::pattern_formatter>("**%$");
+     std::string logger_name = "test";
+-    spdlog::details::log_msg msg(logger_name, spdlog::level::info, "ignored");
++    ds_spdlog::details::log_msg msg(logger_name, ds_spdlog::level::info, "ignored");
+     memory_buf_t formatted;
+     formatter->format(msg, formatted);
+     REQUIRE(msg.color_range_start == 0);
+@@ -139,73 +139,73 @@ TEST_CASE("color range test6", "[pattern_formatter]")
+
+ TEST_CASE("level_left_padded", "[pattern_formatter]")
+ {
+-    REQUIRE(log_to_str("Some message", "[%8l] %v", spdlog::pattern_time_type::local, "\n") == "[    info] Some message\n");
spdlog::pattern_time_type::local, "\n") == "[ info] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%8!l] %v", spdlog::pattern_time_type::local, "\n") == "[ info] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%8l] %v", ds_spdlog::pattern_time_type::local, "\n") == "[ info] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%8!l] %v", ds_spdlog::pattern_time_type::local, "\n") == "[ info] Some message\n"); + } + + TEST_CASE("level_right_padded", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%-8l] %v", spdlog::pattern_time_type::local, "\n") == "[info ] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%-8!l] %v", spdlog::pattern_time_type::local, "\n") == "[info ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-8l] %v", ds_spdlog::pattern_time_type::local, "\n") == "[info ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-8!l] %v", ds_spdlog::pattern_time_type::local, "\n") == "[info ] Some message\n"); + } + + TEST_CASE("level_center_padded", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%=8l] %v", spdlog::pattern_time_type::local, "\n") == "[ info ] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%=8!l] %v", spdlog::pattern_time_type::local, "\n") == "[ info ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=8l] %v", ds_spdlog::pattern_time_type::local, "\n") == "[ info ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=8!l] %v", ds_spdlog::pattern_time_type::local, "\n") == "[ info ] Some message\n"); + } + + TEST_CASE("short level_left_padded", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%3L] %v", spdlog::pattern_time_type::local, "\n") == "[ I] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%3!L] %v", spdlog::pattern_time_type::local, "\n") == "[ I] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%3L] %v", ds_spdlog::pattern_time_type::local, "\n") == "[ I] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%3!L] %v", ds_spdlog::pattern_time_type::local, "\n") == "[ I] Some message\n"); + } + + TEST_CASE("short level_right_padded", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%-3L] %v", spdlog::pattern_time_type::local, "\n") == "[I ] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%-3!L] %v", spdlog::pattern_time_type::local, "\n") == "[I ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-3L] %v", ds_spdlog::pattern_time_type::local, "\n") == "[I ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-3!L] %v", ds_spdlog::pattern_time_type::local, "\n") == "[I ] Some message\n"); + } + + TEST_CASE("short level_center_padded", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%=3L] %v", spdlog::pattern_time_type::local, "\n") == "[ I ] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%=3!L] %v", spdlog::pattern_time_type::local, "\n") == "[ I ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=3L] %v", ds_spdlog::pattern_time_type::local, "\n") == "[ I ] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=3!L] %v", ds_spdlog::pattern_time_type::local, "\n") == "[ I ] Some message\n"); + } + + TEST_CASE("left_padded_short", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%3n] %v", spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%3!n] %v", spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); ++ 
REQUIRE(log_to_str("Some message", "[%3n] %v", ds_spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%3!n] %v", ds_spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); + } + + TEST_CASE("right_padded_short", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%-3n] %v", spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%-3!n] %v", spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-3n] %v", ds_spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%-3!n] %v", ds_spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); + } + + TEST_CASE("center_padded_short", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%=3n] %v", spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); +- REQUIRE(log_to_str("Some message", "[%=3!n] %v", spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=3n] %v", ds_spdlog::pattern_time_type::local, "\n") == "[pattern_tester] Some message\n"); ++ REQUIRE(log_to_str("Some message", "[%=3!n] %v", ds_spdlog::pattern_time_type::local, "\n") == "[pat] Some message\n"); + } + + TEST_CASE("left_padded_huge", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%-300n] %v", spdlog::pattern_time_type::local, "\n") == ++ REQUIRE(log_to_str("Some message", "[%-300n] %v", ds_spdlog::pattern_time_type::local, "\n") == + "[pattern_tester ] Some message\n"); + +- REQUIRE(log_to_str("Some message", "[%-300!n] %v", spdlog::pattern_time_type::local, "\n") == ++ REQUIRE(log_to_str("Some message", "[%-300!n] %v", ds_spdlog::pattern_time_type::local, "\n") == + "[pattern_tester ] Some message\n"); + } + + TEST_CASE("left_padded_max", "[pattern_formatter]") + { +- REQUIRE(log_to_str("Some message", "[%-64n] %v", spdlog::pattern_time_type::local, "\n") == ++ REQUIRE(log_to_str("Some message", "[%-64n] %v", ds_spdlog::pattern_time_type::local, "\n") == + "[pattern_tester ] Some message\n"); + +- REQUIRE(log_to_str("Some message", "[%-64!n] %v", spdlog::pattern_time_type::local, "\n") == ++ REQUIRE(log_to_str("Some message", "[%-64!n] %v", ds_spdlog::pattern_time_type::local, "\n") == + "[pattern_tester ] Some message\n"); + } + +@@ -213,61 +213,61 @@ TEST_CASE("left_padded_max", "[pattern_formatter]") + + TEST_CASE("paddinng_truncate", "[pattern_formatter]") + { +- REQUIRE(log_to_str("123456", "%6!v", spdlog::pattern_time_type::local, "\n") == "123456\n"); +- REQUIRE(log_to_str("123456", "%5!v", spdlog::pattern_time_type::local, "\n") == "12345\n"); +- REQUIRE(log_to_str("123456", "%7!v", spdlog::pattern_time_type::local, "\n") == " 123456\n"); ++ REQUIRE(log_to_str("123456", "%6!v", ds_spdlog::pattern_time_type::local, "\n") == "123456\n"); ++ REQUIRE(log_to_str("123456", "%5!v", ds_spdlog::pattern_time_type::local, "\n") == "12345\n"); ++ REQUIRE(log_to_str("123456", "%7!v", ds_spdlog::pattern_time_type::local, "\n") == " 123456\n"); + +- REQUIRE(log_to_str("123456", "%-6!v", spdlog::pattern_time_type::local, "\n") == "123456\n"); +- REQUIRE(log_to_str("123456", "%-5!v", spdlog::pattern_time_type::local, "\n") == "12345\n"); +- REQUIRE(log_to_str("123456", "%-7!v", spdlog::pattern_time_type::local, "\n") == "123456 \n"); ++ REQUIRE(log_to_str("123456", "%-6!v", 
++ REQUIRE(log_to_str("123456", "%-5!v", ds_spdlog::pattern_time_type::local, "\n") == "12345\n");
++ REQUIRE(log_to_str("123456", "%-7!v", ds_spdlog::pattern_time_type::local, "\n") == "123456 \n");
+
+- REQUIRE(log_to_str("123456", "%=6!v", spdlog::pattern_time_type::local, "\n") == "123456\n");
+- REQUIRE(log_to_str("123456", "%=5!v", spdlog::pattern_time_type::local, "\n") == "12345\n");
+- REQUIRE(log_to_str("123456", "%=7!v", spdlog::pattern_time_type::local, "\n") == "123456 \n");
++ REQUIRE(log_to_str("123456", "%=6!v", ds_spdlog::pattern_time_type::local, "\n") == "123456\n");
++ REQUIRE(log_to_str("123456", "%=5!v", ds_spdlog::pattern_time_type::local, "\n") == "12345\n");
++ REQUIRE(log_to_str("123456", "%=7!v", ds_spdlog::pattern_time_type::local, "\n") == "123456 \n");
+
+- REQUIRE(log_to_str("123456", "%0!v", spdlog::pattern_time_type::local, "\n") == "\n");
++ REQUIRE(log_to_str("123456", "%0!v", ds_spdlog::pattern_time_type::local, "\n") == "\n");
+ }
+
+ TEST_CASE("padding_truncate_funcname", "[pattern_formatter]")
+ {
+- spdlog::sinks::test_sink_st test_sink;
++ ds_spdlog::sinks::test_sink_st test_sink;
+
+ const char *pattern = "%v [%5!!]";
+- auto formatter = std::unique_ptr<spdlog::formatter>(new spdlog::pattern_formatter(pattern));
++ auto formatter = std::unique_ptr<ds_spdlog::formatter>(new ds_spdlog::pattern_formatter(pattern));
+ test_sink.set_formatter(std::move(formatter));
+
+- spdlog::details::log_msg msg1{spdlog::source_loc{"ignored", 1, "func"}, "test_logger", spdlog::level::info, "message"};
++ ds_spdlog::details::log_msg msg1{ds_spdlog::source_loc{"ignored", 1, "func"}, "test_logger", ds_spdlog::level::info, "message"};
+ test_sink.log(msg1);
+ REQUIRE(test_sink.lines()[0] == "message [ func]");
+
+- spdlog::details::log_msg msg2{spdlog::source_loc{"ignored", 1, "function"}, "test_logger", spdlog::level::info, "message"};
++ ds_spdlog::details::log_msg msg2{ds_spdlog::source_loc{"ignored", 1, "function"}, "test_logger", ds_spdlog::level::info, "message"};
+ test_sink.log(msg2);
+ REQUIRE(test_sink.lines()[1] == "message [funct]");
+ }
+
+ TEST_CASE("padding_funcname", "[pattern_formatter]")
+ {
+- spdlog::sinks::test_sink_st test_sink;
++ ds_spdlog::sinks::test_sink_st test_sink;
+
+ const char *pattern = "%v [%10!]";
+- auto formatter = std::unique_ptr<spdlog::formatter>(new spdlog::pattern_formatter(pattern));
++ auto formatter = std::unique_ptr<ds_spdlog::formatter>(new ds_spdlog::pattern_formatter(pattern));
+ test_sink.set_formatter(std::move(formatter));
+
+- spdlog::details::log_msg msg1{spdlog::source_loc{"ignored", 1, "func"}, "test_logger", spdlog::level::info, "message"};
++ ds_spdlog::details::log_msg msg1{ds_spdlog::source_loc{"ignored", 1, "func"}, "test_logger", ds_spdlog::level::info, "message"};
+ test_sink.log(msg1);
+ REQUIRE(test_sink.lines()[0] == "message [      func]");
+
+- spdlog::details::log_msg msg2{spdlog::source_loc{"ignored", 1, "func567890123"}, "test_logger", spdlog::level::info, "message"};
++ ds_spdlog::details::log_msg msg2{ds_spdlog::source_loc{"ignored", 1, "func567890123"}, "test_logger", ds_spdlog::level::info, "message"};
+ test_sink.log(msg2);
+ REQUIRE(test_sink.lines()[1] == "message [func567890123]");
+ }
+
+ TEST_CASE("clone-default-formatter", "[pattern_formatter]")
+ {
+- auto formatter_1 = std::make_shared<spdlog::pattern_formatter>();
++ auto formatter_1 = std::make_shared<ds_spdlog::pattern_formatter>();
+ auto formatter_2 = formatter_1->clone();
+ std::string logger_name = "test";
+- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "some message");
++ ds_spdlog::details::log_msg msg(logger_name, ds_spdlog::level::info, "some message");
+
+ memory_buf_t formatted_1;
+ memory_buf_t formatted_2;
+@@ -279,10 +279,10 @@ TEST_CASE("clone-default-formatter", "[pattern_formatter]")
+
+ TEST_CASE("clone-default-formatter2", "[pattern_formatter]")
+ {
+- auto formatter_1 = std::make_shared<spdlog::pattern_formatter>("%+");
++ auto formatter_1 = std::make_shared<ds_spdlog::pattern_formatter>("%+");
+ auto formatter_2 = formatter_1->clone();
+ std::string logger_name = "test";
+- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "some message");
++ ds_spdlog::details::log_msg msg(logger_name, ds_spdlog::level::info, "some message");
+
+ memory_buf_t formatted_1;
+ memory_buf_t formatted_2;
+@@ -294,10 +294,10 @@ TEST_CASE("clone-default-formatter2", "[pattern_formatter]")
+
+ TEST_CASE("clone-formatter", "[pattern_formatter]")
+ {
+- auto formatter_1 = std::make_shared<spdlog::pattern_formatter>("%D %X [%] [%n] %v");
++ auto formatter_1 = std::make_shared<ds_spdlog::pattern_formatter>("%D %X [%] [%n] %v");
+ auto formatter_2 = formatter_1->clone();
+ std::string logger_name = "test";
+- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "some message");
++ ds_spdlog::details::log_msg msg(logger_name, ds_spdlog::level::info, "some message");
+
+ memory_buf_t formatted_1;
+ memory_buf_t formatted_2;
+@@ -309,11 +309,11 @@ TEST_CASE("clone-formatter", "[pattern_formatter]")
+
+ TEST_CASE("clone-formatter-2", "[pattern_formatter]")
+ {
+- using spdlog::pattern_time_type;
+- auto formatter_1 = std::make_shared<spdlog::pattern_formatter>("%D %X [%] [%n] %v", pattern_time_type::utc, "xxxxxx\n");
++ using ds_spdlog::pattern_time_type;
++ auto formatter_1 = std::make_shared<ds_spdlog::pattern_formatter>("%D %X [%] [%n] %v", pattern_time_type::utc, "xxxxxx\n");
+ auto formatter_2 = formatter_1->clone();
+ std::string logger_name = "test2";
+- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "some message");
++ ds_spdlog::details::log_msg msg(logger_name, ds_spdlog::level::info, "some message");
+
+ memory_buf_t formatted_1;
+ memory_buf_t formatted_2;
+@@ -323,29 +323,29 @@ TEST_CASE("clone-formatter-2", "[pattern_formatter]")
+ REQUIRE(to_string_view(formatted_1) == to_string_view(formatted_2));
+ }
+
+-class custom_test_flag : public spdlog::custom_flag_formatter
++class custom_test_flag : public ds_spdlog::custom_flag_formatter
+ {
+ public:
+ explicit custom_test_flag(std::string txt)
+ : some_txt{std::move(txt)}
+ {}
+
+- void format(const spdlog::details::log_msg &, const std::tm &tm, spdlog::memory_buf_t &dest) override
++ void format(const ds_spdlog::details::log_msg &, const std::tm &tm, ds_spdlog::memory_buf_t &dest) override
+ {
+ if (some_txt == "throw_me")
+ {
+- throw spdlog::spdlog_ex("custom_flag_exception_test");
++ throw ds_spdlog::spdlog_ex("custom_flag_exception_test");
+ }
+ else if (some_txt == "time")
+ {
+- auto formatted = spdlog::fmt_lib::format("{:d}:{:02d}{:s}", tm.tm_hour % 12, tm.tm_min, tm.tm_hour / 12 ? "PM" : "AM");
++ auto formatted = ds_spdlog::fmt_lib::format("{:d}:{:02d}{:s}", tm.tm_hour % 12, tm.tm_min, tm.tm_hour / 12 ? "PM" : "AM");
"PM" : "AM"); + dest.append(formatted.data(), formatted.data() + formatted.size()); + return; + } + some_txt = std::string(padinfo_.width_, ' ') + some_txt; + dest.append(some_txt.data(), some_txt.data() + some_txt.size()); + } +- spdlog::details::padding_info get_padding_info() ++ ds_spdlog::details::padding_info get_padding_info() + { + return padinfo_; + } +@@ -354,24 +354,24 @@ public: + + std::unique_ptr clone() const override + { +- return spdlog::details::make_unique(some_txt); ++ return ds_spdlog::details::make_unique(some_txt); + } + }; + // test clone with custom flag formatters + TEST_CASE("clone-custom_formatter", "[pattern_formatter]") + { +- auto formatter_1 = std::make_shared(); ++ auto formatter_1 = std::make_shared(); + formatter_1->add_flag('t', "custom_output").set_pattern("[%n] [%t] %v"); + auto formatter_2 = formatter_1->clone(); + std::string logger_name = "logger-name"; +- spdlog::details::log_msg msg(logger_name, spdlog::level::info, "some message"); ++ ds_spdlog::details::log_msg msg(logger_name, ds_spdlog::level::info, "some message"); + + memory_buf_t formatted_1; + memory_buf_t formatted_2; + formatter_1->format(msg, formatted_1); + formatter_2->format(msg, formatted_2); + +- auto expected = spdlog::fmt_lib::format("[logger-name] [custom_output] some message{}", spdlog::details::os::default_eol); ++ auto expected = ds_spdlog::fmt_lib::format("[logger-name] [custom_output] some message{}", ds_spdlog::details::os::default_eol); + + REQUIRE(to_string_view(formatted_1) == expected); + REQUIRE(to_string_view(formatted_2) == expected); +@@ -389,11 +389,11 @@ static const char *const test_path = "/a/b//myfile.cpp"; + + TEST_CASE("short filename formatter-1", "[pattern_formatter]") + { +- spdlog::pattern_formatter formatter("%s", spdlog::pattern_time_type::local, ""); ++ ds_spdlog::pattern_formatter formatter("%s", ds_spdlog::pattern_time_type::local, ""); + memory_buf_t formatted; + std::string logger_name = "logger-name"; +- spdlog::source_loc source_loc{test_path, 123, "some_func()"}; +- spdlog::details::log_msg msg(source_loc, "logger-name", spdlog::level::info, "Hello"); ++ ds_spdlog::source_loc source_loc{test_path, 123, "some_func()"}; ++ ds_spdlog::details::log_msg msg(source_loc, "logger-name", ds_spdlog::level::info, "Hello"); + formatter.format(msg, formatted); + + REQUIRE(to_string_view(formatted) == "myfile.cpp"); +@@ -401,11 +401,11 @@ TEST_CASE("short filename formatter-1", "[pattern_formatter]") + + TEST_CASE("short filename formatter-2", "[pattern_formatter]") + { +- spdlog::pattern_formatter formatter("%s:%#", spdlog::pattern_time_type::local, ""); ++ ds_spdlog::pattern_formatter formatter("%s:%#", ds_spdlog::pattern_time_type::local, ""); + memory_buf_t formatted; + std::string logger_name = "logger-name"; +- spdlog::source_loc source_loc{"myfile.cpp", 123, "some_func()"}; +- spdlog::details::log_msg msg(source_loc, "logger-name", spdlog::level::info, "Hello"); ++ ds_spdlog::source_loc source_loc{"myfile.cpp", 123, "some_func()"}; ++ ds_spdlog::details::log_msg msg(source_loc, "logger-name", ds_spdlog::level::info, "Hello"); + formatter.format(msg, formatted); + + REQUIRE(to_string_view(formatted) == "myfile.cpp:123"); +@@ -413,11 +413,11 @@ TEST_CASE("short filename formatter-2", "[pattern_formatter]") + + TEST_CASE("short filename formatter-3", "[pattern_formatter]") + { +- spdlog::pattern_formatter formatter("%s %v", spdlog::pattern_time_type::local, ""); ++ ds_spdlog::pattern_formatter formatter("%s %v", ds_spdlog::pattern_time_type::local, ""); + 
+ memory_buf_t formatted;
+ std::string logger_name = "logger-name";
+- spdlog::source_loc source_loc{"", 123, "some_func()"};
+- spdlog::details::log_msg msg(source_loc, "logger-name", spdlog::level::info, "Hello");
++ ds_spdlog::source_loc source_loc{"", 123, "some_func()"};
++ ds_spdlog::details::log_msg msg(source_loc, "logger-name", ds_spdlog::level::info, "Hello");
+ formatter.format(msg, formatted);
+
+ REQUIRE(to_string_view(formatted) == " Hello");
+@@ -425,11 +425,11 @@ TEST_CASE("short filename formatter-3", "[pattern_formatter]")
+
+ TEST_CASE("full filename formatter", "[pattern_formatter]")
+ {
+- spdlog::pattern_formatter formatter("%g", spdlog::pattern_time_type::local, "");
++ ds_spdlog::pattern_formatter formatter("%g", ds_spdlog::pattern_time_type::local, "");
+ memory_buf_t formatted;
+ std::string logger_name = "logger-name";
+- spdlog::source_loc source_loc{test_path, 123, "some_func()"};
+- spdlog::details::log_msg msg(source_loc, "logger-name", spdlog::level::info, "Hello");
++ ds_spdlog::source_loc source_loc{test_path, 123, "some_func()"};
++ ds_spdlog::details::log_msg msg(source_loc, "logger-name", ds_spdlog::level::info, "Hello");
+ formatter.format(msg, formatted);
+
+ REQUIRE(to_string_view(formatted) == test_path);
+@@ -437,50 +437,50 @@ TEST_CASE("full filename formatter", "[pattern_formatter]")
+
+ TEST_CASE("custom flags", "[pattern_formatter]")
+ {
+- auto formatter = std::make_shared<spdlog::pattern_formatter>();
++ auto formatter = std::make_shared<ds_spdlog::pattern_formatter>();
+ formatter->add_flag<custom_test_flag>('t', "custom1").add_flag<custom_test_flag>('u', "custom2").set_pattern("[%n] [%t] [%u] %v");
+
+ memory_buf_t formatted;
+
+- spdlog::details::log_msg msg(spdlog::source_loc{}, "logger-name", spdlog::level::info, "some message");
++ ds_spdlog::details::log_msg msg(ds_spdlog::source_loc{}, "logger-name", ds_spdlog::level::info, "some message");
+ formatter->format(msg, formatted);
+- auto expected = spdlog::fmt_lib::format("[logger-name] [custom1] [custom2] some message{}", spdlog::details::os::default_eol);
++ auto expected = ds_spdlog::fmt_lib::format("[logger-name] [custom1] [custom2] some message{}", ds_spdlog::details::os::default_eol);
+
+ REQUIRE(to_string_view(formatted) == expected);
+ }
+
+ TEST_CASE("custom flags-padding", "[pattern_formatter]")
+ {
+- auto formatter = std::make_shared<spdlog::pattern_formatter>();
++ auto formatter = std::make_shared<ds_spdlog::pattern_formatter>();
+ formatter->add_flag<custom_test_flag>('t', "custom1").add_flag<custom_test_flag>('u', "custom2").set_pattern("[%n] [%t] [%5u] %v");
+
+ memory_buf_t formatted;
+
+- spdlog::details::log_msg msg(spdlog::source_loc{}, "logger-name", spdlog::level::info, "some message");
++ ds_spdlog::details::log_msg msg(ds_spdlog::source_loc{}, "logger-name", ds_spdlog::level::info, "some message");
+ formatter->format(msg, formatted);
+- auto expected = spdlog::fmt_lib::format("[logger-name] [custom1] [     custom2] some message{}", spdlog::details::os::default_eol);
++ auto expected = ds_spdlog::fmt_lib::format("[logger-name] [custom1] [     custom2] some message{}", ds_spdlog::details::os::default_eol);
+
+ REQUIRE(to_string_view(formatted) == expected);
+ }
+
+ TEST_CASE("custom flags-exception", "[pattern_formatter]")
+ {
+- auto formatter = std::make_shared<spdlog::pattern_formatter>();
++ auto formatter = std::make_shared<ds_spdlog::pattern_formatter>();
+ formatter->add_flag<custom_test_flag>('t', "throw_me").add_flag<custom_test_flag>('u', "custom2").set_pattern("[%n] [%t] [%u] %v");
+
+ memory_buf_t formatted;
+- spdlog::details::log_msg msg(spdlog::source_loc{}, "logger-name", spdlog::level::info, "some message");
+- CHECK_THROWS_AS(formatter->format(msg, formatted), spdlog::spdlog_ex);
++ ds_spdlog::details::log_msg msg(ds_spdlog::source_loc{}, "logger-name", ds_spdlog::level::info, "some message");
++ CHECK_THROWS_AS(formatter->format(msg, formatted), ds_spdlog::spdlog_ex);
+ }
+
+ TEST_CASE("override need_localtime", "[pattern_formatter]")
+ {
+- auto formatter = std::make_shared<spdlog::pattern_formatter>(spdlog::pattern_time_type::local, "\n");
++ auto formatter = std::make_shared<ds_spdlog::pattern_formatter>(ds_spdlog::pattern_time_type::local, "\n");
+ formatter->add_flag<custom_test_flag>('t', "time").set_pattern("%t> %v");
+
+ {
+ memory_buf_t formatted;
+- spdlog::details::log_msg msg(spdlog::source_loc{}, "logger-name", spdlog::level::info, "some message");
++ ds_spdlog::details::log_msg msg(ds_spdlog::source_loc{}, "logger-name", ds_spdlog::level::info, "some message");
+ formatter->format(msg, formatted);
+ REQUIRE(to_string_view(formatted) == "0:00AM> some message\n");
+ }
+@@ -488,13 +488,13 @@ TEST_CASE("override need_localtime", "[pattern_formatter]")
+ {
+ formatter->need_localtime();
+
+- auto now_tm = spdlog::details::os::localtime();
++ auto now_tm = ds_spdlog::details::os::localtime();
+ std::stringstream oss;
+ oss << (now_tm.tm_hour % 12) << ":" << std::setfill('0') << std::setw(2) << now_tm.tm_min << (now_tm.tm_hour / 12 ? "PM" : "AM")
+ << "> some message\n";
+
+ memory_buf_t formatted;
+- spdlog::details::log_msg msg(spdlog::source_loc{}, "logger-name", spdlog::level::info, "some message");
++ ds_spdlog::details::log_msg msg(ds_spdlog::source_loc{}, "logger-name", ds_spdlog::level::info, "some message");
+ formatter->format(msg, formatted);
+ REQUIRE(to_string_view(formatted) == oss.str());
+ }
+diff --git a/tests/test_registry.cpp b/tests/test_registry.cpp
+index 8e632cc6..89c48a52 100644
+--- a/tests/test_registry.cpp
++++ b/tests/test_registry.cpp
+@@ -6,39 +6,39 @@ static const char *const tested_logger_name2 = "null_logger2";
+ #ifndef SPDLOG_NO_EXCEPTIONS
+ TEST_CASE("register_drop", "[registry]")
+ {
+- spdlog::drop_all();
+- spdlog::create<spdlog::sinks::null_sink_mt>(tested_logger_name);
+- REQUIRE(spdlog::get(tested_logger_name) != nullptr);
++ ds_spdlog::drop_all();
++ ds_spdlog::create<ds_spdlog::sinks::null_sink_mt>(tested_logger_name);
++ REQUIRE(ds_spdlog::get(tested_logger_name) != nullptr);
+ // Throw if registering existing name
+- REQUIRE_THROWS_AS(spdlog::create<spdlog::sinks::null_sink_mt>(tested_logger_name), spdlog::spdlog_ex);
++ REQUIRE_THROWS_AS(ds_spdlog::create<ds_spdlog::sinks::null_sink_mt>(tested_logger_name), ds_spdlog::spdlog_ex);
+ }
+
+ TEST_CASE("explicit register", "[registry]")
+ {
+- spdlog::drop_all();
+- auto logger = std::make_shared<spdlog::logger>(tested_logger_name, std::make_shared<spdlog::sinks::null_sink_st>());
+- spdlog::register_logger(logger);
+- REQUIRE(spdlog::get(tested_logger_name) != nullptr);
++ ds_spdlog::drop_all();
++ auto logger = std::make_shared<ds_spdlog::logger>(tested_logger_name, std::make_shared<ds_spdlog::sinks::null_sink_st>());
++ ds_spdlog::register_logger(logger);
++ REQUIRE(ds_spdlog::get(tested_logger_name) != nullptr);
+ // Throw if registering existing name
+- REQUIRE_THROWS_AS(spdlog::create<spdlog::sinks::null_sink_mt>(tested_logger_name), spdlog::spdlog_ex);
++ REQUIRE_THROWS_AS(ds_spdlog::create<ds_spdlog::sinks::null_sink_mt>(tested_logger_name), ds_spdlog::spdlog_ex);
+ }
+ #endif
+
+ TEST_CASE("apply_all", "[registry]")
+ {
+- spdlog::drop_all();
+- auto logger = std::make_shared<spdlog::logger>(tested_logger_name, std::make_shared<spdlog::sinks::null_sink_st>());
+- spdlog::register_logger(logger);
+- auto logger2 = std::make_shared<spdlog::logger>(tested_logger_name2, std::make_shared<spdlog::sinks::null_sink_st>());
+- spdlog::register_logger(logger2);
++ ds_spdlog::drop_all();
++ auto logger = std::make_shared<ds_spdlog::logger>(tested_logger_name, std::make_shared<ds_spdlog::sinks::null_sink_st>());
++ ds_spdlog::register_logger(logger);
++ auto logger2 = std::make_shared<ds_spdlog::logger>(tested_logger_name2, std::make_shared<ds_spdlog::sinks::null_sink_st>());
++ ds_spdlog::register_logger(logger2);
+
+ int counter = 0;
+- spdlog::apply_all([&counter](std::shared_ptr<spdlog::logger>) { counter++; });
++ ds_spdlog::apply_all([&counter](std::shared_ptr<ds_spdlog::logger>) { counter++; });
+ REQUIRE(counter == 2);
+
+ counter = 0;
+- spdlog::drop(tested_logger_name2);
+- spdlog::apply_all([&counter](std::shared_ptr<spdlog::logger> l) {
++ ds_spdlog::drop(tested_logger_name2);
++ ds_spdlog::apply_all([&counter](std::shared_ptr<ds_spdlog::logger> l) {
+ REQUIRE(l->name() == tested_logger_name);
+ counter++;
+ });
+@@ -47,70 +47,70 @@ TEST_CASE("apply_all", "[registry]")
+
+ TEST_CASE("drop", "[registry]")
+ {
+- spdlog::drop_all();
+- spdlog::create<spdlog::sinks::null_sink_mt>(tested_logger_name);
+- spdlog::drop(tested_logger_name);
+- REQUIRE_FALSE(spdlog::get(tested_logger_name));
++ ds_spdlog::drop_all();
++ ds_spdlog::create<ds_spdlog::sinks::null_sink_mt>(tested_logger_name);
++ ds_spdlog::drop(tested_logger_name);
++ REQUIRE_FALSE(ds_spdlog::get(tested_logger_name));
+ }
+
+ TEST_CASE("drop-default", "[registry]")
+ {
+- spdlog::set_default_logger(spdlog::null_logger_st(tested_logger_name));
+- spdlog::drop(tested_logger_name);
+- REQUIRE_FALSE(spdlog::default_logger());
+- REQUIRE_FALSE(spdlog::get(tested_logger_name));
++ ds_spdlog::set_default_logger(ds_spdlog::null_logger_st(tested_logger_name));
++ ds_spdlog::drop(tested_logger_name);
++ REQUIRE_FALSE(ds_spdlog::default_logger());
++ REQUIRE_FALSE(ds_spdlog::get(tested_logger_name));
+ }
+
+ TEST_CASE("drop_all", "[registry]")
+ {
+- spdlog::drop_all();
+- spdlog::create<spdlog::sinks::null_sink_mt>(tested_logger_name);
+- spdlog::create<spdlog::sinks::null_sink_mt>(tested_logger_name2);
+- spdlog::drop_all();
+- REQUIRE_FALSE(spdlog::get(tested_logger_name));
+- REQUIRE_FALSE(spdlog::get(tested_logger_name2));
+- REQUIRE_FALSE(spdlog::default_logger());
++ ds_spdlog::drop_all();
++ ds_spdlog::create<ds_spdlog::sinks::null_sink_mt>(tested_logger_name);
++ ds_spdlog::create<ds_spdlog::sinks::null_sink_mt>(tested_logger_name2);
++ ds_spdlog::drop_all();
++ REQUIRE_FALSE(ds_spdlog::get(tested_logger_name));
++ REQUIRE_FALSE(ds_spdlog::get(tested_logger_name2));
++ REQUIRE_FALSE(ds_spdlog::default_logger());
+ }
+
+ TEST_CASE("drop non existing", "[registry]")
+ {
+- spdlog::drop_all();
+- spdlog::create<spdlog::sinks::null_sink_mt>(tested_logger_name);
+- spdlog::drop("some_name");
+- REQUIRE_FALSE(spdlog::get("some_name"));
+- REQUIRE(spdlog::get(tested_logger_name));
+- spdlog::drop_all();
++ ds_spdlog::drop_all();
++ ds_spdlog::create<ds_spdlog::sinks::null_sink_mt>(tested_logger_name);
++ ds_spdlog::drop("some_name");
++ REQUIRE_FALSE(ds_spdlog::get("some_name"));
++ REQUIRE(ds_spdlog::get(tested_logger_name));
++ ds_spdlog::drop_all();
+ }
+
+ TEST_CASE("default logger", "[registry]")
+ {
+- spdlog::drop_all();
+- spdlog::set_default_logger(spdlog::null_logger_st(tested_logger_name));
+- REQUIRE(spdlog::get(tested_logger_name) == spdlog::default_logger());
+- spdlog::drop_all();
++ ds_spdlog::drop_all();
++ ds_spdlog::set_default_logger(ds_spdlog::null_logger_st(tested_logger_name));
++ REQUIRE(ds_spdlog::get(tested_logger_name) == ds_spdlog::default_logger());
++ ds_spdlog::drop_all();
+ }
+
+ TEST_CASE("set_default_logger(nullptr)", "[registry]")
+ {
+- spdlog::set_default_logger(nullptr);
+- REQUIRE_FALSE(spdlog::default_logger());
++ ds_spdlog::set_default_logger(nullptr);
++ REQUIRE_FALSE(ds_spdlog::default_logger());
+ }
+
+ TEST_CASE("disable automatic registration", "[registry]")
+ {
+ // set some global parameters
+- spdlog::level::level_enum log_level = spdlog::level::level_enum::warn;
+- spdlog::set_level(log_level);
++ ds_spdlog::level::level_enum log_level = ds_spdlog::level::level_enum::warn;
++ ds_spdlog::set_level(log_level);
+ // but disable automatic registration
+- spdlog::set_automatic_registration(false);
+- auto logger1 = spdlog::create<spdlog::sinks::daily_file_sink_st>(tested_logger_name, SPDLOG_FILENAME_T("filename"), 11, 59);
+- auto logger2 = spdlog::create_async<spdlog::sinks::stdout_color_sink_mt>(tested_logger_name2);
++ ds_spdlog::set_automatic_registration(false);
++ auto logger1 = ds_spdlog::create<ds_spdlog::sinks::daily_file_sink_st>(tested_logger_name, SPDLOG_FILENAME_T("filename"), 11, 59);
++ auto logger2 = ds_spdlog::create_async<ds_spdlog::sinks::stdout_color_sink_mt>(tested_logger_name2);
+ // loggers should not be part of the registry
+- REQUIRE_FALSE(spdlog::get(tested_logger_name));
+- REQUIRE_FALSE(spdlog::get(tested_logger_name2));
++ REQUIRE_FALSE(ds_spdlog::get(tested_logger_name));
++ REQUIRE_FALSE(ds_spdlog::get(tested_logger_name2));
+ // but make sure they are still initialized according to global defaults
+ REQUIRE(logger1->level() == log_level);
+ REQUIRE(logger2->level() == log_level);
+- spdlog::set_level(spdlog::level::info);
+- spdlog::set_automatic_registration(true);
++ ds_spdlog::set_level(ds_spdlog::level::info);
++ ds_spdlog::set_automatic_registration(true);
+ }
+diff --git a/tests/test_sink.h b/tests/test_sink.h
+index 57db65c1..4963d018 100644
+--- a/tests/test_sink.h
++++ b/tests/test_sink.h
+@@ -12,7 +12,7 @@
+ #include <mutex>
+ #include <thread>
+
+-namespace spdlog {
++namespace ds_spdlog {
+ namespace sinks {
+
+ template<class Mutex>
+@@ -76,4 +76,4 @@ using test_sink_mt = test_sink<std::mutex>;
+ using test_sink_st = test_sink<details::null_mutex>;
+
+ } // namespace sinks
+-} // namespace spdlog
++} // namespace ds_spdlog
+diff --git a/tests/test_stdout_api.cpp b/tests/test_stdout_api.cpp
+index d55223ff..e407bcf9 100644
+--- a/tests/test_stdout_api.cpp
++++ b/tests/test_stdout_api.cpp
+@@ -6,93 +6,93 @@
+ #include "spdlog/sinks/stdout_color_sinks.h"
+ TEST_CASE("stdout_st", "[stdout]")
+ {
+- auto l = spdlog::stdout_logger_st("test");
++ auto l = ds_spdlog::stdout_logger_st("test");
+ l->set_pattern("%+");
+- l->set_level(spdlog::level::trace);
++ l->set_level(ds_spdlog::level::trace);
+ l->trace("Test stdout_st");
+- spdlog::drop_all();
++ ds_spdlog::drop_all();
+ }
+
+ TEST_CASE("stdout_mt", "[stdout]")
+ {
+- auto l = spdlog::stdout_logger_mt("test");
++ auto l = ds_spdlog::stdout_logger_mt("test");
+ l->set_pattern("%+");
+- l->set_level(spdlog::level::debug);
++ l->set_level(ds_spdlog::level::debug);
+ l->debug("Test stdout_mt");
+- spdlog::drop_all();
++ ds_spdlog::drop_all();
+ }
+
+ TEST_CASE("stderr_st", "[stderr]")
+ {
+- auto l = spdlog::stderr_logger_st("test");
++ auto l = ds_spdlog::stderr_logger_st("test");
+ l->set_pattern("%+");
+ l->info("Test stderr_st");
+- spdlog::drop_all();
++ ds_spdlog::drop_all();
+ }
+
+ TEST_CASE("stderr_mt", "[stderr]")
+ {
+- auto l = spdlog::stderr_logger_mt("test");
++ auto l = ds_spdlog::stderr_logger_mt("test");
+ l->set_pattern("%+");
+ l->info("Test stderr_mt");
+ l->warn("Test stderr_mt");
+ l->error("Test stderr_mt");
+ l->critical("Test stderr_mt");
+- spdlog::drop_all();
++ ds_spdlog::drop_all();
+ }
+
+ // color loggers
+ TEST_CASE("stdout_color_st", "[stdout]")
+ {
+- auto l = spdlog::stdout_color_st("test");
++ auto l = ds_spdlog::stdout_color_st("test");
+ l->set_pattern("%+");
+ l->info("Test stdout_color_st");
+- spdlog::drop_all();
++ ds_spdlog::drop_all();
+ }
+
+ TEST_CASE("stdout_color_mt", "[stdout]")
+ {
+- auto l = spdlog::stdout_color_mt("test");
++ auto l = ds_spdlog::stdout_color_mt("test");
+ l->set_pattern("%+");
+- l->set_level(spdlog::level::trace);
++ l->set_level(ds_spdlog::level::trace);
+ l->trace("Test stdout_color_mt");
+- spdlog::drop_all();
++ ds_spdlog::drop_all();
+ }
+
+ TEST_CASE("stderr_color_st", "[stderr]")
+ {
+- auto l = spdlog::stderr_color_st("test");
++ auto l = ds_spdlog::stderr_color_st("test");
ds_spdlog::stderr_color_st("test"); + l->set_pattern("%+"); +- l->set_level(spdlog::level::debug); ++ l->set_level(ds_spdlog::level::debug); + l->debug("Test stderr_color_st"); +- spdlog::drop_all(); ++ ds_spdlog::drop_all(); + } + + TEST_CASE("stderr_color_mt", "[stderr]") + { +- auto l = spdlog::stderr_color_mt("test"); ++ auto l = ds_spdlog::stderr_color_mt("test"); + l->set_pattern("%+"); + l->info("Test stderr_color_mt"); + l->warn("Test stderr_color_mt"); + l->error("Test stderr_color_mt"); + l->critical("Test stderr_color_mt"); +- spdlog::drop_all(); ++ ds_spdlog::drop_all(); + } + + #ifdef SPDLOG_WCHAR_TO_UTF8_SUPPORT + + TEST_CASE("wchar_api", "[stdout]") + { +- auto l = spdlog::stdout_logger_st("wchar_logger"); ++ auto l = ds_spdlog::stdout_logger_st("wchar_logger"); + l->set_pattern("%+"); +- l->set_level(spdlog::level::trace); ++ l->set_level(ds_spdlog::level::trace); + l->trace(L"Test wchar_api"); + l->trace(L"Test wchar_api {}", L"param"); + l->trace(L"Test wchar_api {}", 1); + l->trace(L"Test wchar_api {}", std::wstring{L"wstring param"}); + l->trace(std::wstring{L"Test wchar_api wstring"}); + SPDLOG_LOGGER_DEBUG(l, L"Test SPDLOG_LOGGER_DEBUG {}", L"param"); +- spdlog::drop_all(); ++ ds_spdlog::drop_all(); + } + + #endif +diff --git a/tests/test_stopwatch.cpp b/tests/test_stopwatch.cpp +index 81827b87..83984e59 100644 +--- a/tests/test_stopwatch.cpp ++++ b/tests/test_stopwatch.cpp +@@ -9,7 +9,7 @@ TEST_CASE("stopwatch1", "[stopwatch]") + milliseconds wait_ms(200); + milliseconds tolerance_ms(250); + auto start = clock::now(); +- spdlog::stopwatch sw; ++ ds_spdlog::stopwatch sw; + std::this_thread::sleep_for(wait_ms); + auto stop = clock::now(); + auto diff_ms = std::chrono::duration_cast(stop - start); +@@ -19,7 +19,7 @@ TEST_CASE("stopwatch1", "[stopwatch]") + + TEST_CASE("stopwatch2", "[stopwatch]") + { +- using spdlog::sinks::test_sink_st; ++ using ds_spdlog::sinks::test_sink_st; + using std::chrono::duration_cast; + using std::chrono::milliseconds; + using clock = std::chrono::steady_clock; +@@ -30,8 +30,8 @@ TEST_CASE("stopwatch2", "[stopwatch]") + auto test_sink = std::make_shared(); + + auto start = clock::now(); +- spdlog::stopwatch sw; +- spdlog::logger logger("test-stopwatch", test_sink); ++ ds_spdlog::stopwatch sw; ++ ds_spdlog::logger logger("test-stopwatch", test_sink); + logger.set_pattern("%v"); + std::this_thread::sleep_for(wait_duration); + auto stop = clock::now(); +diff --git a/tests/test_systemd.cpp b/tests/test_systemd.cpp +index 8688f41d..78fb7864 100644 +--- a/tests/test_systemd.cpp ++++ b/tests/test_systemd.cpp +@@ -3,9 +3,9 @@ + + TEST_CASE("systemd", "[all]") + { +- auto systemd_sink = std::make_shared(); +- spdlog::logger logger("spdlog_systemd_test", systemd_sink); +- logger.set_level(spdlog::level::trace); ++ auto systemd_sink = std::make_shared(); ++ ds_spdlog::logger logger("spdlog_systemd_test", systemd_sink); ++ logger.set_level(ds_spdlog::level::trace); + logger.trace("test spdlog trace"); + logger.debug("test spdlog debug"); + SPDLOG_LOGGER_INFO((&logger), "test spdlog info"); +diff --git a/tests/test_time_point.cpp b/tests/test_time_point.cpp +index bacff69a..3f0311a7 100644 +--- a/tests/test_time_point.cpp ++++ b/tests/test_time_point.cpp +@@ -4,10 +4,10 @@ + + TEST_CASE("time_point1", "[time_point log_msg]") + { +- std::shared_ptr test_sink(new spdlog::sinks::test_sink_st); +- spdlog::logger logger("test-time_point", test_sink); ++ std::shared_ptr test_sink(new ds_spdlog::sinks::test_sink_st); ++ ds_spdlog::logger 
logger("test-time_point", test_sink); + +- spdlog::source_loc source{}; ++ ds_spdlog::source_loc source{}; + std::chrono::system_clock::time_point tp{std::chrono::system_clock::now()}; + test_sink->set_pattern("%T.%F"); // interested in the time_point + +@@ -15,15 +15,15 @@ TEST_CASE("time_point1", "[time_point log_msg]") + test_sink->set_delay(std::chrono::milliseconds(10)); + for (int i = 0; i < 5; i++) + { +- spdlog::details::log_msg msg{tp, source, "test_logger", spdlog::level::info, "message"}; ++ ds_spdlog::details::log_msg msg{tp, source, "test_logger", ds_spdlog::level::info, "message"}; + test_sink->log(msg); + } + +- logger.log(tp, source, spdlog::level::info, "formatted message"); +- logger.log(tp, source, spdlog::level::info, "formatted message"); +- logger.log(tp, source, spdlog::level::info, "formatted message"); +- logger.log(tp, source, spdlog::level::info, "formatted message"); +- logger.log(source, spdlog::level::info, "formatted message"); // last line has different time_point ++ logger.log(tp, source, ds_spdlog::level::info, "formatted message"); ++ logger.log(tp, source, ds_spdlog::level::info, "formatted message"); ++ logger.log(tp, source, ds_spdlog::level::info, "formatted message"); ++ logger.log(tp, source, ds_spdlog::level::info, "formatted message"); ++ logger.log(source, ds_spdlog::level::info, "formatted message"); // last line has different time_point + + // now the real test... that the times are the same. + std::vector lines = test_sink->lines(); +@@ -32,5 +32,5 @@ TEST_CASE("time_point1", "[time_point log_msg]") + REQUIRE(lines[4] == lines[5]); + REQUIRE(lines[6] == lines[7]); + REQUIRE(lines[8] != lines[9]); +- spdlog::drop_all(); ++ ds_spdlog::drop_all(); + } +diff --git a/tests/utils.cpp b/tests/utils.cpp +index 6d027797..f14b015e 100644 +--- a/tests/utils.cpp ++++ b/tests/utils.cpp +@@ -9,7 +9,7 @@ + + void prepare_logdir() + { +- spdlog::drop_all(); ++ ds_spdlog::drop_all(); + #ifdef _WIN32 + system("rmdir /S /Q test_logs"); + #else +@@ -48,7 +48,7 @@ std::size_t count_lines(const std::string &filename) + + void require_message_count(const std::string &filename, const std::size_t messages) + { +- if (strlen(spdlog::details::os::default_eol) == 0) ++ if (strlen(ds_spdlog::details::os::default_eol) == 0) + { + REQUIRE(count_lines(filename) == 1); + } +-- +2.33.0 + -- Gitee