@@ -79,6 +79,8 @@ class AscendDirectTransport : public Transport {
 
     void workerThread();
 
+    void queryThread();
+
     void processSliceList(const std::vector<Slice *> &slice_list);
 
     void localCopy(TransferRequest::OpCode opcode,
@@ -116,16 +118,21 @@ class AscendDirectTransport : public Transport {
     std::set<std::string> connected_segments_;
     std::mutex connection_mutex_;
 
-    // Async processing related members (similar to hccl_transport)
+    // Async processing related members
     std::thread worker_thread_;
     std::queue<std::vector<Slice *>> slice_queue_;
     std::mutex queue_mutex_;
     std::condition_variable queue_cv_;
 
+    std::thread query_thread_;
+    std::queue<std::vector<Slice *>> query_slice_queue_;
+    std::mutex query_mutex_;
+    std::condition_variable query_cv_;
+
     int32_t device_logic_id_{};
     aclrtContext rt_context_{nullptr};
     int32_t connect_timeout_ = 10000;
-    int32_t transfer_timeout_ = 10000;
+    int64_t transfer_timeout_ = 10000;
     std::string local_adxl_engine_name_{};
     aclrtStream stream_{};
     bool use_buffer_pool_{false};
2 changes: 2 additions & 0 deletions mooncake-transfer-engine/include/transport/transport.h
@@ -142,6 +142,8 @@ class Transport {
         } hccl;
         struct {
             uint64_t dest_addr;
+            void *handle;
+            int64_t start_time;
         } ascend_direct;
     };
 
@@ -45,10 +45,14 @@ AscendDirectTransport::~AscendDirectTransport() {
     // Stop worker thread
     running_ = false;
     queue_cv_.notify_all();
+    query_cv_.notify_all();
 
     if (worker_thread_.joinable()) {
         worker_thread_.join();
     }
+    if (query_thread_.joinable()) {
+        query_thread_.join();
+    }
 
     // Disconnect all connections
     std::lock_guard<std::mutex> lock(connection_mutex_);
@@ -123,6 +127,7 @@ int AscendDirectTransport::install(std::string &local_server_name,
     // Start worker thread
     running_ = true;
     worker_thread_ = std::thread(&AscendDirectTransport::workerThread, this);
+    query_thread_ = std::thread(&AscendDirectTransport::queryThread, this);
     return 0;
 }
 
@@ -156,8 +161,8 @@ int AscendDirectTransport::InitAdxlEngine() {
         }
     }
     // set default buffer pool
-    options["adxl.BufferPool"] = "4:8";
-    use_buffer_pool_ = true;
+    options["adxl.BufferPool"] = "0:0";
+    use_buffer_pool_ = false;
     char *buffer_pool = std::getenv("ASCEND_BUFFER_POOL");
     if (buffer_pool) {
         options["adxl.BufferPool"] = buffer_pool;
@@ -192,9 +197,11 @@ int AscendDirectTransport::InitAdxlEngine() {
             parseFromString<int32_t>(connect_transfer_str);
         if (transfer_timeout.has_value()) {
             transfer_timeout_ = transfer_timeout.value();
-            LOG(INFO) << "Set transfer timeout to:" << transfer_timeout_;
+            LOG(INFO) << "Set transfer timeout to:" << transfer_timeout_
+                      << " us.";
         }
     }
+    transfer_timeout_ = transfer_timeout_ * 1000000;
Reviewer comment (Contributor, severity: medium): This conversion to nanoseconds is correct for the timeout logic. However, the corresponding log message for transfer_timeout_ at line 200 is emitted before this conversion, showing the value in what seems to be milliseconds. This could be misleading for users debugging timeout issues. It would be clearer to log the timeout value after this conversion, and to include the units (e.g., 'ns') in the log message to avoid confusion.

     return 0;
 }
 
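The suggestion above amounts to reordering the conversion and the log statement. Below is a minimal sketch of that idea, assuming the value read from ASCEND_TRANSFER_TIMEOUT is in milliseconds and is later compared against elapsed nanoseconds; the helper name and the plain-iostream logging are hypothetical and not the code in this PR:

```cpp
#include <cstdint>
#include <cstdlib>
#include <iostream>

// Hypothetical helper illustrating the review suggestion: convert the
// configured timeout to nanoseconds first, then log the converted value
// with an explicit unit so debugging output is unambiguous.
int64_t ResolveTransferTimeoutNs(int64_t default_ms = 10000) {
    int64_t timeout_ms = default_ms;
    if (const char *env = std::getenv("ASCEND_TRANSFER_TIMEOUT")) {
        timeout_ms = std::strtoll(env, nullptr, 10);  // assumed to be in ms
    }
    const int64_t timeout_ns = timeout_ms * 1000000;  // ms -> ns
    std::cout << "Set transfer timeout to: " << timeout_ns << " ns.\n";
    return timeout_ns;
}
```

Logging after the multiplication keeps the printed number consistent with the value that queryThread() actually compares against the elapsed time in nanoseconds.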
@@ -542,6 +549,93 @@ void AscendDirectTransport::workerThread() {
     LOG(INFO) << "AscendDirectTransport worker thread stopped";
 }
 
+void AscendDirectTransport::queryThread() {
+    LOG(INFO) << "AscendDirectTransport query thread started";
+    std::vector<std::vector<Slice *>> pending_batches;
+    while (running_) {
+        {
+            std::unique_lock<std::mutex> lock(query_mutex_);
+            if (pending_batches.empty()) {
+                query_cv_.wait(lock, [this] {
+                    return !running_ || !query_slice_queue_.empty();
+                });
+            }
+            if (!running_) {
+                break;
+            }
+            while (!query_slice_queue_.empty()) {
+                pending_batches.emplace_back(
+                    std::move(query_slice_queue_.front()));
+                query_slice_queue_.pop();
+            }
+        }
+
+        if (pending_batches.empty()) {
+            continue;
+        }
+
+        auto it = pending_batches.begin();
+        while (it != pending_batches.end()) {
+            auto &slice_list = *it;
+            if (slice_list.empty()) {
+                it = pending_batches.erase(it);
+                continue;
+            }
+            auto handle = static_cast<adxl::TransferReq>(
+                slice_list[0]->ascend_direct.handle);
+            adxl::TransferStatus task_status;
+            auto ret = adxl_->GetTransferStatus(handle, task_status);
+            if (ret != adxl::SUCCESS ||
+                task_status == adxl::TransferStatus::FAILED) {
+                LOG(ERROR) << "Get transfer status failed, ret: " << ret;
+                for (auto &slice : slice_list) {
+                    slice->markFailed();
+                }
+                it = pending_batches.erase(it);
+            } else if (task_status == adxl::TransferStatus::COMPLETED) {
+                auto now = getCurrentTimeInNano();
+                auto duration = now - slice_list[0]->ascend_direct.start_time;
+                auto target_segment_desc =
+                    metadata_->getSegmentDescByID(slice_list[0]->target_id);
+                if (target_segment_desc) {
+                    auto target_adxl_engine_name =
+                        (target_segment_desc->rank_info.hostIp + ":" +
+                         std::to_string(
+                             target_segment_desc->rank_info.hostPort));
+                    LOG(INFO) << "Transfer to " << target_adxl_engine_name
+                              << " time: " << duration / 1000 << "us";
+                }
+                for (auto &slice : slice_list) {
+                    slice->markSuccess();
+                }
+                it = pending_batches.erase(it);
+            } else {
+                auto now = getCurrentTimeInNano();
+                if (now - slice_list[0]->ascend_direct.start_time >
+                    transfer_timeout_) {
+                    LOG(ERROR)
+                        << "Transfer timeout, you can increase the timeout "
+                           "duration to reduce "
+                           "the failure rate by configuring "
+                           "the ASCEND_TRANSFER_TIMEOUT environment variable.";
+                    for (auto &slice : slice_list) {
+                        slice->markFailed();
+                    }
+                    it = pending_batches.erase(it);
+                } else {
+                    ++it;
+                }
+            }
+        }
+
+        if (!pending_batches.empty()) {
+            // Avoid busy loop
+            std::this_thread::sleep_for(std::chrono::microseconds(10));
+        }
+    }
+    LOG(INFO) << "AscendDirectTransport query thread stopped";
+}
+
 void AscendDirectTransport::processSliceList(
     const std::vector<Slice *> &slice_list) {
     if (slice_list.empty()) {
@@ -591,7 +685,6 @@ void AscendDirectTransport::processSliceList(
         }
         return;
     }
-    auto start = std::chrono::steady_clock::now();
     std::vector<adxl::TransferOpDesc> op_descs;
     op_descs.reserve(slice_list.size());
     for (auto &slice : slice_list) {
@@ -602,26 +695,25 @@ void AscendDirectTransport::processSliceList(
         op_desc.len = slice->length;
         op_descs.emplace_back(op_desc);
     }
-    auto status = adxl_->TransferSync(target_adxl_engine_name.c_str(),
-                                      operation, op_descs, transfer_timeout_);
+    auto start_time = getCurrentTimeInNano();
+    for (auto &slice : slice_list) {
+        slice->ascend_direct.start_time = start_time;
+    }
+    adxl::TransferReq req_handle;
+    auto status =
+        adxl_->TransferAsync(target_adxl_engine_name.c_str(), operation,
+                             op_descs, adxl::TransferArgs(), req_handle);
     if (status == adxl::SUCCESS) {
         for (auto &slice : slice_list) {
-            slice->markSuccess();
+            slice->ascend_direct.handle = req_handle;
         }
-        LOG(INFO) << "Transfer to:" << target_adxl_engine_name << ", cost: "
-                  << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::steady_clock::now() - start)
-                         .count()
-                  << " us";
-    } else {
-        if (status == adxl::TIMEOUT) {
-            LOG(ERROR) << "Transfer timeout to: " << target_adxl_engine_name
-                       << ", you can increase the timeout duration to reduce "
-                          "the failure rate by configuring "
-                          "the ASCEND_TRANSFER_TIMEOUT environment variable.";
-        } else {
-            LOG(ERROR) << "Transfer slice failed with status: " << status;
-        }
+        {
+            std::unique_lock<std::mutex> lock(query_mutex_);
+            query_slice_queue_.push(slice_list);
+        }
+        query_cv_.notify_one();
+    } else {
+        LOG(ERROR) << "Transfer slice failed with status: " << status;
         for (auto &slice : slice_list) {
             slice->markFailed();
         }