diff --git a/mooncake-common/common.cmake b/mooncake-common/common.cmake index a35104d38..564b1b535 100644 --- a/mooncake-common/common.cmake +++ b/mooncake-common/common.cmake @@ -61,6 +61,7 @@ option(USE_MUSA "option for enabling gpu features for MTHREADS GPU" OFF) option(USE_HIP "option for enabling gpu features for AMD GPU" OFF) option(USE_NVMEOF "option for using NVMe over Fabric" OFF) option(USE_TCP "option for using TCP transport" ON) +option(USE_BAREX "option for using accl-barex transport" OFF) option(USE_ASCEND "option for using npu with HCCL" OFF) option(USE_ASCEND_DIRECT "option for using ascend npu with adxl engine" OFF) option(USE_ASCEND_HETEROGENEOUS "option for transferring between ascend npu and gpu" OFF) @@ -143,6 +144,10 @@ if (USE_TCP) add_compile_definitions(USE_TCP) endif() +if (USE_BAREX) + add_compile_definitions(USE_BAREX) +endif() + if (USE_ASCEND OR USE_ASCEND_DIRECT) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DOPEN_BUILD_PROJECT ") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DOPEN_BUILD_PROJECT ") diff --git a/mooncake-integration/allocator.py b/mooncake-integration/allocator.py index 185d7535f..af56cbc45 100644 --- a/mooncake-integration/allocator.py +++ b/mooncake-integration/allocator.py @@ -94,6 +94,6 @@ def get_allocator(cls, device: torch_device) -> CUDAPluggableAllocator: if device not in cls._instances: so_path = cls._get_so_path() cls._instances[device] = CUDAPluggableAllocator( - so_path, "u2mm_alloc_wrapper", "u2mm_free_wrapper" + so_path, "u2mm_alloc_wrapper_with_stream", "u2mm_free_wrapper_with_stream" ) return cls._instances[device] diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index 316a4d109..26151786e 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -132,7 +132,21 @@ int TransferEnginePy::initializeExt(const char *local_hostname, free_list_.resize(kSlabSizeKBTabLen); #if !defined(USE_ASCEND) && !defined(USE_ASCEND_DIRECT) && \ !defined(USE_ASCEND_HETEROGENEOUS) - doBuddyAllocate(kMaxClassId); + bool pass_alloc = false; + const char *pass_alloc_env = std::getenv("PASS_ALLOC"); + if (pass_alloc_env) { + try { + if (std::stoi(pass_alloc_env) != 0) { + pass_alloc = true; + } + } catch (const std::exception &) { + LOG(WARNING) << "Ignore value from environment variable " + "PASS_ALLOC"; + } + } + if (!pass_alloc) { + doBuddyAllocate(kMaxClassId); + } #endif return 0; } @@ -266,6 +280,9 @@ int TransferEnginePy::transferSync(const char *target_hostname, if (handle_map_.count(target_hostname)) { handle = handle_map_[target_hostname]; } else { + LOG(INFO) + << "transferSync, cache not found, openSegment with target " + << target_hostname; handle = engine_->openSegment(target_hostname); if (handle == (Transport::SegmentHandle)-1) return -1; handle_map_[target_hostname] = handle; @@ -300,7 +317,19 @@ int TransferEnginePy::transferSync(const char *target_hostname, batch_id, {entry}, TransferMetadata::NotifyDesc{notify->name, notify->msg}) : engine_->submitTransfer(batch_id, {entry}); - if (!s.ok()) return -1; + if (!s.ok()) { + Status segment_status = engine_->CheckSegmentStatus(handle); + if (!segment_status.ok()) { + LOG(WARNING) + << "submitTransfer failed with target " << target_hostname + << ", CheckSegmentStatus not ok, ready to closeSegment"; + std::lock_guard guard(mutex_); + engine_->closeSegment(handle); + engine_->getMetadata()->removeSegmentDesc(target_hostname); + 
handle_map_.erase(target_hostname); + } + return -1; + } TransferStatus status; bool completed = false; @@ -387,6 +416,16 @@ int TransferEnginePy::batchTransferSync( : engine_->submitTransfer(batch_id, entries); if (!s.ok()) { engine_->freeBatchID(batch_id); + Status segment_status = engine_->CheckSegmentStatus(handle); + if (!segment_status.ok()) { + LOG(WARNING) + << "submitTransfer failed with target " << target_hostname + << ", CheckSegmentStatus not ok, ready to closeSegment"; + std::lock_guard guard(mutex_); + engine_->closeSegment(handle); + engine_->getMetadata()->removeSegmentDesc(target_hostname); + handle_map_.erase(target_hostname); + } return -1; } diff --git a/mooncake-transfer-engine/example/transfer_engine_bench.cpp b/mooncake-transfer-engine/example/transfer_engine_bench.cpp index dde71b57a..21687a068 100644 --- a/mooncake-transfer-engine/example/transfer_engine_bench.cpp +++ b/mooncake-transfer-engine/example/transfer_engine_bench.cpp @@ -66,7 +66,7 @@ DEFINE_string(mode, "initiator", "data blocks from target node"); DEFINE_string(operation, "read", "Operation type: read or write"); -DEFINE_string(protocol, "rdma", "Transfer protocol: rdma|tcp"); +DEFINE_string(protocol, "rdma", "Transfer protocol: rdma|barex|tcp"); DEFINE_string(device_name, "mlx5_2", "Device name to use, valid if protocol=rdma"); @@ -301,6 +301,12 @@ int initiator() { args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; xport = engine->installTransport("rdma", args); + } else if (FLAGS_protocol == "barex") { + auto nic_priority_matrix = loadNicPriorityMatrix(); + void **args = (void **)malloc(2 * sizeof(void *)); + args[0] = (void *)nic_priority_matrix.c_str(); + args[1] = nullptr; + xport = engine->installTransport("barex", args); } else if (FLAGS_protocol == "tcp") { xport = engine->installTransport("tcp", nullptr); } else if (FLAGS_protocol == "nvlink") { @@ -421,6 +427,12 @@ int target() { args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; engine->installTransport("rdma", args); + } else if (FLAGS_protocol == "barex") { + auto nic_priority_matrix = loadNicPriorityMatrix(); + void **args = (void **)malloc(2 * sizeof(void *)); + args[0] = (void *)nic_priority_matrix.c_str(); + args[1] = nullptr; + engine->installTransport("barex", args); } else if (FLAGS_protocol == "tcp") { engine->installTransport("tcp", nullptr); } else if (FLAGS_protocol == "nvlink") { diff --git a/mooncake-transfer-engine/include/config.h b/mooncake-transfer-engine/include/config.h index 41e92ecc6..5ff3b9b33 100644 --- a/mooncake-transfer-engine/include/config.h +++ b/mooncake-transfer-engine/include/config.h @@ -56,6 +56,7 @@ struct GlobalConfig { bool use_ipv6 = false; size_t fragment_limit = 16384; bool enable_dest_device_affinity = false; + size_t eic_max_block_size = 64UL * 1024 * 1024; EndpointStoreType endpoint_store_type = EndpointStoreType::SIEVE; }; diff --git a/mooncake-transfer-engine/include/transfer_engine.h b/mooncake-transfer-engine/include/transfer_engine.h index 117416843..40472cf22 100644 --- a/mooncake-transfer-engine/include/transfer_engine.h +++ b/mooncake-transfer-engine/include/transfer_engine.h @@ -95,6 +95,8 @@ class TransferEngine { SegmentHandle openSegment(const std::string &segment_name); + Status CheckSegmentStatus(SegmentID sid); + int closeSegment(SegmentHandle handle); int removeLocalSegment(const std::string &segment_name); @@ -249,6 +251,7 @@ class TransferEngine { // Set it to false only for testing. 
    bool auto_discover_;
    std::vector<std::string> filter_;
+    bool use_barex_ = false;

 #ifdef WITH_METRICS
     ylt::metric::counter_t transferred_bytes_counter_{
diff --git a/mooncake-transfer-engine/include/transfer_metadata.h b/mooncake-transfer-engine/include/transfer_metadata.h
index ba35c4f17..a5133aac1 100644
--- a/mooncake-transfer-engine/include/transfer_metadata.h
+++ b/mooncake-transfer-engine/include/transfer_metadata.h
@@ -103,12 +103,14 @@ class TransferMetadata {
     struct RpcMetaDesc {
         std::string ip_or_host_name;
         uint16_t rpc_port;
+        uint16_t barex_port;
         int sockfd;  // local cache
     };

     struct HandShakeDesc {
         std::string local_nic_path;
         std::string peer_nic_path;
+        uint16_t barex_port;
         std::vector<uint32_t> qp_num;
         std::string reply_msg;  // on error
     };
diff --git a/mooncake-transfer-engine/include/transfer_metadata_plugin.h b/mooncake-transfer-engine/include/transfer_metadata_plugin.h
index 44b5610f0..b95f7d31f 100644
--- a/mooncake-transfer-engine/include/transfer_metadata_plugin.h
+++ b/mooncake-transfer-engine/include/transfer_metadata_plugin.h
@@ -69,7 +69,7 @@ struct HandShakePlugin {

 std::vector<std::string> findLocalIpAddresses();

-uint16_t findAvailableTcpPort(int &sockfd);
+uint16_t findAvailableTcpPort(int &sockfd, bool set_range = false);

 } // namespace mooncake
diff --git a/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h b/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h
new file mode 100644
index 000000000..a8466f5ca
--- /dev/null
+++ b/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h
@@ -0,0 +1,194 @@
+// Copyright 2024 KVCache.AI
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
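+//
+// barex_context.h wraps a single accl::barex XContext (one RDMA device) and
+// caches the XChannels opened to each remote segment, keyed by segment id
+// and peer NIC id. submitPostSend() fans a batch of slices out across the
+// cached channels, one per queue pair.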
+
+#ifndef BAREX_CONTEXT_H_
+#define BAREX_CONTEXT_H_
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "common.h"
+#include "transport/transport.h"
+
+#ifdef USE_BAREX
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#endif
+
+namespace mooncake {
+
+#ifdef USE_BAREX
+
+using namespace accl::barex;
+using XChannel = accl::barex::XChannel;
+using SegmentID = Transport::SegmentID;
+using XContext = accl::barex::XContext;
+using BarexResult = accl::barex::BarexResult;
+
+class ChannelCache {
+   public:
+    // put channel
+    void put(SegmentID key, int nic_id, XChannel* channel) {
+        RWSpinlock::WriteGuard guard(lock_);
+        auto& channels = cache_[key];
+        auto& vec = channels[nic_id];
+        status_map_[key] = true;
+        vec.push_back(channel);
+    }
+
+    // get channel
+    XChannel* find(SegmentID key, int nic_id, int idx) {
+        RWSpinlock::ReadGuard guard(lock_);
+        auto it = cache_.find(key);
+        if (it == cache_.end()) return nullptr;
+        auto& channels = it->second;
+        auto ch_it = channels.find(nic_id);
+        if (ch_it == channels.end()) return nullptr;
+        auto& vec = ch_it->second;
+        if (idx >= 0 && idx < static_cast<int>(vec.size())) {
+            return vec[idx];
+        }
+        return nullptr;
+    }
+
+    // delete channel
+    bool erase(SegmentID key, int nic_id, int idx) {
+        RWSpinlock::WriteGuard guard(lock_);
+        auto it = cache_.find(key);
+        if (it == cache_.end()) return false;
+
+        auto& channels = it->second;
+        auto ch_it = channels.find(nic_id);
+        if (ch_it == channels.end()) return false;
+
+        auto& vec = ch_it->second;
+        if (idx < 0 || idx >= static_cast<int>(vec.size())) return false;
+
+        vec.erase(vec.begin() + idx);
+        status_map_[key] = false;
+        if (vec.empty()) {
+            channels.erase(ch_it);
+            if (channels.empty()) {
+                cache_.erase(it);
+            }
+        }
+        return true;
+    }
+
+    // check whether every cached channel of a segment is still active
+    bool CheckAllChannels(SegmentID segment_id) {
+        RWSpinlock::ReadGuard guard(lock_);
+        auto it = cache_.find(segment_id);
+        if (it == cache_.end()) {
+            return false;
+        }
+        auto& inner_map = it->second;
+        for (auto& pair : inner_map) {
+            auto& channels = pair.second;
+            for (XChannel* channel : channels) {
+                if (!channel->IsActive()) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    // check and delete invalid channels
+    int RemoveInvalidChannels(SegmentID segment_id) {
+        RWSpinlock::WriteGuard guard(lock_);
+        auto it = cache_.find(segment_id);
+        if (it == cache_.end()) {
+            return 0;
+        }
+
+        int invalid_count = 0;
+        auto& inner_map = it->second;
+
+        for (auto& pair : inner_map) {
+            auto& channels = pair.second;
+            auto new_end = std::remove_if(
+                channels.begin(), channels.end(),
+                [](XChannel* channel) { return !channel->IsActive(); });
+            invalid_count += std::distance(new_end, channels.end());
+            channels.erase(new_end, channels.end());
+        }
+        return invalid_count;
+    }
+
+    // get all channels
+    std::vector<XChannel*> copyAll() {
+        RWSpinlock::WriteGuard guard(lock_);
+        std::vector<XChannel*> result;
+        for (const auto& [key, channels] : cache_) {
+            for (const auto& [nic_id, vec] : channels) {
+                result.insert(result.end(), vec.begin(), vec.end());
+            }
+        }
+        return result;
+    }
+
+   private:
+    std::unordered_map<SegmentID,
+                       std::unordered_map<int, std::vector<XChannel*>>>
+        cache_;
+    std::unordered_map<SegmentID, bool> status_map_;
+    RWSpinlock lock_;
+};
+class BarexContext {
+   public:
+    int submitPostSend(const std::vector<Transport::Slice*>& slice_list);
+    int addChannel(SegmentID sid, int device_id, XChannel* ch);
+    XChannel* getChannel(SegmentID sid, int device_id, int idx);
+    int checkStatus(SegmentID sid);
+    XContext* getCtx();
+    // int ClearAllChannel();
+    std::vector<XChannel*> getAllChannel();
+    bool active() const { return active_; }
+    void setQpNum(int qp_num) { qp_num_per_ctx_ = qp_num; }
+    int getQpNum() const { return qp_num_per_ctx_; }
+
+   public:
+    BarexContext(XContext* xcontext, bool use_cpu, int device_id);
+
+    ~BarexContext();
+
+    XContext* xcontext_;
+    bool barex_use_cpu_;
+    int barex_local_device_;
+
+   private:
+    ChannelCache channel_cache_;
+    bool active_ = true;
+    int qp_num_per_ctx_ = 2;
+};
+#endif
+}  // namespace mooncake
+
+#endif  // BAREX_CONTEXT_H_
\ No newline at end of file
diff --git a/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h b/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h
new file mode 100644
index 000000000..cf6f2fe1d
--- /dev/null
+++ b/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h
@@ -0,0 +1,175 @@
+// Copyright 2024 KVCache.AI
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef BAREX_TRANSPORT_H_
+#define BAREX_TRANSPORT_H_
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "topology.h"
+#include "transfer_metadata.h"
+#include "transport/transport.h"
+#include "transport/barex_transport/barex_context.h"
+
+namespace mooncake {
+
+using TransferRequest = Transport::TransferRequest;
+using TransferStatus = Transport::TransferStatus;
+using TransferStatusEnum = Transport::TransferStatusEnum;
+using SegmentID = Transport::SegmentID;
+using BatchID = Transport::BatchID;
+
+class TransferMetadata;
+class CountDownLatch {
+   private:
+    int count_;
+    std::mutex mtx;
+    std::condition_variable cv;
+
+   public:
+    CountDownLatch(int count) : count_(count) {}
+
+    void CountDown() {
+        std::unique_lock<std::mutex> lk(mtx);
+        count_--;
+        if (count_ <= 0) {
+            cv.notify_all();
+        }
+    }
+
+    void Wait() {
+        std::unique_lock<std::mutex> lk(mtx);
+        cv.wait(lk, [this] { return count_ <= 0; });
+    }
+};
+class BarexTransport : public Transport {
+   public:
+    using BufferDesc = TransferMetadata::BufferDesc;
+    using SegmentDesc = TransferMetadata::SegmentDesc;
+    using HandShakeDesc = TransferMetadata::HandShakeDesc;
+
+   public:
+    BarexTransport();
+
+    ~BarexTransport();
+
+    int install(std::string &local_server_name,
+                std::shared_ptr<TransferMetadata> meta,
+                std::shared_ptr<Topology> topo) override;
+
+    const char *getName() const override { return "barex"; }
+
+    void setLocalPort(int port) { local_port_ = port; }
+
+    void setPeerPort(int port) { peer_port_ = port; }
+
+    int getLocalPort() { return local_port_; }
+
+    int getPeerPort() { return peer_port_; }
+
+    int registerLocalMemory(void *addr, size_t length,
+                            const std::string &location, bool remote_accessible,
+                            bool update_metadata) override;
+
+    int registerLocalMemoryBase(void *addr, size_t length,
+                                const std::string &location,
+                                bool remote_accessible, bool update_metadata,
+                                bool is_gpu);
+
+    int unregisterLocalMemory(void *addr, bool update_metadata = true) override;
+
+    int registerLocalMemoryBatch(const std::vector<BufferEntry> &buffer_list,
+                                 const std::string &location) override;
+
+    int unregisterLocalMemoryBatch(
+        const std::vector<void *> &addr_list) override;
+
+    // TRANSFER
+
+    Status submitTransfer(BatchID batch_id,
+                          const std::vector<TransferRequest> &entries) override;
+
+    Status submitTransferTask(
+        const std::vector<TransferTask *> &task_list) override;
+
+    Status getTransferStatus(BatchID batch_id,
+                             std::vector<TransferStatus> &status);
+
+    Status getTransferStatus(BatchID batch_id, size_t task_id,
+                             TransferStatus &status) override;
+
+    SegmentID getSegmentID(const std::string &segment_name);
+
+    Status OpenChannel(const std::string &segment_name, SegmentID sid) override;
+    Status CheckStatus(SegmentID sid) override;
+
+   private:
+    int allocateLocalSegmentID();
+
+   public:
+    int onSetupRdmaConnections(const HandShakeDesc &peer_desc,
+                               HandShakeDesc &local_desc);
+
+    int sendHandshake(const std::string &peer_server_name,
+                      const HandShakeDesc &local_desc,
+                      HandShakeDesc &peer_desc) {
+        return metadata_->sendHandshake(peer_server_name, local_desc,
+                                        peer_desc);
+    }
+
+   private:
+    int initializeRdmaResources();
+
+    int startHandshakeDaemon(std::string &local_server_name);
+
+   public:
+    static int selectDevice(SegmentDesc *desc, uint64_t offset, size_t length,
+                            int &buffer_id, int &device_id, int retry_cnt = 0);
+
+   private:
+#ifdef USE_BAREX
+    std::vector<std::shared_ptr<BarexContext>> server_context_list_;
+    std::vector<std::shared_ptr<BarexContext>> client_context_list_;
+    std::shared_ptr server_threadpool_;
+    std::shared_ptr client_threadpool_;
+    std::shared_ptr mempool_;
+    std::shared_ptr listener_;
+    std::shared_ptr connector_;
+#endif
+    std::shared_ptr<Topology> local_topology_;
+    std::mutex buf_mutex_;
+    std::map<void *, std::pair<size_t, int>> buf_length_map_;
+    bool use_random_dev_ = false;
+    bool barex_use_cpu_ = false;
+    int barex_local_device_ = 0;
+    int local_port_ = 8089;
+    int peer_port_ = 8089;
+    std::random_device rd;
+};
+
+}  // namespace mooncake
+
+#endif  // BAREX_TRANSPORT_H_
\ No newline at end of file
diff --git a/mooncake-transfer-engine/include/transport/transport.h b/mooncake-transfer-engine/include/transport/transport.h
index dcccc7789..0510c19f3 100644
--- a/mooncake-transfer-engine/include/transport/transport.h
+++ b/mooncake-transfer-engine/include/transport/transport.h
@@ -90,6 +90,7 @@ class Transport {
         std::string peer_nic_path;
         SliceStatus status;
         TransferTask *task;
+        std::vector<uint32_t> dest_rkeys;
         bool from_cache;

         union {
@@ -97,6 +98,7 @@
             uint64_t dest_addr;
             uint32_t source_lkey;
             uint32_t dest_rkey;
+            int lkey_index;
             int rkey_index;
             volatile int *qp_depth;
             uint32_t retry_cnt;
@@ -257,6 +259,11 @@
         size_t length;
     };

+    virtual Status OpenChannel(const std::string &segment_name, SegmentID sid) {
+        return Status::OK();
+    }
+    virtual Status CheckStatus(SegmentID sid) { return Status::OK(); }
+
   protected:
    virtual int install(std::string &local_server_name,
                        std::shared_ptr<TransferMetadata> meta,
diff --git a/mooncake-transfer-engine/src/CMakeLists.txt b/mooncake-transfer-engine/src/CMakeLists.txt
index 9c7f92525..0b92a0f82 100644
--- a/mooncake-transfer-engine/src/CMakeLists.txt
+++ b/mooncake-transfer-engine/src/CMakeLists.txt
@@ -9,7 +9,7 @@ if (BUILD_SHARED_LIBS)
     install(TARGETS transfer_engine DESTINATION lib)
 endif()

-add_compile_definitions(transfer_engine PUBLIC MOONCAKE_USE_ETCD)
+add_compile_definitions(transfer_engine PUBLIC MOONCAKE_USE_ETCD CMAKE_INCLUDE)
 if (USE_ETCD)
     if (USE_ETCD_LEGACY)
         if (USE_STATIC_ETCD_CPP_API)
@@ -39,6 +39,10 @@ target_link_libraries(
     base transport rdma_transport ibverbs glog::glog gflags::gflags pthread
     JsonCpp::JsonCpp numa yalantinglibs::yalantinglibs
 )
+if (USE_BAREX)
+
target_link_libraries(transfer_engine PUBLIC barex_transport) +endif() + if (USE_CUDA) target_include_directories(transfer_engine PRIVATE /usr/local/cuda/include) target_link_libraries(transfer_engine PUBLIC cuda cudart rt) diff --git a/mooncake-transfer-engine/src/config.cpp b/mooncake-transfer-engine/src/config.cpp index 80828402d..c2bf041ad 100644 --- a/mooncake-transfer-engine/src/config.cpp +++ b/mooncake-transfer-engine/src/config.cpp @@ -170,6 +170,17 @@ void loadGlobalConfig(GlobalConfig &config) { << "Ignore value from environment variable MC_SLICE_SIZE"; } + const char *min_reg_size_env = std::getenv("MC_MIN_REG_SIZE"); + if (min_reg_size_env) { + size_t val = atoll(min_reg_size_env); + if (val > 0) { + config.eic_max_block_size = val; + LOG(INFO) << "Barex set MC_MIN_REG_SIZE=" << val; + } else + LOG(WARNING) + << "Ignore value from environment variable MC_MIN_REG_SIZE"; + } + const char *retry_cnt_env = std::getenv("MC_RETRY_CNT"); if (retry_cnt_env) { size_t val = atoi(retry_cnt_env); diff --git a/mooncake-transfer-engine/src/multi_transport.cpp b/mooncake-transfer-engine/src/multi_transport.cpp index 9c24836a2..3f2078940 100644 --- a/mooncake-transfer-engine/src/multi_transport.cpp +++ b/mooncake-transfer-engine/src/multi_transport.cpp @@ -17,6 +17,9 @@ #include "config.h" #include "transport/rdma_transport/rdma_transport.h" +#ifdef USE_BAREX +#include "transport/barex_transport/barex_transport.h" +#endif #ifdef USE_TCP #include "transport/tcp_transport/tcp_transport.h" #endif @@ -202,6 +205,11 @@ Transport *MultiTransport::installTransport(const std::string &proto, if (std::string(proto) == "rdma") { transport = new RdmaTransport(); } +#ifdef USE_BAREX + else if (std::string(proto) == "barex") { + transport = new BarexTransport(); + } +#endif #ifdef USE_TCP else if (std::string(proto) == "tcp") { transport = new TcpTransport(); @@ -244,6 +252,40 @@ Transport *MultiTransport::installTransport(const std::string &proto, return nullptr; } +#ifdef USE_BAREX + bool use_eic = false; + for (auto &dev : topo->getHcaList()) { + if (dev.find("soe") != std::string::npos || + dev.find("solar") != std::string::npos) { + use_eic = true; + } + } + + if (std::string(proto) == "barex") { + std::string nics; + for (auto &dev : topo->getHcaList()) { + if (use_eic) { + if (dev.find("soe") == std::string::npos && + dev.find("solar") == std::string::npos) { + // ignore no eic nics + continue; + } + } + nics += dev; + nics += ","; + } + + // Remove the last extra comma + if (!nics.empty()) { + nics.pop_back(); + } + + if (!nics.empty()) { + LOG(INFO) << "ACCL_USE_NICS is set to " << nics; + setenv("ACCL_USE_NICS", nics.c_str(), 1); + } + } +#endif if (transport->install(local_server_name_, metadata_, topo)) { return nullptr; } diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index 32b46b13f..d644166c0 100644 --- a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -25,6 +25,7 @@ #include "transfer_metadata_plugin.h" #include "transport/transport.h" +#include "transport/barex_transport/barex_transport.h" namespace mooncake { @@ -72,6 +73,15 @@ int TransferEngine::init(const std::string &metadata_conn_string, "files are opened."; } // Set resources to the maximum value +#ifdef USE_BAREX + const char *use_barex_env = std::getenv("USE_BAREX"); + if (use_barex_env) { + int val = atoi(use_barex_env); + if (val != 0) { + use_barex_ = true; + } + } +#endif #ifdef USE_ASCEND // The only 
difference in initializing the Ascend Transport is that the @@ -99,7 +109,19 @@ int TransferEngine::init(const std::string &metadata_conn_string, desc.ip_or_host_name = host_name; desc.rpc_port = port; desc.sockfd = -1; - +#ifdef USE_BAREX + if (use_barex_) { + int tmp_fd = -1; + desc.barex_port = findAvailableTcpPort(tmp_fd, true); + if (desc.barex_port == 0) { + LOG(ERROR) + << "Barex: No valid port found for local barex service."; + return -1; + } + close(tmp_fd); + tmp_fd = -1; + } +#endif if (metadata_conn_string == P2PHANDSHAKE) { rpc_binding_method = "P2P handshake"; desc.rpc_port = findAvailableTcpPort(desc.sockfd); @@ -145,7 +167,10 @@ int TransferEngine::init(const std::string &metadata_conn_string, LOG(INFO) << "Transfer Engine RPC using " << rpc_binding_method << ", listening on " << desc.ip_or_host_name << ":" - << desc.rpc_port; + << desc.rpc_port + << (use_barex_ + ? ", barex use port:" + std::to_string(desc.barex_port) + : ""); metadata_ = std::make_shared(metadata_conn_string); #ifdef USE_ASCEND @@ -228,11 +253,26 @@ int TransferEngine::init(const std::string &metadata_conn_string, if (local_topology_->getHcaList().size() > 0 && !getenv("MC_FORCE_TCP")) { // only install RDMA transport when there is at least one HCA - Transport *rdma_transport = - multi_transports_->installTransport("rdma", local_topology_); - if (!rdma_transport) { - LOG(ERROR) << "Failed to install RDMA transport"; + Transport *rdma_transport = nullptr; + if (use_barex_) { +#ifdef USE_BAREX + rdma_transport = multi_transports_->installTransport( + "barex", local_topology_); +#else + LOG(ERROR) << "Set USE BAREX while barex not compiled"; + return -1; +#endif + } else { + rdma_transport = multi_transports_->installTransport( + "rdma", local_topology_); + } + if (rdma_transport == nullptr) { + LOG(ERROR) << "Failed to install RDMA transport, type=" + << (use_barex_ ? "barex" : "rdma"); return -1; + } else { + LOG(INFO) << "installTransport, type=" + << (use_barex_ ? 
"barex" : "rdma"); } } else { Transport *tcp_transport = @@ -328,7 +368,37 @@ Transport::SegmentHandle TransferEngine::openSegment( while (!trimmed_segment_name.empty() && trimmed_segment_name[0] == '/') trimmed_segment_name.erase(0, 1); if (trimmed_segment_name.empty()) return ERR_INVALID_ARGUMENT; - return metadata_->getSegmentID(trimmed_segment_name); + SegmentID sid = metadata_->getSegmentID(trimmed_segment_name); +#ifdef USE_BAREX + if (use_barex_) { + Transport *transport = multi_transports_->getTransport("barex"); + if (!transport) { + LOG(ERROR) << "Barex proto not installed"; + return (Transport::SegmentHandle)-1; + } + Status s = transport->OpenChannel(segment_name, sid); + if (!s.ok()) { + LOG(ERROR) << "openSegment, OpenChannel failed"; + return (Transport::SegmentHandle)-1; + } + } +#endif + return sid; +} + +Status TransferEngine::CheckSegmentStatus(SegmentID sid) { +#ifdef USE_BAREX + if (use_barex_) { + Transport *transport = multi_transports_->getTransport("barex"); + BarexTransport *barex_transport = + dynamic_cast(transport); + return barex_transport->CheckStatus(sid); + } else { + return Status::OK(); + } +#else + return Status::OK(); +#endif } int TransferEngine::closeSegment(Transport::SegmentHandle handle) { return 0; } diff --git a/mooncake-transfer-engine/src/transfer_metadata.cpp b/mooncake-transfer-engine/src/transfer_metadata.cpp index 1024f6577..1fe503fae 100644 --- a/mooncake-transfer-engine/src/transfer_metadata.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata.cpp @@ -55,6 +55,7 @@ struct TransferHandshakeUtil { Json::Value root; root["local_nic_path"] = desc.local_nic_path; root["peer_nic_path"] = desc.peer_nic_path; + root["barex_port"] = desc.barex_port; Json::Value qpNums(Json::arrayValue); for (const auto &qp : desc.qp_num) qpNums.append(qp); root["qp_num"] = qpNums; @@ -65,6 +66,7 @@ struct TransferHandshakeUtil { static int decode(Json::Value root, TransferMetadata::HandShakeDesc &desc) { desc.local_nic_path = root["local_nic_path"].asString(); desc.peer_nic_path = root["peer_nic_path"].asString(); + desc.barex_port = root["barex_port"].asInt(); for (const auto &qp : root["qp_num"]) desc.qp_num.push_back(qp.asUInt()); desc.reply_msg = root["reply_msg"].asString(); @@ -157,7 +159,8 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, segmentJSON["tcp_data_port"] = desc.tcp_data_port; segmentJSON["timestamp"] = getCurrentDateTime(); - if (segmentJSON["protocol"] == "rdma") { + if (segmentJSON["protocol"] == "rdma" || + segmentJSON["protocol"] == "barex") { Json::Value devicesJSON(Json::arrayValue); for (const auto &device : desc.devices) { Json::Value deviceJSON; @@ -286,6 +289,15 @@ int TransferMetadata::updateSegmentDesc(const std::string &segment_name, int TransferMetadata::removeSegmentDesc(const std::string &segment_name) { if (p2p_handshake_mode_) { + auto iter = segment_name_to_id_map_.find(segment_name); + if (iter != segment_name_to_id_map_.end()) { + LOG(INFO) << "removeSegmentDesc " << segment_name << " finish"; + segment_id_to_desc_map_.erase(iter->second); + segment_name_to_id_map_.erase(iter); + } else { + LOG(INFO) << "removeSegmentDesc " << segment_name + << " not found, already removed maybe"; + } return 0; } if (!storage_plugin_->remove(getFullMetadataKey(segment_name))) { @@ -306,7 +318,7 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, if (segmentJSON.isMember("timestamp")) desc->timestamp = segmentJSON["timestamp"].asString(); - if (desc->protocol == "rdma") { + if (desc->protocol == "rdma" || 
desc->protocol == "barex") { for (const auto &deviceJSON : segmentJSON["devices"]) { DeviceDesc device; device.name = deviceJSON["name"].asString(); @@ -332,8 +344,11 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, if (buffer.name.empty() || !buffer.addr || !buffer.length || buffer.rkey.empty() || buffer.rkey.size() != buffer.lkey.size()) { - LOG(WARNING) << "Corrupted segment descriptor, name " - << segment_name << " protocol " << desc->protocol; + LOG(WARNING) + << "Corrupted segment descriptor, name " << segment_name + << " protocol " << desc->protocol << ", " << buffer.name + << ", " << buffer.addr << ", " << buffer.length << ", " + << buffer.rkey.size() << ", " << buffer.lkey.size(); return nullptr; } desc->buffers.push_back(buffer); diff --git a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp index 1e686c711..968430d43 100644 --- a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp @@ -1135,11 +1135,31 @@ std::vector findLocalIpAddresses() { return ips; } -uint16_t findAvailableTcpPort(int &sockfd) { +uint16_t findAvailableTcpPort(int &sockfd, bool set_range) { static std::random_device rand_gen; std::uniform_int_distribution rand_dist; - const int min_port = globalConfig().rpc_min_port; - const int max_port = globalConfig().rpc_max_port; + int min_port = globalConfig().rpc_min_port; + int max_port = globalConfig().rpc_max_port; +#ifdef USE_BAREX + if (set_range) { + min_port = 17000; + max_port = 35000; + const char *min_port_env = std::getenv("ACCL_MIN_PORT"); + const char *max_port_env = std::getenv("ACCL_MAX_PORT"); + if (min_port_env) { + int val = atoi(min_port_env); + if (val > 1024 && val < 65536) { + min_port = val; + } + } + if (max_port_env) { + int val = atoi(max_port_env); + if (val > 1024 && val < 65536 && val > min_port) { + max_port = val; + } + } + } +#endif const int max_attempts = 500; bool use_ipv6 = globalConfig().use_ipv6; diff --git a/mooncake-transfer-engine/src/transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/CMakeLists.txt index 5517a5ddc..026d75fe9 100644 --- a/mooncake-transfer-engine/src/transport/CMakeLists.txt +++ b/mooncake-transfer-engine/src/transport/CMakeLists.txt @@ -9,6 +9,11 @@ if (USE_TCP) target_sources(transport PUBLIC $) endif() +if (USE_BAREX) + add_subdirectory(barex_transport) + target_sources(transport PUBLIC $) +endif() + if (USE_NVMEOF) add_subdirectory(nvmeof_transport) target_sources(transport PUBLIC $) diff --git a/mooncake-transfer-engine/src/transport/barex_transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/barex_transport/CMakeLists.txt new file mode 100644 index 000000000..f5df7ce9b --- /dev/null +++ b/mooncake-transfer-engine/src/transport/barex_transport/CMakeLists.txt @@ -0,0 +1,5 @@ +file(GLOB BAREX_SOURCES "*.cpp") + +add_library(barex_transport OBJECT ${BAREX_SOURCES}) +target_link_libraries(barex_transport PRIVATE pthread accl_barex) +target_compile_definitions(barex_transport PRIVATE CMAKE_INCLUDE=1) \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/barex_transport/barex_context.cpp b/mooncake-transfer-engine/src/transport/barex_transport/barex_context.cpp new file mode 100644 index 000000000..848f79fc9 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/barex_transport/barex_context.cpp @@ -0,0 +1,199 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "transport/barex_transport/barex_context.h" + +namespace mooncake { + +using namespace accl::barex; + +BarexContext::BarexContext(XContext* xcontext, bool use_cpu, int device_id) + : xcontext_(xcontext), + barex_use_cpu_(use_cpu), + barex_local_device_(device_id) {} + +BarexContext::~BarexContext() { + if (xcontext_) { + xcontext_->Shutdown(); + xcontext_->WaitStop(); + delete xcontext_; + } +} + +int BarexContext::addChannel(SegmentID sid, int device_id, XChannel* ch) { + channel_cache_.put(sid, device_id, ch); + return 0; +} + +XChannel* BarexContext::getChannel(SegmentID sid, int device_id, int idx) { + XChannel* channel = channel_cache_.find(sid, device_id, idx); + return channel; +} + +int BarexContext::checkStatus(SegmentID sid) { + return channel_cache_.RemoveInvalidChannels(sid); +} + +XContext* BarexContext::getCtx() { return xcontext_; } + +std::vector BarexContext::getAllChannel() { + return channel_cache_.copyAll(); +} + +int BarexContext::submitPostSend( + const std::vector& slice_list) { + std::unordered_map< + SegmentID, std::unordered_map>> + sid_dev_data_map; + std::unordered_map>> + sid_dev_slice_map; + for (auto slice : slice_list) { + accl::barex::rw_memp_t w_m; + w_m.sg.addr = (uint64_t)slice->source_addr; + w_m.sg.length = (uint32_t)slice->length; + w_m.sg.lkey = slice->rdma.source_lkey; + w_m.data.d_type = barex_use_cpu_ ? 
CPU : GPU; + w_m.data.device_id = barex_local_device_; + w_m.r_addr = slice->rdma.dest_addr; + w_m.r_key = slice->rdma.dest_rkey; + w_m.r_ttl_ms = UINT64_MAX; + auto& dev_map = sid_dev_data_map[slice->target_id]; + int lkey_index = slice->rdma.lkey_index; + dev_map[lkey_index].push_back(w_m); + auto& slice_map = sid_dev_slice_map[slice->target_id]; + slice_map[lkey_index].push_back(slice); + } + for (auto& pair : sid_dev_data_map) { + SegmentID sid = pair.first; + auto& dev_map = pair.second; + + for (auto& dev_pair : dev_map) { + int dev = dev_pair.first; + std::vector& data_vec = dev_pair.second; + std::vector& slice_vec = + sid_dev_slice_map[sid][dev]; + size_t data_size = data_vec.size(); + int qp_in_use = qp_num_per_ctx_; + if (data_size < (size_t)qp_num_per_ctx_) { + qp_in_use = data_size; + } + size_t begin_idx = 0; + size_t end_idx = 0; + size_t batch_size = data_size / qp_in_use; + size_t reminder = data_size % qp_in_use; + + int retry_cnt = 5; + for (int i = 0; i < qp_in_use; i++) { + XChannel* channel = nullptr; + for (int j = 0; j < retry_cnt; j++) { + channel = channel_cache_.find(sid, dev, i); + if (!channel) { + LOG(ERROR) + << "Write fail, sid " << sid << ", dev " << dev + << ", id " << i << " not found, retry " << j << "/" + << retry_cnt; + break; + } + if (!channel->IsActive()) { + LOG(WARNING) + << "Write fail, channel status error " << channel + << " retry " << j << "/" << retry_cnt; + channel_cache_.erase(sid, dev, i); + continue; + } + } + if (!channel) { + LOG(ERROR) << "Write fail, no channel found"; + return -1; + } + + end_idx += batch_size; + if (i == qp_in_use - 1) { + end_idx += reminder; + } + int peer_nic_id = channel->GetPeerNicId(); + auto data_chunk_read = + std::make_shared>(); + auto data_chunk_write = + std::make_shared>(); + auto slice_chunk_read = + std::make_shared>(); + auto slice_chunk_write = + std::make_shared>(); + for (size_t idx = begin_idx; idx < end_idx; idx++) { + data_vec[idx].r_key = + slice_vec[idx]->dest_rkeys[peer_nic_id]; + if (slice_vec[idx]->opcode == + Transport::TransferRequest::READ) { + data_chunk_read->emplace_back(data_vec[idx]); + slice_chunk_read->emplace_back(slice_vec[idx]); + } else { + data_chunk_write->emplace_back(data_vec[idx]); + slice_chunk_write->emplace_back(slice_vec[idx]); + } + } + + if (!data_chunk_write->empty()) { + BarexResult r = channel->WriteBatch( + data_chunk_write, + [slice_chunk_write](accl::barex::Status s) { + if (!s.IsOk()) { + LOG(ERROR) << "WriteBatch fail, " + << s.ErrMsg().c_str(); + for (auto slice : *slice_chunk_write) { + slice->markFailed(); + } + } else { + for (auto slice : *slice_chunk_write) { + slice->markSuccess(); + } + } + }, + true); + if (r != accl::barex::BAREX_SUCCESS) { + LOG(ERROR) << "WriteBatch fail, ret " << r; + return -2; + } + } + if (!data_chunk_read->empty()) { + BarexResult r = channel->ReadBatch( + data_chunk_read, + [slice_chunk_read](accl::barex::Status s) { + if (!s.IsOk()) { + LOG(ERROR) + << "ReadBatch fail, " << s.ErrMsg().c_str(); + for (auto slice : *slice_chunk_read) { + slice->markFailed(); + } + } else { + for (auto slice : *slice_chunk_read) { + slice->markSuccess(); + } + } + }, + true); + if (r != accl::barex::BAREX_SUCCESS) { + LOG(ERROR) << "ReadBatch fail, ret " << r; + return -2; + } + } + begin_idx += batch_size; + } + } + } + return 0; +} + +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp 
b/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp new file mode 100644 index 000000000..3f9558940 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp @@ -0,0 +1,1521 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "transport/barex_transport/barex_transport.h" + +#include +#include +#include + +#include +#include +#include +#include + +#include "common.h" +#include "config.h" +#include "memory_location.h" +#include "topology.h" +// #include "transport/rdma_transport/rdma_context.h" +// #include "transport/rdma_transport/rdma_endpoint.h" + +namespace mooncake { +using namespace accl::barex; + +class EmptyCallback : public XChannelCallback { + public: + void OnRecvCall(XChannel *channel, char *buf, size_t len, + x_msg_header header) {} +}; + +BarexTransport::BarexTransport() {} + +BarexTransport::~BarexTransport() { +#ifdef CONFIG_USE_BATCH_DESC_SET + for (auto &entry : batch_desc_set_) delete entry.second; + batch_desc_set_.clear(); +#endif + for (auto ctx : client_context_list_) { + std::vector chs = ctx->getAllChannel(); + for (auto ch : chs) { + BarexResult ret = + connector_->CloseChannel(ch, [&, ch](accl::barex::Status s) { + LOG(INFO) << "CloseChannel() finished, s.IsOk=" << s.IsOk(); + ch->Destroy(); + }); + if (ret != accl::barex::BAREX_SUCCESS) { + LOG(ERROR) << "CloseChannel() failed, ret " << ret; + } + } + } + client_context_list_.clear(); + server_context_list_.clear(); + metadata_->removeSegmentDesc(local_server_name_); + batch_desc_set_.clear(); + connector_->Shutdown(); + connector_->WaitStop(); + listener_->Shutdown(); + listener_->WaitStop(); + server_threadpool_->Shutdown(); + server_threadpool_->WaitStop(); + client_threadpool_->Shutdown(); + client_threadpool_->WaitStop(); + mempool_->Shutdown(); + mempool_->WaitStop(); +} + +int BarexTransport::install(std::string &local_server_name, + std::shared_ptr meta, + std::shared_ptr topo) { + if (topo == nullptr) { + LOG(ERROR) << "BarexTransport: missing topology"; + return ERR_INVALID_ARGUMENT; + } + + metadata_ = meta; + local_server_name_ = local_server_name; + local_topology_ = topo; + + const char *barex_random_dev_env = std::getenv("BAREX_USE_RANDOM_DEV"); + if (barex_random_dev_env) { + int val = atoi(barex_random_dev_env); + if (val != 0) { + LOG(INFO) << "BarexTransport: use random rdma device"; + use_random_dev_ = true; + } + } + + const char *barex_use_cpu_env = std::getenv("ACCL_USE_CPU"); + if (barex_use_cpu_env) { + int val = atoi(barex_use_cpu_env); + if (val != 0) { + LOG(INFO) << "BarexTransport: use_cpu"; + barex_use_cpu_ = true; + } + } + + const char *barex_local_device_env = std::getenv("ACCL_LOCAL_DEVICE"); + if (barex_local_device_env) { + int val = atoi(barex_local_device_env); + LOG(INFO) << "BarexTransport: set local device id " << val; + barex_local_device_ = val; + } + + auto ret = initializeRdmaResources(); + if (ret) { + LOG(ERROR) << "BarexTransport: cannot initialize RDMA 
resources"; + return ret; + } + + ret = allocateLocalSegmentID(); + if (ret) { + LOG(ERROR) << "Transfer engine cannot be initialized: cannot " + "allocate local segment"; + return ret; + } + + ret = startHandshakeDaemon(local_server_name); + if (ret) { + LOG(ERROR) << "BarexTransport: cannot start handshake daemon"; + return ret; + } + + ret = metadata_->updateLocalSegmentDesc(); + if (ret) { + LOG(ERROR) << "BarexTransport: cannot publish segments"; + return ret; + } + + return 0; +} + +int BarexTransport::registerLocalMemory(void *addr, size_t length, + const std::string &name, + bool remote_accessible, + bool update_metadata) { + auto &config = globalConfig(); + size_t buffer_size = config.eic_max_block_size; + size_t remaining = length; + void *current_ptr = addr; + device_type dtype; + + if (name.find("cuda") != std::string::npos || name == kWildcardLocation) { + dtype = GPU; + } else if (name.find("cpu") != std::string::npos) { + dtype = CPU; + } else { + LOG(ERROR) + << "BarexTransport: registerLocalMemory, cannot recognize: name " + << name << ", need include cpu or cuda in name"; + return ERR_INVALID_ARGUMENT; + } + + bool is_gpu = dtype == GPU ? true : false; + + while (remaining > 0) { + size_t buffer_len = std::min(buffer_size, remaining); + int ret = + registerLocalMemoryBase(current_ptr, buffer_len, name, + remote_accessible, update_metadata, is_gpu); + if (ret) { + LOG(ERROR) << "registerLocalMemoryBase failed, ret " << ret; + return -1; + } + current_ptr = static_cast(current_ptr) + buffer_len; + remaining -= buffer_len; + } + + std::lock_guard guard(buf_mutex_); + if (dtype == CPU) { + buf_length_map_.emplace(addr, std::make_pair(length, 0)); + } else { + buf_length_map_.emplace(addr, std::make_pair(length, 1)); + } + + return 0; +} + +int BarexTransport::registerLocalMemoryBase(void *addr, size_t length, + const std::string &name, + bool remote_accessible, + bool update_metadata, bool is_gpu) { + (void)remote_accessible; + BufferDesc buffer_desc; + memp_t mem; + BarexResult result; + device_type dtype = is_gpu ? GPU : CPU; + result = mempool_->RegUserMr(mem, addr, length, dtype); + if (result != BAREX_SUCCESS) { + LOG(ERROR) << "BarexTransport: registerLocalMemory failed" + << ", result " << result << ", addr " << addr << ", length " + << length << ", name " << name; + return ERR_ADDRESS_NOT_REGISTERED; + } else { + for (auto &mr : mem.mrs) { + buffer_desc.lkey.push_back(mr.second->lkey); + buffer_desc.rkey.push_back(mr.second->rkey); + } + } + + // Get the memory location automatically after registered MR(pinned), + // when the name is kWildcardLocation("*"). 
+    if (name == kWildcardLocation) {
+        bool only_first_page = true;
+        const std::vector<MemoryLocationEntry> entries =
+            getMemoryLocation(addr, length, only_first_page);
+        for (auto &entry : entries) {
+            buffer_desc.name = entry.location;
+            buffer_desc.addr = entry.start;
+            buffer_desc.length = entry.len;
+            int rc =
+                metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata);
+            if (rc) return rc;
+        }
+    } else {
+        buffer_desc.name = name;
+        buffer_desc.addr = (uint64_t)addr;
+        buffer_desc.length = length;
+        int rc = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata);
+        if (rc) return rc;
+    }
+
+    return 0;
+}
+
+int BarexTransport::unregisterLocalMemory(void *addr, bool update_metadata) {
+    int rc = metadata_->removeLocalMemoryBuffer(addr, update_metadata);
+    if (rc) return rc;
+
+    auto &config = globalConfig();
+    size_t buffer_size = config.eic_max_block_size;
+    void *current_ptr = addr;
+    device_type dtype = CPU;
+    BarexResult result;
+    size_t remaining = 0;
+    memp_t mem;
+    {
+        std::lock_guard<std::mutex> guard(buf_mutex_);
+        auto iter = buf_length_map_.find(addr);
+        if (iter != buf_length_map_.end()) {
+            remaining = iter->second.first;
+            dtype = iter->second.second ? GPU : CPU;
+            buf_length_map_.erase(iter);
+        }
+    }
+
+    while (remaining > 0) {
+        size_t buffer_len = std::min(buffer_size, remaining);
+        if (current_ptr > addr) {
+            int rc = metadata_->removeLocalMemoryBuffer(current_ptr,
+                                                        update_metadata);
+            if (rc) {
+                LOG(WARNING) << "unregisterLocalMemory, "
+                                "removeLocalMemoryBuffer failed, addr "
+                             << addr;
+            }
+        }
+        result = mempool_->DeregUserMr(current_ptr, dtype);
+        if (result != BAREX_SUCCESS) {
+            LOG(ERROR) << "unregisterLocalMemory, DeregUserMr failed, ret "
+                       << result << ", addr " << current_ptr;
+            return -1;
+        }
+        current_ptr = static_cast<char *>(current_ptr) + buffer_len;
+        remaining -= buffer_len;
+    }
+
+    return 0;
+}
+
+int BarexTransport::allocateLocalSegmentID() {
+    auto desc = std::make_shared<SegmentDesc>();
+    if (!desc) return ERR_MEMORY;
+    desc->name = local_server_name_;
+    desc->protocol = "barex";
+    for (auto &entry : server_context_list_) {
+        TransferMetadata::DeviceDesc device_desc;
+        device_desc.name = entry->getCtx()->GetXDevice()->GetName();
+        // TODO: does barex need this?
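+        // lid/gid identify an InfiniBand endpoint in the rdma transport's
+        // segment descriptor. barex establishes its own connections, so the
+        // placeholder values below are kept only so the descriptor schema
+        // stays compatible with the rdma decoder.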
+ device_desc.lid = 0; // entry->lid(); + device_desc.gid = "ignore"; // entry->gid(); + desc->devices.push_back(device_desc); + } + desc->topology = *(local_topology_.get()); + metadata_->addLocalSegment(LOCAL_SEGMENT_ID, local_server_name_, + std::move(desc)); + return 0; +} + +int BarexTransport::registerLocalMemoryBatch( + const std::vector &buffer_list, + const std::string &location) { + for (auto &buffer : buffer_list) { + int ret = registerLocalMemory(buffer.addr, buffer.length, location, + true, false); + if (ret) { + LOG(ERROR) << "BarexTransport: Failed to register memory: addr " + << buffer.addr << " length " << buffer.length; + return ERR_ADDRESS_NOT_REGISTERED; + } + } + + return metadata_->updateLocalSegmentDesc(); +} + +int BarexTransport::unregisterLocalMemoryBatch( + const std::vector &addr_list) { + std::vector> results; + for (auto &addr : addr_list) { + results.emplace_back( + std::async(std::launch::async, [this, addr]() -> int { + return unregisterLocalMemory(addr, false); + })); + } + + for (size_t i = 0; i < addr_list.size(); ++i) { + if (results[i].get()) + LOG(WARNING) << "BarexTransport: Failed to unregister memory: addr " + << addr_list[i]; + } + + return metadata_->updateLocalSegmentDesc(); +} + +Status BarexTransport::submitTransfer( + BatchID batch_id, const std::vector &entries) { + auto &batch_desc = *((BatchDesc *)(batch_id)); + if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) { + LOG(ERROR) + << "BarexTransport: Exceed the limitation of current batch's " + "capacity"; + return Status::InvalidArgument( + "BarexTransport: Exceed the limitation of capacity, batch id: " + + std::to_string(batch_id)); + } + + std::unordered_map, std::vector> + slices_to_post; + size_t task_id = batch_desc.task_list.size(); + batch_desc.task_list.resize(task_id + entries.size()); + auto local_segment_desc = metadata_->getSegmentDescByID(LOCAL_SEGMENT_ID); + // const size_t kBlockSize = globalConfig().slice_size; + const size_t kBlockMaxSize = globalConfig().eic_max_block_size; + const int kMaxRetryCount = globalConfig().retry_cnt; + std::unordered_map> + segment_desc_map; + for (auto &request : entries) { + auto target_id = request.target_id; + if (!segment_desc_map.count(target_id)) + segment_desc_map[target_id] = + metadata_->getSegmentDescByID(target_id); + } + for (auto &request : entries) { + TransferTask &task = batch_desc.task_list[task_id]; + ++task_id; + SegmentID target_id = request.target_id; + auto peer_segment_desc = segment_desc_map[target_id]; + if (!peer_segment_desc) { + LOG(ERROR) << "peer_segment_desc not found for target_id " + << target_id; + return Status::InvalidArgument( + "BarexTransport: peer_segment_desc not found, batch id: " + + std::to_string(batch_id)); + } + size_t kBlockSize = std::min(request.length, kBlockMaxSize); + for (uint64_t offset = 0; offset < request.length; + offset += kBlockSize) { + Slice *slice = getSliceCache().allocate(); + slice->source_addr = (char *)request.source + offset; + slice->length = std::min(request.length - offset, kBlockSize); + slice->opcode = request.opcode; + slice->rdma.dest_addr = request.target_offset + offset; + slice->rdma.retry_cnt = 0; + slice->rdma.max_retry_cnt = kMaxRetryCount; + slice->task = &task; + slice->target_id = request.target_id; + slice->ts = 0; + slice->status = Slice::PENDING; + task.slice_list.push_back(slice); + + int peer_buffer_id = -1, extra_peer_buffer_id = 0, + peer_device_id = -1; + int local_buffer_id = -1, extra_local_buffer_id = 0, device_id = -1, + 
retry_cnt = 0; + while (retry_cnt < kMaxRetryCount) { + int ret = selectDevice( + local_segment_desc.get(), (uint64_t)slice->source_addr, + slice->length, local_buffer_id, device_id, retry_cnt++); + if (ret) { + if (ret == ERR_ADDRESS_NOT_REGISTERED) { + LOG(WARNING) + << "local_segment_desc selectDevice failed"; + continue; + } else { + // need 2 blocks + extra_local_buffer_id = local_buffer_id + 1; + } + } + ret = + selectDevice(peer_segment_desc.get(), slice->rdma.dest_addr, + slice->length, peer_buffer_id, peer_device_id, + slice->rdma.retry_cnt); + if (ret) { + if (ret == ERR_ADDRESS_NOT_REGISTERED) { + LOG(WARNING) << "peer_segment_desc selectDevice failed"; + continue; + } else { + // need 2 blocks + extra_peer_buffer_id = peer_buffer_id + 1; + } + } + assert(device_id >= 0); + if (device_id >= + static_cast(client_context_list_.size()) || + use_random_dev_) { + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis( + 0, client_context_list_.size() - 1); + device_id = dis(gen); + } + auto &context = client_context_list_[device_id]; + if (!context->active()) continue; + assert(context->getCtx()->GetXDevice()->GetId() == device_id); + // 4 types, local:peer = 1:1, 1:2, 2:1, 2:2 + if (!extra_local_buffer_id && !extra_peer_buffer_id) { // 1:1 + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += slice->length; + __sync_fetch_and_add(&task.slice_count, 1); + break; + } else if (!extra_local_buffer_id && + extra_peer_buffer_id) { // 1:2 + auto &first_peer_buffer_desc = + peer_segment_desc.get()->buffers[peer_buffer_id]; + auto &last_peer_buffer_desc = + peer_segment_desc.get()->buffers[extra_peer_buffer_id]; + size_t first_length = first_peer_buffer_desc.addr + + first_peer_buffer_desc.length - + slice->rdma.dest_addr; + size_t last_length = slice->rdma.dest_addr + slice->length - + last_peer_buffer_desc.addr; + assert(first_length + last_length == slice->length); + // add first part + slice->length = first_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = + (char *)request.source + offset + first_length; + last_slice->length = last_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_length; + __sync_fetch_and_add(&task.slice_count, 1); + break; + } else if (extra_local_buffer_id && + !extra_peer_buffer_id) { // 2:1 + auto &first_local_buffer_desc = + 
                        local_segment_desc.get()->buffers[local_buffer_id];
+                    auto &last_local_buffer_desc = local_segment_desc.get()->buffers[extra_local_buffer_id];
+                    size_t first_length = first_local_buffer_desc.addr + first_local_buffer_desc.length - (size_t)slice->source_addr;
+                    size_t last_length = (size_t)slice->source_addr + slice->length - last_local_buffer_desc.addr;
+                    assert(first_length + last_length == slice->length);
+                    // add first part
+                    slice->length = first_length;
+                    slice->rdma.source_lkey = local_segment_desc->buffers[local_buffer_id].lkey[device_id];
+                    slice->rdma.lkey_index = device_id;
+                    slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey;
+                    slices_to_post[context].push_back(slice);
+                    task.total_bytes += first_length;
+                    __sync_fetch_and_add(&task.slice_count, 1);
+                    // add last part
+                    Slice *last_slice = getSliceCache().allocate();
+                    last_slice->source_addr = (char *)request.source + offset + first_length;
+                    last_slice->length = last_length;
+                    last_slice->opcode = request.opcode;
+                    last_slice->rdma.dest_addr = request.target_offset + offset + first_length;
+                    last_slice->rdma.retry_cnt = 0;
+                    last_slice->rdma.max_retry_cnt = kMaxRetryCount;
+                    last_slice->task = &task;
+                    last_slice->target_id = request.target_id;
+                    last_slice->ts = 0;
+                    last_slice->status = Slice::PENDING;
+                    task.slice_list.push_back(last_slice);
+                    last_slice->rdma.source_lkey = local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id];
+                    last_slice->rdma.lkey_index = device_id;
+                    last_slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey;
+                    slices_to_post[context].push_back(last_slice);
+                    task.total_bytes += last_length;
+                    __sync_fetch_and_add(&task.slice_count, 1);
+                    break;
+                } else {  // 2:2
+                    auto &first_local_buffer_desc = local_segment_desc.get()->buffers[local_buffer_id];
+                    auto &last_local_buffer_desc = local_segment_desc.get()->buffers[extra_local_buffer_id];
+                    size_t first_local_length = first_local_buffer_desc.addr + first_local_buffer_desc.length - (size_t)slice->source_addr;
+                    size_t last_local_length = (size_t)slice->source_addr + slice->length - last_local_buffer_desc.addr;
+                    assert(first_local_length + last_local_length == slice->length);
+                    auto &first_peer_buffer_desc = peer_segment_desc.get()->buffers[peer_buffer_id];
+                    auto &last_peer_buffer_desc = peer_segment_desc.get()->buffers[extra_peer_buffer_id];
+                    size_t first_peer_length = first_peer_buffer_desc.addr + first_peer_buffer_desc.length - slice->rdma.dest_addr;
+                    size_t last_peer_length = slice->rdma.dest_addr + slice->length - last_peer_buffer_desc.addr;
+                    assert(first_peer_length + last_peer_length == slice->length);
+                    if (first_local_length == first_peer_length) {
+                        // add first part
+                        slice->length = first_local_length;
+                        slice->rdma.source_lkey = local_segment_desc->buffers[local_buffer_id].lkey[device_id];
+                        slice->rdma.lkey_index = device_id;
+                        slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey;
+                        slices_to_post[context].push_back(slice);
+                        task.total_bytes += first_local_length;
+                        __sync_fetch_and_add(&task.slice_count, 1);
+                        // add last part
+                        Slice *last_slice = getSliceCache().allocate();
+                        last_slice->source_addr = (char *)request.source + offset + first_local_length;
+                        last_slice->length = last_local_length;
+                        last_slice->opcode = request.opcode;
+                        last_slice->rdma.dest_addr = request.target_offset + offset + first_local_length;
+                        last_slice->rdma.retry_cnt = 0;
+                        last_slice->rdma.max_retry_cnt = kMaxRetryCount;
+                        last_slice->task = &task;
+                        last_slice->target_id = request.target_id;
+                        last_slice->ts = 0;
+                        last_slice->status = Slice::PENDING;
+                        task.slice_list.push_back(last_slice);
+                        last_slice->rdma.source_lkey = local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id];
+                        last_slice->rdma.lkey_index = device_id;
+                        last_slice->dest_rkeys = peer_segment_desc->buffers[extra_peer_buffer_id].rkey;
+                        slices_to_post[context].push_back(last_slice);
+                        task.total_bytes += last_local_length;
+                        __sync_fetch_and_add(&task.slice_count, 1);
+                    } else if (first_local_length > first_peer_length) {
+                        // add first part
+                        slice->length = first_peer_length;
+                        slice->rdma.source_lkey = local_segment_desc->buffers[local_buffer_id].lkey[device_id];
+                        slice->rdma.lkey_index = device_id;
+                        slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey;
+                        slices_to_post[context].push_back(slice);
+                        task.total_bytes += first_peer_length;
+                        __sync_fetch_and_add(&task.slice_count, 1);
+                        // add second part
+                        Slice *second_slice = getSliceCache().allocate();
+                        second_slice->source_addr = (char *)request.source + offset + first_peer_length;
+                        second_slice->length = first_local_length - first_peer_length;
+                        second_slice->opcode = request.opcode;
+                        second_slice->rdma.dest_addr = request.target_offset + offset + first_peer_length;
+                        second_slice->rdma.retry_cnt = 0;
+                        second_slice->rdma.max_retry_cnt = kMaxRetryCount;
+                        second_slice->task = &task;
+                        second_slice->target_id = request.target_id;
+                        second_slice->ts = 0;
+                        second_slice->status = Slice::PENDING;
+                        task.slice_list.push_back(second_slice);
+                        second_slice->rdma.source_lkey = local_segment_desc->buffers[local_buffer_id].lkey[device_id];
+                        second_slice->rdma.lkey_index = device_id;
+                        second_slice->dest_rkeys = peer_segment_desc->buffers[extra_peer_buffer_id].rkey;
+                        slices_to_post[context].push_back(second_slice);
+                        task.total_bytes += first_local_length - first_peer_length;
+                        __sync_fetch_and_add(&task.slice_count, 1);
+                        // add last part
+                        Slice *last_slice = getSliceCache().allocate();
+                        last_slice->source_addr = (char *)request.source + offset + first_local_length;
+                        last_slice->length = last_local_length;
+                        last_slice->opcode = request.opcode;
+                        last_slice->rdma.dest_addr = request.target_offset + offset + first_local_length;
+                        last_slice->rdma.retry_cnt = 0;
+                        last_slice->rdma.max_retry_cnt = kMaxRetryCount;
+                        last_slice->task = &task;
+                        last_slice->target_id = request.target_id;
+                        last_slice->ts = 0;
+                        last_slice->status = Slice::PENDING;
+                        task.slice_list.push_back(last_slice);
+                        last_slice->rdma.source_lkey = local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id];
+                        last_slice->rdma.lkey_index = device_id;
+                        last_slice->dest_rkeys = peer_segment_desc->buffers[extra_peer_buffer_id].rkey;
+                        slices_to_post[context].push_back(last_slice);
+                        task.total_bytes += last_local_length;
+                        __sync_fetch_and_add(&task.slice_count, 1);
+                    } else {
+                        // first_local_length < first_peer_length
+                        // add first part
+                        slice->length = first_local_length;
+                        slice->rdma.source_lkey = local_segment_desc->buffers[local_buffer_id].lkey[device_id];
+                        slice->rdma.lkey_index = device_id;
+                        slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey;
+                        slices_to_post[context].push_back(slice);
+                        task.total_bytes += first_local_length;
+                        __sync_fetch_and_add(&task.slice_count, 1);
+                        // add second part
+                        Slice *second_slice = getSliceCache().allocate();
+                        second_slice->source_addr = (char *)request.source + offset + first_local_length;
+                        second_slice->length = first_peer_length - first_local_length;
+                        second_slice->opcode = request.opcode;
+                        second_slice->rdma.dest_addr = request.target_offset + offset + first_local_length;
+                        second_slice->rdma.retry_cnt = 0;
+                        second_slice->rdma.max_retry_cnt = kMaxRetryCount;
+                        second_slice->task = &task;
+                        second_slice->target_id = request.target_id;
+                        second_slice->ts = 0;
+                        second_slice->status = Slice::PENDING;
+                        task.slice_list.push_back(second_slice);
+                        second_slice->rdma.source_lkey = local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id];
+                        second_slice->rdma.lkey_index = device_id;
+                        second_slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey;
+                        slices_to_post[context].push_back(second_slice);
+                        task.total_bytes += first_peer_length - first_local_length;
+                        __sync_fetch_and_add(&task.slice_count, 1);
+                        // add last part
+                        Slice *last_slice = getSliceCache().allocate();
+                        last_slice->source_addr = (char *)request.source + offset + first_peer_length;
+                        last_slice->length = last_peer_length;
+                        last_slice->opcode = request.opcode;
+                        last_slice->rdma.dest_addr = request.target_offset + offset + first_peer_length;
+                        last_slice->rdma.retry_cnt = 0;
+                        last_slice->rdma.max_retry_cnt = kMaxRetryCount;
+                        last_slice->task = &task;
+                        last_slice->target_id = request.target_id;
+                        last_slice->ts = 0;
+                        last_slice->status = Slice::PENDING;
+                        task.slice_list.push_back(last_slice);
+                        last_slice->rdma.source_lkey = local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id];
+                        last_slice->rdma.lkey_index = device_id;
+                        last_slice->dest_rkeys = peer_segment_desc->buffers[extra_peer_buffer_id].rkey;
+                        slices_to_post[context].push_back(last_slice);
+                        task.total_bytes += last_peer_length;
+                        __sync_fetch_and_add(&task.slice_count, 1);
+                    }
+                    break;
+                }
+            }
+            if (device_id < 0) {
+                auto source_addr = slice->source_addr;
+                for (auto &entry : slices_to_post)
+                    for (auto s : entry.second) getSliceCache().deallocate(s);
+                LOG(ERROR) << "BarexTransport: Address not registered by any device(s) " << source_addr;
+                return Status::AddressNotRegistered(
+                    "BarexTransport: not registered by any device(s), address: " +
+                    std::to_string(reinterpret_cast<uint64_t>(source_addr)));
+            }
+        }
+    }
+    for (auto &entry : slices_to_post) {
+        int ret = entry.first->submitPostSend(entry.second);
+        if (ret) {
+            return Status::InvalidArgument("submitPostSend failed");
+        }
+    }
+    return Status::OK();
+}
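
Editor's note: every 2:1 and 2:2 branch above reduces to the same cut-at-the-buffer-boundary arithmetic, with the invariant that the two part lengths reconstruct the slice length. A standalone, compilable sketch of that invariant (types, names, and addresses are hypothetical, not part of this patch):

// Sketch of the two-buffer split: a region [addr, addr + len) that starts
// inside `first` and ends inside the adjacent `last` buffer is cut at the
// boundary, so first_len + last_len always equals len.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct BufferRange {
    uint64_t addr;    // registered base address
    uint64_t length;  // registered length
};

int main() {
    BufferRange first{0x1000, 0x1000};      // [0x1000, 0x2000)
    BufferRange last{0x2000, 0x1000};       // [0x2000, 0x3000)
    uint64_t addr = 0x1800, len = 0x1000;   // spans both buffers

    uint64_t first_len = first.addr + first.length - addr;  // bytes in `first`
    uint64_t last_len = addr + len - last.addr;             // bytes in `last`
    assert(first_len + last_len == len);
    std::printf("first=%llu last=%llu\n", (unsigned long long)first_len,
                (unsigned long long)last_len);  // 2048 and 2048 here
    return 0;
}
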
"peer_segment_desc not found for target_id " + << target_id; + return Status::InvalidArgument( + "BarexTransport: peer_segment_desc not found"); + } + size_t kBlockSize = std::min(request.length, kBlockMaxSize); + for (uint64_t offset = 0; offset < request.length; + offset += kBlockSize) { + Slice *slice = getSliceCache().allocate(); + assert(slice); + slice->source_addr = (char *)request.source + offset; + slice->length = std::min(request.length - offset, kBlockSize); + slice->opcode = request.opcode; + slice->rdma.dest_addr = request.target_offset + offset; + slice->rdma.retry_cnt = request.advise_retry_cnt; + slice->rdma.max_retry_cnt = kMaxRetryCount; + slice->task = &task; + slice->target_id = request.target_id; + slice->status = Slice::PENDING; + slice->ts = 0; + task.slice_list.push_back(slice); + + int peer_buffer_id = -1, extra_peer_buffer_id = 0, + peer_device_id = -1; + int local_buffer_id = -1, extra_local_buffer_id = 0, device_id = -1, + retry_cnt = request.advise_retry_cnt; + bool found_device = false; + while (retry_cnt < kMaxRetryCount) { + int ret = selectDevice( + local_segment_desc.get(), (uint64_t)slice->source_addr, + slice->length, local_buffer_id, device_id, retry_cnt++); + if (ret) { + if (ret == ERR_ADDRESS_NOT_REGISTERED) { + LOG(WARNING) + << "local_segment_desc selectDevice failed"; + continue; + } else { + // need 2 blocks + extra_local_buffer_id = local_buffer_id + 1; + } + } + ret = + selectDevice(peer_segment_desc.get(), slice->rdma.dest_addr, + slice->length, peer_buffer_id, peer_device_id, + slice->rdma.retry_cnt); + if (ret) { + if (ret == ERR_ADDRESS_NOT_REGISTERED) { + LOG(WARNING) << "peer_segment_desc selectDevice failed"; + continue; + } else { + // need 2 blocks + extra_peer_buffer_id = peer_buffer_id + 1; + } + } + assert(device_id >= 0); + if (device_id >= + static_cast(client_context_list_.size()) || + use_random_dev_) { + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis( + 0, client_context_list_.size() - 1); + device_id = dis(gen); + } + auto &context = client_context_list_[device_id]; + assert(context.get()); + if (!context->active()) continue; + assert(context->getCtx()->GetXDevice()->GetId() == device_id); + assert(local_buffer_id >= 0 && + local_buffer_id < local_segment_desc->buffers.size()); + assert( + local_segment_desc->buffers[local_buffer_id].lkey.size() == + client_context_list_.size()); + // 4 types, local:peer = 1:1, 1:2, 2:1, 2:2 + if (!extra_local_buffer_id && !extra_peer_buffer_id) { // 1:1 + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += slice->length; + __sync_fetch_and_add(&task.slice_count, 1); + found_device = true; + break; + } else if (!extra_local_buffer_id && + extra_peer_buffer_id) { // 1:2 + auto &first_peer_buffer_desc = + peer_segment_desc.get()->buffers[peer_buffer_id]; + auto &last_peer_buffer_desc = + peer_segment_desc.get()->buffers[extra_peer_buffer_id]; + size_t first_length = first_peer_buffer_desc.addr + + first_peer_buffer_desc.length - + slice->rdma.dest_addr; + size_t last_length = slice->rdma.dest_addr + slice->length - + last_peer_buffer_desc.addr; + assert(first_length + last_length == slice->length); + // add first part + slice->length = first_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; + slice->rdma.lkey_index 
= device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = + (char *)request.source + offset + first_length; + last_slice->length = last_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_length; + __sync_fetch_and_add(&task.slice_count, 1); + found_device = true; + break; + } else if (extra_local_buffer_id && + !extra_peer_buffer_id) { // 2:1 + auto &first_local_buffer_desc = + local_segment_desc.get()->buffers[local_buffer_id]; + auto &last_local_buffer_desc = + local_segment_desc.get() + ->buffers[extra_local_buffer_id]; + size_t first_length = first_local_buffer_desc.addr + + first_local_buffer_desc.length - + (size_t)slice->source_addr; + size_t last_length = (size_t)slice->source_addr + + slice->length - + last_local_buffer_desc.addr; + assert(first_length + last_length == slice->length); + // add first part + slice->length = first_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = + (char *)request.source + offset + first_length; + last_slice->length = last_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_length; + __sync_fetch_and_add(&task.slice_count, 1); + found_device = true; + break; + } else { // 2:2 + auto &first_local_buffer_desc = + local_segment_desc.get()->buffers[local_buffer_id]; + auto &last_local_buffer_desc = + local_segment_desc.get() + ->buffers[extra_local_buffer_id]; + size_t first_local_length = first_local_buffer_desc.addr + + first_local_buffer_desc.length - + (size_t)slice->source_addr; + size_t last_local_length = (size_t)slice->source_addr + + slice->length - + last_local_buffer_desc.addr; + assert(first_local_length + last_local_length == + slice->length); + auto 
&first_peer_buffer_desc = + peer_segment_desc.get()->buffers[peer_buffer_id]; + auto &last_peer_buffer_desc = + peer_segment_desc.get()->buffers[extra_peer_buffer_id]; + size_t first_peer_length = first_peer_buffer_desc.addr + + first_peer_buffer_desc.length - + slice->rdma.dest_addr; + size_t last_peer_length = slice->rdma.dest_addr + + slice->length - + last_peer_buffer_desc.addr; + assert(first_peer_length + last_peer_length == + slice->length); + if (first_local_length == first_peer_length) { + // add first part + slice->length = first_local_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + + offset + first_local_length; + last_slice->length = last_local_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_local_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + } else if (first_local_length > first_peer_length) { + // add first part + slice->length = first_peer_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_peer_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add second part + Slice *second_slice = getSliceCache().allocate(); + second_slice->source_addr = + (char *)request.source + offset + first_peer_length; + second_slice->length = + first_local_length - first_peer_length; + second_slice->opcode = request.opcode; + second_slice->rdma.dest_addr = + request.target_offset + offset + first_peer_length; + second_slice->rdma.retry_cnt = 0; + second_slice->rdma.max_retry_cnt = kMaxRetryCount; + second_slice->task = &task; + second_slice->target_id = request.target_id; + second_slice->ts = 0; + second_slice->status = Slice::PENDING; + task.slice_list.push_back(second_slice); + second_slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; + second_slice->rdma.lkey_index = device_id; + second_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; + slices_to_post[context].push_back(second_slice); + task.total_bytes += + first_local_length - first_peer_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + + offset + first_local_length; + last_slice->length = last_local_length; + last_slice->opcode = 
request.opcode; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_local_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + } else { + // first_local_length < first_peer_length + // add first part + slice->length = first_local_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add second part + Slice *second_slice = getSliceCache().allocate(); + second_slice->source_addr = (char *)request.source + + offset + first_local_length; + second_slice->length = + first_peer_length - first_local_length; + second_slice->opcode = request.opcode; + second_slice->rdma.dest_addr = + request.target_offset + offset + first_local_length; + second_slice->rdma.retry_cnt = 0; + second_slice->rdma.max_retry_cnt = kMaxRetryCount; + second_slice->task = &task; + second_slice->target_id = request.target_id; + second_slice->ts = 0; + second_slice->status = Slice::PENDING; + task.slice_list.push_back(second_slice); + second_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; + second_slice->rdma.lkey_index = device_id; + second_slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(second_slice); + task.total_bytes += + first_peer_length - first_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = + (char *)request.source + offset + first_peer_length; + last_slice->length = last_peer_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_peer_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_peer_length; + __sync_fetch_and_add(&task.slice_count, 1); + } + found_device = true; + break; + } + } + if (!found_device) { + auto source_addr = slice->source_addr; + for (auto &entry : slices_to_post) + for (auto s : entry.second) getSliceCache().deallocate(s); + LOG(ERROR) + << "Memory region not registered by any active device(s): " + << source_addr; + return Status::AddressNotRegistered( + "Memory region not registered by any active device(s): " + 
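
Editor's note: submitTransferTask() first caps each slice at eic_max_block_size (64 MiB by default, see the config.h hunk above) and lets the final slice carry the remainder. A minimal runnable sketch of that chunking policy, with a hypothetical request length:

// Sketch of the request-partitioning loop: slices of at most
// kBlockMaxSize bytes, final slice carries the remainder.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t kBlockMaxSize = 64ULL * 1024 * 1024;  // eic_max_block_size default
    uint64_t request_length = 150ULL * 1024 * 1024;      // hypothetical request

    uint64_t block = std::min(request_length, kBlockMaxSize);
    for (uint64_t offset = 0; offset < request_length; offset += block) {
        uint64_t slice_len = std::min(request_length - offset, block);
        std::printf("slice @%llu len=%llu\n", (unsigned long long)offset,
                    (unsigned long long)slice_len);  // 64 MiB, 64 MiB, 22 MiB
    }
    return 0;
}
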
+
+Status BarexTransport::getTransferStatus(BatchID batch_id,
+                                         std::vector<TransferStatus> &status) {
+    auto &batch_desc = *((BatchDesc *)(batch_id));
+    const size_t task_count = batch_desc.task_list.size();
+    status.resize(task_count);
+    for (size_t task_id = 0; task_id < task_count; task_id++) {
+        auto &task = batch_desc.task_list[task_id];
+        status[task_id].transferred_bytes = task.transferred_bytes;
+        uint64_t success_slice_count = task.success_slice_count;
+        uint64_t failed_slice_count = task.failed_slice_count;
+        if (success_slice_count + failed_slice_count == task.slice_count) {
+            if (failed_slice_count) {
+                status[task_id].s = TransferStatusEnum::FAILED;
+            } else {
+                status[task_id].s = TransferStatusEnum::COMPLETED;
+            }
+            task.is_finished = true;
+        } else {
+            status[task_id].s = TransferStatusEnum::WAITING;
+        }
+    }
+    return Status::OK();
+}
+
+Status BarexTransport::getTransferStatus(BatchID batch_id, size_t task_id,
+                                         TransferStatus &status) {
+    auto &batch_desc = *((BatchDesc *)(batch_id));
+    const size_t task_count = batch_desc.task_list.size();
+    if (task_id >= task_count) {
+        return Status::InvalidArgument(
+            "BarexTransport::getTransferStatus invalid argument, batch id: " +
+            std::to_string(batch_id));
+    }
+    auto &task = batch_desc.task_list[task_id];
+    status.transferred_bytes = task.transferred_bytes;
+    uint64_t success_slice_count = task.success_slice_count;
+    uint64_t failed_slice_count = task.failed_slice_count;
+    if (success_slice_count + failed_slice_count == task.slice_count) {
+        if (failed_slice_count)
+            status.s = TransferStatusEnum::FAILED;
+        else
+            status.s = TransferStatusEnum::COMPLETED;
+        task.is_finished = true;
+    } else {
+        status.s = TransferStatusEnum::WAITING;
+    }
+    return Status::OK();
+}
+
+BarexTransport::SegmentID BarexTransport::getSegmentID(
+    const std::string &segment_name) {
+    return metadata_->getSegmentID(segment_name);
+}
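
Editor's note: both getTransferStatus() overloads apply the same aggregation rule: a task is COMPLETED only when every slice succeeded, FAILED as soon as any slice failed, and WAITING while slices are still in flight. A self-contained sketch of that rule (the enum and function here are illustrative stand-ins, not the patch's types):

// Status aggregation: (ok + failed == total) decides whether the task is
// finished; any failed slice makes the whole task FAILED.
#include <cstdint>
#include <cstdio>

enum class TransferState { WAITING, COMPLETED, FAILED };

TransferState aggregate(uint64_t ok, uint64_t failed, uint64_t total) {
    if (ok + failed != total) return TransferState::WAITING;
    return failed ? TransferState::FAILED : TransferState::COMPLETED;
}

int main() {
    std::printf("%d\n", (int)aggregate(3, 0, 4));  // 0: WAITING (one slice left)
    std::printf("%d\n", (int)aggregate(4, 0, 4));  // 1: COMPLETED
    std::printf("%d\n", (int)aggregate(3, 1, 4));  // 2: FAILED
    return 0;
}
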
+
+Status BarexTransport::OpenChannel(const std::string &segment_name,
+                                   SegmentID sid) {
+    auto [ip, port] = parseHostNameWithPort(segment_name);
+
+    HandShakeDesc local_desc, peer_desc;
+    local_desc.barex_port = getLocalPort();
+
+    int rc = metadata_->sendHandshake(segment_name, local_desc, peer_desc);
+    if (rc) return Status::Socket("sendHandshake failed");
+    if (!peer_desc.reply_msg.empty()) {
+        LOG(ERROR) << "Handshake request rejected by peer " << segment_name;
+        return Status::Socket("handshake rejected by peer");
+    } else {
+        LOG(INFO) << "Handshake finished, peer_server " << segment_name << ":"
+                  << peer_desc.barex_port;
+        setPeerPort(peer_desc.barex_port);
+    }
+
+    int client_ctx_cnt = client_context_list_.size();
+    int total_channels = client_ctx_cnt * client_context_list_[0]->getQpNum();
+    CountDownLatch connect_latch(total_channels);
+    std::vector<XChannel *> channels;
+    static std::mutex push_channel_mtx;
+    for (int i = 0; i < total_channels; i++) {
+        BarexResult result = connector_->Connect(
+            ip, getPeerPort(),
+            [=, &channels, &connect_latch](XChannel *channel,
+                                           accl::barex::Status s) {
+                if (!s.IsOk()) {
+                    LOG(ERROR) << "BarexTransport::OpenChannel failed, " << s.ErrMsg();
+                } else {
+                    std::unique_lock<std::mutex> lk(push_channel_mtx);
+                    channels.push_back(channel);
+                    LOG(INFO) << "Open channel " << i + 1 << "/" << total_channels;
+                }
+                connect_latch.CountDown();
+            });
+        if (result != BAREX_SUCCESS) {
+            LOG(ERROR) << "BarexTransport::OpenChannel failed, result=" << result;
+            connect_latch.CountDown();
+        }
+    }
+    connect_latch.Wait();
+    if ((int)channels.size() != total_channels) {
+        LOG(ERROR) << "open channel failed, need " << total_channels
+                   << " but got " << channels.size();
+        return Status::InvalidArgument("connect failed");
+    }
+    for (auto channel : channels) {
+        int idx = channel->GetContext()->GetXDevice()->GetId();
+        assert(client_context_list_[idx]->getCtx()->GetXDevice()->GetId() == idx);
+        client_context_list_[idx]->addChannel(sid, idx, channel);
+    }
+    return Status::OK();
+}
+
+Status BarexTransport::CheckStatus(SegmentID sid) {
+    bool has_bad_channel = false;
+    for (auto ctx : client_context_list_) {
+        int ret = ctx->checkStatus(sid);
+        if (ret) {
+            LOG(INFO) << "checkStatus failed in ctx " << ctx.get()
+                      << ", bad channel cnt=" << ret;
+            has_bad_channel = true;
+        }
+    }
+    if (has_bad_channel) {
+        LOG(ERROR) << "CheckStatus for sid " << sid << " failed";
+        return Status::InvalidArgument("sid status error");
+    }
+    return Status::OK();
+}
+
+int BarexTransport::onSetupRdmaConnections(const HandShakeDesc &peer_desc,
+                                           HandShakeDesc &local_desc) {
+    local_desc.barex_port = getLocalPort();
+    return 0;
+}
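
Editor's note: OpenChannel() blocks on a CountDownLatch until every asynchronous Connect() callback has fired, successful or not (note that the failure path also counts down, so Wait() cannot hang on a failed Connect call). The latch itself is not shown in this patch; a minimal sketch with equivalent semantics, assuming the accl-barex helper behaves like a standard count-down latch:

// Minimal count-down latch: Wait() returns once CountDown() has been
// called `count` times.
#include <condition_variable>
#include <mutex>

class CountDownLatch {
   public:
    explicit CountDownLatch(int count) : count_(count) {}
    void CountDown() {
        std::lock_guard<std::mutex> lk(mu_);
        if (count_ > 0 && --count_ == 0) cv_.notify_all();
    }
    void Wait() {
        std::unique_lock<std::mutex> lk(mu_);
        cv_.wait(lk, [this] { return count_ == 0; });
    }

   private:
    std::mutex mu_;
    std::condition_variable cv_;
    int count_;
};
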
+
+int BarexTransport::initializeRdmaResources() {
+    auto hca_list = local_topology_->getHcaList();
+    BarexResult result;
+    XDeviceManager *manager = nullptr;
+    XThreadpool *server_threadpool = nullptr;
+    XThreadpool *client_threadpool = nullptr;
+    XSimpleMempool *mempool = nullptr;
+    result = XDeviceManager::Singleton(manager);
+    if (result != BAREX_SUCCESS) {
+        LOG(ERROR) << "BarexTransport: Create XDeviceManager failed";
+        return ERR_DEVICE_NOT_FOUND;
+    }
+    auto devices = manager->AllDevices();
+    if (devices.empty()) {
+        LOG(ERROR) << "BarexTransport: No available RNIC";
+        return ERR_DEVICE_NOT_FOUND;
+    } else {
+        LOG(INFO) << devices.size() << " rdma devices found";
+    }
+    result = XSimpleMempool::NewInstance(mempool, "barex-mempool", devices);
+    if (result != BAREX_SUCCESS) {
+        LOG(ERROR) << "BarexTransport: Create XSimpleMempool failed";
+        return ERR_INVALID_ARGUMENT;
+    }
+    mempool_ = std::shared_ptr<XSimpleMempool>(mempool);
+    result = XThreadpool::NewInstance(server_threadpool, 10,
+                                      "barex-server-threadpool");
+    if (result != BAREX_SUCCESS) {
+        LOG(ERROR) << "BarexTransport: Create Server XThreadpool failed";
+        return ERR_INVALID_ARGUMENT;
+    }
+    server_threadpool_ = std::shared_ptr<XThreadpool>(server_threadpool);
+    result = XThreadpool::NewInstance(client_threadpool, 10,
+                                      "barex-client-threadpool");
+    if (result != BAREX_SUCCESS) {
+        LOG(ERROR) << "BarexTransport: Create Client XThreadpool failed";
+        return ERR_INVALID_ARGUMENT;
+    }
+    client_threadpool_ = std::shared_ptr<XThreadpool>(client_threadpool);
+    auto &config = globalConfig();
+    for (auto &dev : devices) {
+        if (std::find(hca_list.begin(), hca_list.end(), dev->GetName()) ==
+            hca_list.end()) {
+            LOG(WARNING) << "BarexTransport: device " << dev->GetName()
+                         << " not found in hca_list, ignored";
+            continue;
+        }
+        ContextConfig server_config = XConfigUtil::DefaultContextConfig();
+        XContext *raw_server_context = nullptr;
+        result = XContext::NewInstance(raw_server_context, server_config,
+                                       new EmptyCallback(), dev, mempool,
+                                       server_threadpool);
+        if (result != BAREX_SUCCESS) {
+            local_topology_->disableDevice(dev->GetName());
+            LOG(WARNING) << "BarexTransport: Create XContext failed, Disable device "
+                         << dev->GetName();
+        } else {
+            raw_server_context->Start();
+            auto server_context = std::make_shared<BarexContext>(
+                raw_server_context, barex_use_cpu_, barex_local_device_);
+            server_context->setQpNum(config.num_qp_per_ep);
+            server_context_list_.push_back(server_context);
+        }
+        ContextConfig client_config = XConfigUtil::DefaultContextConfig();
+        XContext *raw_client_context = nullptr;
+        result = XContext::NewInstance(raw_client_context, client_config,
+                                       new EmptyCallback(), dev, mempool,
+                                       client_threadpool);
+        if (result != BAREX_SUCCESS) {
+            local_topology_->disableDevice(dev->GetName());
+            LOG(WARNING) << "BarexTransport: Create XContext failed, Disable device "
+                         << dev->GetName();
+        } else {
+            raw_client_context->Start();
+            auto client_context = std::make_shared<BarexContext>(
+                raw_client_context, barex_use_cpu_, barex_local_device_);
+            client_context->setQpNum(config.num_qp_per_ep);
+            client_context_list_.push_back(client_context);
+        }
+    }
+
+    if (local_topology_->empty()) {
+        LOG(ERROR) << "BarexTransport: No available RNIC";
+        return ERR_DEVICE_NOT_FOUND;
+    }
+    return 0;
+}
+
+int BarexTransport::startHandshakeDaemon(std::string &local_server_name) {
+    std::vector<XContext *> raw_server_contexts;
+    std::vector<XContext *> raw_client_contexts;
+    for (auto ctx : server_context_list_) {
+        raw_server_contexts.emplace_back(ctx->getCtx());
+    }
+    for (auto ctx : client_context_list_) {
+        raw_client_contexts.emplace_back(ctx->getCtx());
+    }
+    XListener *listener = nullptr;
+
+    int port = metadata_->localRpcMeta().barex_port;
+    setLocalPort(port);
+    BarexResult result = XListener::NewInstance(listener, 2, getLocalPort(),
+                                                TIMER_3S, raw_server_contexts);
+    if (result != BAREX_SUCCESS) {
+        LOG(ERROR) << "BarexTransport: startHandshakeDaemon, create listener "
+                      "failed, result "
+                   << result;
+        return ERR_INVALID_ARGUMENT;
+    }
+    result = listener->Listen();
+    if (result != BAREX_SUCCESS) {
+        LOG(ERROR) << "BarexTransport: startHandshakeDaemon, Listen failed, result "
+                   << result;
+        return ERR_INVALID_ARGUMENT;
+    }
+    listener_ = std::shared_ptr<XListener>(listener);
+    XConnector *connector = nullptr;
+    result = XConnector::NewInstance(connector, 2, TIMER_3S, raw_client_contexts);
+    if (result != BAREX_SUCCESS) {
+        LOG(ERROR) << "BarexTransport: startHandshakeDaemon, create connector "
+                      "failed, result "
+                   << result;
+        return ERR_INVALID_ARGUMENT;
+    }
+    connector_ = std::shared_ptr<XConnector>(connector);
+    return metadata_->startHandshakeDaemon(
+        std::bind(&BarexTransport::onSetupRdmaConnections, this,
+                  std::placeholders::_1, std::placeholders::_2),
+        metadata_->localRpcMeta().rpc_port, metadata_->localRpcMeta().sockfd);
+}
+
+// According to the request's offset and length, find a proper buffer_id and
+// device_id as output.
+// Returns 0 if the region fits in a single registered buffer, 1 if it spans
+// two adjacent registered buffers (the caller must split the slice), and
+// ERR_ADDRESS_NOT_REGISTERED otherwise.
+int BarexTransport::selectDevice(SegmentDesc *desc, uint64_t offset,
+                                 size_t length, int &buffer_id, int &device_id,
+                                 int retry_count) {
+    if (!desc) return ERR_ADDRESS_NOT_REGISTERED;
+    int ret = 0;
+    for (buffer_id = 0; buffer_id < (int)desc->buffers.size(); ++buffer_id) {
+        auto &buffer_desc = desc->buffers[buffer_id];
+        if (buffer_desc.addr > offset ||
+            offset >= buffer_desc.addr + buffer_desc.length) {
+            continue;
+        } else {
+            if (offset + length > buffer_desc.addr + buffer_desc.length) {
+                // MR crosses two buffers; split into two parts
+                if (buffer_id + 1 < (int)desc->buffers.size()) {
+                    auto &next_buffer_desc = desc->buffers[buffer_id + 1];
+                    if (offset + length > next_buffer_desc.addr &&
+                        offset + length <=
+                            next_buffer_desc.addr + next_buffer_desc.length) {
+                        ret = 1;
+                    } else {
+                        LOG(ERROR) << "selectDevice failed, 2 buffers in need "
+                                      "but next buffer not fit,"
+                                   << " offset " << offset << " length " << length
+                                   << " buffer_id " << buffer_id
+                                   << " buffer_desc.addr " << buffer_desc.addr
+                                   << " buffer_desc.length " << buffer_desc.length
+                                   << " buffer_id " << buffer_id + 1
+                                   << " next_buffer_desc.addr " << next_buffer_desc.addr
+                                   << " next_buffer_desc.length " << next_buffer_desc.length;
+                        return ERR_ADDRESS_NOT_REGISTERED;
+                    }
+                } else {
+                    LOG(ERROR) << "selectDevice failed, last buffer overflow,"
+                               << " offset " << offset << " length " << length
+                               << " buffer_id " << buffer_id
+                               << " buffer_desc.addr " << buffer_desc.addr
+                               << " buffer_desc.length " << buffer_desc.length;
+                    return ERR_ADDRESS_NOT_REGISTERED;
+                }
+            }
+            device_id = desc->topology.selectDevice(buffer_desc.name, retry_count);
+            if (device_id >= 0) return ret;
+            device_id = desc->topology.selectDevice(kWildcardLocation, retry_count);
+            if (device_id >= 0) return ret;
+        }
+    }
+
+    return ERR_ADDRESS_NOT_REGISTERED;
+}
+}  // namespace mooncake
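
Editor's note: selectDevice() encodes three outcomes: 0 when the region fits one registered buffer, 1 when it ends inside the next adjacent buffer, and ERR_ADDRESS_NOT_REGISTERED otherwise. A self-contained sketch of just those containment rules (device selection elided; types and the error constant are simplified stand-ins for the patch's):

// Containment classification over a sorted list of registered buffers.
#include <cstdint>
#include <cstdio>
#include <vector>

struct Buffer { uint64_t addr, length; };
constexpr int kNotRegistered = -1;

int classify(const std::vector<Buffer> &bufs, uint64_t off, uint64_t len) {
    for (size_t i = 0; i < bufs.size(); ++i) {
        const auto &b = bufs[i];
        if (off < b.addr || off >= b.addr + b.length) continue;
        if (off + len <= b.addr + b.length) return 0;  // one buffer
        if (i + 1 < bufs.size()) {
            const auto &n = bufs[i + 1];
            if (off + len > n.addr && off + len <= n.addr + n.length)
                return 1;  // spans into the adjacent buffer
        }
        return kNotRegistered;  // tail not covered
    }
    return kNotRegistered;  // start not covered
}

int main() {
    std::vector<Buffer> bufs = {{0x1000, 0x1000}, {0x2000, 0x1000}};
    std::printf("%d\n", classify(bufs, 0x1100, 0x100));  // 0
    std::printf("%d\n", classify(bufs, 0x1F00, 0x200));  // 1
    std::printf("%d\n", classify(bufs, 0x2F00, 0x200));  // -1 (overflows last)
    return 0;
}
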
diff --git a/scripts/build_wheel.sh b/scripts/build_wheel.sh
index df5d205a2..e34017f0e 100755
--- a/scripts/build_wheel.sh
+++ b/scripts/build_wheel.sh
@@ -37,9 +37,13 @@ else
 fi

 # Copy nvlink-allocator.so to mooncake directory (only if it exists - CUDA builds only)
-if [ -f build/mooncake-transfer-engine/nvlink-allocator/nvlink_allocator.so ]; then
-    echo "Copying CUDA nvlink_allocator.so..."
-    cp build/mooncake-transfer-engine/nvlink-allocator/nvlink_allocator.so mooncake-wheel/mooncake/nvlink_allocator.so
+if [ -f build/mooncake-transfer-engine/nvlink-allocator/nvlink_allocator.so ] \
+    || [ -f /usr/lib/libaccl_barex.so ] \
+    || [ -f /usr/lib64/libaccl_barex.so ]; then
+    if [ -f build/mooncake-transfer-engine/nvlink-allocator/nvlink_allocator.so ]; then
+        echo "Copying CUDA nvlink_allocator.so..."
+        cp build/mooncake-transfer-engine/nvlink-allocator/nvlink_allocator.so mooncake-wheel/mooncake/nvlink_allocator.so
+    fi
     echo "Copying allocator libraries..."
     # Copy allocator.py
     cp mooncake-integration/allocator.py mooncake-wheel/mooncake/allocator.py
@@ -296,6 +300,7 @@ else
         --exclude libascend_trace.so* \
         --exclude libmetadef*.so \
         --exclude libllm_datadist*.so \
+        --exclude libaccl_barex.so* \
         -w ${REPAIRED_DIR}/ --plat ${PLATFORM_TAG}
 fi