From dd94c9721cd035af647a8b07da35c4cab26b4551 Mon Sep 17 00:00:00 2001 From: PaddlePaddle-Gardener Date: Thu, 13 Jan 2022 14:28:52 +0800 Subject: [PATCH] mirgate_38582 --- cmake/external/libmct.cmake | 8 +- .../distributed/common/chunk_allocator.h | 95 +++ paddle/fluid/distributed/fleet.cc | 191 ++++- .../distributed/service/brpc_ps_client.cc | 738 +++++++++++++++++- .../distributed/service/brpc_ps_server.cc | 10 + paddle/fluid/distributed/table/CMakeLists.txt | 11 +- .../distributed/table/common_dense_table.h | 38 +- .../fluid/distributed/table/depends/dense.h | 85 ++ .../distributed/table/depends/feature_value.h | 247 +++--- .../distributed/table/memory_sparse_table.cc | 635 +++++++++++++++ .../distributed/table/memory_sparse_table.h | 96 +++ .../distributed/test/dense_table_test.cc | 62 +- .../distributed/test/feature_value_test.cc | 31 +- 13 files changed, 2000 insertions(+), 247 deletions(-) create mode 100644 paddle/fluid/distributed/common/chunk_allocator.h diff --git a/cmake/external/libmct.cmake b/cmake/external/libmct.cmake index d318bc7d0f..92c3165fba 100644 --- a/cmake/external/libmct.cmake +++ b/cmake/external/libmct.cmake @@ -19,11 +19,11 @@ IF((NOT DEFINED LIBMCT_VER) OR (NOT DEFINED LIBMCT_URL)) MESSAGE(STATUS "use pre defined download url") SET(LIBMCT_VER "0.1.0" CACHE STRING "" FORCE) SET(LIBMCT_NAME "libmct" CACHE STRING "" FORCE) - SET(LIBMCT_URL "https://pslib.bj.bcebos.com/libmct.tar.gz" CACHE STRING "" FORCE) + SET(LIBMCT_URL "https://pslib.bj.bcebos.com/libmct/libmct.tar.gz" CACHE STRING "" FORCE) ENDIF() MESSAGE(STATUS "LIBMCT_NAME: ${LIBMCT_NAME}, LIBMCT_URL: ${LIBMCT_URL}") -SET(LIBMCT_SOURCE_DIR "${THIRD_PARTY_PATH}/libmct") -SET(LIBMCT_DOWNLOAD_DIR "${LIBMCT_SOURCE_DIR}/src/${LIBMCT_PROJECT}") +SET(LIBMCT_PREFIX_DIR "${THIRD_PARTY_PATH}/libmct") +SET(LIBMCT_DOWNLOAD_DIR "${LIBMCT_PREFIX_DIR}/src/${LIBMCT_PROJECT}") SET(LIBMCT_DST_DIR "libmct") SET(LIBMCT_INSTALL_ROOT "${THIRD_PARTY_PATH}/install") SET(LIBMCT_INSTALL_DIR ${LIBMCT_INSTALL_ROOT}/${LIBMCT_DST_DIR}) @@ -42,7 +42,7 @@ FILE(WRITE ${LIBMCT_DOWNLOAD_DIR}/CMakeLists.txt ExternalProject_Add( ${LIBMCT_PROJECT} ${EXTERNAL_PROJECT_LOG_ARGS} - PREFIX ${LIBMCT_SOURCE_DIR} + PREFIX ${LIBMCT_PREFIX_DIR} DOWNLOAD_DIR ${LIBMCT_DOWNLOAD_DIR} DOWNLOAD_COMMAND wget --no-check-certificate ${LIBMCT_URL} -c -q -O ${LIBMCT_NAME}.tar.gz && tar zxvf ${LIBMCT_NAME}.tar.gz diff --git a/paddle/fluid/distributed/common/chunk_allocator.h b/paddle/fluid/distributed/common/chunk_allocator.h new file mode 100644 index 0000000000..17f7bb1422 --- /dev/null +++ b/paddle/fluid/distributed/common/chunk_allocator.h @@ -0,0 +1,95 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace paddle { +namespace distributed { + +// Fast allocation and deallocation of objects by allocating them in chunks. 
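+//
+// Minimal usage sketch (illustrative only; the element type `int` and the
+// chunk size below are placeholders, not part of this patch):
+//
+//   ChunkAllocator<int> alloc(64);   // 64 nodes reserved per chunk
+//   int* p = alloc.acquire(7);       // placement-new constructs the value
+//   alloc.release(p);                // runs the destructor and recycles the node
+//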
+template +class ChunkAllocator { + public: + explicit ChunkAllocator(size_t chunk_size = 64) { + CHECK(sizeof(Node) == std::max(sizeof(void*), sizeof(T))); + _chunk_size = chunk_size; + _chunks = NULL; + _free_nodes = NULL; + _counter = 0; + } + ChunkAllocator(const ChunkAllocator&) = delete; + ~ChunkAllocator() { + while (_chunks != NULL) { + Chunk* x = _chunks; + _chunks = _chunks->next; + free(x); + } + } + template + T* acquire(ARGS&&... args) { + if (_free_nodes == NULL) { + create_new_chunk(); + } + + T* x = (T*)(void*)_free_nodes; // NOLINT + _free_nodes = _free_nodes->next; + new (x) T(std::forward(args)...); + _counter++; + return x; + } + void release(T* x) { + x->~T(); + Node* node = (Node*)(void*)x; // NOLINT + node->next = _free_nodes; + _free_nodes = node; + _counter--; + } + size_t size() const { return _counter; } + + private: + struct alignas(T) Node { + union { + Node* next; + char data[sizeof(T)]; + }; + }; + struct Chunk { + Chunk* next; + Node nodes[]; + }; + + size_t _chunk_size; // how many elements in one chunk + Chunk* _chunks; // a list + Node* _free_nodes; // a list + size_t _counter; // how many elements are acquired + + void create_new_chunk() { + Chunk* chunk; + posix_memalign(reinterpret_cast(&chunk), + std::max(sizeof(void*), alignof(Chunk)), + sizeof(Chunk) + sizeof(Node) * _chunk_size); + chunk->next = _chunks; + _chunks = chunk; + + for (size_t i = 0; i < _chunk_size; i++) { + Node* node = &chunk->nodes[i]; + node->next = _free_nodes; + _free_nodes = node; + } + } +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc index 9e2a0b3522..5caeab832a 100644 --- a/paddle/fluid/distributed/fleet.cc +++ b/paddle/fluid/distributed/fleet.cc @@ -135,13 +135,15 @@ uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) { std::vector FleetWrapper::GetClientsInfo() { VLOG(3) << "Going to get client info"; - return pserver_ptr_->get_client_info(); - return std::vector(); + auto* communicator = Communicator::GetInstance(); + std::vector res = communicator->GetClientInfo(); + return res; } void FleetWrapper::CreateClient2ClientConnection() { - VLOG(3) << "Going to create client2client connection"; - pserver_ptr_->create_client2client_connection( + VLOG(1) << "Going to create client2client connection"; + auto* communicator = Communicator::GetInstance(); + communicator->_worker_ptr->create_client2client_connection( client2client_request_timeout_ms_, client2client_connect_timeout_ms_, client2client_max_retry_); } @@ -370,12 +372,26 @@ void FleetWrapper::PushDenseVarsAsync( const std::vector& var_names, std::vector>* push_sparse_status, float scale_datanorm, int batch_size) { - auto* communicator = Communicator::GetInstance(); - PADDLE_ENFORCE_EQ( - communicator->Check(table_id), true, - platform::errors::InvalidArgument( - "can not find table: %s, please check your config", table_id)); - communicator->Send(var_names, scope); + auto place = platform::CPUPlace(); + std::vector regions; + for (auto& t : var_names) { + Variable* var = scope.FindVar(t); + CHECK(var != nullptr) << "var[" << t << "] not found"; + LoDTensor* tensor = var->GetMutable(); + float* g = tensor->mutable_data(place); + paddle::distributed::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + VLOG(3) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id " + << table_id << " Temp_data[0] " << g[0] << " Temp_data[-1] " + << g[tensor->numel() - 1]; + } + + auto* communicator = + 
dynamic_cast(Communicator::GetInstance()); + auto push_status = communicator->_worker_ptr->push_dense( + regions.data(), regions.size(), table_id); + + communicator->PushDensePostProcessing(); } void FleetWrapper::PushSparseVarsAsync( @@ -417,15 +433,139 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync( return; } -void FleetWrapper::LoadModel(const std::string& path, const std::string& mode) { +void FleetWrapper::PushSparseFromTensorAsync( + const uint64_t table_id, int fea_dim, uint64_t padding_id, + platform::Place place, std::vector* inputs, + const LoDTensor* shows, const LoDTensor* clks, + std::vector* outputs) { + int batch_size = -1; + bool batch_size_consist = true; + for (auto* input : *inputs) { + int cur_batch_size = + input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0]; + if (batch_size == -1) { + batch_size = cur_batch_size; + } else { + // CHECK(batch_size == cur_batch_size); // NOLINT + batch_size_consist = false; + break; + } + } + CHECK(batch_size > 0); // NOLINT + + int show_size = + shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0]; + CHECK(show_size == batch_size || show_size == 1); + int clk_size = + clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0]; + CHECK(clk_size == batch_size || clk_size == 1); + + CHECK(outputs->size() == inputs->size()); + std::vector push_keys; + push_keys.reserve(MAX_FEASIGN_NUM / 100); + std::vector> push_values; + push_values.reserve(MAX_FEASIGN_NUM / 100); + size_t output_len = 0; + size_t input_idx = 0; + + VLOG(2) << "fleet.cc::emb_dim: " << fea_dim; + + // TODO(zhaocaibei123): check type of show/clk is int? float? uint64? + // const long int* show_tensor = shows->data(); + // const long int* clk_tensor = clks->data(); + const int64_t* show_tensor = shows->data(); + const int64_t* clk_tensor = clks->data(); + + for (size_t index = 0; index < inputs->size(); ++index) { + framework::LoDTensor* g_tensor = outputs->at(index); + float* g = g_tensor->data(); + // no cvm + if (batch_size_consist) { // TODO(zhaocaibei123): add config + // scale_sparse_gradient_with_batch_size_ + Eigen::Map< + Eigen::Matrix> + g_mat(g, g_tensor->numel() / fea_dim, fea_dim); + g_mat.rightCols(fea_dim) *= batch_size; + } + + const framework::LoDTensor* tensor = inputs->at(index); + const int64_t* ids = tensor->data(); + size_t len = tensor->numel(); + output_len = 0; + + if (tensor->lod().size() > 0) { + for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) { + for (int j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1]; + ++j, output_len += fea_dim) { + uint64_t real_id = static_cast(ids[j]); + if (real_id == padding_id) { + continue; + } + push_keys.emplace_back(real_id); + push_values.emplace_back(fea_dim + 3); + // slot show clk grad... consistent with CtrCommonPushValue defined in + // ctr_accessor.h + push_values.back()[0] = 2; // TODO(zhaocaibei123): slot + push_values.back()[1] = + (i >= show_size ? 1 : static_cast(show_tensor[i])); + push_values.back()[2] = + (i >= clk_size ? 0 : static_cast(clk_tensor[i])); + + float* data = push_values.back().data() + 3; + + memcpy(data, g + output_len, sizeof(float) * fea_dim); + + ++input_idx; + } + } + } else { + for (size_t i = 0; i < len; ++i, output_len += fea_dim) { + uint64_t real_id = static_cast(ids[i]); + if (real_id == padding_id) { + continue; + } + push_keys.emplace_back(real_id); + push_values.emplace_back(fea_dim + 3); + // slot show clk grad... 
consistent with CtrCommonPushValue defined in + // ctr_accessor.h + push_values.back()[0] = 2; // TODO(zhaocaibei123): slot + push_values.back()[1] = + (i >= show_size ? 1 : static_cast(show_tensor[i])); + push_values.back()[2] = + (i >= clk_size ? 0 : static_cast(clk_tensor[i])); + + float* data = push_values.back().data() + 3; + + memcpy(data, g + output_len, sizeof(float) * fea_dim); + + ++input_idx; + } + } + CHECK(output_len == g_tensor->numel()); + } + + std::vector push_g_vec(input_idx, nullptr); + + for (auto i = 0u; i < push_keys.size(); ++i) { + push_g_vec[i] = push_values.at(i).data(); + } + + auto* communicator = Communicator::GetInstance(); + PADDLE_ENFORCE_EQ( + communicator->Check(table_id), true, + platform::errors::InvalidArgument( + "can not find table: %s, please check your config", table_id)); + auto status = communicator->_worker_ptr->push_sparse( + table_id, push_keys.data(), (const float**)push_g_vec.data(), + push_keys.size()); +} + +void FleetWrapper::LoadModel(const std::string& path, const int mode) { auto* communicator = Communicator::GetInstance(); - auto ret = communicator->_worker_ptr->load(path, mode); - // auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode)); + auto ret = communicator->_worker_ptr->load(path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "load model from path:" << path << " failed"; - sleep(sleep_seconds_before_fail_exit_); - exit(-1); } } @@ -450,8 +590,6 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) { int32_t feasign_cnt = ret.get(); if (feasign_cnt == -1) { LOG(ERROR) << "save model failed"; - sleep(sleep_seconds_before_fail_exit_); - exit(-1); } } @@ -562,16 +700,23 @@ void FleetWrapper::ClientFlush() { int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler) { - VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler"; - VLOG(3) << "pserver_ptr_=" << pserver_ptr_; - VLOG(3) << "_worker_ptr=" << pserver_ptr_->_worker_ptr; - return pserver_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, - handler); + VLOG(1) << "calling FleetWrapper::RegisterClientToClientMsgHandler"; + auto* communicator = Communicator::GetInstance(); + // for unittest which does not call fleet.init_worker() first + if (communicator == nullptr) { + VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler communicator is " + "null"; + return -1; + } else { + return communicator->_worker_ptr->registe_client2client_msg_handler( + msg_type, handler); + } } std::future FleetWrapper::SendClientToClientMsg( int msg_type, int to_client_id, const std::string& msg) { - return pserver_ptr_->_worker_ptr->send_client2client_msg(msg_type, + auto* communicator = Communicator::GetInstance(); + return communicator->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg); } diff --git a/paddle/fluid/distributed/service/brpc_ps_client.cc b/paddle/fluid/distributed/service/brpc_ps_client.cc index a6ad9d08f5..a0a09b14db 100644 --- a/paddle/fluid/distributed/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/service/brpc_ps_client.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/distributed/service/brpc_ps_client.h" #include "paddle/fluid/framework/archive.h" -const static int max_port = 65535; +static const int max_port = 65535; DEFINE_int32(pserver_push_dense_merge_limit, 12, "limit max push_dense local merge requests"); @@ -52,6 +52,9 @@ DEFINE_int32(pserver_connect_timeout_ms, 10000, DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge 
thread num"); +DEFINE_int32(pserver_sparse_table_shard_num, 1000, + "sparse table shard for save & load"); + namespace paddle { namespace framework { class Scope; @@ -102,6 +105,7 @@ int32_t BrpcPsClient::start_client_service() { LOG(ERROR) << "BrpcPsServer start failed"; return -1; } + _server_started = true; _env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port, _client_id); return 0; @@ -117,6 +121,12 @@ int32_t BrpcPsClient::create_client2client_connection( options.max_retry = max_retry; std::vector client_list = _env->get_ps_clients(); + VLOG(1) << "BrpcPsClient::create_c2c_connection client_list size: " + << client_list.size(); + for (auto cc : client_list) { + VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: " + << cc.to_string(); + } _client_channels.resize(client_list.size()); std::ostringstream os; std::string server_ip_port; @@ -184,8 +194,51 @@ int32_t BrpcPsClient::initialize() { // 启动client探听接口, 并相互建立连接 start_client_service(); + // 异步push 请求队列初始化 + const auto &worker_param = _config.worker_param().downpour_worker_param(); + for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) { + auto type = worker_param.downpour_table_param(i).type(); + auto table_id = worker_param.downpour_table_param(i).table_id(); + if (type == PS_DENSE_TABLE) { + _push_dense_task_queue_map[table_id] = + paddle::framework::MakeChannel(); + } + if (type == PS_SPARSE_TABLE) { + _push_sparse_task_queue_map[table_id] = + paddle::framework::MakeChannel(); + _push_sparse_merge_count_map[table_id] = 0; + } + } + + auto &profiler = CostProfiler::instance(); + profiler.register_profiler("pserver_client_pull_dense"); + profiler.register_profiler("pserver_client_pull_sparse"); + profiler.register_profiler("pserver_client_pull_sparse_local"); + profiler.register_profiler("pserver_client_push_sparse"); + profiler.register_profiler("pserver_client_push_sparse_parse"); + profiler.register_profiler("client_push_sparse_put"); + profiler.register_profiler("pserver_client_push_sparse"); + profiler.register_profiler("pserver_client_push_sparse_merge"); + profiler.register_profiler("pserver_client_push_sparse_rpc"); + profiler.register_profiler("pserver_client_push_dense"); + profiler.register_profiler("pserver_client_push_dense_parse"); + profiler.register_profiler("push_dense_put"); + profiler.register_profiler("pserver_client_push_dense_merge"); + profiler.register_profiler("pserver_client_push_dense_rpc"); + profiler.register_profiler("pserver_client_push_dense_send"); + _running = true; _flushing = false; + // 启动异步push线程 + _async_push_sparse_thread = + std::thread(std::bind(&BrpcPsClient::push_sparse_task_consume, this)); + // _async_push_sparse_thread.detach(); + _async_push_dense_thread = + std::thread(std::bind(&BrpcPsClient::push_dense_task_consume, this)); + // for debug + // _print_thread = + // std::thread(std::bind(&BrpcPsClient::print_queue_size_thread, this)); + return 0; } @@ -238,7 +291,7 @@ std::future BrpcPsClient::print_table_stat(uint32_t table_id) { uint64_t feasign_size = 0; uint64_t mf_size = 0; paddle::framework::BinaryArchive ar; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) { ret = -1; @@ -277,7 +330,7 @@ std::future BrpcPsClient::send_cmd( DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [request_call_num, cmd_id](void *done) { int ret = 0; - auto *closure = (DownpourBrpcClosure 
*)done; + auto *closure = reinterpret_cast(done); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, cmd_id) != 0) { ret = -1; @@ -298,7 +351,7 @@ std::future BrpcPsClient::send_cmd( } PsService_Stub rpc_stub(get_cmd_channel(i)); closure->cntl(i)->set_timeout_ms( - 10800000); // cmd msg don't limit timeout for save/load + 10800000 * 2); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); } @@ -312,7 +365,7 @@ std::future BrpcPsClient::send_save_cmd( request_call_num, [request_call_num, cmd_id](void *done) { int ret = 0; uint32_t feasign_size = 0; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_save_response(i, cmd_id) < 0) { ret = -1; @@ -362,11 +415,14 @@ std::future BrpcPsClient::load(uint32_t table_id, std::future BrpcPsClient::save(const std::string &epoch, const std::string &mode) { + VLOG(1) << "BrpcPsClient::save path " << epoch; return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode}); } std::future BrpcPsClient::save(uint32_t table_id, const std::string &epoch, const std::string &mode) { + VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id " + << table_id; return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); } @@ -378,6 +434,7 @@ std::future BrpcPsClient::clear(uint32_t table_id) { } std::future BrpcPsClient::flush() { + VLOG(0) << "BrpcPsClient::flush begin"; _flushing = true; std::promise promise; std::future fut = promise.get_future(); @@ -385,16 +442,49 @@ std::future BrpcPsClient::flush() { VLOG(3) << "wait _async_call_num:" << _async_call_num; usleep(100000); // sleep 100ms wait async end } while (_async_call_num > 0); + VLOG(1) << "flush _async_call_num = 0"; promise.set_value(0); _flushing = false; + VLOG(0) << "BrpcPsClient::flush done"; + print_queue_size(); return fut; } +void BrpcPsClient::print_queue_size() { + for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) { + auto table_id = push_sparse_task_itr.first; + auto queue_size = push_sparse_task_itr.second->Size(); + VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id + << " size: " << queue_size; + } + + for (auto &task_queue_itr : _push_dense_task_queue_map) { + auto table_id = task_queue_itr.first; + auto queue_size = task_queue_itr.second->Size(); + VLOG(0) << "BrpcPsClient::print_queue_size: table " << table_id + << " size: " << queue_size; + } +} + +void BrpcPsClient::print_queue_size_thread() { + while (_running) { + usleep(1000000 * 60 * 2); + print_queue_size(); + } +} + void BrpcPsClient::finalize_worker() { flush(); + VLOG(0) << "BrpcPsClient::finalize_worker begin join thread"; _running = false; + _async_push_dense_thread.join(); + _async_push_sparse_thread.join(); + // _print_thread.join(); + VLOG(0) << "BrpcPsClient::finalize_worker begin join server"; _server.Stop(1000); _server.Join(); + _server_started = false; + VLOG(0) << "BrpcPsClient::finalize_worker done"; } std::future BrpcPsClient::stop_server() { @@ -422,19 +512,20 @@ std::future BrpcPsClient::pull_geo_param(size_t table_id, DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [keys, values, accessor](void *done) { int ret = 0; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); uint32_t shard_nums; if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) { ret = -1; } auto &res_io_buffer = closure->cntl(0)->response_attachment(); 
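+        // Response attachment layout: [shard_nums : uint32] followed by
+        // shard_nums uint64 keys and shard_nums * update_size() bytes of
+        // update values, consumed sequentially below.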
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); - io_buffer_itr.copy_and_forward((void *)(&shard_nums), sizeof(uint32_t)); + io_buffer_itr.copy_and_forward(reinterpret_cast(&shard_nums), + sizeof(uint32_t)); keys->resize(shard_nums); values->resize(shard_nums * accessor->update_dim()); - io_buffer_itr.copy_and_forward((void *)(keys->data()), + io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT sizeof(uint64_t) * shard_nums); - io_buffer_itr.copy_and_forward((void *)(values->data()), + io_buffer_itr.copy_and_forward((void *)(values->data()), // NOLINT shard_nums * accessor->update_size()); closure->set_promise_value(ret); }); @@ -466,8 +557,19 @@ std::future BrpcPsClient::push_sparse_param( std::vector> value_ptrs; ids.resize(request_call_num); value_ptrs.resize(request_call_num); + + const auto &server_param = _config.server_param().downpour_server_param(); + uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num; + for (int i = 0; i < server_param.downpour_table_param_size(); ++i) { + const auto &table_param = server_param.downpour_table_param(i); + if (table_param.table_id() == table_id) { + shard_num = table_param.shard_num(); + break; + } + } + for (size_t i = 0; i < num; ++i) { - size_t pserver_idx = keys[i] % request_call_num; + size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]); ids[pserver_idx].push_back(keys[i]); value_ptrs[pserver_idx].push_back(update_values[i]); } @@ -481,7 +583,7 @@ std::future BrpcPsClient::push_sparse_param( push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM); push_request->set_table_id(table_id); push_request->set_client_id(_client_id); - push_request->add_params((char *)&kv_size, sizeof(uint32_t)); + push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); char *push_data_ptr = const_cast(push_data->data()); @@ -503,6 +605,7 @@ std::future BrpcPsClient::push_sparse_param( std::future BrpcPsClient::pull_dense(Region *regions, size_t region_num, size_t table_id) { + auto timer = std::make_shared("pserver_client_pull_dense"); auto *accessor = table_accessor(table_id); size_t request_call_num = _server_channels.size(); uint32_t num_per_shard = @@ -514,7 +617,7 @@ std::future BrpcPsClient::pull_dense(Region *regions, int ret = 0; size_t region_idx = 0; // 当前填充的region偏移 size_t region_data_idx = 0; // 当前填充的region内data偏移 - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); size_t shard_data_size = num_per_shard * accessor->select_size(); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) { @@ -537,7 +640,8 @@ std::future BrpcPsClient::pull_dense(Region *regions, if (region.size - region_data_idx >= shard_buffer_remain) { // region待填充空间 >= 分片buffer数据, 直接拷贝置入 io_buffer_itr.copy_and_forward( - (void *)(region.data + region_data_idx), shard_buffer_remain); + reinterpret_cast(region.data + region_data_idx), + shard_buffer_remain); region_data_idx += shard_buffer_remain; shard_buffer_remain = 0; } else if (region.size - region_data_idx == 0) { @@ -547,7 +651,7 @@ std::future BrpcPsClient::pull_dense(Region *regions, } else { // region不足以容纳所有数据,则能放多少 拷贝多少 io_buffer_itr.copy_and_forward( - (void *)(region.data + region_data_idx), + reinterpret_cast(region.data + region_data_idx), region.size - region_data_idx); shard_buffer_remain -= (region.size - region_data_idx); ++region_idx; @@ -557,6 +661,7 @@ 
std::future BrpcPsClient::pull_dense(Region *regions, } closure->set_promise_value(ret); }); + closure->add_timer(timer); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); @@ -564,7 +669,7 @@ std::future BrpcPsClient::pull_dense(Region *regions, closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); - closure->request(i)->add_params((char *)&num_per_shard, + closure->request(i)->add_params((char *)&num_per_shard, // NOLINT sizeof(num_per_shard)); PsService_Stub rpc_stub(get_dense_channel(i)); rpc_stub.service(closure->cntl(i), closure->request(i), @@ -608,7 +713,7 @@ std::future BrpcPsClient::push_dense_param(const Region *regions, DownpourBrpcClosure *closure = new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) { int ret = 0; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) { ret = -1; @@ -621,26 +726,28 @@ std::future BrpcPsClient::push_dense_param(const Region *regions, closure->add_promise(promise); std::future fut = promise->get_future(); static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10; - static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE]; //用于数据补齐 - //开始多shard并行拷贝&请求 + static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE]; // 用于数据补齐 + // 开始多shard并行拷贝&请求 for (size_t i = 0; i < request_call_num; ++i) { closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); auto &request_buffer = closure->cntl(i)->request_attachment(); - request_buffer.append((void *)&num_per_shard, sizeof(uint32_t)); + request_buffer.append(reinterpret_cast(&num_per_shard), + sizeof(uint32_t)); auto ®ion_list = regions_partition[i]; size_t fill_remain_size = shard_data_size; for (auto ®ion : region_list) { fill_remain_size -= region.size; - request_buffer.append((void *)region.data, region.size); + request_buffer.append(reinterpret_cast(region.data), region.size); } - //保证各分片数据对齐 + // 保证各分片数据对齐 while (fill_remain_size > 0) { size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE ? 
REGION_ASSIGN_BUFFER_SIZE : fill_remain_size; - request_buffer.append((void *)region_assign_buffer, fill_num); + request_buffer.append(reinterpret_cast(region_assign_buffer), + fill_num); fill_remain_size -= fill_num; } PsService_Stub rpc_stub(get_dense_channel(i)); @@ -654,7 +761,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient( size_t table_id, const uint64_t *keys, const float **update_values, size_t num, void *done) { auto *accessor = table_accessor(table_id); - //发送RPC请求 + // 发送RPC请求 DownpourBrpcClosure *closure = reinterpret_cast(done); auto promise = std::make_shared>(); closure->add_promise(promise); @@ -666,8 +773,18 @@ std::future BrpcPsClient::push_sparse_raw_gradient( ids.resize(request_call_num); value_ptrs.resize(request_call_num); + const auto &server_param = _config.server_param().downpour_server_param(); + uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num; + for (int i = 0; i < server_param.downpour_table_param_size(); ++i) { + const auto &table_param = server_param.downpour_table_param(i); + if (table_param.table_id() == table_id) { + shard_num = table_param.shard_num(); + break; + } + } + for (size_t i = 0; i < num; ++i) { - size_t pserver_idx = keys[i] % request_call_num; + size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]); ids[pserver_idx].push_back(keys[i]); value_ptrs[pserver_idx].push_back(update_values[i]); } @@ -684,7 +801,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient( push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); push_request->set_table_id(table_id); push_request->set_client_id(_client_id); - push_request->add_params((char *)&kv_size, sizeof(uint32_t)); + push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); char *push_data_ptr = const_cast(push_data->data()); @@ -726,14 +843,11 @@ std::future BrpcPsClient::push_dense_raw_gradient( memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t)); memcpy(push_data_ptr + sizeof(uint32_t), total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); - VLOG(1) << "push_dense_raw_gradient finish memcpy"; // closure->cntl(i)->set_request_compress_type( // (brpc::CompressType)FLAGS_pserver_communicate_compress_type); PsService_Stub rpc_stub(get_dense_channel(i)); - VLOG(1) << "push_dense_raw_gradient get_dense_channel " << i; rpc_stub.service(closure->cntl(i), closure->request(i), closure->response(i), closure); - VLOG(1) << "push_dense_raw_gradient async service " << i; } return fut; } @@ -770,14 +884,27 @@ std::future BrpcPsClient::pull_sparse(float **select_values, size_t table_id, const uint64_t *keys, size_t num, bool is_training) { + auto timer = std::make_shared("pserver_client_pull_sparse"); + auto local_timer = + std::make_shared("pserver_client_pull_sparse_local"); size_t request_call_num = _server_channels.size(); auto shard_sorted_kvs = std::make_shared< std::vector>>>(); shard_sorted_kvs->resize(request_call_num); + const auto &server_param = _config.server_param().downpour_server_param(); + uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num; + for (int i = 0; i < server_param.downpour_table_param_size(); ++i) { + const auto &table_param = server_param.downpour_table_param(i); + if (table_param.table_id() == table_id) { + shard_num = table_param.shard_num(); + break; + } + } + for (size_t i = 0; i < num; ++i) { - size_t shard_id = keys[i] % request_call_num; + size_t shard_id = get_sparse_shard(shard_num, 
request_call_num, keys[i]); shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); } @@ -787,7 +914,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, DownpourBrpcClosure *closure = new DownpourBrpcClosure( request_call_num, [shard_sorted_kvs, value_size](void *done) { int ret = 0; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) { if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) { ret = -1; @@ -803,14 +930,14 @@ std::future BrpcPsClient::pull_sparse(float **select_values, for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) { auto *kv_pair = &(request_kvs[kv_idx]); if (kv_pair->first == last_key) { - memcpy((void *)kv_pair->second, (void *)last_value_data, - value_size); + memcpy(reinterpret_cast(kv_pair->second), + reinterpret_cast(last_value_data), value_size); } else { last_key = kv_pair->first; last_value_data = kv_pair->second; if (value_size != - io_buffer_itr.copy_and_forward((void *)(last_value_data), - value_size)) { + io_buffer_itr.copy_and_forward( + reinterpret_cast(last_value_data), value_size)) { LOG(WARNING) << "res data is lack or not in format"; ret = -1; break; @@ -820,7 +947,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, } closure->set_promise_value(ret); }); - + closure->add_timer(timer); auto promise = std::make_shared>(); closure->add_promise(promise); std::future fut = promise->get_future(); @@ -838,7 +965,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, size_t sorted_kv_size = sorted_kvs.size(); auto &request_buffer = closure->cntl(i)->request_attachment(); - request_buffer.append((void *)&is_training, sizeof(bool)); + request_buffer.append(reinterpret_cast(&is_training), sizeof(bool)); std::vector keys_counter; keys_counter.reserve(sorted_kv_size); @@ -846,7 +973,8 @@ std::future BrpcPsClient::pull_sparse(float **select_values, ++kv_request_count; uint32_t keys = 1; last_key = sorted_kvs[kv_idx].first; - request_buffer.append((void *)&last_key, sizeof(uint64_t)); + request_buffer.append(reinterpret_cast(&last_key), + sizeof(uint64_t)); while (kv_idx < sorted_kv_size - 1 && last_key == sorted_kvs[kv_idx + 1].first) { ++kv_idx; @@ -855,7 +983,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, keys_counter.push_back(keys); } - request_buffer.append((void *)keys_counter.data(), + request_buffer.append(reinterpret_cast(keys_counter.data()), sizeof(uint32_t) * keys_counter.size()); if (kv_request_count == 0) { @@ -864,7 +992,7 @@ std::future BrpcPsClient::pull_sparse(float **select_values, closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE); closure->request(i)->set_table_id(table_id); closure->request(i)->set_client_id(_client_id); - closure->request(i)->add_params((char *)&kv_request_count, + closure->request(i)->add_params((char *)&kv_request_count, // NOLINT sizeof(uint32_t)); PsService_Stub rpc_stub(get_cmd_channel(i)); closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); @@ -886,7 +1014,7 @@ std::future BrpcPsClient::send_client2client_msg( return fut; } auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) { - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); int32_t ret = closure->check_response(0, msg_type + 1000); closure->set_promise_value(ret); }); @@ -915,7 +1043,7 @@ std::future BrpcPsClient::push_sparse_raw_gradient_partial( push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); push_request->set_table_id(table_id); 
push_request->set_client_id(_client_id); - push_request->add_params((char *)&num, sizeof(uint32_t)); + push_request->add_params((char *)&num, sizeof(uint32_t)); // NOLINT auto *push_data = push_request->mutable_data(); push_data->resize(num * (sizeof(uint64_t) + value_size)); char *push_data_ptr = const_cast(push_data->data()); @@ -966,8 +1094,8 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, save_vec.push_back(save_huge_vec.data() + i * var_shape); } - auto status = pull_sparse((float **)save_vec.data(), table_id, - save_key.data(), save_key.size(), true); + auto status = pull_sparse(reinterpret_cast(save_vec.data()), + table_id, save_key.data(), save_key.size(), true); status.wait(); // create lod tensor @@ -1000,5 +1128,529 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id, return 0; } +std::future BrpcPsClient::push_sparse(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num) { + auto push_timer = std::make_shared("pserver_client_push_sparse"); + CostTimer parse_timer("pserver_client_push_sparse_parse"); + int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size(); + while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) { + // LOG(INFO) << "push_sparse Waiting for async_call_num comsume, task_num:" + // << push_sparse_async_num << ", max_task_limit:" << + // FLAGS_pserver_max_async_call_num; + usleep(5000); // 5ms + // push_sparse_async_num = _push_sparse_task_queue_map[table_id]->size(); + push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size(); + } + auto put_timer = std::make_shared("client_push_sparse_put"); + thread_local std::vector>> + shard_sorted_kv_list; + auto *accessor = table_accessor(table_id); + size_t request_call_num = _server_channels.size(); + shard_sorted_kv_list.resize(request_call_num); + for (auto &x : shard_sorted_kv_list) { + x.clear(); + } + const auto &server_param = _config.server_param().downpour_server_param(); + uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num; + for (int i = 0; i < server_param.downpour_table_param_size(); ++i) { + const auto &table_param = server_param.downpour_table_param(i); + if (table_param.table_id() == table_id) { + shard_num = table_param.shard_num(); + break; + } + } + for (size_t i = 0; i < num; ++i) { + size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]); + shard_sorted_kv_list[shard_id].push_back({keys[i], update_values[i]}); + } + auto sparse_task_data = _sparse_task_pool.get(); + sparse_task_data->shared_data.resize(request_call_num); + auto async_task = new SparseAsyncTask(sparse_task_data, table_id, push_timer); + + for (size_t i = 0; i < request_call_num; ++i) { + auto &sorted_kv_list = shard_sorted_kv_list[i]; + size_t sorted_kv_size = sorted_kv_list.size(); + auto &shard_kv_data = async_task->data()->shared_data[i]; + shard_kv_data.key_list.resize(sorted_kv_size); + shard_kv_data.value_list.resize(sorted_kv_size); + + if (sorted_kv_size == 0) { + shard_kv_data.kv_num = 0; + continue; + } + + uint32_t value_size = accessor->update_size(); + for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { + shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first; + shard_kv_data.value_list[kv_idx].assign( + (const char *)sorted_kv_list[kv_idx].second, value_size); + } + shard_kv_data.kv_num = sorted_kv_size; + } + + std::future fut = async_task->get_future(); + _push_sparse_task_queue_map[table_id]->Put(std::move(async_task)); + return fut; +} + +void 
BrpcPsClient::push_sparse_task_consume() { + uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit; + std::vector> task_list; + size_t request_call_num = _server_channels.size(); + ::ThreadPool async_push_sparse_shard_threads( + FLAGS_pserver_sparse_merge_thread); + while (_running) { + platform::Timer timeline; + timeline.Start(); + // 所有sparseTable的pushTask 进行处理 + for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) { + auto table_id = push_sparse_task_itr.first; + auto *accessor = table_accessor(table_id); + auto &task_queue = push_sparse_task_itr.second; + auto queue_size = task_queue->Size(); + if (queue_size == 0) { + continue; + } + if (merge_size > 0 && (queue_size <= 1 && _flushing == false)) { + continue; + } + ++_async_call_num; + + int merge_count = 0; + for (size_t i = 0; i < task_list.size(); ++i) { + if (task_list[i]->data()) { + _sparse_task_pool.push(task_list[i]->data()); + } + } + auto sparse_task_data = _sparse_task_pool.get(); + + task_list.clear(); + int cur_meger_size = task_queue->Size(); + + // task_list[0] 为一个空SparseAsyncTask, 分shard异步merge结果存入此结构。 + sparse_task_data->shared_data.resize(request_call_num); + auto push_timer = + std::make_shared("pserver_client_push_sparse"); + + auto async_task = + new SparseAsyncTask(sparse_task_data, table_id, push_timer); + + task_list.reserve(cur_meger_size + 1); + + task_list.push_back( + std::move(std::shared_ptr(async_task))); + + while (!task_queue->Empty() && merge_count < cur_meger_size) { + ++merge_count; + SparseAsyncTask *task; + task_queue->Get(task); + task_list.push_back(std::shared_ptr(task)); + } + + _push_sparse_merge_count_map[table_id] += merge_count; + + // 达到或大于 merge_size发送, 发送过程中 + std::vector request_kv_num(request_call_num, 0); + + if (_push_sparse_merge_count_map[table_id] >= merge_size || + _flushing == true) { + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = reinterpret_cast(done); + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + --_async_call_num; + }); + + for_each(task_list.begin() + 1, task_list.end(), + [&request_kv_num, request_call_num, + closure](std::shared_ptr &task) { + closure->add_timer(task->timer()); + closure->add_promise(task->promise()); + }); + + CostTimer merge_timer("pserver_client_push_sparse_merge"); + auto rpc_timer = + std::make_shared("pserver_client_push_sparse_rpc"); + closure->add_timer(rpc_timer); + + std::vector> merge_status(request_call_num); + for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + merge_status[shard_idx] = + async_push_sparse_shard_threads.enqueue(std::bind( + &BrpcPsClient::push_sparse_async_shard_push, this, task_list, + request_kv_num, table_id, shard_idx, closure, accessor)); + } + for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + merge_status[shard_idx].wait(); + } + merge_status.clear(); + std::vector>().swap(merge_status); + _push_sparse_merge_count_map[table_id] = 0; + + auto queue_size = task_queue->Size(); + } else { // 未达到阈值 只做多路归并 + std::vector> merge_status(request_call_num); + for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + merge_status[shard_idx] = + async_push_sparse_shard_threads.enqueue(std::bind( + &BrpcPsClient::push_sparse_async_shard_merge, this, task_list, + request_kv_num, table_id, shard_idx, accessor)); + } + for (int 
shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + merge_status[shard_idx].wait(); + } + + // meger到task_list[0] + auto async_task = new SparseAsyncTask(*(task_list[0].get())); + + task_queue->Put(std::move(async_task)); + --_async_call_num; + merge_status.clear(); + std::vector>().swap(merge_status); + } + } + timeline.Pause(); + auto wait_ms = + FLAGS_pserver_async_push_sparse_interval_ms - (timeline.ElapsedMS()); + if (wait_ms > 0) { + usleep(wait_ms * 1000); + } + } +} + +void sparse_local_merge(ValueAccessor *accessor, float *merge_data, + const float *another_data) { + size_t col_num = accessor->update_size() / sizeof(float); + float *merge_data_shell[col_num]; + const float *another_data_shell[col_num]; + for (int i = 0; i < col_num; ++i) { + merge_data_shell[i] = merge_data + i; + another_data_shell[i] = another_data + i; + } + accessor->merge(merge_data_shell, another_data_shell, 1); +} + +int BrpcPsClient::push_sparse_async_shard_merge( + std::vector> &task_list, + std::vector &request_kv_num, int table_id, int shard_idx, + ValueAccessor *accessor) { + size_t merged_kv_count = 0; + uint64_t min_key = UINT64_MAX; + uint32_t value_size = accessor->update_size(); + + thread_local std::vector> sorted_kv_list; + sorted_kv_list.clear(); + for (int i = 1; i < task_list.size(); ++i) { + size_t kv_num = task_list[i]->data()->shared_data[shard_idx].kv_num; + auto &key_list = task_list[i]->data()->shared_data[shard_idx].key_list; + auto &value_list = task_list[i]->data()->shared_data[shard_idx].value_list; + + for (int j = 0; j < kv_num; ++j) { + if (value_list[j].size() < value_size) { + LOG(WARNING) << "value_list[" << j << "]: " << value_list[j].c_str() + << "is invalid."; + continue; + } + char *task_data_ptr = const_cast(value_list[j].data()); + sorted_kv_list.push_back( + {key_list[j], reinterpret_cast(task_data_ptr)}); + } + } + + // 按key排序&去重 + std::sort(sorted_kv_list.begin(), sorted_kv_list.end(), + [](const std::pair &k1, + const std::pair &k2) { + return k1.first < k2.first; + }); + + auto &async_task = task_list[0]; + size_t sorted_kv_size = sorted_kv_list.size(); + auto &shard_kv_data = async_task->data()->shared_data[shard_idx]; + shard_kv_data.key_list.resize(sorted_kv_size); + shard_kv_data.value_list.resize(sorted_kv_size); + + // 将去重后数据写入分shard包 + if (sorted_kv_size == 0) { + shard_kv_data.kv_num = 0; + return 0; + } else if (sorted_kv_size == 1) { + shard_kv_data.kv_num = 1; + shard_kv_data.key_list[0] = sorted_kv_list[0].first; + shard_kv_data.value_list[0].assign((const char *)(sorted_kv_list[0].second), + value_size); + return 0; + } + + // 去重 本地merge + uint64_t last_key = sorted_kv_list[0].first; + const float *last_value_data = sorted_kv_list[0].second; + float *last_merge_data = NULL; + std::shared_ptr merger_buffer(new char[value_size], + array_deleter()); + for (size_t kv_idx = 1; kv_idx < sorted_kv_size; ++kv_idx) { + while (kv_idx < sorted_kv_size && + last_key == sorted_kv_list[kv_idx].first) { + if (last_merge_data == NULL) { + last_merge_data = reinterpret_cast(merger_buffer.get()); + memcpy(last_merge_data, last_value_data, value_size); + } + sparse_local_merge(accessor, last_merge_data, + sorted_kv_list[kv_idx].second); + ++kv_idx; + } + if (last_merge_data != NULL) { + shard_kv_data.value_list[merged_kv_count].assign( + (const char *)last_merge_data, value_size); + last_merge_data = NULL; + } else { + shard_kv_data.value_list[merged_kv_count].assign( + (const char *)sorted_kv_list[kv_idx - 1].second, value_size); + } + 
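+    // Commit the key for the (possibly locally merged) value written just
+    // above, then advance the output cursor.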
shard_kv_data.key_list[merged_kv_count++] = last_key; + if (kv_idx < sorted_kv_size) { + last_key = sorted_kv_list[kv_idx].first; + last_value_data = sorted_kv_list[kv_idx].second; + } + if (kv_idx == sorted_kv_size - 1) { + shard_kv_data.value_list[merged_kv_count].assign( + (const char *)last_value_data, value_size); + shard_kv_data.key_list[merged_kv_count++] = last_key; + } + } + shard_kv_data.kv_num = merged_kv_count; + return 0; +} + +int BrpcPsClient::push_sparse_async_shard_push( + std::vector> &task_list, + std::vector &request_kv_num, int table_id, int shard_idx, + DownpourBrpcClosure *closure, ValueAccessor *accessor) { + push_sparse_async_shard_merge(task_list, request_kv_num, table_id, shard_idx, + accessor); + size_t merged_kv_count = task_list[0]->data()->shared_data[shard_idx].kv_num; + + auto &merged_key_list = task_list[0]->data()->shared_data[shard_idx].key_list; + auto &merged_value_list = + task_list[0]->data()->shared_data[shard_idx].value_list; + + // 发送RPC请求 + auto *push_request = closure->request(shard_idx); + push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); + push_request->set_table_id(table_id); + push_request->set_client_id(_client_id); + push_request->add_params(reinterpret_cast(&merged_kv_count), + sizeof(uint32_t)); // NOLINT + auto *push_data = push_request->mutable_data(); + push_data->resize(merged_kv_count * + (sizeof(uint64_t) + accessor->update_size())); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, merged_key_list.data(), + merged_kv_count * sizeof(uint64_t)); + push_data_ptr += merged_kv_count * sizeof(uint64_t); + for (int i = 0; i < merged_kv_count; ++i) { + const char *task_data_ptr = merged_value_list[i].data(); + + memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT + accessor->update_size()); + push_data_ptr += accessor->update_size(); + } + PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + closure->cntl(shard_idx)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), + closure->response(shard_idx), closure); + _push_sparse_merge_count_map[table_id] = 0; + return 0; +} + +std::future BrpcPsClient::push_dense(const Region *regions, + size_t region_num, + size_t table_id) { + auto *accessor = table_accessor(table_id); + auto push_timer = std::make_shared("pserver_client_push_dense"); + auto parse_timer = + std::make_shared("pserver_client_push_dense_parse"); + int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size(); + while (push_dense_async_num > FLAGS_pserver_max_async_call_num) { + LOG(INFO) << "push_dense Waiting for async_call_num comsume, task_num:" + << push_dense_async_num + << ", max_task_limit:" << FLAGS_pserver_max_async_call_num; + usleep(5000); // 5ms + push_dense_async_num = _push_dense_task_queue_map[table_id]->Size(); + } + auto push_dense_timer = std::make_shared("push_dense_put"); + // auto dense_data = _dense_matrix_obj_pool.get(); + auto dense_data = std::make_shared>(); + auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer); + size_t request_call_num = _server_channels.size(); + + uint32_t num_per_shard = + dense_dim_per_shard(accessor->fea_dim(), request_call_num); + + // 将region数据拷贝到转置矩阵中 + async_task->data()->resize(num_per_shard * request_call_num * + accessor->update_dim()); + float *data = async_task->data()->data(); + size_t data_size = async_task->data()->size(); + uint32_t pos = 0; + for (size_t i = 0; i < region_num; ++i) { + 
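+    // regions[i].size is a byte count; convert it to a float count before the
+    // bounds check against the pre-sized send buffer.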
uint32_t data_num = regions[i].size / sizeof(float); + CHECK(pos + data_num <= data_size) + << "invalid dense size, cur pos[" << pos << "]" + << " data_num[" << data_num << "] size[" << data_size << "]"; + const float *region_data = (const float *)(regions[i].data); + memcpy(data + pos, region_data, regions[i].size); + pos += data_num; + } + std::future fut = async_task->get_future(); + _push_dense_task_queue_map[table_id]->Put(std::move(async_task)); + return fut; +} + +void BrpcPsClient::push_dense_task_consume() { + uint64_t merge_size = FLAGS_pserver_push_dense_merge_limit; + static bool scale_gradient = FLAGS_pserver_scale_gradient_by_merge; + ::ThreadPool async_merge_dense_threads(10); + while (_running) { + platform::Timer timeline; + timeline.Start(); + for (auto &task_queue_itr : _push_dense_task_queue_map) { + auto &task_queue = task_queue_itr.second; + auto queue_size = task_queue->Size(); + if (queue_size == 0) { + continue; + } + if (queue_size <= merge_size && _flushing == false) { + continue; + } + ++_async_call_num; + DenseAsyncTask *task; + task_queue->Get(task); + auto *accessor = table_accessor(task->table_id()); + // 设置请求回调 + size_t request_call_num = _server_channels.size(); + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = reinterpret_cast(done); + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + --_async_call_num; + }); + + auto &total_send_data_vec = *(task->data()); + float *total_send_data = + reinterpret_cast(total_send_data_vec.data()); + size_t total_send_data_size = total_send_data_vec.size(); + { + CostTimer merge_timer("pserver_client_push_dense_merge"); + uint32_t merge_count = 0; + std::vector> merge_status(merge_size); + while (!task_queue->Empty() && merge_count < merge_size) { + auto *async_task = new DenseAsyncTask(); + task_queue->Get(async_task); + closure->add_timer(async_task->timer()); + closure->add_promise(async_task->promise()); + merge_status[merge_count] = async_merge_dense_threads.enqueue( + [closure, accessor, &total_send_data, total_send_data_size, + async_task]() -> int { + auto &tmp_task_vec = *(async_task->data()); + const float *merge_data = tmp_task_vec.data(); + accessor->merge(&total_send_data, &merge_data, + total_send_data_size); +#pragma optimize("", off) + auto *debug_closure = closure; + auto *debug_task = async_task; + delete async_task; +#pragma optimize("", on) + return 0; + }); + ++merge_count; + } + for (int i = 0; i < merge_count; ++i) { + merge_status[i].wait(); + } + + VLOG(3) << "BrpcPsClient::push_dense_task_consume before merge " + "total_send_data[0]" + << total_send_data[0] << " total_send_data[-2]" + << total_send_data[total_send_data_size - 2] + << total_send_data[0] << " total_send_data[-1]" + << total_send_data[total_send_data_size - 1]; + + if (scale_gradient && merge_count > 1) { + Eigen::Map mat(total_send_data, 1, + total_send_data_size); + mat *= (1.0 / (merge_count + 1)); + } + + VLOG(3) << "BrpcPsClient::push_dense_task_consume after merge " + "total_send_data[0]" + << total_send_data[0] << " total_send_data[-2]" + << total_send_data[total_send_data_size - 2] + << " total_send_data[-1]" + << total_send_data[total_send_data_size - 1] << " merge_count " + << merge_count; + } + std::shared_ptr task_ptr(task); + push_dense_raw_gradient(task_ptr, total_send_data, total_send_data_size, 
+ closure); + } + timeline.Pause(); + auto wait_ms = + FLAGS_pserver_async_push_dense_interval_ms - (timeline.ElapsedMS()); + if (wait_ms > 0) { + usleep(wait_ms * 1000); + } + } +} + +void BrpcPsClient::push_dense_raw_gradient( + std::shared_ptr &task, float *total_send_data, + size_t total_send_data_size, DownpourBrpcClosure *closure) { + auto *accessor = table_accessor(task->table_id()); + size_t request_call_num = _server_channels.size(); + // 将数据拷贝到请求buffer区 + auto timer = std::make_shared("pserver_client_push_dense_rpc"); + closure->add_timer(timer); + uint32_t num_per_shard = + dense_dim_per_shard(accessor->fea_dim(), request_call_num); + auto send_timer = + std::make_shared("pserver_client_push_dense_send"); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); + closure->request(i)->set_table_id(task->table_id()); + closure->request(i)->set_client_id(_client_id); + auto *push_data = closure->request(i)->mutable_data(); + push_data->clear(); + push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float)); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t)); + memcpy(push_data_ptr + sizeof(uint32_t), + total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); + closure->cntl(i)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + PsService_Stub rpc_stub(get_dense_channel(i)); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc b/paddle/fluid/distributed/service/brpc_ps_server.cc index a1440260bf..dd7072be7d 100644 --- a/paddle/fluid/distributed/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/service/brpc_ps_server.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/distributed/service/brpc_ps_server.h" #include // NOLINT #include "butil/object_pool.h" +#include "paddle/fluid/distributed/common/cost_timer.h" #include "paddle/fluid/distributed/table/depends/sparse_utils.h" #include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/framework/archive.h" @@ -117,6 +118,11 @@ int32_t BrpcPsService::initialize() { _service_handler_map[PS_START_PROFILER] = &BrpcPsService::start_profiler; _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::stop_profiler; _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::push_global_step; + auto &profiler = CostProfiler::instance(); + profiler.register_profiler("pserver_server_pull_dense"); + profiler.register_profiler("pserver_server_push_dense"); + profiler.register_profiler("pserver_server_pull_sparse"); + profiler.register_profiler("pserver_server_push_sparse"); // shard初始化,server启动后才可从env获取到server_list的shard信息 initialize_shard_info(); @@ -190,6 +196,7 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, "PsRequestMessage.datas is requeired at least 1 for num of dense"); return 0; } + CostTimer timer("pserver_server_pull_dense"); uint32_t num = *(const uint32_t *)request.params(0).c_str(); if (num < 0) { set_response_code(response, -1, @@ -246,6 +253,7 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request, return 0; } + CostTimer timer("pserver_server_push_dense"); /* Push Content: |--num--|---valuesData---| @@ -356,6 +364,7 @@ int32_t BrpcPsService::pull_sparse(Table *table, return 0; } + CostTimer timer("pserver_server_pull_sparse"); 
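+  // Scoped cost sample; the key matches "pserver_server_pull_sparse"
+  // registered with CostProfiler in BrpcPsService::initialize().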
uint32_t num = *(uint32_t *)(request.params(0).c_str()); auto dim = table->value_accesor()->select_dim(); @@ -396,6 +405,7 @@ int32_t BrpcPsService::push_sparse(Table *table, "least 1 for num of sparse_key"); return 0; } + CostTimer timer("pserver_server_push_sparse"); uint32_t num = *(uint32_t *)(request.params(0).c_str()); /* Push Content: diff --git a/paddle/fluid/distributed/table/CMakeLists.txt b/paddle/fluid/distributed/table/CMakeLists.txt index 7ec7041b63..b0a553f210 100644 --- a/paddle/fluid/distributed/table/CMakeLists.txt +++ b/paddle/fluid/distributed/table/CMakeLists.txt @@ -16,6 +16,11 @@ set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DIS get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) +set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/") +include_directories(${PADDLE_LIB_THIRD_PARTY_PATH}libmct/src/extern_libmct/libmct/include) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") + set(EXTERN_DEP "") if(WITH_HETERPS) set(TABLE_SRC common_sparse_table.cc ssd_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc) @@ -37,7 +42,11 @@ set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPI set_source_files_properties(sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto) cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule) +cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table) + +cc_library(table SRCS table.cc DEPS memory_sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) -cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost ctr_accessor) +target_link_libraries(table -fopenmp) diff --git a/paddle/fluid/distributed/table/common_dense_table.h b/paddle/fluid/distributed/table/common_dense_table.h index 74366f0358..c8813dc330 100644 --- a/paddle/fluid/distributed/table/common_dense_table.h +++ b/paddle/fluid/distributed/table/common_dense_table.h @@ -32,39 +32,32 @@ class DenseOptimizer; class CommonDenseTable : public DenseTable { public: - explicit CommonDenseTable() {} + CommonDenseTable() {} virtual ~CommonDenseTable() {} - virtual int32_t initialize() override; - virtual int32_t initialize_shard() override { return 0; } + int32_t initialize() override; + int32_t initialize_shard() override { return 0; } virtual void create_initializer(const std::string& attr, const std::string& name); virtual int32_t initialize_value(); virtual int32_t initialize_optimizer(); - virtual int32_t pull_dense(float* pull_values, size_t num) override; - virtual int32_t push_dense_param(const float* values, size_t num) override; - virtual int32_t push_dense(const float* values, size_t num) override; - virtual int32_t pour() override; - virtual int32_t set_global_lr(float* lr) override; + int32_t pull_dense(float* pull_values, size_t num) override; + int32_t push_dense_param(const float* values, size_t num) override; + int32_t push_dense(const float* values, size_t num) override; + int32_t pour() 
override; + int32_t set_global_lr(float* lr) override; - int32_t load(const std::string& path, const std::string& param) override { - VLOG(0) << "WARNING: dense variables will load on No.0 trainer"; - return 0; - } + int32_t load(const std::string& path, const std::string& param) override; + int32_t save(const std::string& path, const std::string& param) override; - int32_t save(const std::string& path, const std::string& param) override { - VLOG(0) << "WARNING: dense variables will save on No.0 trainer"; - return 0; - } - - virtual int32_t flush() override { return 0; } - virtual int32_t shrink(const std::string& param) override { return 0; } - virtual void clear() override { return; } + int32_t flush() override { return 0; } + int32_t shrink(const std::string& param) override { return 0; } + void clear() override { return; } protected: int32_t _push_dense(const float* values, size_t num); private: - const int task_pool_size_ = 1; + const int task_pool_size_ = 10; bool sync = true; std::vector> _shards_task_pool; int param_dim_ = 0; @@ -74,6 +67,9 @@ class CommonDenseTable : public DenseTable { ReservoirValue pull_reservoir_; std::unordered_map initializers_; std::unordered_map names_index_; + int total_dim_ = 0; + int fixed_len_params_dim_ = 0; // used for save/load + std::vector param_col_ids_; // used for save/load }; } // namespace distributed diff --git a/paddle/fluid/distributed/table/depends/dense.h b/paddle/fluid/distributed/table/depends/dense.h index 8079003d1b..d2042b7a71 100644 --- a/paddle/fluid/distributed/table/depends/dense.h +++ b/paddle/fluid/distributed/table/depends/dense.h @@ -99,6 +99,7 @@ class DSGD : public DenseOptimizer { }; // adam optimizer for dense tensor +// TODO(zhaocaibei123): add CHECK(common_dense_table.task_pool_size_) == 1 class DAdam : public DenseOptimizer { public: explicit DAdam(const CommonAccessorParameter& accessor, @@ -131,6 +132,8 @@ class DAdam : public DenseOptimizer { epsilon = 1.0e-8; } + // make sure common_dense_table.task_pool_size_ == 1; + // otherwise, task_pool_size_ times beta1_pow/beta2_pow multiplication void update(const float* update_values, size_t num, int begin, int end) override { auto update_numel = end - begin; @@ -183,5 +186,87 @@ class DAdam : public DenseOptimizer { float epsilon; }; +// adam optimizer for dense tensor +class DAdamD2Sum : public DenseOptimizer { + public: + explicit DAdamD2Sum(const CommonAccessorParameter& accessor, + std::vector>* values) { + lr_hardcode = 5e-6; + auto& names = accessor.params(); + for (int x = 0; x < static_cast(names.size()); ++x) { + if (names[x] == "LearningRate") { + learning_rate = (*values)[x].data(); + } + if (names[x] == "Param") { + param = (*values)[x].data(); + } + if (names[x] == "Moment") { + mom_velocity = (*values)[x].data(); + } + if (names[x] == "G2Sum") { + ada_g2sum = (*values)[x].data(); + } + if (names[x] == "D2Sum") { + ada_d2sum = (*values)[x].data(); + } + if (names[x] == "MomentDecayRate") { + mom_decay_rate = (*values)[x].data(); + } + if (names[x] == "AdaDecayRate") { + ada_decay_rate = (*values)[x].data(); + } + if (names[x] == "AdaEpsilon") { + ada_epsilon = (*values)[x].data(); + } + } + } + + void update(const float* update_values, size_t num, int begin, + int end) override { + auto update_numel = end - begin; + Eigen::Map mat_ada_g2sum(ada_g2sum + begin, 1, + update_numel); + + Eigen::Map mat_ada_d2sum(ada_d2sum + begin, 1, + update_numel); + Eigen::Map mat_mom_velocity(mom_velocity + begin, 1, + update_numel); + Eigen::Map mat_w(param + begin, 1, 
update_numel); + + Eigen::Map mat_grad(update_values + begin, 1, + update_numel); + + mat_ada_d2sum = (mat_ada_d2sum * ada_decay_rate[0]).array() + 1; + mat_ada_g2sum = + (mat_ada_g2sum * ada_decay_rate[0]) + mat_grad.cwiseProduct(mat_grad); + + thread_local std::vector scale_vec; + scale_vec.resize(update_numel); + Eigen::Map scale(scale_vec.data(), 1, update_numel); + memcpy(scale_vec.data(), mat_ada_d2sum.data(), + sizeof(float) * update_numel); + + scale = scale.array() * ada_epsilon[0]; + scale = (mat_ada_d2sum + scale).cwiseQuotient(mat_ada_g2sum + scale); + scale = scale.cwiseSqrt(); + mat_mom_velocity = + (mat_mom_velocity - mat_grad) * mom_decay_rate[0] + mat_grad; + + mat_w -= learning_rate[0] * mat_mom_velocity.cwiseProduct(scale); + } + + float* learning_rate; + float lr_hardcode; + + float* param; + float* mom_velocity; + float* ada_g2sum; + float* ada_d2sum; + + float* mom_decay_rate; + float* ada_decay_rate; + float* ada_epsilon; +}; + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/table/depends/feature_value.h b/paddle/fluid/distributed/table/depends/feature_value.h index ad037a86bc..7a83fdec1d 100644 --- a/paddle/fluid/distributed/table/depends/feature_value.h +++ b/paddle/fluid/distributed/table/depends/feature_value.h @@ -14,35 +14,11 @@ #pragma once -#include -#include -#include // NOLINT -#include -#include -#include // NOLINT -#include -#include -#include #include #include "gflags/gflags.h" -#include "butil/object_pool.h" -#include "paddle/fluid/distributed/common/utils.h" -#include "paddle/fluid/distributed/table/depends/initializers.h" -#include "paddle/fluid/distributed/thirdparty/round_robin.h" -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/rw_lock.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/place.h" -#include "paddle/fluid/platform/port.h" -#include "paddle/fluid/string/printf.h" -#include "paddle/fluid/string/string_helper.h" +#include +#include "paddle/fluid/distributed/common/chunk_allocator.h" namespace paddle { namespace distributed { @@ -55,112 +31,169 @@ class FixedFeatureValue { public: FixedFeatureValue() {} ~FixedFeatureValue() {} - float *data() { return data_.data(); } - size_t size() { return data_.size(); } - void resize(size_t size) { data_.resize(size); } - void shrink_to_fit() { data_.shrink_to_fit(); } + float* data() { return _data.data(); } + size_t size() { return _data.size(); } + void resize(size_t size) { _data.resize(size); } + void shrink_to_fit() { _data.shrink_to_fit(); } private: - std::vector data_; + std::vector _data; }; -class SparseTableShard { +template +struct alignas(64) SparseTableShard { public: - typedef typename robin_hood::unordered_map + typedef typename mct::closed_hash_map> map_type; - SparseTableShard() {} - ~SparseTableShard() {} + struct iterator { + typename map_type::iterator it; + size_t bucket; + map_type* buckets; + friend bool operator==(const iterator& a, const iterator& b) { + return a.it == b.it; + } + friend bool operator!=(const iterator& a, const iterator& b) { + return a.it != b.it; + } + const KEY& key() const { return it->first; } + VALUE& value() const { return *(VALUE*)(void*)it->second; } // NOLINT + 
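+    // Advance; when the current bucket is exhausted, hop to the begin() of the
+    // next non-empty bucket (settling on the last bucket's end() if none remain).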
iterator& operator++() { + ++it; - FixedFeatureValue *Init(const uint64_t &id) { - size_t hash = hasher_(id); - size_t bucket = compute_bucket(hash); - auto &table = values_[bucket]; + while (it == buckets[bucket].end() && + bucket + 1 < CTR_SPARSE_SHARD_BUCKET_NUM) { + it = buckets[++bucket].begin(); + } - FixedFeatureValue *value = nullptr; - value = butil::get_object(); - table[id] = value; - return value; + return *this; + } + iterator operator++(int) { + iterator ret = *this; + ++*this; + return ret; + } + }; + struct local_iterator { + typename map_type::iterator it; + friend bool operator==(const local_iterator& a, const local_iterator& b) { + return a.it == b.it; + } + friend bool operator!=(const local_iterator& a, const local_iterator& b) { + return a.it != b.it; + } + const KEY& key() const { return it->first; } + VALUE& value() const { return *(VALUE*)(void*)it->second; } // NOLINT + local_iterator& operator++() { + ++it; + return *this; + } + local_iterator operator++(int) { return {it++}; } + }; + + ~SparseTableShard() { clear(); } + bool empty() { return _alloc.size() == 0; } + size_t size() { return _alloc.size(); } + void set_max_load_factor(float x) { + for (size_t bucket = 0; bucket < CTR_SPARSE_SHARD_BUCKET_NUM; bucket++) { + _buckets[bucket].max_load_factor(x); + } } - - // dont judge if (has(id)) - float *Get(const uint64_t &id) { - size_t hash = hasher_(id); - size_t bucket = compute_bucket(hash); - auto &table = values_[bucket]; - - // auto &value = table.at(id); - // return value->data_.data(); - auto res = table.find(id); - FixedFeatureValue *value = res->second; - return value->data(); + size_t bucket_count() { return CTR_SPARSE_SHARD_BUCKET_NUM; } + size_t bucket_size(size_t bucket) { return _buckets[bucket].size(); } + void clear() { + for (size_t bucket = 0; bucket < CTR_SPARSE_SHARD_BUCKET_NUM; bucket++) { + map_type& data = _buckets[bucket]; + for (auto it = data.begin(); it != data.end(); ++it) { + _alloc.release((VALUE*)(void*)it->second); // NOLINT + } + data.clear(); + } } - - // for load, to reset count, unseen_days - FixedFeatureValue *GetValue(const uint64_t &id) { - size_t hash = hasher_(id); - size_t bucket = compute_bucket(hash); - - auto &table = values_[bucket]; - auto res = table.find(id); - return res->second; + iterator begin() { + auto it = _buckets[0].begin(); + size_t bucket = 0; + while (it == _buckets[bucket].end() && + bucket + 1 < CTR_SPARSE_SHARD_BUCKET_NUM) { + it = _buckets[++bucket].begin(); + } + return {it, bucket, _buckets}; } - - void erase(uint64_t feasign) { - size_t hash = hasher_(feasign); + iterator end() { + return {_buckets[CTR_SPARSE_SHARD_BUCKET_NUM - 1].end(), + CTR_SPARSE_SHARD_BUCKET_NUM - 1, _buckets}; + } + local_iterator begin(size_t bucket) { return {_buckets[bucket].begin()}; } + local_iterator end(size_t bucket) { return {_buckets[bucket].end()}; } + iterator find(const KEY& key) { + size_t hash = _hasher(key); size_t bucket = compute_bucket(hash); - auto &table = values_[bucket]; - - auto iter = table.find(feasign); - if (iter != table.end()) { - butil::return_object(iter->second); - iter = table.erase(iter); + auto it = _buckets[bucket].find_with_hash(key, hash); + if (it == _buckets[bucket].end()) { + return end(); } + return {it, bucket, _buckets}; + } + VALUE& operator[](const KEY& key) { return emplace(key).first.value(); } + std::pair insert(const KEY& key, const VALUE& val) { + return emplace(key, val); } + std::pair insert(const KEY& key, VALUE&& val) { + return emplace(key, std::move(val)); + } + 
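+  // emplace(): find-or-create. A missing key gets its VALUE acquired from the
+  // per-shard ChunkAllocator; the bool in the returned pair reports whether an
+  // insertion actually happened.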
template + std::pair emplace(const KEY& key, ARGS&&... args) { + size_t hash = _hasher(key); + size_t bucket = compute_bucket(hash); + auto res = _buckets[bucket].insert_with_hash({key, NULL}, hash); - void clear() {} + if (res.second) { + res.first->second = _alloc.acquire(std::forward(args)...); + } - size_t compute_bucket(size_t hash) { - if (CTR_SPARSE_SHARD_BUCKET_NUM == 1) { - return 0; - } else { - return hash >> (sizeof(size_t) * 8 - CTR_SPARSE_SHARD_BUCKET_NUM_BITS); + return {{res.first, bucket, _buckets}, res.second}; + } + iterator erase(iterator it) { + _alloc.release((VALUE*)(void*)it.it->second); // NOLINT + size_t bucket = it.bucket; + auto it2 = _buckets[bucket].erase(it.it); + while (it2 == _buckets[bucket].end() && + bucket + 1 < CTR_SPARSE_SHARD_BUCKET_NUM) { + it2 = _buckets[++bucket].begin(); } + return {it2, bucket, _buckets}; } - - map_type::iterator end() { - return values_[CTR_SPARSE_SHARD_BUCKET_NUM - 1].end(); + void quick_erase(iterator it) { + _alloc.release((VALUE*)(void*)it.it->second); // NOLINT + _buckets[it.bucket].quick_erase(it.it); } - - map_type::iterator Find(uint64_t id) { - size_t hash = hasher_(id); - size_t bucket = compute_bucket(hash); - auto &table = values_[bucket]; - - auto got = table.find(id); - if (got == table.end()) { - return end(); - } else { - return got; + local_iterator erase(size_t bucket, local_iterator it) { + _alloc.release((VALUE*)(void*)it.it->second); // NOLINT + return {_buckets[bucket].erase(it.it)}; + } + void quick_erase(size_t bucket, local_iterator it) { + _alloc.release((VALUE*)(void*)it.it->second); // NOLINT + _buckets[bucket].quick_erase(it.it); + } + size_t erase(const KEY& key) { + auto it = find(key); + if (it == end()) { + return 0; } + quick_erase(it); + return 1; } - - private: - bool Has(const uint64_t id) { - size_t hash = hasher_(id); - size_t bucket = compute_bucket(hash); - auto &table = values_[bucket]; - - auto got = table.find(id); - if (got == table.end()) { - return false; + size_t compute_bucket(size_t hash) { + if (CTR_SPARSE_SHARD_BUCKET_NUM == 1) { + return 0; } else { - return true; + return hash >> (sizeof(size_t) * 8 - CTR_SPARSE_SHARD_BUCKET_NUM_BITS); } } - public: - map_type values_[CTR_SPARSE_SHARD_BUCKET_NUM]; - std::hash hasher_; + private: + map_type _buckets[CTR_SPARSE_SHARD_BUCKET_NUM]; + ChunkAllocator _alloc; + std::hash _hasher; }; } // namespace distributed diff --git a/paddle/fluid/distributed/table/memory_sparse_table.cc b/paddle/fluid/distributed/table/memory_sparse_table.cc index e69de29bb2..7501207abe 100644 --- a/paddle/fluid/distributed/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/table/memory_sparse_table.cc @@ -0,0 +1,635 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
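A minimal usage sketch for the SparseTableShard introduced above, assuming its template parameters are <uint64_t, FixedFeatureValue> (as used by memory_sparse_table.h and the updated feature_value_test.cc) and that the libmct closed_hash_map header is on the include path; the function name below is illustrative only and is not part of the patch:

    #include <cstring>
    #include <vector>
    #include "paddle/fluid/distributed/table/depends/feature_value.h"

    using paddle::distributed::FixedFeatureValue;
    using paddle::distributed::SparseTableShard;

    void sparse_shard_sketch() {
      SparseTableShard<uint64_t, FixedFeatureValue> shard;

      // operator[] is find-or-create: the first access allocates a
      // FixedFeatureValue from the shard's ChunkAllocator.
      std::vector<float> init = {0.1f, 0.2f, 0.3f};
      FixedFeatureValue& value = shard[42];
      value.resize(init.size());
      std::memcpy(value.data(), init.data(), init.size() * sizeof(float));

      // find() hashes the key and routes it to one of the
      // CTR_SPARSE_SHARD_BUCKET_NUM closed-hash buckets.
      auto it = shard.find(42);
      if (it != shard.end()) {
        float first = it.value().data()[0];  // 0.1f
        (void)first;
      }

      // erase() returns the value object to the allocator.
      shard.erase(42);
    }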
+ +#include +#include + +#include "paddle/fluid/distributed/common/cost_timer.h" +#include "paddle/fluid/distributed/table/memory_sparse_table.h" +#include "paddle/fluid/framework/io/fs.h" + +#include "boost/lexical_cast.hpp" +#include "glog/logging.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace distributed { + +// TODO(zhaocaibei123): configure +bool FLAGS_pserver_create_value_when_push = false; +int FLAGS_pserver_table_save_max_retry = 3; +bool FLAGS_pserver_enable_create_feasign_randomly = false; + +int32_t MemorySparseTable::initialize() { + _shards_task_pool.resize(_task_pool_size); + for (int i = 0; i < _shards_task_pool.size(); ++i) { + _shards_task_pool[i].reset(new ::ThreadPool(1)); + } + auto& profiler = CostProfiler::instance(); + profiler.register_profiler("pserver_sparse_update_all"); + profiler.register_profiler("pserver_sparse_select_all"); + initialize_value(); + VLOG(0) << "initalize MemorySparseTable succ"; + return 0; +} + +int32_t MemorySparseTable::initialize_value() { + _sparse_table_shard_num = static_cast(_config.shard_num()); + _avg_local_shard_num = + SparseTable::sparse_local_shard_num(_sparse_table_shard_num, _shard_num); + _real_local_shard_num = _avg_local_shard_num; + if (_real_local_shard_num * (_shard_idx + 1) > _sparse_table_shard_num) { + _real_local_shard_num = + _sparse_table_shard_num - _real_local_shard_num * _shard_idx; + _real_local_shard_num = + _real_local_shard_num < 0 ? 0 : _real_local_shard_num; + } + VLOG(1) << "memory sparse table _avg_local_shard_num: " + << _avg_local_shard_num + << " _real_local_shard_num: " << _real_local_shard_num; + + _local_shards.reset(new shard_type[_real_local_shard_num]); + + return 0; +} + +int32_t MemorySparseTable::load(const std::string& path, + const std::string& param) { + std::string table_path = table_dir(path); + auto file_list = _afs_client.list(table_path); + + std::sort(file_list.begin(), file_list.end()); + for (auto file : file_list) { + VLOG(1) << "MemorySparseTable::load() file list: " << file; + } + + int load_param = atoi(param.c_str()); + auto expect_shard_num = _sparse_table_shard_num; + if (file_list.size() != expect_shard_num) { + LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size() + << " not equal to expect_shard_num:" << expect_shard_num; + return -1; + } + if (file_list.size() == 0) { + LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path; + return -1; + } + + size_t file_start_idx = _shard_idx * _avg_local_shard_num; + + size_t feature_value_size = _value_accesor->size() / sizeof(float); + + int thread_num = _real_local_shard_num < 15 ? 
_real_local_shard_num : 15; + omp_set_num_threads(thread_num); +#pragma omp parallel for schedule(dynamic) + for (size_t i = 0; i < _real_local_shard_num; ++i) { + FsChannelConfig channel_config; + channel_config.path = file_list[file_start_idx + i]; + VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path + << " into local shard " << i; + channel_config.converter = _value_accesor->converter(load_param).converter; + channel_config.deconverter = + _value_accesor->converter(load_param).deconverter; + + bool is_read_failed = false; + int retry_num = 0; + int err_no = 0; + do { + is_read_failed = false; + err_no = 0; + std::string line_data; + auto read_channel = _afs_client.open_r(channel_config, 0, &err_no); + char* end = NULL; + auto& shard = _local_shards[i]; + try { + while (read_channel->read_line(line_data) == 0 && + line_data.size() > 1) { + uint64_t key = std::strtoul(line_data.data(), &end, 10); + auto& value = shard[key]; + value.resize(feature_value_size); + int parse_size = + _value_accesor->parse_from_string(++end, value.data()); + value.resize(parse_size); + + // for debug + for (int ii = 0; ii < parse_size; ++ii) { + VLOG(2) << "MemorySparseTable::load key: " << key << " value " << ii + << ": " << value.data()[ii] << " local_shard: " << i; + } + } + read_channel->close(); + if (err_no == -1) { + ++retry_num; + is_read_failed = true; + LOG(ERROR) + << "MemorySparseTable load failed after read, retry it! path:" + << channel_config.path << " , retry_num=" << retry_num; + } + } catch (...) { + ++retry_num; + is_read_failed = true; + LOG(ERROR) << "MemorySparseTable load failed, retry it! path:" + << channel_config.path << " , retry_num=" << retry_num; + } + if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) { + LOG(ERROR) << "MemorySparseTable load failed reach max limit!"; + exit(-1); + } + } while (is_read_failed); + } + LOG(INFO) << "MemorySparseTable load success, path from " + << file_list[file_start_idx] << " to " + << file_list[file_start_idx + _real_local_shard_num - 1]; + return 0; +} + +int32_t MemorySparseTable::load_local_fs(const std::string& path, + const std::string& param) { + std::string table_path = table_dir(path); + auto file_list = paddle::framework::localfs_list(table_path); + + int load_param = atoi(param.c_str()); + auto expect_shard_num = _sparse_table_shard_num; + if (file_list.size() != expect_shard_num) { + LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size() + << " not equal to expect_shard_num:" << expect_shard_num; + return -1; + } + if (file_list.size() == 0) { + LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path; + return -1; + } + + size_t file_start_idx = _shard_idx * _avg_local_shard_num; + + size_t feature_value_size = _value_accesor->size() / sizeof(float); + + int thread_num = _real_local_shard_num < 15 ? 
_real_local_shard_num : 15; + omp_set_num_threads(thread_num); +#pragma omp parallel for schedule(dynamic) + for (size_t i = 0; i < _real_local_shard_num; ++i) { + bool is_read_failed = false; + int retry_num = 0; + int err_no = 0; + do { + is_read_failed = false; + err_no = 0; + std::string line_data; + std::ifstream file(file_list[file_start_idx + i]); + char* end = NULL; + auto& shard = _local_shards[i]; + try { + while (std::getline(file, line_data) && line_data.size() > 1) { + uint64_t key = std::strtoul(line_data.data(), &end, 10); + auto& value = shard[key]; + value.resize(feature_value_size); + int parse_size = + _value_accesor->parse_from_string(++end, value.data()); + value.resize(parse_size); + } + file.close(); + if (err_no == -1) { + ++retry_num; + is_read_failed = true; + LOG(ERROR) + << "MemorySparseTable load failed after read, retry it! path:" + << file_list[file_start_idx + i] << " , retry_num=" << retry_num; + } + } catch (...) { + ++retry_num; + is_read_failed = true; + LOG(ERROR) << "MemorySparseTable load failed, retry it! path:" + << file_list[file_start_idx + i] + << " , retry_num=" << retry_num; + } + if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) { + LOG(ERROR) << "MemorySparseTable load failed reach max limit!"; + exit(-1); + } + } while (is_read_failed); + } + LOG(INFO) << "MemorySparseTable load success, path from " + << file_list[file_start_idx] << " to " + << file_list[file_start_idx + _real_local_shard_num - 1]; + return 0; +} + +int32_t MemorySparseTable::save(const std::string& dirname, + const std::string& param) { + VLOG(0) << "MemorySparseTable::save dirname: " << dirname; + int save_param = + atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2 + std::string table_path = table_dir(dirname); + _afs_client.remove(paddle::string::format_string( + "%s/part-%03d-*", table_path.c_str(), _shard_idx)); + std::atomic feasign_size_all{0}; + + size_t file_start_idx = _avg_local_shard_num * _shard_idx; + + int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20; + omp_set_num_threads(thread_num); +#pragma omp parallel for schedule(dynamic) + for (size_t i = 0; i < _real_local_shard_num; ++i) { + FsChannelConfig channel_config; + if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) { + channel_config.path = paddle::string::format_string( + "%s/part-%03d-%05d.gz", table_path.c_str(), _shard_idx, + file_start_idx + i); + } else { + channel_config.path = + paddle::string::format_string("%s/part-%03d-%05d", table_path.c_str(), + _shard_idx, file_start_idx + i); + } + channel_config.converter = _value_accesor->converter(save_param).converter; + channel_config.deconverter = + _value_accesor->converter(save_param).deconverter; + bool is_write_failed = false; + int feasign_size = 0; + int retry_num = 0; + int err_no = 0; + auto& shard = _local_shards[i]; + do { + err_no = 0; + feasign_size = 0; + is_write_failed = false; + auto write_channel = + _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no); + for (auto it = shard.begin(); it != shard.end(); ++it) { + if (_value_accesor->save(it.value().data(), save_param)) { + std::string format_value = _value_accesor->parse_to_string( + it.value().data(), it.value().size()); + if (0 != + write_channel->write_line(paddle::string::format_string( + "%lu %s", it.key(), format_value.c_str()))) { + ++retry_num; + is_write_failed = true; + LOG(ERROR) + << "MemorySparseTable save prefix failed, retry it! 
path:" + << channel_config.path << " , retry_num=" << retry_num; + break; + } + ++feasign_size; + } + } + write_channel->close(); + if (err_no == -1) { + ++retry_num; + is_write_failed = true; + LOG(ERROR) + << "MemorySparseTable save prefix failed after write, retry it! " + << "path:" << channel_config.path << " , retry_num=" << retry_num; + } + if (is_write_failed) { + _afs_client.remove(channel_config.path); + } + if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) { + LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!"; + exit(-1); + } + } while (is_write_failed); + feasign_size_all += feasign_size; + for (auto it = shard.begin(); it != shard.end(); ++it) { + _value_accesor->update_stat_after_save(it.value().data(), save_param); + } + LOG(INFO) << "MemorySparseTable save prefix success, path: " + << channel_config.path; + } + // int32 may overflow need to change return value + return 0; +} + +int32_t MemorySparseTable::save_local_fs(const std::string& dirname, + const std::string& param, + const std::string& prefix) { + int save_param = + atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2 + std::string table_path = table_dir(dirname); + int feasign_cnt = 0; + size_t file_start_idx = _avg_local_shard_num * _shard_idx; + + int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20; + std::atomic feasign_size_all{0}; + + omp_set_num_threads(thread_num); +#pragma omp parallel for schedule(dynamic) + for (size_t i = 0; i < _real_local_shard_num; ++i) { + feasign_cnt = 0; + auto& shard = _local_shards[i]; + std::string file_name = paddle::string::format_string( + "%s/part-%s-%03d-%05d", table_path.c_str(), prefix.c_str(), _shard_idx, + file_start_idx + i); + std::ofstream os; + os.open(file_name); + for (auto it = shard.begin(); it != shard.end(); ++it) { + if (_value_accesor->save(it.value().data(), save_param)) { + std::string format_value = _value_accesor->parse_to_string( + it.value().data(), it.value().size()); + std::string out_line = paddle::string::format_string( + "%lu %s\n", it.key(), format_value.c_str()); + // VLOG(2) << out_line.c_str(); + os.write(out_line.c_str(), sizeof(char) * out_line.size()); + ++feasign_cnt; + } + } + os.close(); + LOG(INFO) << "MemorySparseTable save prefix success, path:" << file_name + << "feasign_cnt: " << feasign_cnt; + } + return 0; +} + +int64_t MemorySparseTable::local_size() { + int64_t local_size = 0; + for (size_t i = 0; i < _real_local_shard_num; ++i) { + local_size += _local_shards[i].size(); + } + return local_size; +} + +int64_t MemorySparseTable::local_mf_size() { + std::vector size_arr(_real_local_shard_num, 0); + std::vector> tasks(_real_local_shard_num); + int64_t ret_size = 0; + for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { + tasks[shard_id] = + _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue( + [this, shard_id, &size_arr]() -> int { + auto& local_shard = _local_shards[shard_id]; + for (auto it = local_shard.begin(); it != local_shard.end(); + ++it) { + if (_value_accesor->has_mf(it.value().size())) { + size_arr[shard_id] += 1; + } + } + return 0; + }); + } + for (size_t i = 0; i < _real_local_shard_num; ++i) { + tasks[i].wait(); + } + for (auto x : size_arr) { + ret_size += x; + } + return ret_size; +} + +std::pair MemorySparseTable::print_table_stat() { + int64_t feasign_size = local_size(); + int64_t mf_size = local_mf_size(); + return {feasign_size, mf_size}; +} + +int32_t MemorySparseTable::pull_sparse(float* pull_values, + 
const PullSparseValue& pull_value) { + CostTimer timer("pserver_sparse_select_all"); + std::vector> tasks(_real_local_shard_num); + + const size_t value_size = _value_accesor->size() / sizeof(float); + size_t mf_value_size = _value_accesor->mf_size() / sizeof(float); + size_t select_value_size = _value_accesor->select_size() / sizeof(float); + // std::atomic missed_keys{0}; + + std::vector>> task_keys( + _real_local_shard_num); + size_t num = pull_value.numel_; + for (size_t i = 0; i < num; ++i) { + int shard_id = (pull_value.feasigns_[i] % _sparse_table_shard_num) % + _avg_local_shard_num; + task_keys[shard_id].push_back({pull_value.feasigns_[i], i}); + } + for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { + tasks[shard_id] = + _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue( + [this, shard_id, &task_keys, value_size, pull_values, mf_value_size, + select_value_size]() -> int { + auto& local_shard = _local_shards[shard_id]; + float data_buffer[value_size]; // NOLINT + float* data_buffer_ptr = data_buffer; + + auto& keys = task_keys[shard_id]; + for (size_t i = 0; i < keys.size(); i++) { + uint64_t key = keys[i].first; + auto itr = local_shard.find(key); + size_t data_size = value_size - mf_value_size; + if (itr == local_shard.end()) { + // ++missed_keys; + if (FLAGS_pserver_create_value_when_push) { + memset(data_buffer, 0, sizeof(float) * data_size); + } else { + auto& feature_value = local_shard[key]; + feature_value.resize(data_size); + float* data_ptr = feature_value.data(); + _value_accesor->create(&data_buffer_ptr, 1); + memcpy(data_ptr, data_buffer_ptr, + data_size * sizeof(float)); + } + } else { + data_size = itr.value().size(); + memcpy(data_buffer_ptr, itr.value().data(), + data_size * sizeof(float)); + } + for (int mf_idx = data_size; mf_idx < value_size; ++mf_idx) { + data_buffer[mf_idx] = 0.0; + } + auto offset = keys[i].second; + float* select_data = pull_values + select_value_size * offset; + _value_accesor->select(&select_data, + (const float**)&data_buffer_ptr, 1); + } + + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + + return 0; +} + +int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values, + const uint64_t* keys, size_t num) { + return 0; +} + +int32_t MemorySparseTable::push_sparse(const uint64_t* keys, + const float* values, size_t num) { + CostTimer timer("pserver_sparse_update_all"); + std::vector> tasks(_real_local_shard_num); + std::vector>> task_keys( + _real_local_shard_num); + for (size_t i = 0; i < num; ++i) { + int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num; + task_keys[shard_id].push_back({keys[i], i}); + } + + const size_t value_col = _value_accesor->size() / sizeof(float); + size_t mf_value_col = _value_accesor->mf_size() / sizeof(float); + size_t update_value_col = _value_accesor->update_size() / sizeof(float); + + for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue( + [this, shard_id, value_col, mf_value_col, update_value_col, values, + &task_keys]() -> int { + auto& keys = task_keys[shard_id]; + auto& local_shard = _local_shards[shard_id]; + float data_buffer[value_col]; // NOLINT + float* data_buffer_ptr = data_buffer; + for (int i = 0; i < keys.size(); ++i) { + uint64_t key = keys[i].first; + uint64_t push_data_idx = keys[i].second; + const float* update_data = + values + push_data_idx * update_value_col; + auto 
itr = local_shard.find(key); + if (itr == local_shard.end()) { + VLOG(0) << "sparse table push_sparse: " << key << "not found!"; + if (FLAGS_pserver_enable_create_feasign_randomly && + !_value_accesor->create_value(1, update_data)) { + continue; + } + auto value_size = value_col - mf_value_col; + auto& feature_value = local_shard[key]; + feature_value.resize(value_size); + _value_accesor->create(&data_buffer_ptr, 1); + memcpy(feature_value.data(), data_buffer_ptr, + value_size * sizeof(float)); + itr = local_shard.find(key); + } + + auto& feature_value = itr.value(); + float* value_data = feature_value.data(); + size_t value_size = feature_value.size(); + + if (value_size == value_col) { // 已拓展到最大size, 则就地update + _value_accesor->update(&value_data, &update_data, 1); + } else { + // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了 + memcpy(data_buffer_ptr, value_data, value_size * sizeof(float)); + _value_accesor->update(&data_buffer_ptr, &update_data, 1); + + if (_value_accesor->need_extend_mf(data_buffer)) { + feature_value.resize(value_col); + value_data = feature_value.data(); + _value_accesor->create(&value_data, 1); + } + memcpy(value_data, data_buffer_ptr, value_size * sizeof(float)); + } + } + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + return 0; +} + +int32_t MemorySparseTable::push_sparse(const uint64_t* keys, + const float** values, size_t num) { + _push_sparse(keys, values, num); + return 0; +} + +int32_t MemorySparseTable::_push_sparse(const uint64_t* keys, + const float** values, size_t num) { + std::vector> tasks(_real_local_shard_num); + std::vector>> task_keys( + _real_local_shard_num); + for (size_t i = 0; i < num; ++i) { + int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num; + task_keys[shard_id].push_back({keys[i], i}); + } + + size_t value_col = _value_accesor->size() / sizeof(float); + size_t mf_value_col = _value_accesor->mf_size() / sizeof(float); + size_t update_value_col = _value_accesor->update_size() / sizeof(float); + + for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue( + [this, shard_id, value_col, mf_value_col, update_value_col, values, + &task_keys]() -> int { + auto& keys = task_keys[shard_id]; + auto& local_shard = _local_shards[shard_id]; + float data_buffer[value_col]; // NOLINT + float* data_buffer_ptr = data_buffer; + for (int i = 0; i < keys.size(); ++i) { + uint64_t key = keys[i].first; + uint64_t push_data_idx = keys[i].second; + const float* update_data = values[push_data_idx]; + auto itr = local_shard.find(key); + if (itr == local_shard.end()) { + if (FLAGS_pserver_enable_create_feasign_randomly && + !_value_accesor->create_value(1, update_data)) { + continue; + } + auto value_size = value_col - mf_value_col; + auto& feature_value = local_shard[key]; + feature_value.resize(value_size); + _value_accesor->create(&data_buffer_ptr, 1); + memcpy(feature_value.data(), data_buffer_ptr, + value_size * sizeof(float)); + itr = local_shard.find(key); + } + auto& feature_value = itr.value(); + float* value_data = feature_value.data(); + size_t value_size = feature_value.size(); + if (value_size == value_col) { // 已拓展到最大size, 则就地update + _value_accesor->update(&value_data, &update_data, 1); + } else { + // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了 + memcpy(data_buffer_ptr, value_data, value_size * sizeof(float)); + _value_accesor->update(&data_buffer_ptr, &update_data, 1); + if 
(_value_accesor->need_extend_mf(data_buffer)) { + feature_value.resize(value_col); + value_data = feature_value.data(); + _value_accesor->create(&value_data, 1); + } + memcpy(value_data, data_buffer_ptr, value_size * sizeof(float)); + } + } + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + return 0; +} + +int32_t MemorySparseTable::flush() { return 0; } + +int32_t MemorySparseTable::shrink(const std::string& param) { + VLOG(0) << "MemorySparseTable::shrink"; + // TODO(zhaocaibei123): implement with multi-thread + for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { + // shrink + auto& shard = _local_shards[shard_id]; + for (auto it = shard.begin(); it != shard.end();) { + if (_value_accesor->shrink(it.value().data())) { + it = shard.erase(it); + } else { + ++it; + } + } + } + return 0; +} + +void MemorySparseTable::clear() { VLOG(0) << "clear coming soon"; } + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/memory_sparse_table.h b/paddle/fluid/distributed/table/memory_sparse_table.h index e69de29bb2..cb552beab1 100644 --- a/paddle/fluid/distributed/table/memory_sparse_table.h +++ b/paddle/fluid/distributed/table/memory_sparse_table.h @@ -0,0 +1,96 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
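The shard bookkeeping above (initialize_value and the key routing in pull_sparse/push_sparse) can be traced with a small standalone sketch. It assumes SparseTable::sparse_local_shard_num() is a ceiling division over the server count, which this patch does not show; the constants and names below are illustrative only:

    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t sparse_table_shard_num = 1000;  // _config.shard_num()
      const int64_t shard_num = 3;                  // number of servers
      const int64_t shard_idx = 2;                  // this server's index

      // Assumed behaviour of SparseTable::sparse_local_shard_num(): ceil division.
      int64_t avg_local_shard_num =
          (sparse_table_shard_num + shard_num - 1) / shard_num;  // 334

      // Clamp for the last server, as in MemorySparseTable::initialize_value().
      int64_t real_local_shard_num = avg_local_shard_num;
      if (avg_local_shard_num * (shard_idx + 1) > sparse_table_shard_num) {
        real_local_shard_num =
            sparse_table_shard_num - avg_local_shard_num * shard_idx;
        if (real_local_shard_num < 0) real_local_shard_num = 0;
      }
      std::cout << "real local shards: " << real_local_shard_num << "\n";  // 332

      // Key routing used by pull_sparse/push_sparse.
      uint64_t key = 123456789;
      int64_t local_shard_id =
          (key % sparse_table_shard_num) % avg_local_shard_num;
      std::cout << "local shard id: " << local_shard_id << "\n";  // 121
      return 0;
    }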
+ +#pragma once + +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include "Eigen/Dense" +#include "paddle/fluid/distributed/table/accessor.h" +#include "paddle/fluid/distributed/table/common_table.h" +#include "paddle/fluid/distributed/table/depends/feature_value.h" +#include "paddle/fluid/string/string_helper.h" + +#define PSERVER_SAVE_SUFFIX ".shard" + +namespace paddle { +namespace distributed { + +class MemorySparseTable : public SparseTable { + public: + typedef SparseTableShard shard_type; + MemorySparseTable() {} + virtual ~MemorySparseTable() {} + + // unused method begin + virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; } + virtual int32_t push_dense_param(const float* values, size_t num) { + return 0; + } + virtual int32_t push_dense(const float* values, size_t num) { return 0; } + // unused method end + + virtual int32_t initialize(); + virtual int32_t initialize_shard() { return 0; } + virtual int32_t initialize_value(); + + virtual int32_t load(const std::string& path, const std::string& param); + + virtual int32_t save(const std::string& path, const std::string& param); + + int32_t load_local_fs(const std::string& path, const std::string& param); + int32_t save_local_fs(const std::string& path, const std::string& param, + const std::string& prefix); + + int64_t local_size(); + int64_t local_mf_size(); + + virtual std::pair print_table_stat(); + virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); + + virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, + size_t num); + + virtual int32_t push_sparse(const uint64_t* keys, const float* values, + size_t num); + + virtual int32_t push_sparse(const uint64_t* keys, const float** values, + size_t num); + + virtual int32_t flush(); + virtual int32_t shrink(const std::string& param); + virtual void clear(); + + protected: + virtual int32_t _push_sparse(const uint64_t* keys, const float** values, + size_t num); + + protected: + const int _task_pool_size = 24; + size_t _avg_local_shard_num; + size_t _real_local_shard_num; + size_t _sparse_table_shard_num; + std::vector> _shards_task_pool; + std::unique_ptr _local_shards; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/test/dense_table_test.cc b/paddle/fluid/distributed/test/dense_table_test.cc index f2f1e098fa..2e48b791dc 100644 --- a/paddle/fluid/distributed/test/dense_table_test.cc +++ b/paddle/fluid/distributed/test/dense_table_test.cc @@ -27,9 +27,6 @@ class Table; TEST(CommonDenseTable, Adam) { int fea_dim = 10; int trainers = 2; - float beta1 = 0.9; - float beta2 = 0.999; - float epsilon = 1.0e-8; TableParameter table_config; table_config.set_table_class("CommonDenseTable"); @@ -39,27 +36,33 @@ TEST(CommonDenseTable, Adam) { accessor_config->set_accessor_class("CommMergeAccessor"); CommonAccessorParameter *common_config = table_config.mutable_common(); // set adam optimize config - common_config->set_name("adam"); + common_config->set_name("adam_d2sum"); common_config->set_table_name("adam_test_table"); common_config->set_trainer_num(trainers); common_config->add_params("Param"); common_config->add_dims(fea_dim); common_config->add_initializers("gaussian_random&0&0.0&1.0"); - common_config->add_params("LearningRate"); - common_config->add_dims(1); - common_config->add_initializers("fill_constant&1.0"); - common_config->add_params("Moment1"); + common_config->add_params("D2Sum"); + common_config->add_dims(fea_dim); 
+ common_config->add_initializers("fill_constant&0.0"); + common_config->add_params("G2Sum"); common_config->add_dims(fea_dim); common_config->add_initializers("fill_constant&0.0"); - common_config->add_params("Moment2"); + common_config->add_params("Moment"); common_config->add_dims(fea_dim); common_config->add_initializers("fill_constant&0.0"); - common_config->add_params("Beta1Pow"); + common_config->add_params("MomentDecayRate"); common_config->add_dims(1); - common_config->add_initializers("fill_constant&1.0"); - common_config->add_params("Beta2Pow"); + common_config->add_initializers("fill_constant&0.99"); + common_config->add_params("AdaDecayRate"); common_config->add_dims(1); - common_config->add_initializers("fill_constant&1.0"); + common_config->add_initializers("fill_constant&0.9999"); + common_config->add_params("AdaEpsilon"); + common_config->add_dims(1); + common_config->add_initializers("fill_constant&1.0e-8"); + common_config->add_params("LearningRate"); + common_config->add_dims(1); + common_config->add_initializers("fill_constant&5e-6"); auto ret = table->initialize(table_config, fs_config); ASSERT_EQ(ret, 0); @@ -89,29 +92,30 @@ TEST(CommonDenseTable, Adam) { pull_values.resize(fea_dim); table->pull_dense(pull_values.data(), fea_dim); - std::vector beta1_pow, beta2_pow, lr, mom1, mom2, param; - beta1_pow.push_back(beta1); - beta2_pow.push_back(beta2); - lr.push_back(1.0); + float mom_rate = 0.99; + float decay_rate = 0.9999; + float epsilon = 1.0e-8; + float lr = 5e-6; + std::vector d2sum, g2sum, mom, param; for (int i = 0; i < fea_dim; i++) { - mom1.push_back(0.0); - mom2.push_back(0.0); + mom.push_back(0.0); + d2sum.push_back(0.0); + g2sum.push_back(0.0); param.push_back(init_values[i]); } for (int i = 0; i < trainers; i++) { - auto lr_ = lr[0] * sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]); for (int j = 0; j < fea_dim; j++) { - mom1[j] = beta1 * mom1[j] + (1 - beta1) * trainer_gradient_values[i][j]; - mom2[j] = beta2 * mom2[j] + - (1 - beta2) * trainer_gradient_values[i][j] * - trainer_gradient_values[i][j]; - param[j] = - param[j] - - lr_ * (mom1[j] / (sqrt(mom2[j]) + epsilon * sqrt(1 - beta2_pow[0]))); + d2sum[j] = d2sum[j] * decay_rate + 1; + g2sum[j] = g2sum[j] * decay_rate + + trainer_gradient_values[i][j] * trainer_gradient_values[i][j]; + float scale = d2sum[j] * epsilon; + scale = (scale + d2sum[j]) / (scale + g2sum[j]); + scale = sqrt(scale); + mom[j] = (mom[j] - trainer_gradient_values[i][j]) * mom_rate + + trainer_gradient_values[i][j]; + param[j] = param[j] - lr * scale * mom[j]; } - beta1_pow[0] *= beta1; - beta2_pow[0] *= beta2; } for (int j = 0; j < fea_dim; j++) { ASSERT_TRUE(abs(param[j] - pull_values[j]) < 1e-5); diff --git a/paddle/fluid/distributed/test/feature_value_test.cc b/paddle/fluid/distributed/test/feature_value_test.cc index 9c9f0ffcac..9bd00dcc56 100644 --- a/paddle/fluid/distributed/test/feature_value_test.cc +++ b/paddle/fluid/distributed/test/feature_value_test.cc @@ -12,38 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include - -#include -#include -#include // NOLINT +#include "paddle/fluid/distributed/table/depends/feature_value.h" #include - -#include "google/protobuf/text_format.h" #include "gtest/gtest.h" -#include "paddle/fluid/distributed/table/depends/feature_value.h" namespace paddle { namespace distributed { TEST(BENCHMARK, LargeScaleKV) { - std::shared_ptr shard = - std::make_shared(); + typedef SparseTableShard shard_type; + shard_type shard; uint64_t key = 1; - auto itr = shard->Find(key); - ASSERT_TRUE(itr == shard->end()); + auto itr = shard.find(key); + ASSERT_TRUE(itr == shard.end()); std::vector vec = {0.0, 0.1, 0.2, 0.3}; - auto* feature_value = shard->Init(key); - feature_value->resize(vec.size()); - memcpy(feature_value->data(), vec.data(), vec.size() * sizeof(float)); + auto& feature_value = shard[key]; + feature_value.resize(vec.size()); + memcpy(feature_value.data(), vec.data(), vec.size() * sizeof(float)); - itr = shard->Find(key); - ASSERT_TRUE(itr != shard->end()); + itr = shard.find(key); + ASSERT_TRUE(itr != shard.end()); - feature_value = itr->second; - float* value_data = feature_value->data(); + feature_value = itr.value(); + float* value_data = feature_value.data(); ASSERT_FLOAT_EQ(value_data[0], 0.0); ASSERT_FLOAT_EQ(value_data[1], 0.1); -- Gitee
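As a closing reference, the adam_d2sum step exercised by the updated dense_table_test.cc (and implemented with Eigen in DAdamD2Sum::update) reduces to the following per-element arithmetic. The helper below is a sketch built from the test's reference loop and its fill_constant hyper-parameters, not part of the patch:

    #include <cmath>
    #include <cstdio>

    // Per-element adam_d2sum step: decay the d2sum/g2sum accumulators, derive a
    // per-element scale, update the momentum, and apply it to the parameter.
    void adam_d2sum_step(float grad, float mom_decay_rate, float ada_decay_rate,
                         float ada_epsilon, float lr, float* d2sum, float* g2sum,
                         float* mom, float* param) {
      *d2sum = *d2sum * ada_decay_rate + 1.0f;
      *g2sum = *g2sum * ada_decay_rate + grad * grad;
      float scale = *d2sum * ada_epsilon;
      scale = (scale + *d2sum) / (scale + *g2sum);
      scale = std::sqrt(scale);
      *mom = (*mom - grad) * mom_decay_rate + grad;
      *param -= lr * scale * *mom;
    }

    int main() {
      float d2sum = 0.f, g2sum = 0.f, mom = 0.f, param = 1.0f;
      // Values taken from the test config: MomentDecayRate=0.99,
      // AdaDecayRate=0.9999, AdaEpsilon=1e-8, LearningRate=5e-6.
      adam_d2sum_step(/*grad=*/0.5f, 0.99f, 0.9999f, 1e-8f, 5e-6f,
                      &d2sum, &g2sum, &mom, &param);
      std::printf("param=%f mom=%f\n", param, mom);
      return 0;
    }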