7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -132,6 +132,13 @@ else()
list(APPEND 3RDPART_LIB_LIST "xdnn_static")
endif()

# pipeline parallel feature
option(WITH_PIPELINE_PARALLEL "Build with pipeline parallel" OFF)
if(WITH_PIPELINE_PARALLEL)
message(STATUS "Notice: Building with pipeline parallel.")
add_definitions(-DPIPELINE_PARALLEL=true)
endif()

# Enable AVX512_FP16 optimization
# add_definitions(-DAVX512_FP32_WEIGHT_ONLY_FP16=true)
add_definitions(-DAVX512_FP16_WEIGHT_ONLY_FP16=true)
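When WITH_PIPELINE_PARALLEL is ON (e.g. cmake -DWITH_PIPELINE_PARALLEL=ON ..), the PIPELINE_PARALLEL macro above is defined for the whole build. A minimal sketch of how C++ code can branch on it; the function name is illustrative only, not part of this PR:

```cpp
#include <cstdio>

// Illustrative only: compile-time branch on the PIPELINE_PARALLEL macro
// injected by add_definitions(-DPIPELINE_PARALLEL=true).
static void reportParallelMode() {
#ifdef PIPELINE_PARALLEL
    std::printf("pipeline parallel build: layers are split across stages\n");
#else
    std::printf("tensor parallel only build\n");
#endif
}
```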
52 changes: 42 additions & 10 deletions src/comm_helper/comm_helper.cpp
@@ -18,29 +18,45 @@

static ccl::communicator *pcomm;

extern "C" int init(int *rank, int *size) {
// world_color is initialized to the pipeline parallel stage count (ppSize)
// and is re-assigned below to this process's MPI world color, which equals ppRank
extern "C" int init(int *world_size, int *world_rank, int *world_color) {
ccl::init();

MPI_Init(NULL, NULL);
MPI_Comm_size(MPI_COMM_WORLD, size);
MPI_Comm_rank(MPI_COMM_WORLD, rank);
MPI_Comm_size(MPI_COMM_WORLD, world_size);
MPI_Comm_rank(MPI_COMM_WORLD, world_rank);

// world_color = world_rank / tpSize = world_rank / (world_size / ppSize)
// e.g. world_size = 8, XFT_PIPELINE_STAGES = ppSize = 4 -> tpSize = 2:
// world_rank = 0, 1 -> world_color = ppRank = 0, 0 -> tpRank = 0, 1
// world_rank = 2, 3 -> world_color = ppRank = 1, 1 -> tpRank = 0, 1
// world_rank = 4, 5 -> world_color = ppRank = 2, 2 -> tpRank = 0, 1
// world_rank = 6, 7 -> world_color = ppRank = 3, 3 -> tpRank = 0, 1
*world_color = *world_rank / (*world_size / *world_color);
MPI_Comm row_comm;
MPI_Comm_split(MPI_COMM_WORLD, *world_color, *world_rank, &row_comm);

int row_size, row_rank;
MPI_Comm_size(row_comm, &row_size);
MPI_Comm_rank(row_comm, &row_rank);

ccl::shared_ptr_class<ccl::kvs> kvs;
ccl::kvs::address_type mainAddr;

if (*rank == 0) {
if (row_rank == 0) {
kvs = ccl::create_main_kvs();
mainAddr = kvs->get_address();
MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, row_comm);
} else {
MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, row_comm);
kvs = ccl::create_kvs(mainAddr);
}

pcomm = new ccl::communicator(ccl::create_communicator(*size, *rank, kvs));
pcomm = new ccl::communicator(ccl::create_communicator(row_size, row_rank, kvs));

*rank = pcomm->rank();
*size = pcomm->size();
*world_size = pcomm->size();
*world_rank = pcomm->rank();

#ifdef USE_SHM
char myHostname[MPI_MAX_PROCESSOR_NAME];
@@ -53,7 +69,7 @@ extern "C" int init(int *rank, int *size) {
MPI_COMM_WORLD);

int sameHostnames = 1;
for (int i = 1; i < *size; i++) {
for (int i = 1; i < *world_size; i++) {
if (strcmp(myHostname, &all_hostnames[i * MPI_MAX_PROCESSOR_NAME]) != 0) {
sameHostnames = 0;
break;
@@ -89,4 +105,20 @@ extern "C" void broadcast(int *buf, size_t count) {
extern "C" void allgatherv(
const float *sendBuf, size_t count, float *recvBuf, const std::vector<long unsigned int> &recvCounts) {
ccl::allgatherv(sendBuf, count, recvBuf, recvCounts, *pcomm).wait();
}

extern "C" void worldSendFP32(const float *buf, int count, int dest, int tag) {
MPI_Send((const void *)buf, count, MPI_FLOAT, dest, tag, MPI_COMM_WORLD);
}

extern "C" void worldRecvFP32(float *buf, int count, int source, int tag) {
MPI_Recv((void *)buf, count, MPI_FLOAT, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}

extern "C" void worldSendINT32(const int32_t *buf, int count, int dest, int tag) {
MPI_Send((const void *)buf, count, MPI_INT32_T, dest, tag, MPI_COMM_WORLD);
}

extern "C" void worldRecvINT32(int32_t *buf, int count, int source, int tag) {
MPI_Recv((void *)buf, count, MPI_INT32_T, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
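As a sanity check of the new rank math in init(), this standalone sketch (no MPI needed) reproduces the world_rank -> (ppRank, tpRank) table from the comment for world_size = 8 and ppSize = 4; it assumes, as the code does, that ppSize divides world_size evenly:

```cpp
#include <cstdio>

int main() {
    const int worldSize = 8;
    const int ppSize = 4;                  // world_color is initialized to ppSize
    const int tpSize = worldSize / ppSize; // ranks per pipeline stage

    for (int worldRank = 0; worldRank < worldSize; ++worldRank) {
        int worldColor = worldRank / tpSize; // becomes ppRank after MPI_Comm_split
        int tpRank = worldRank % tpSize;     // rank inside the row communicator
        std::printf("world_rank=%d -> ppRank=%d, tpRank=%d\n", worldRank, worldColor, tpRank);
    }
    return 0;
}
```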
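The new worldSend*/worldRecv* wrappers are plain point-to-point calls over MPI_COMM_WORLD. A hedged sketch of how a caller might pass activations between adjacent pipeline stages; the tag scheme and the stage-to-rank offset (peers sit tpSize world ranks apart) are assumptions for illustration, not taken from this PR:

```cpp
extern "C" void worldSendFP32(const float *buf, int count, int dest, int tag);
extern "C" void worldRecvFP32(float *buf, int count, int source, int tag);

// Send this stage's output to the peer with the same tpRank in the next stage.
void sendToNextStage(const float *activations, int count,
                     int worldRank, int tpSize, int ppRank, int ppSize, int tag) {
    if (ppRank + 1 < ppSize) worldSendFP32(activations, count, worldRank + tpSize, tag);
}

// Receive the previous stage's output from the peer tpSize world ranks behind.
void recvFromPrevStage(float *activations, int count,
                       int worldRank, int tpSize, int ppRank, int tag) {
    if (ppRank > 0) worldRecvFP32(activations, count, worldRank - tpSize, tag);
}
```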
12 changes: 11 additions & 1 deletion src/common/transformer_ctx.h
@@ -84,6 +84,12 @@ struct DecoderContext {
// # of splits (the same as NUMA node number in the system)
const int numSplit;

// For pipeline parallel and tensor parallel config
int ppSize = 1; // pipeline parallel stage size
int ppRank = 0; // pipeline parallel stage rank
int tpSize = 1; // tensor parallel size
int tpRank = 0; // tensor parallel rank

enum ActivationType { RELU, GELU, SWIGLU, SILU };
ActivationType actType;

@@ -105,7 +111,7 @@
public:
DecoderContext(int _layers, int _hiddenSize, int _attHeadNum, int _kvHeadNum, int _imSize, const std::string &act,
float epsilon, int _vocabSize, int _embeddingSize, int _maxPositions, int _maxPosEmbed, int _maxSeqLength,
int _splitIdx, int _splits, RopeParams *_ropeParamsPtr = nullptr, int numThreads = 0)
int _splitIdx, int _splits, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr, int numThreads = 0)
: layers(_layers)
, hiddenSize(_hiddenSize)
, intermediateSize(_imSize)
@@ -119,6 +125,10 @@
, ropeParamsPtr(_ropeParamsPtr)
, splitIdx(_splitIdx)
, numSplit(_splits)
, ppSize(_ppSize)
, ppRank(_ppRank)
, tpSize(_splits)
, tpRank(_splitIdx)
, epsilon(epsilon) {
if (attHeadNum != 0) {
this->attHeadSize = hiddenSize / attHeadNum;
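The two new constructor parameters default to 1 and 0, so existing single-stage callers compile unchanged, while tpSize/tpRank simply mirror the existing numSplit/splitIdx. A minimal sketch of both call shapes; the model dimensions are made-up placeholders, only showing where the new arguments slot in:

```cpp
#include "transformer_ctx.h"

// tpRank/tpSize/ppRank/ppSize would normally come from the comm_helper init() above.
void buildContexts(int tpRank, int tpSize, int ppRank, int ppSize) {
    // Existing caller: _ppSize/_ppRank take their defaults (1 and 0).
    DecoderContext ctxTpOnly(32, 4096, 32, 32, 11008, "silu", 1e-6f, 32000,
            4096, 2048, 2048, 2048, /*splitIdx*/ tpRank, /*splits*/ tpSize);

    // Pipeline-parallel caller: forward this process's ppSize/ppRank explicitly.
    DecoderContext ctxPp(32, 4096, 32, 32, 11008, "silu", 1e-6f, 32000,
            4096, 2048, 2048, 2048, /*splitIdx*/ tpRank, /*splits*/ tpSize,
            /*ppSize*/ ppSize, /*ppRank*/ ppRank);
}
```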
6 changes: 4 additions & 2 deletions src/layers/attention.h
@@ -273,6 +273,7 @@ class Attention {
imBuffer.Assign(inputBuffer.Data(), inputBuffer.Rows(), inputBuffer.Cols(), inputBuffer.Stride());
inputBuffer.Assign(tmp, rows, cols, stride);
}

// TODO: refine the logic (and support large inputSeqLen when pastSeqLen > 0)
if constexpr (std::is_same_v<InT, bfloat16_t> && std::is_same_v<OutT, bfloat16_t>) {
if (pastSeqLen == 0) {
@@ -284,8 +285,9 @@
if (ctx->inputSeqLen >= 1024 && pastSeqLen == 0)
flashAttention(
ctx, qkvGroupMatMul, outBuffer, imBuffer, presentKey, presentValue, attnMask, pastSeqLen);
else
else {
fusedAttention(ctx, query, key, value, imBuffer, presentKey, presentValue, attnMask, pastSeqLen);
}
}
t4.release();

@@ -375,7 +377,7 @@
// to make sure it works better (the logic here tries to keep each head's BMM result [seq * seq] in cache)
// WARN: the reserved field in the context is used to make it effective for all layers, do not change it in other places
int &mBlockSize = ctx->reserved1;
if (layerId == 0) {
if (layerId % (ctx->layers / ctx->ppSize) == 0) {
// TODO: if pastSeqLen > 0 and inputSeqLen large.
if (pastSeqLen == 0) {
const int l2CacheSize = 2 * 1024 * 1024; // TODO: get it dynamically
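With pipeline parallel, each stage presumably owns only ctx->layers / ctx->ppSize decoder layers, so the cached mBlockSize is now recomputed on the first layer of every stage instead of only on global layer 0. A standalone sketch of which layer IDs trigger the recomputation (assuming layers divides evenly by ppSize, as the modulo implies; the counts are hypothetical):

```cpp
#include <cstdio>

int main() {
    const int layers = 32, ppSize = 4;          // hypothetical model / stage count
    const int layersPerStage = layers / ppSize; // 8 layers per stage

    // layerId % layersPerStage == 0 holds for layers 0, 8, 16, 24:
    // the first layer handled by each pipeline stage.
    for (int layerId = 0; layerId < layers; ++layerId) {
        if (layerId % layersPerStage == 0)
            std::printf("layer %2d: first layer of stage %d, recompute mBlockSize\n",
                    layerId, layerId / layersPerStage);
    }
    return 0;
}
```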
7 changes: 7 additions & 0 deletions src/models/CMakeLists.txt
@@ -18,3 +18,10 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} MODEL_SRCS)

add_library(models OBJECT ${MODEL_SRCS})
add_dependencies(models utils)

if(WITH_PIPELINE_PARALLEL)
find_package(MPI REQUIRED)
include_directories(${MPI_INCLUDE_PATH})
add_definitions(${MPI_CXX_COMPILE_FLAGS})
target_link_libraries(models ${MPI_CXX_LIBRARIES})
endif()