Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mooncake-common/common.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ option(USE_CUDA "option for enabling gpu features" OFF)
option(USE_MUSA "option for enabling Moore Threads gpu features by leveraging MUSA (Meta-computing Unified System Architecture)" OFF)
option(USE_NVMEOF "option for using NVMe over Fabric" OFF)
option(USE_TCP "option for using TCP transport" ON)
option(USE_BAREX "option for using accl-barex transport" OFF)
option(USE_ASCEND "option for using npu with HCCL" OFF)
option(USE_ASCEND_DIRECT "option for using ascend npu with adxl engine" OFF)
option(USE_ASCEND_HETEROGENEOUS "option for transferring between ascend npu and gpu" OFF)
Expand Down Expand Up @@ -123,6 +124,10 @@ if (USE_TCP)
add_compile_definitions(USE_TCP)
endif()

if (USE_BAREX)
add_compile_definitions(USE_BAREX)
endif()

if (USE_ASCEND OR USE_ASCEND_DIRECT)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DOPEN_BUILD_PROJECT ")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DOPEN_BUILD_PROJECT ")
Expand Down
2 changes: 1 addition & 1 deletion mooncake-integration/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,6 @@ def get_allocator(cls, device: torch_device) -> CUDAPluggableAllocator:
if device not in cls._instances:
so_path = cls._get_so_path()
cls._instances[device] = CUDAPluggableAllocator(
so_path, "u2mm_alloc_wrapper", "u2mm_free_wrapper"
so_path, "u2mm_alloc_wrapper_with_stream", "u2mm_free_wrapper_with_stream"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this has a compatible issue

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't worry. It works.

)
return cls._instances[device]
33 changes: 31 additions & 2 deletions mooncake-integration/transfer_engine/transfer_engine_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,17 @@ int TransferEnginePy::initializeExt(const char *local_hostname,
free_list_.resize(kSlabSizeKBTabLen);
#if !defined(USE_ASCEND) && !defined(USE_ASCEND_DIRECT) && \
!defined(USE_ASCEND_HETEROGENEOUS)
doBuddyAllocate(kMaxClassId);
bool pass_alloc = false;
const char *pass_alloc_env = std::getenv("PASS_ALLOC");
if (pass_alloc_env) {
int val = atoi(pass_alloc_env);
if (val != 0) {
pass_alloc = true;
}
}
Comment on lines 137 to 146
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of atoi is generally discouraged as it has undefined behavior on errors (e.g., when the input string is not a valid number or is out of range). For parsing environment variables, it's safer to use std::stoi within a try-catch block or strtol to handle potential parsing errors gracefully. This pattern of using atoi appears in other places in this PR as well (transfer_engine.cpp:79, transfer_metadata_plugin.cpp:1150, etc.). Consider replacing it with a safer alternative throughout.

Suggested change
if (pass_alloc_env) {
int val = atoi(pass_alloc_env);
if (val != 0) {
pass_alloc = true;
}
}
if (pass_alloc_env) {
try {
if (std::stoi(pass_alloc_env) != 0) {
pass_alloc = true;
}
} catch (const std::exception&) {
// Ignore invalid values or log a warning
}
}

if (!pass_alloc) {
doBuddyAllocate(kMaxClassId);
}
#endif
return 0;
}
Expand Down Expand Up @@ -266,6 +276,7 @@ int TransferEnginePy::transferSync(const char *target_hostname,
if (handle_map_.count(target_hostname)) {
handle = handle_map_[target_hostname];
} else {
LOG(INFO) << "transferSync, cache not found, openSegment with target " << target_hostname;
handle = engine_->openSegment(target_hostname);
if (handle == (Transport::SegmentHandle)-1) return -1;
handle_map_[target_hostname] = handle;
Expand Down Expand Up @@ -300,7 +311,17 @@ int TransferEnginePy::transferSync(const char *target_hostname,
batch_id, {entry},
TransferMetadata::NotifyDesc{notify->name, notify->msg})
: engine_->submitTransfer(batch_id, {entry});
if (!s.ok()) return -1;
if (!s.ok()) {
Status segment_status = engine_->CheckSegmentStatus(handle);
if (!segment_status.ok()) {
LOG(WARNING) << "submitTransfer failed with target " << target_hostname << ", CheckSegmentStatus not ok, ready to closeSegment";
std::lock_guard<std::mutex> guard(mutex_);
engine_->closeSegment(handle);
engine_->getMetadata()->removeSegmentDesc(target_hostname);
handle_map_.erase(target_hostname);
}
Comment on lines 321 to 330
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code for cleaning up a failed segment is duplicated in batchTransferSync (lines 411-418). To improve maintainability and avoid potential inconsistencies in the future, this logic should be extracted into a private helper function.

return -1;
}

TransferStatus status;
bool completed = false;
Expand Down Expand Up @@ -387,6 +408,14 @@ int TransferEnginePy::batchTransferSync(
: engine_->submitTransfer(batch_id, entries);
if (!s.ok()) {
engine_->freeBatchID(batch_id);
Status segment_status = engine_->CheckSegmentStatus(handle);
if (!segment_status.ok()) {
LOG(WARNING) << "submitTransfer failed with target " << target_hostname << ", CheckSegmentStatus not ok, ready to closeSegment";
std::lock_guard<std::mutex> guard(mutex_);
engine_->closeSegment(handle);
engine_->getMetadata()->removeSegmentDesc(target_hostname);
handle_map_.erase(target_hostname);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code block looks like the same as above

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For each non-OK request, it should check the results. I guess we should wrap this code block with USE_BAREX.

return -1;
}

Expand Down
16 changes: 14 additions & 2 deletions mooncake-transfer-engine/example/transfer_engine_bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ DEFINE_string(mode, "initiator",
"data blocks from target node");
DEFINE_string(operation, "read", "Operation type: read or write");

DEFINE_string(protocol, "rdma", "Transfer protocol: rdma|tcp");
DEFINE_string(protocol, "rdma", "Transfer protocol: rdma|barex|tcp");

DEFINE_string(device_name, "mlx5_2",
"Device name to use, valid if protocol=rdma");
Expand Down Expand Up @@ -317,6 +317,12 @@ int initiator() {
args[0] = (void *)nic_priority_matrix.c_str();
args[1] = nullptr;
xport = engine->installTransport("rdma", args);
} else if (FLAGS_protocol == "barex") {
auto nic_priority_matrix = loadNicPriorityMatrix();
void **args = (void **)malloc(2 * sizeof(void *));
args[0] = (void *)nic_priority_matrix.c_str();
args[1] = nullptr;
xport = engine->installTransport("barex", args);
} else if (FLAGS_protocol == "tcp") {
xport = engine->installTransport("tcp", nullptr);
} else if (FLAGS_protocol == "nvlink") {
Expand Down Expand Up @@ -436,7 +442,13 @@ int target() {
void **args = (void **)malloc(2 * sizeof(void *));
args[0] = (void *)nic_priority_matrix.c_str();
args[1] = nullptr;
engine->installTransport("rdma", args);
engine->installTransport("rdma", args);
} else if (FLAGS_protocol == "barex") {
auto nic_priority_matrix = loadNicPriorityMatrix();
void **args = (void **)malloc(2 * sizeof(void *));
args[0] = (void *)nic_priority_matrix.c_str();
args[1] = nullptr;
engine->installTransport("barex", args);
} else if (FLAGS_protocol == "tcp") {
engine->installTransport("tcp", nullptr);
} else if (FLAGS_protocol == "nvlink") {
Expand Down
1 change: 1 addition & 0 deletions mooncake-transfer-engine/include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ struct GlobalConfig {
bool use_ipv6 = false;
size_t fragment_limit = 16384;
bool enable_dest_device_affinity = false;
size_t eic_max_block_size = 64UL * 1024 * 1024;
};

void loadGlobalConfig(GlobalConfig &config);
Expand Down
3 changes: 3 additions & 0 deletions mooncake-transfer-engine/include/transfer_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class TransferEngine {

SegmentHandle openSegment(const std::string &segment_name);

Status CheckSegmentStatus(SegmentID sid);

int closeSegment(SegmentHandle handle);

int removeLocalSegment(const std::string &segment_name);
Expand Down Expand Up @@ -249,6 +251,7 @@ class TransferEngine {
// Set it to false only for testing.
bool auto_discover_;
std::vector<std::string> filter_;
bool use_barex_ = false;

#ifdef WITH_METRICS
ylt::metric::counter_t transferred_bytes_counter_{
Expand Down
2 changes: 2 additions & 0 deletions mooncake-transfer-engine/include/transfer_metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,14 @@ class TransferMetadata {
struct RpcMetaDesc {
std::string ip_or_host_name;
uint16_t rpc_port;
uint16_t barex_port;
int sockfd; // local cache
};

struct HandShakeDesc {
std::string local_nic_path;
std::string peer_nic_path;
uint16_t barex_port;
std::vector<uint32_t> qp_num;
std::string reply_msg; // on error
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct HandShakePlugin {

std::vector<std::string> findLocalIpAddresses();

uint16_t findAvailableTcpPort(int &sockfd);
uint16_t findAvailableTcpPort(int &sockfd, bool set_range=false);

} // namespace mooncake

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
// Copyright 2024 KVCache.AI
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BAREX_CONTEXT_H_
#define BAREX_CONTEXT_H_

#include <infiniband/verbs.h>

#include <atomic>
#include <cstddef>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "common.h"
#include "transport/transport.h"

#ifdef USE_BAREX
#include <accl/barex/barex.h>
#include <accl/barex/xcontext.h>
#include <accl/barex/xlistener.h>
#include <accl/barex/xconnector.h>
#include <accl/barex/xsimple_mempool.h>
#include <accl/barex/xthreadpool.h>
#include <accl/barex/xtimer.h>
#include <accl/barex/xconfig_util.h>
#endif

namespace mooncake {

#ifdef USE_BAREX

using namespace accl::barex;
using XChannel = accl::barex::XChannel;
using SegmentID = Transport::SegmentID;
using XContext = accl::barex::XContext;
using BarexResult = accl::barex::BarexResult;

class ChannelCache {
public:
// 添加一个 channel 到指定 key & nic_id
Copy link

Copilot AI Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is in Chinese. Should be translated to English: '// Add a channel to the specified key & nic_id'.

Suggested change
// 添加一个 channel 到指定 key & nic_id
// Add a channel to the specified key & nic_id

Copilot uses AI. Check for mistakes.
void put(SegmentID key, int nic_id, XChannel* channel) {
RWSpinlock::WriteGuard guard(lock_);
auto& channels = cache_[key];
auto& vec = channels[nic_id];
status_map_[key] = true;
vec.push_back(channel);
}

// 获取 sid 下指定 nic_id 和 idx 的 channel
Copy link

Copilot AI Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is in Chinese. Should be translated to English: '// Get the channel for the specified nic_id and idx under sid'.

Suggested change
// 获取 sid 下指定 nic_id idx 的 channel
// Get the channel for the specified nic_id and idx under sid

Copilot uses AI. Check for mistakes.
XChannel* find(SegmentID key, int nic_id, int idx) {
RWSpinlock::ReadGuard guard(lock_);
auto it = cache_.find(key);
if (it == cache_.end()) return nullptr;
auto& channels = it->second;
auto ch_it = channels.find(nic_id);
if (ch_it == channels.end()) return nullptr;
auto& vec = ch_it->second;
if (idx >= 0 && idx < static_cast<int>(vec.size())) {
return vec[idx];
}
return nullptr;
}

// 删除某个 channel(通过id和idx)
Copy link

Copilot AI Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is in Chinese. Should be translated to English: '// Delete a channel (by id and idx)'.

Suggested change
// 删除某个 channel(通过id和idx)
// Delete a channel (by id and idx)

Copilot uses AI. Check for mistakes.
bool erase(SegmentID key, int nic_id, int idx) {
RWSpinlock::WriteGuard guard(lock_);
auto it = cache_.find(key);
if (it == cache_.end()) return false;

auto& channels = it->second;
auto ch_it = channels.find(nic_id);
if (ch_it == channels.end()) return false;

auto& vec = ch_it->second;
if (idx < 0 || idx >= static_cast<int>(vec.size())) return false;

vec.erase(vec.begin() + idx);
status_map_[key] = false;
if (vec.empty()) {
channels.erase(ch_it);
if (channels.empty()) {
cache_.erase(it);
}
}
return true;
}

// 查询某个 SegmentID 下的 channel 状态
Copy link

Copilot AI Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is in Chinese. Should be translated to English: '// Check the channel status under a SegmentID'.

Suggested change
// 查询某个 SegmentID 下的 channel 状态
// Check the channel status under a SegmentID

Copilot uses AI. Check for mistakes.
bool CheckAllChannels(SegmentID segment_id) {
RWSpinlock::WriteGuard guard(lock_);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

CheckAllChannels appears to be a read-only operation, but it uses a RWSpinlock::WriteGuard. This is inefficient and semantically incorrect. It should use a RWSpinlock::ReadGuard to allow for concurrent reads. The same issue exists in copyAll at line 148.

Suggested change
RWSpinlock::WriteGuard guard(lock_);
RWSpinlock::ReadGuard guard(lock_);

auto it = cache_.find(segment_id);
if (it == cache_.end()) {
return false;
}
auto& inner_map = it->second;
for (auto& pair : inner_map) {
auto& channels = pair.second;
for (XChannel* channel : channels) {
if (!channel->IsActive()) {
return false;
}
}
}
return true;
}

// 检查并删除某个 SegmentID 下的异常channel,并返回删除的数量
Copy link

Copilot AI Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is in Chinese. Should be translated to English: '// Check and remove invalid channels under a SegmentID, return the number of removed channels'.

Suggested change
// 检查并删除某个 SegmentID 下的异常channel,并返回删除的数量
// Check and remove invalid channels under a SegmentID, return the number of removed channels

Copilot uses AI. Check for mistakes.
int RemoveInvalidChannels(SegmentID segment_id) {
RWSpinlock::WriteGuard guard(lock_);
auto it = cache_.find(segment_id);
if (it == cache_.end()) {
return 0;
}

int invalid_count = 0;
auto& inner_map = it->second;

for (auto& pair : inner_map) {
auto& channels = pair.second;
auto new_end = std::remove_if(channels.begin(), channels.end(),
[](XChannel* channel) {
return !channel->IsActive();
});
invalid_count += std::distance(new_end, channels.end());
channels.erase(new_end, channels.end());
}
return invalid_count;
}

// 将所有的 channel 以 vector 形式返回
Copy link

Copilot AI Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is in Chinese. Should be translated to English: '// Return all channels as a vector'.

Suggested change
// 将所有的 channel 以 vector 形式返回
// Return all channels as a vector

Copilot uses AI. Check for mistakes.
std::vector<XChannel*> copyAll() {
RWSpinlock::WriteGuard guard(lock_);
std::vector<XChannel*> result;
for (const auto& [key, channels] : cache_) {
for (const auto& [nic_id, vec] : channels) {
result.insert(result.end(), vec.begin(), vec.end());
}
}
return result;
}

private:
std::unordered_map<SegmentID, std::unordered_map<int, std::vector<XChannel*>>> cache_;
std::unordered_map<SegmentID, bool> status_map_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The status_map_ member of ChannelCache is written to in put() and erase(), but it is never read. This appears to be dead code and should be removed to simplify the class.

RWSpinlock lock_;
};
class BarexContext {
public:
int submitPostSend(const std::vector<Transport::Slice *> &slice_list);
int addChannel(SegmentID sid, int device_id, XChannel *ch);
XChannel* getChannel(SegmentID sid, int device_id, int idx);
int checkStatus(SegmentID sid);
XContext* getCtx();
// int ClearAllChannel();
std::vector<XChannel*> getAllChannel();
bool active() const { return active_; }
void setQpNum(int qp_num) { qp_num_per_ctx_ = qp_num; }
int getQpNum() const { return qp_num_per_ctx_; }

public:
BarexContext(XContext* xcontext, bool use_cpu, int device_id);

~BarexContext();

XContext* xcontext_;
bool barex_use_cpu_;
int barex_local_device_;

private:
ChannelCache channel_cache_;
bool active_ = true;
int qp_num_per_ctx_ = 2;

};
#endif
} // namespace mooncake

#endif // BAREX_CONTEXT_H_
Loading