Skip to content

Commit 34ccd8b

Browse files
author
youxiao
committed
Adapt to the ADXL connection auto-release feature
1 parent 496cecd commit 34ccd8b

File tree

3 files changed, with 46 additions and 16 deletions.

mooncake-common/common.cmake

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -160,14 +160,15 @@ if (USE_ASCEND OR USE_ASCEND_DIRECT)
160160
file(GLOB ASCEND_TOOLKIT_ROOT "/usr/local/Ascend/ascend-toolkit/latest/*-linux")
161161
endif()
162162
set(ASCEND_LIB_DIR "${ASCEND_TOOLKIT_ROOT}/lib64")
163-
set(ASCEND_DEVLIB_DIR "${ASCEND_TOOLKIT_ROOT}/devlib")
164163
set(ASCEND_INCLUDE_DIR "${ASCEND_TOOLKIT_ROOT}/include")
165164
add_compile_options(-Wno-ignored-qualifiers)
166165
include_directories(/usr/local/include /usr/include ${ASCEND_INCLUDE_DIR})
167-
link_directories(${ASCEND_LIB_DIR} ${ASCEND_DEVLIB_DIR})
166+
link_directories(${ASCEND_LIB_DIR})
168167
endif()
169168

170169
if (USE_ASCEND)
170+
set(ASCEND_DEVLIB_DIR "${ASCEND_TOOLKIT_ROOT}/devlib")
171+
link_directories(${ASCEND_DEVLIB_DIR})
171172
add_compile_definitions(USE_ASCEND)
172173
endif()
173174

mooncake-transfer-engine/include/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.h

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,11 @@ class AscendDirectTransport : public Transport {
8181

8282
void processSliceList(const std::vector<Slice *> &slice_list);
8383

84+
void connectAndTransfer(const std::string &target_adxl_engine_name,
85+
adxl::TransferOp operation,
86+
const std::vector<Slice *> &slice_list,
87+
int32_t times = 0);
88+
8489
void localCopy(TransferRequest::OpCode opcode,
8590
const std::vector<Slice *> &slice_list);
8691

@@ -105,7 +110,7 @@ class AscendDirectTransport : public Transport {
105110
int checkAndConnect(const std::string &target_adxl_engine_name);
106111

107112
int disconnect(const std::string &target_adxl_engine_name,
108-
int32_t timeout_in_millis);
113+
int32_t timeout_in_millis, bool force = false);
109114

110115
std::atomic_bool running_;
111116
std::unique_ptr<adxl::AdxlEngine> adxl_;
@@ -131,6 +136,8 @@ class AscendDirectTransport : public Transport {
131136
bool use_buffer_pool_{false};
132137

133138
int32_t base_port_ = 20000;
139+
140+
std::unordered_set<SegmentID> need_update_metadata_segs_;
134141
};
135142

136143
} // namespace mooncake

mooncake-transfer-engine/src/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.cpp

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
namespace mooncake {
3737
namespace {
3838
constexpr size_t kMemcpyBatchLimit = 4096;
39-
}
39+
constexpr int32_t kMaxAdxlConnectRetries = 3;
40+
} // namespace
4041
AscendDirectTransport::AscendDirectTransport() : running_(false) {}
4142

4243
AscendDirectTransport::~AscendDirectTransport() {
@@ -78,6 +79,7 @@ AscendDirectTransport::~AscendDirectTransport() {
7879
}
7980
}
8081
addr_to_mem_handle_.clear();
82+
adxl_->Finalize();
8183
}
8284

8385
int AscendDirectTransport::install(std::string &local_server_name,
@@ -547,8 +549,9 @@ void AscendDirectTransport::processSliceList(
547549
if (slice_list.empty()) {
548550
return;
549551
}
550-
auto target_segment_desc =
551-
metadata_->getSegmentDescByID(slice_list[0]->target_id);
552+
auto it = need_update_metadata_segs_.find(slice_list[0]->target_id);
553+
auto target_segment_desc = metadata_->getSegmentDescByID(
554+
slice_list[0]->target_id, (it != need_update_metadata_segs_.end()));
552555
if (!target_segment_desc) {
553556
LOG(ERROR) << "Cannot find segment descriptor for target_id: "
554557
<< slice_list[0]->target_id;
@@ -557,6 +560,9 @@ void AscendDirectTransport::processSliceList(
557560
}
558561
return;
559562
}
563+
if (it != need_update_metadata_segs_.end()) {
564+
need_update_metadata_segs_.erase(it);
565+
}
560566
auto target_adxl_engine_name =
561567
(target_segment_desc->rank_info.hostIp + ":" +
562568
std::to_string(target_segment_desc->rank_info.hostPort));
@@ -582,10 +588,14 @@ void AscendDirectTransport::processSliceList(
582588
<< "us";
583589
return;
584590
}
591+
return connectAndTransfer(target_adxl_engine_name, operation, slice_list);
592+
}
593+
594+
void AscendDirectTransport::connectAndTransfer(
595+
const std::string &target_adxl_engine_name, adxl::TransferOp operation,
596+
const std::vector<Slice *> &slice_list, int32_t times) {
585597
int ret = checkAndConnect(target_adxl_engine_name);
586598
if (ret != 0) {
587-
LOG(ERROR) << "Failed to connect to segment: "
588-
<< target_segment_desc->name;
589599
for (auto &slice : slice_list) {
590600
slice->markFailed();
591601
}
@@ -613,6 +623,14 @@ void AscendDirectTransport::processSliceList(
613623
std::chrono::steady_clock::now() - start)
614624
.count()
615625
<< " us";
626+
} else if (status == adxl::NOT_CONNECTED) {
627+
LOG(INFO) << "Connection reset by backend, retry times:" << times;
628+
disconnect(target_adxl_engine_name, 0, true);
629+
if (times < kMaxAdxlConnectRetries) {
630+
return connectAndTransfer(target_adxl_engine_name, operation,
631+
slice_list, times + 1);
632+
}
633+
return;
616634
} else {
617635
if (status == adxl::TIMEOUT) {
618636
LOG(ERROR) << "Transfer timeout to: " << target_adxl_engine_name
@@ -628,6 +646,7 @@ void AscendDirectTransport::processSliceList(
628646
// the connection is probably broken.
629647
// set small timeout to just release local res.
630648
disconnect(target_adxl_engine_name, 10);
649+
need_update_metadata_segs_.emplace(slice_list[0]->target_id);
631650
}
632651
}
633652

@@ -844,21 +863,24 @@ int AscendDirectTransport::checkAndConnect(
844863
}
845864

846865
int AscendDirectTransport::disconnect(
847-
const std::string &target_adxl_engine_name, int32_t timeout_in_millis) {
866+
const std::string &target_adxl_engine_name, int32_t timeout_in_millis,
867+
bool force) {
848868
std::lock_guard<std::mutex> lock(connection_mutex_);
849869
auto it = connected_segments_.find(target_adxl_engine_name);
850870
if (it == connected_segments_.end()) {
851871
LOG(INFO) << "Target adxl engine: " << target_adxl_engine_name
852872
<< " is not connected.";
853873
return 0;
854874
}
855-
auto status =
856-
adxl_->Disconnect(target_adxl_engine_name.c_str(), timeout_in_millis);
857-
if (status != adxl::SUCCESS) {
858-
LOG(ERROR) << "Failed to disconnect to: " << target_adxl_engine_name
859-
<< ", status: " << status;
860-
connected_segments_.erase(target_adxl_engine_name);
861-
return -1;
875+
if (!force) {
876+
auto status = adxl_->Disconnect(target_adxl_engine_name.c_str(),
877+
timeout_in_millis);
878+
if (status != adxl::SUCCESS) {
879+
LOG(ERROR) << "Failed to disconnect to: " << target_adxl_engine_name
880+
<< ", status: " << status;
881+
connected_segments_.erase(target_adxl_engine_name);
882+
return -1;
883+
}
862884
}
863885
connected_segments_.erase(target_adxl_engine_name);
864886
return 0;

0 commit comments

Comments
 (0)