diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4e8ce36c..21783d16 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -132,6 +132,13 @@ else()
     list(APPEND 3RDPART_LIB_LIST "xdnn_static")
 endif()
 
+# pipeline parallel feature
+option(WITH_PIPELINE_PARALLEL "Build with pipeline parallel" OFF)
+if(WITH_PIPELINE_PARALLEL)
+    message(STATUS "Notice: Building with pipeline parallel.")
+    add_definitions(-DPIPELINE_PARALLEL=true)
+endif()
+
 # Enable AVX512_FP16 optimization
 # add_definitions(-DAVX512_FP32_WEIGHT_ONLY_FP16=true)
 add_definitions(-DAVX512_FP16_WEIGHT_ONLY_FP16=true)
diff --git a/src/comm_helper/comm_helper.cpp b/src/comm_helper/comm_helper.cpp
index 4591e014..96b7e21f 100644
--- a/src/comm_helper/comm_helper.cpp
+++ b/src/comm_helper/comm_helper.cpp
@@ -18,29 +18,45 @@
 
 static ccl::communicator *pcomm;
 
-extern "C" int init(int *rank, int *size) {
+// world_color is initialized to pipeline_parallel_stages_num (ppSize)
+// and will be re-assigned to the MPI world color, which equals ppRank
+extern "C" int init(int *world_size, int *world_rank, int *world_color) {
     ccl::init();
 
     MPI_Init(NULL, NULL);
-    MPI_Comm_size(MPI_COMM_WORLD, size);
-    MPI_Comm_rank(MPI_COMM_WORLD, rank);
+    MPI_Comm_size(MPI_COMM_WORLD, world_size);
+    MPI_Comm_rank(MPI_COMM_WORLD, world_rank);
+
+    // world_color = world_rank / tpSize = world_rank / (world_size / ppSize)
+    // like: world_color = 0~7 / (8 / 4), XFT_PIPELINE_STAGES = ppSize = 4; tpSize = 2
+    //       world_rank = 0, 1,   -> world_color = ppRank = 0, 0,   -> tpRank = 0, 1;
+    //                    2, 3,                             1, 1,              0, 1;
+    //                    4, 5,                             2, 2,              0, 1;
+    //                    6, 7;                             3, 3;              0, 1;
+    *world_color = *world_rank / (*world_size / *world_color);
+    MPI_Comm row_comm;
+    MPI_Comm_split(MPI_COMM_WORLD, *world_color, *world_rank, &row_comm);
+
+    int row_size, row_rank;
+    MPI_Comm_size(row_comm, &row_size);
+    MPI_Comm_rank(row_comm, &row_rank);
 
     ccl::shared_ptr_class<ccl::kvs> kvs;
     ccl::kvs::address_type mainAddr;
 
-    if (*rank == 0) {
+    if (row_rank == 0) {
         kvs = ccl::create_main_kvs();
         mainAddr = kvs->get_address();
-        MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
+        MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, row_comm);
     } else {
-        MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
+        MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, row_comm);
         kvs = ccl::create_kvs(mainAddr);
     }
 
-    pcomm = new ccl::communicator(ccl::create_communicator(*size, *rank, kvs));
+    pcomm = new ccl::communicator(ccl::create_communicator(row_size, row_rank, kvs));
 
-    *rank = pcomm->rank();
-    *size = pcomm->size();
+    *world_size = pcomm->size();
+    *world_rank = pcomm->rank();
 
 #ifdef USE_SHM
     char myHostname[MPI_MAX_PROCESSOR_NAME];
@@ -53,7 +69,7 @@ extern "C" int init(int *rank, int *size) {
             MPI_COMM_WORLD);
 
     int sameHostnames = 1;
-    for (int i = 1; i < *size; i++) {
+    for (int i = 1; i < *world_size; i++) {
         if (strcmp(myHostname, &all_hostnames[i * MPI_MAX_PROCESSOR_NAME]) != 0) {
             sameHostnames = 0;
             break;
@@ -89,4 +105,20 @@ extern "C" void broadcast(int *buf, size_t count) {
 extern "C" void allgatherv(
         const float *sendBuf, size_t count, float *recvBuf, const std::vector &recvCounts) {
     ccl::allgatherv(sendBuf, count, recvBuf, recvCounts, *pcomm).wait();
+}
+
+extern "C" void worldSendFP32(const float *buf, int count, int dest, int tag) {
+    MPI_Send((const void *)buf, count, MPI_FLOAT, dest, tag, MPI_COMM_WORLD);
+}
+
+extern "C" void worldRecvFP32(float *buf, int count, int source, int tag) {
+    MPI_Recv((void *)buf, count, MPI_FLOAT, source, tag, MPI_COMM_WORLD,
MPI_STATUS_IGNORE); +} + +extern "C" void worldSendINT32(const int32_t *buf, int count, int dest, int tag) { + MPI_Send((const void *)buf, count, MPI_INT32_T, dest, tag, MPI_COMM_WORLD); +} + +extern "C" void worldRecvINT32(int32_t *buf, int count, int source, int tag) { + MPI_Recv((void *)buf, count, MPI_INT32_T, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE); } \ No newline at end of file diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index 6be7f62d..e2e876de 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -84,6 +84,12 @@ struct DecoderContext { // # of splits (the same as NUMA node number in the system) const int numSplit; + // For pipeline parallel and tensor parallel config + int ppSize = 1; // pipeline parallel stage size + int ppRank = 0; // pipeline parallel stage rank + int tpSize = 1; // tensor parallel size + int tpRank = 0; // tensor parallel rank + enum ActivationType { RELU, GELU, SWIGLU, SILU }; ActivationType actType; @@ -105,7 +111,7 @@ struct DecoderContext { public: DecoderContext(int _layers, int _hiddenSize, int _attHeadNum, int _kvHeadNum, int _imSize, const std::string &act, float epsilon, int _vocabSize, int _embeddingSize, int _maxPositions, int _maxPosEmbed, int _maxSeqLength, - int _splitIdx, int _splits, RopeParams *_ropeParamsPtr = nullptr, int numThreads = 0) + int _splitIdx, int _splits, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr, int numThreads = 0) : layers(_layers) , hiddenSize(_hiddenSize) , intermediateSize(_imSize) @@ -119,6 +125,10 @@ struct DecoderContext { , ropeParamsPtr(_ropeParamsPtr) , splitIdx(_splitIdx) , numSplit(_splits) + , ppSize(_ppSize) + , ppRank(_ppRank) + , tpSize(_splits) + , tpRank(_splitIdx) , epsilon(epsilon) { if (attHeadNum != 0) { this->attHeadSize = hiddenSize / attHeadNum; diff --git a/src/layers/attention.h b/src/layers/attention.h index e03d82c2..0e22a755 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -273,6 +273,7 @@ class Attention { imBuffer.Assign(inputBuffer.Data(), inputBuffer.Rows(), inputBuffer.Cols(), inputBuffer.Stride()); inputBuffer.Assign(tmp, rows, cols, stride); } + // TODO: refine the logic (and support large inputSeqLen when pastSeqLen > 0) if constexpr (std::is_same_v && std::is_same_v) { if (pastSeqLen == 0) { @@ -284,8 +285,9 @@ class Attention { if (ctx->inputSeqLen >= 1024 && pastSeqLen == 0) flashAttention( ctx, qkvGroupMatMul, outBuffer, imBuffer, presentKey, presentValue, attnMask, pastSeqLen); - else + else { fusedAttention(ctx, query, key, value, imBuffer, presentKey, presentValue, attnMask, pastSeqLen); + } } t4.release(); @@ -375,7 +377,7 @@ class Attention { // to make sure it works better (the logic here is trying to make sure each head of BMM result [seq * seq] in cache) // WARN: reserve field in context is used to make it effective for all layers, do not change it in other places int &mBlockSize = ctx->reserved1; - if (layerId == 0) { + if (layerId % (ctx->layers / ctx->ppSize) == 0) { // TODO: if pastSeqLen > 0 and inputSeqLen large. 
if (pastSeqLen == 0) { const int l2CacheSize = 2 * 1024 * 1024; // TODO: get it dynamically diff --git a/src/models/CMakeLists.txt b/src/models/CMakeLists.txt index 18f4fe52..8b5bc081 100644 --- a/src/models/CMakeLists.txt +++ b/src/models/CMakeLists.txt @@ -18,3 +18,10 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} MODEL_SRCS) add_library(models OBJECT ${MODEL_SRCS}) add_dependencies(models utils) + +if(WITH_PIPELINE_PARALLEL) + find_package(MPI REQUIRED) + include_directories(${MPI_INCLUDE_PATH}) + add_definitions(${MPI_CXX_COMPILE_FLAGS}) + target_link_libraries(models ${MPI_CXX_LIBRARIES}) +endif() \ No newline at end of file diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index 971dfc59..5e4e86f0 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -63,6 +63,81 @@ struct MlpTypeExtractor> { using Tout = OutT; }; +/* +Pipeline parallel and tensor parallel introduction: + + 1) MPI_Instances = 16,XFT_PIPELINE_STAGES = 4 => ctx->ppSize = 4, ctx->tpSize = 4 + 2) TP sync by oneCCL(row_comm) or shared_memory + 3) PP sync by MPI MPI_COMM_WORLD + + World Rank: => Row Rank: => Rank: tp0 tp1 tp2 tp3 + [ 0, 1, 2, 3, [ 0, 1, 2, 3]; pp0 [ 0, 1, 2, 3]; + 4, 5, 6, 7, [ 0, 1, 2, 3]; pp1 [ 0, 1, 2, 3]; + 8, 9, 10, 11, [ 0, 1, 2, 3]; pp2 [ 0, 1, 2, 3]; + 12, 13, 14, 15]; [ 0, 1, 2, 3]; pp3 [ 0, 1, 2, 3]; + + Prompts + │ + ┌──────────────────┬─────────┴────────┬──────────────────┐ + │ │ │ │ + ▼ ▼ ▼ ▼ + Embedding(PP0) Embedding(PP0) Embedding(PP0) Embedding(PP0) + │ │ │ │ + PP0 │ │ │ │ + ┌─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┐ + │ TP0 │ TP1 │ TP2 │ TP3 │ layer0-7 │ + │ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ │ + │ │ OMP │ │ OMP │ │ OMP │ │ OMP │ │ + │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ + │ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ + │ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ │ + │ ┌───────┼──────────────────┼─────AllReduce────┼──────────────────┼────────┐ │ + │ └───────┼──────────────────┼──────────────────┼──────────────────┼────────┘ │ + └─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┘ + PP1 │ MPI Send/Recv │ │ │ + ┌─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┐ + │ TP0 │ TP1 │ TP2 │ TP3 │ layer8-15 │ + │ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ │ + │ │ OMP │ │ OMP │ │ OMP │ │ OMP │ │ + │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ + │ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ + │ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ │ + │ ┌───────┼──────────────────┼─────AllReduce────┼──────────────────┼────────┐ │ + │ └───────┼──────────────────┼──────────────────┼──────────────────┼────────┘ │ + └─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┘ + PP2 │ MPI Send/Recv │ │ │ + ┌─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┐ + │ TP0 │ TP1 │ TP2 │ TP3 │ layer16-23 │ + │ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ │ + │ │ OMP │ │ OMP │ │ OMP │ │ OMP │ │ + │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ + │ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ + │ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ │ + │ 
┌───────┼──────────────────┼─────AllReduce────┼──────────────────┼────────┐ │ + │ └───────┼──────────────────┼──────────────────┼──────────────────┼────────┘ │ + └─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┘ + PP3 │ MPI Send/Recv │ │ │ + ┌─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┐ + │ TP0 │ TP1 │ TP2 │ TP3 │ layer24-31 │ + │ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ │ + │ │ OMP │ │ OMP │ │ OMP │ │ OMP │ │ + │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ + │ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ + │ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ │ + │ ┌───────┼──────────────────┼─────AllReduce────┼──────────────────┼────────┐ │ + │ └───────┼──────────────────┼──────────────────┼──────────────────┼────────┘ │ + └─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┘ + │ │ │ │ + ▼ ▼ ▼ ▼ + Predictor(PP3) Predictor(PP3) Predictor(PP3) Predictor(PP3) + │ MPI Send/Recv │ │ │ + ▼ ▼ ▼ ▼ + Searchers(PP0) Searchers(PP0) Searchers(PP0) Searchers(PP0) + │ + ▼ + Output +*/ + // Template parameters: // ATTN_CLS - class for attention impl. // MLP_CLS - MLP implementation @@ -134,7 +209,15 @@ class CommonDecoder : public AbstractDecoder { vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, ropeParamsPtr); // Decoder - for (int i = 0; i < layers; ++i) { + if (layers % ctx->ppSize != 0) { + std::cerr << "Warning: layers cannot be evenly divided by pipeline parallel stage size(ppSize)." + << std::endl; + std::exit(-1); + } + + int layers_per_pp_stage = layers / ctx->ppSize; + int start_layer = ctx->ppRank * layers_per_pp_stage; + for (int i = start_layer; i < start_layer + layers_per_pp_stage; ++i) { auto pdec = new DECODER(ctx, i); this->setDecoderWeights(pdec, modelPath, i); this->decoders.push_back(pdec); @@ -215,9 +298,9 @@ class CommonDecoder : public AbstractDecoder { dbg.debugPrint("---- embedding.forward ----\n"); dbg.debugPrint("ids:\n"); dbg.dumpMatrix(ids, batchSize, inputSeqLen, inputSeqLen); - dbg.debugPrint("embBuf(rows: %d, cols: %d, stride: %d):\n", this->embBuf->Rows(), this->embBuf->Cols(), - this->embBuf->Stride()); - dbg.dumpMatrix(*this->embBuf); + dbg.debugPrint( + "embBuf(rows: %d, cols: %d, stride: %d):\n", batchSize * inputSeqLen, ctx->hiddenSize, ctx->hiddenSize); + dbg.dumpMatrix(embBuf, batchSize * inputSeqLen, ctx->hiddenSize, ctx->hiddenSize); #endif // Prepare attention mask @@ -227,9 +310,22 @@ class CommonDecoder : public AbstractDecoder { int *positionIds = this->getPositionIds(ids, batchSize, inputSeqLen, step + this->prefixSharing); t1.release(); +#ifdef PIPELINE_PARALLEL + // if current pipeline parallel stage rank isn't the first stage, should receive previous stage data + if (ctx->ppSize > 1 && ctx->ppRank > 0) { + int curr_world_rank = ctx->ppRank * ctx->tpSize + ctx->tpRank; + int prev_world_rank = (ctx->ppRank - 1) * ctx->tpSize + ctx->tpRank; + int count = batchSize * inputSeqLen * ctx->hiddenSize; + MPI_Recv(embBuf, count, MPI_FLOAT, prev_world_rank, curr_world_rank, MPI_COMM_WORLD, MPI_STATUS_IGNORE); + // TODO: Error: different scope when dynamic loading so file + // this->messenger.worldRecvFP32(embBuf, count, prev_world_rank, curr_world_rank); + } +#endif + // Decoder: forward int hiddenSize = ctx->hiddenSize; - for (int i = 0; i < this->decoders.size(); ++i) { + int layers_per_pp_stage = this->decoders.size(); + for (int 
i = 0; i < layers_per_pp_stage; ++i) {
             int workers = this->messenger.getSize();
             if (step == 0 && this->prefixSharing) {
                 // Expand the prefix KV cache for each batch
@@ -279,6 +375,18 @@ class CommonDecoder : public AbstractDecoder {
             }
         }
 
+#ifdef PIPELINE_PARALLEL
+        // If the current pipeline stage isn't the last stage, send data to the next stage and return nullptr
+        if (ctx->ppSize > 1 && ctx->ppRank < ctx->ppSize - 1) {
+            int next_world_rank = (ctx->ppRank + 1) * ctx->tpSize + ctx->tpRank;
+            int count = batchSize * inputSeqLen * ctx->hiddenSize;
+            MPI_Send(embBuf, count, MPI_FLOAT, next_world_rank, next_world_rank, MPI_COMM_WORLD);
+            // TODO: Error: different scope when dynamic loading so file
+            // this->messenger.worldSendFP32(embBuf, count, next_world_rank, next_world_rank);
+            return std::tuple(nullptr, 0, 0);
+        }
+#endif
+
         // Prepare input for final Layer Norm (only care about the last row of the result)
         // Shape of embBuf: (bs, seqLen, hiddenSize)
         MlpOutT *lnIn = embBuf;
@@ -376,6 +484,7 @@ class CommonDecoder : public AbstractDecoder {
         t1.release();
 
         // Decoder: forward
+        // TODO: Add PIPELINE_PARALLEL feature
         int hiddenSize = ctx->hiddenSize;
         for (int i = 0; i < this->decoders.size(); ++i) {
             int workers = this->messenger.getSize();
@@ -469,13 +578,16 @@ class CommonDecoder : public AbstractDecoder {
     DecoderContext *getDecoderContext(int layers, const int hiddenSize, const int attHeadNum, const int kvHeadNum,
             const int imSize, const std::string &act, const float epsilon, int vocabSize, int embeddingSize,
             int maxPositions, int maxPosEmbed, int maxSeqLength, RopeParams *ropeParamsPtr) {
-        int splits = messenger.getSize();
-        int splitIdx = messenger.getRank();
+        int tpSize = messenger.getSize();
+        int tpRank = messenger.getRank();
+        int ppSize = Env::getPipelineStage();
+        int ppRank = messenger.getColor();
+        // printf("ppSize: %d, ppRank: %d, tpSize: %d, tpRank: %d\n", ppSize, ppRank, tpSize, tpRank);
 
         if (context != nullptr) {
             if (context->hiddenSize == hiddenSize && context->attHeadNum == attHeadNum
                     && context->kvHeadNum == kvHeadNum && context->intermediateSize == imSize
-                    && context->splitIdx == splitIdx) {
+                    && context->tpRank == tpRank) {
                 return context.get();
             } else {
                 printf("Different context size not unsupported!\n");
@@ -483,7 +595,8 @@ class CommonDecoder : public AbstractDecoder {
             }
         } else {
             this->context.reset(new DecoderContext(layers, hiddenSize, attHeadNum, kvHeadNum, imSize, act, epsilon,
-                    vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, splitIdx, splits, ropeParamsPtr));
+                    vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, ppSize, ppRank,
+                    ropeParamsPtr));
         }
 
         return this->context.get();
diff --git a/src/models/models.cpp b/src/models/models.cpp
index 989984df..fcddb647 100644
--- a/src/models/models.cpp
+++ b/src/models/models.cpp
@@ -50,7 +50,8 @@ GenerationMode getGenerationMode(SearcherConfig &config_) {
 }
 
 Model::Model() : decoder(nullptr), searcher(nullptr), isNewInput(true) {
-    Env::initValue();
+    Env::initVerbose();
+    Env::initPipelineStage();
     TimeLine::init();
 }
 
diff --git a/src/searchers/greedy_search.cpp b/src/searchers/greedy_search.cpp
index 672ba225..0e55648e 100644
--- a/src/searchers/greedy_search.cpp
+++ b/src/searchers/greedy_search.cpp
@@ -13,6 +13,7 @@
 // limitations under the License.
 // ============================================================================
 #include "greedy_search.h"
+#include "messenger.h"
 #include "search_utils.h"
 
 GreedySearch::GreedySearch(AbstractDecoder &dec, const SearcherConfig &config)
@@ -27,7 +28,45 @@ GreedySearch::GreedySearch(AbstractDecoder &dec, const SearcherConfig &config)
     stopWordsIndex = {};
 }
 
-// Get next tokens accoring to the prompt IDs
+std::vector<int> GreedySearch::syncToken(std::tuple<float *, int, int> &result) {
+    // Sync the selected tokens from the last (predictor) stage to the first (embedding) stage in pipeline parallel
+#ifdef PIPELINE_PARALLEL
+    DecoderContext *ctx = decoder.getContext();
+    // Messenger &messenger = decoder.getMessenger();
+
+    if (std::get<0>(result) == nullptr) { // The first embedding pipeline parallel stage
+        this->nextTokens = std::vector<int>(batchSize, 0);
+        if (ctx->ppSize > 1 && ctx->ppRank == 0) {
+            int predictor_world_rank = (ctx->ppSize - 1) * ctx->tpSize + ctx->tpRank;
+            MPI_Recv(this->nextTokens.data(), batchSize, MPI_INT32_T, predictor_world_rank, predictor_world_rank,
+                    MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+            // TODO: Error: different scope when dynamic loading so file
+            // messenger.worldRecvINT32(this->nextTokens.data(), batchSize, predictor_world_rank, predictor_world_rank);
+        }
+    } else { // The last predictor pipeline parallel stage
+        this->nextTokens = this->search(result);
+        if (ctx->ppSize > 1 && ctx->ppRank == ctx->ppSize - 1) {
+            int embedding_world_rank = 0 * ctx->tpSize + ctx->tpRank;
+            int predictor_world_rank = (ctx->ppSize - 1) * ctx->tpSize + ctx->tpRank;
+            MPI_Send(this->nextTokens.data(), batchSize, MPI_INT32_T, embedding_world_rank, predictor_world_rank,
+                    MPI_COMM_WORLD);
+            // TODO: Error: different scope when dynamic loading so file
+            // messenger.worldSendINT32(this->nextTokens.data(), batchSize, embedding_world_rank, predictor_world_rank);
+        }
+    }
+#else
+    this->nextTokens = this->search(result);
+#endif
+
+    this->curLen++;
+    for (int batchId = 0; batchId < batchSize; ++batchId) {
+        output.insert(output.begin() + (batchId + 1) * curLen - 1, nextTokens[batchId]);
+    }
+
+    return this->nextTokens;
+}
+
+// Get next tokens according to the prompt IDs for the first token
 std::vector<int> GreedySearch::getNextToken(int *ids, int batchSize, int seqLen) {
     TimeLine t("1st Token");
     this->step = 0;
@@ -46,30 +85,17 @@ std::vector<int> GreedySearch::getNextToken(int *ids, int batchSize, int seqLen)
 
     std::tuple<float *, int, int> result = decoder.forward(ids, dims, this->step++);
 
-    this->nextTokens = search(result);
-
-    this->curLen++;
-    for (int batchId = 0; batchId < batchSize; ++batchId) {
-        output.insert(output.begin() + (batchId + 1) * curLen - 1, nextTokens[batchId]);
-    }
-
-    return this->nextTokens;
+    return this->syncToken(result);
 }
 
-// Get next tokens according to previous predicted ID
+// Get next tokens according to the previously predicted IDs for subsequent tokens
 std::vector<int> GreedySearch::getNextToken() {
     TimeLine t("Next Token");
     int64_t dims[3] = {batchSize, 1, 1};
-    std::tuple<float *, int, int> result = decoder.forward(nextTokens.data(), dims, this->step++);
-    this->nextTokens = search(result);
-
-    this->curLen++;
-    for (int batchId = 0; batchId < batchSize; ++batchId) {
-        output.insert(output.begin() + (batchId + 1) * curLen - 1, nextTokens[batchId]);
-    }
+    std::tuple<float *, int, int> result = decoder.forward(nextTokens.data(), dims, this->step++);
 
-    return this->nextTokens;
+    return this->syncToken(result);
 }
 
 bool GreedySearch::isDone() {
diff --git a/src/searchers/greedy_search.h b/src/searchers/greedy_search.h
index ee478875..af33e5a9 100644
--- a/src/searchers/greedy_search.h
+++ b/src/searchers/greedy_search.h
@@ -34,6 +34,7 @@ class GreedySearch : public AbstractSearcher {
     bool setStopWords(std::vector<std::vector<int>> stopWordsList);
 
 private:
+    std::vector<int> syncToken(std::tuple<float *, int, int> &result);
     std::vector<int> search(std::tuple<float *, int, int> &result);
 
     AbstractDecoder &decoder;
diff --git a/src/searchers/sample_search.cpp b/src/searchers/sample_search.cpp
index 6ae6c2db..872f5dc8 100644
--- a/src/searchers/sample_search.cpp
+++ b/src/searchers/sample_search.cpp
@@ -59,6 +59,7 @@ std::vector<int> SampleSearch::getNextToken(int *ids, int batchSize, int seqLen)
     std::tuple<float *, int, int> result = decoder.forward(ids, dims, this->step++);
 
     nextTokens.resize(batchSize);
+    // TODO: Add PIPELINE_PARALLEL feature
     sample(result);
 
     this->curLen++;
@@ -75,6 +76,7 @@ std::vector<int> SampleSearch::getNextToken() {
     int64_t dims[3] = {batchSize, 1, 1};
     std::tuple<float *, int, int> result = decoder.forward(nextTokens.data(), dims, this->step++);
 
+    // TODO: Add PIPELINE_PARALLEL feature
     sample(result);
 
     this->curLen++;
diff --git a/src/utils/messenger.h b/src/utils/messenger.h
index c5565476..f0325b21 100644
--- a/src/utils/messenger.h
+++ b/src/utils/messenger.h
@@ -24,6 +24,7 @@
 #include "oneapi/ccl.hpp"
 #include "shm_reduction.h"
 #include "timeline.h"
+#include "verbose.h"
 
 class Messenger {
 private:
@@ -46,7 +47,7 @@ class Messenger {
             exit(-1);
         }
 
-        helperInit = (int (*)(int *, int *))dlsym(commHelperHanlde, "init");
+        helperInit = (int (*)(int *, int *, int *))dlsym(commHelperHanlde, "init");
         helperFreePCOMM = (void (*)())dlsym(commHelperHanlde, "freePCOMM");
         helperAllreduce = (void (*)(float *, float *, size_t))dlsym(commHelperHanlde, "allreduce");
         helperAllreduceBF16 = (void (*)(bfloat16_t *, bfloat16_t *, size_t))dlsym(commHelperHanlde, "allreduceBF16");
@@ -54,9 +55,15 @@ class Messenger {
         helperAllgatherv = (void (*)(const float *, size_t, float *, const std::vector &))dlsym(
                 commHelperHanlde, "allgatherv");
 
+        helperWorldSendFP32 = (void (*)(const float *, int, int, int))dlsym(commHelperHanlde, "worldSendFP32");
+        helperWorldRecvFP32 = (void (*)(float *, int, int, int))dlsym(commHelperHanlde, "worldRecvFP32");
+        helperWorldSendINT32 = (void (*)(const int32_t *, int, int, int))dlsym(commHelperHanlde, "worldSendINT32");
+        helperWorldRecvINT32 = (void (*)(int32_t *, int, int, int))dlsym(commHelperHanlde, "worldRecvINT32");
+
         atexit(Messenger::mpi_finalize);
 
-        int sameHostnames = (*helperInit)(&rank, &size);
+        color = Env::getPipelineStage();
+        int sameHostnames = (*helperInit)(&size, &rank, &color);
 
 #ifdef USE_SHM
         if (sameHostnames && !std::getenv("XFT_ONECCL")) {
@@ -88,6 +95,8 @@ class Messenger {
 
     int getSize() { return size; }
 
+    int getColor() { return color; }
+
     // From some example code of oneCCL, inplace reducing is supported
     // Only float is used now
     void reduceAdd(float *sendBuf, float *recvBuf, size_t count) {
@@ -144,6 +153,22 @@ class Messenger {
         if (check()) { (*helperAllgatherv)(send_buf, count, recv_buf, recv_counts); }
     }
 
+    void worldSendFP32(const float *buf, int count, int dest, int tag) {
+        if (check()) { (*helperWorldSendFP32)(buf, count, dest, tag); }
+    }
+
+    void worldRecvFP32(float *buf, int count, int source, int tag) {
+        if (check()) { (*helperWorldRecvFP32)(buf, count, source, tag); }
+    }
+
+    void worldSendINT32(const int32_t *buf, int count, int dest, int tag) {
+        if (check()) { (*helperWorldSendINT32)(buf, count, dest, tag); }
+    }
+
+    void worldRecvINT32(int32_t *buf, int count, int source, int tag) {
+        if (check()) { (*helperWorldRecvINT32)(buf, count, source, tag); }
+    }
+
     bool withMpirun() {
         return (std::getenv("MPI_LOCALRANKID") || std::getenv("MPI_LOCALNRANKS") || std::getenv("PMI_RANK")
                        || std::getenv("PMI_SIZE") || std::getenv("PMIX_RANK"))
@@ -176,16 +201,21 @@ class Messenger {
 private:
     int size;
     int rank;
+    int color; // Processes with the same color will be placed into the same sub-communicator
    bool localRanksFlag;
 #ifdef USE_SHM
     ShmReduction *pshm;
 #endif
     void *commHelperHanlde;
 
-    int (*helperInit)(int *, int *);
+    int (*helperInit)(int *, int *, int *);
     void (*helperFreePCOMM)();
     void (*helperAllreduce)(float *, float *, size_t);
     void (*helperAllreduceBF16)(bfloat16_t *, bfloat16_t *, size_t);
     void (*helperBroadcast)(int *, size_t);
     void (*helperAllgatherv)(const float *, size_t, float *, const std::vector &);
+    void (*helperWorldSendFP32)(const float *, int, int, int);
+    void (*helperWorldRecvFP32)(float *, int, int, int);
+    void (*helperWorldSendINT32)(const int32_t *, int, int, int);
+    void (*helperWorldRecvINT32)(int32_t *, int, int, int);
 };
diff --git a/src/utils/verbose.h b/src/utils/verbose.h
index 3f29c118..e19c9103 100644
--- a/src/utils/verbose.h
+++ b/src/utils/verbose.h
@@ -46,19 +46,19 @@ class Printer {
 
 class Env {
 private:
-    static int &verbose_value() {
+    static int &verboseValue() {
         static int value = 0;
         return value;
     }
 
 public:
-    static void initValue() {
+    static void initVerbose() {
         char *xft_verbose_value = getenv("XFT_VERBOSE");
         if (xft_verbose_value != NULL) {
             int value = atoi(xft_verbose_value);
-            verbose_value() = value;
+            verboseValue() = value;
         } else {
-            verbose_value() = 0;
+            verboseValue() = 0;
         }
 
         // TODO: Move XFT_FAKE_MODEL here.
@@ -67,7 +67,32 @@ class Env {
         }
     }
 
-    static int getVerbose() { return verbose_value(); }
+    static int getVerbose() { return verboseValue(); }
+
+// Pipeline Parallel
+private:
+    static int &pipelineStageValue() {
+        static int value = 1;
+        return value;
+    }
+
+public:
+    static void initPipelineStage() {
+        char *xft_pipeline_value = getenv("XFT_PIPELINE_STAGES");
+        if (xft_pipeline_value != NULL) {
+#ifdef PIPELINE_PARALLEL
+            int value = atoi(xft_pipeline_value);
+            if (value >= 1)
+                pipelineStageValue() = value;
+#else
+            printf("[WARNING] XFT_PIPELINE_STAGES requires building with WITH_PIPELINE_PARALLEL=ON.\n");
+#endif
+        } else {
+            pipelineStageValue() = 1;
+        }
+    }
+
+    static int getPipelineStage() { return pipelineStageValue(); }
 };
 
 #define GEMMVERBOSE(api_func, compute_func) \