36 changes: 20 additions & 16 deletions csrc/mla_preprocess/op_host/mla_preprocess.h

@@ -43,7 +43,6 @@ constexpr uint32_t L1_BIAS_SIZE = 2048;
 constexpr uint32_t L0C_SIZE = 128 * 1024;
 constexpr uint32_t CONCAT_SIZE = 512;
 
-constexpr uint32_t HIDDEN_STRATE = 7168;
 constexpr uint32_t HIDDEN_STRATE_ROPE = 192;
 constexpr uint32_t HIDDEN_STRATE_MM = 2112;
 constexpr uint32_t HIDDEN_STRATE_RMS = 1536;
@@ -62,6 +61,7 @@ constexpr uint32_t INDEX_WUQ = 18;
 constexpr uint32_t INDEX_WUK = 20;
 
 constexpr uint32_t MAX_SUPPORT_TOKEN_NUMS = 1024;
+constexpr uint32_t MAX_CACHE_MODE_NUMS = 3;
 
 inline uint32_t CeilDiv(const uint32_t dividend, const uint32_t divisor)
 {
@@ -122,6 +122,7 @@ struct PlatformInfo {
 };
 
 struct OpParam {
+    uint32_t hiddenStateDim;
     uint32_t N;
     uint32_t headNum;
     int32_t cacheMode;
@@ -391,7 +392,7 @@ class MlaPreprocessTiling
 void MlaPreprocessTiling::RmsNormQuantTiling()
 {
     tilingData->rmsNumCore1 = platformInfo.coreNumAiv;
-    tilingData->rmsNumCol1 = HIDDEN_STRATE;
+    tilingData->rmsNumCol1 = opParam.hiddenStateDim;
     tilingData->rmsNumRow1 = opParam.N;
     tilingData->rmsQuantMin1 = -CONST_128;
     tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
@@ -507,9 +508,9 @@ void MlaPreprocessTiling::EinSumQuantTiling()
 void MlaPreprocessTiling::SetMlapoWorkSpace()
 {
     uint64_t s1wsFactor =
-        static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t),
+        static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(opParam.hiddenStateDim * sizeof(int8_t),
                                                                 opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
-                                                     : HIDDEN_STRATE * sizeof(int8_t));
+                                                     : opParam.hiddenStateDim * sizeof(int8_t));
     uint64_t workSizeS1 = s1wsFactor;
     uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
     uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
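
Note: the ternary above sizes the first workspace row. A minimal standalone sketch of the same computation (the helper name and signature are illustrative, not from this patch; cacheMode == 2 is the int8 NZ key-cache path per the constants in op_kernel/mla_preprocess.h):

    #include <algorithm>
    #include <cstdint>

    // Stage-1 workspace row size in bytes: the int8 NZ key-cache mode must also
    // hold headNum * axesAlignSize uint16_t elements; every other mode needs one
    // int8 row of the hidden dimension.
    inline uint64_t Stage1WorkspaceBytes(uint32_t hiddenStateDim, uint32_t headNum,
                                         int32_t cacheMode, uint32_t axesAlignSize)
    {
        const uint64_t rowBytes = static_cast<uint64_t>(hiddenStateDim) * sizeof(int8_t);
        if (cacheMode == 2) {
            return std::max(rowBytes, static_cast<uint64_t>(headNum) * axesAlignSize * sizeof(uint16_t));
        }
        return rowBytes;
    }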
@@ -552,21 +553,21 @@ void MlaPreprocessTiling::Init()
 {
     tilingData->numCore = platformInfo.coreNumAic;
     tilingData->n = opParam.N;
-
+    tilingData->hiddenStateDim = opParam.hiddenStateDim;
     bool deqOnTheFly = false;
     if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
         deqOnTheFly = true;
     }
 
     PpMatmulTilingApi mm1TilingApi(platformInfo,
-                                   1,             // numBatch
-                                   opParam.N,     // m
-                                   HIDDEN_STRATE, // k
-                                   HIDDEN_STRATE_MM, // n
-                                   false,         // transA
-                                   true,          // transB
-                                   true,          // enDequant
-                                   deqOnTheFly);  // in bf16.cce?
+                                   1,                      // numBatch
+                                   opParam.N,              // m
+                                   opParam.hiddenStateDim, // k
+                                   HIDDEN_STRATE_MM,       // n
+                                   false,                  // transA
+                                   true,                   // transB
+                                   true,                   // enDequant
+                                   deqOnTheFly);           // in bf16.cce?
     mm1TilingApi.GetTilingData(tilingData->mm1);
 
     PpMatmulTilingApi mm2TilingApi(platformInfo,
@@ -654,8 +655,10 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
 
     int32_t N = hiddenState.sizes()[0];
     int32_t headNum = wuk.sizes()[0];
+    uint32_t hiddenStateDim = hiddenState.sizes().back();
 
     OpParam opParam;
+    opParam.hiddenStateDim = hiddenStateDim;
     opParam.N = N;
     opParam.headNum = headNum;
     opParam.cacheMode = static_cast<int32_t>(cacheMode);
@@ -677,18 +680,19 @@
     // tiling
     int32_t bIndex = N - 1;
     uint32_t tilingSize = sizeof(MlaTilingData);
+    int32_t tilingOffset = tilingSize * MAX_SUPPORT_TOKEN_NUMS * (opParam.cacheMode - 1) + (tilingSize * bIndex);
     static auto global_tiling_data =
-        at::empty({tilingSize * MAX_SUPPORT_TOKEN_NUMS},
+        at::empty({tilingSize * MAX_SUPPORT_TOKEN_NUMS * MAX_CACHE_MODE_NUMS},
                   at::TensorOptions().dtype(at::kByte).device(hiddenState.options().device()));
     if (bIndex >= 0 && bIndex < MAX_SUPPORT_TOKEN_NUMS) {
-        aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * bIndex), tilingSize, &tilingData, tilingSize,
+        aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + tilingOffset, tilingSize, &tilingData, tilingSize,
                     ACL_MEMCPY_HOST_TO_DEVICE);
     } else {
         // Handle the case where bIndex is out of range
         TORCH_CHECK(false, "bIndex is out of range: ", bIndex);
     }
     at::Tensor tiling = at::from_blob(
-        global_tiling_data.data_ptr<uint8_t>() + (tilingSize * bIndex),
+        global_tiling_data.data_ptr<uint8_t>() + tilingOffset,
         tilingSize,
         at::kByte);
 
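
Note: with the cache now sized for all modes, each (cacheMode, bIndex) pair owns one tiling record. A minimal sketch of the indexing above (the helper name is illustrative; it assumes cacheMode is 1-based here, matching the cacheMode - 1 term in the patch):

    #include <cstdint>

    constexpr uint32_t MAX_SUPPORT_TOKEN_NUMS = 1024;
    constexpr uint32_t MAX_CACHE_MODE_NUMS = 3;

    // One contiguous block of MAX_SUPPORT_TOKEN_NUMS tiling records per cache
    // mode; bIndex = N - 1 selects the record inside that block.
    inline int64_t TilingRecordOffset(uint32_t tilingSize, int32_t cacheMode, int32_t bIndex)
    {
        return static_cast<int64_t>(tilingSize) * MAX_SUPPORT_TOKEN_NUMS * (cacheMode - 1) +
               static_cast<int64_t>(tilingSize) * bIndex;
    }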
3 changes: 3 additions & 0 deletions csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h

@@ -90,6 +90,9 @@ struct MlaTilingData {
     uint32_t esqHeadTail{0};
     uint32_t esqColLoop{0};
     uint32_t esqColTail{0};
+
+    // hidden state dimension
+    uint32_t hiddenStateDim{7168};
 };
 
 #endif // MLAPREPROCESS_TILING_H
2 changes: 1 addition & 1 deletion csrc/mla_preprocess/op_kernel/mla_preprocess.h

@@ -49,7 +49,6 @@ constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2; // high performance KV NZ format
 constexpr uint8_t CACHE_MODE_NZCACHE = 3;
 
 // pp matmul
-constexpr uint32_t HIDDTEN_STATE = 7168;
 constexpr uint32_t FLOAT_BLOCK_SIZE = 64;
 constexpr uint32_t HALF_BLOCK_SIZE = 64;
 constexpr uint32_t HALF_VECTOR_SIZE = 64;
@@ -102,6 +101,7 @@ constexpr uint32_t KEY_FP16_CACHEMODE_0_QUANTMODE_0 = 0;
 constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
 constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
 constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
+constexpr uint32_t KEY_BF16_CACHEMODE_2_QUANTMODE_0 = 258;
 constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
 
 enum class QuantMode : int32_t {
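
Note: the key values suggest an encoding of dtype as a high base plus cacheMode in the low bits, which is why the new bf16/cacheMode-2 key lands on 258. A sketch of that inferred scheme (the helper is illustrative and is not part of the kernel):

    #include <cstdint>

    // Inferred from the constants above: fp16 keys start at 0, bf16 keys at 256,
    // and cacheMode is added on top (all listed keys are quantMode 0).
    inline uint32_t MakeTilingKey(bool isBf16, uint32_t cacheMode)
    {
        return (isBf16 ? 256u : 0u) + cacheMode;  // bf16, cacheMode 2 -> 258
    }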
17 changes: 17 additions & 0 deletions csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp

@@ -41,6 +41,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
 
     mlaTilingData.tilingKey = tilingData->tilingKey;
     mlaTilingData.n = tilingData->n;
+    mlaTilingData.hiddenStateDim = tilingData->hiddenStateDim;
 
     mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch;
     mlaTilingData.mm1.m = tilingData->mm1.m;
@@ -202,6 +203,22 @@
             }
             break;
         }
+        case KEY_BF16_CACHEMODE_2_QUANTMODE_0: {
+            MLAPO_BF16::MLAOperation<__bf16, 2, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+                                     QuantMode::PER_TENSOR_ASYMM_QUANT>
+                opBf16Cm2Qm0(mlaTilingData, tiling);
+            opBf16Cm2Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+                              quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                              bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                              s1, s2, s3, s4, s5);
+            if ASCEND_IS_AIC {
+                opBf16Cm2Qm0.ProcessCube();
+            }
+            if ASCEND_IS_AIV {
+                opBf16Cm2Qm0.ProcessVector();
+            }
+            break;
+        }
         case KEY_BF16_CACHEMODE_3_QUANTMODE_0: {
             MLAPO_BF16::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                                      QuantMode::PER_TENSOR_ASYMM_QUANT>
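
Note: cacheMode is a non-type template argument of MLAOperation, so each tiling key must name its own case to instantiate a distinct specialization. A stripped-down sketch of the pattern (all names here are illustrative):

    #include <cstdint>

    // Each runtime key selects a compile-time specialization.
    template <int CacheMode>
    void RunMla()
    {
        // instantiate the MLAOperation specialization for CacheMode and run it
    }

    inline void Dispatch(uint32_t tilingKey)
    {
        switch (tilingKey) {
            case 258: RunMla<2>(); break;  // KEY_BF16_CACHEMODE_2_QUANTMODE_0
            case 259: RunMla<3>(); break;  // KEY_BF16_CACHEMODE_3_QUANTMODE_0
            default: break;
        }
    }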
17 changes: 8 additions & 9 deletions csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp

@@ -2386,6 +2386,7 @@ class MLAOperation
         this->num_row = mlaParams_.n;
         this->epsilon_ = 1e-6;
         this->mlaParams = mlaParams_;
+        this->hiddenStateDim = mlaParams_.hiddenStateDim;
     }
 
     __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2692,6 +2693,7 @@
     uint32_t blockOffset;
     uint32_t perTaskNum;
     uint32_t resTaskNum;
+    uint32_t hiddenStateDim;
     MlaTilingData mlaParams;
 
     uint32_t num_core_;
@@ -2795,18 +2797,15 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
         uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
         uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
         uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
+        const uint32_t base_offset = hiddenStateDim * 6;
         AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
-        AscendC::LocalTensor<InDtype> scale_tensor =
-            buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
-        AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
-        AscendC::LocalTensor<float> res1_tensor =
-            buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
+        AscendC::LocalTensor<InDtype> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(base_offset);
+        AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(base_offset + 32);
+        AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(base_offset + 64);
         AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
+            base_offset + 64 + num_col_align_f32 * 4);
         AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
-            BUF_FACTOR * num_col_align_f32 * 4 + 64);
+            base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
         Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
     }
     FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1);
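
Note: base_offset folds the former HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 chain into one expression. A small standalone check of that byte arithmetic (7168 is only the old default, used as an example value):

    #include <cstdint>

    constexpr uint32_t kHiddenDimExample = 7168;             // the old hard-coded value
    constexpr uint32_t kBaseOffset = kHiddenDimExample * 6;  // 43008 bytes
    // Three 2-byte-per-element regions of width hiddenStateDim precede the quant
    // scale (base_offset), the zero-point (+32), and the fp32 residuals (+64).
    static_assert(kBaseOffset == kHiddenDimExample * 2 * 3, "equals the old chained sum");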
24 changes: 12 additions & 12 deletions csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp

@@ -2034,6 +2034,7 @@ class MLAOperation
         this->num_row = mlaParams_.n;
         this->epsilon_ = 1e-6;
         this->mlaParams = mlaParams_;
+        this->hiddenStateDim = mlaParams_.hiddenStateDim;
     }
 
     __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2294,6 +2295,7 @@
     uint32_t blockOffset;
     uint32_t perTaskNum;
     uint32_t resTaskNum;
+    uint32_t hiddenStateDim;
     MlaTilingData mlaParams;
 
     // rmsnormQuant
@@ -2389,21 +2391,19 @@ __aicore__ inline void MLAOperation<cacheMode, weightFormat1, weightFormat2, wei
         uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
         uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
         uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
+        const uint32_t gamma_offset = hiddenStateDim * 2;
+        const uint32_t beta_offset = gamma_offset + hiddenStateDim * 2;
+        const uint32_t scale_offset = beta_offset + hiddenStateDim * 2;
         AscendC::LocalTensor<half> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(0);
-        AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2);
-        AscendC::LocalTensor<half> beta_tensor =
-            buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
-        AscendC::LocalTensor<half> scale_tensor =
-            buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
-        AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
-        AscendC::LocalTensor<float> res1_tensor =
-            buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
+        AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(gamma_offset);
+        AscendC::LocalTensor<half> beta_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(beta_offset);
+        AscendC::LocalTensor<half> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(scale_offset);
+        AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(scale_offset + 32);
+        AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(scale_offset + 64);
         AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
+            scale_offset + 64 + num_col_align_f32 * 4);
         AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
-            BUF_FACTOR * num_col_align_f32 * 4 + 32);
+            scale_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32);
         Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor,
                       res3_tensor);
     }
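
Note: the three named offsets reproduce the old chained HIDDTEN_STATE sums. A worked example with the previous hard-coded dimension (the values illustrate the byte math only; half elements are 2 bytes):

    #include <cstdint>

    constexpr uint32_t kHiddenDimExample = 7168;                      // old default
    constexpr uint32_t kGammaOff = kHiddenDimExample * 2;             // 14336 bytes
    constexpr uint32_t kBetaOff = kGammaOff + kHiddenDimExample * 2;  // 28672 bytes
    constexpr uint32_t kScaleOff = kBetaOff + kHiddenDimExample * 2;  // 43008 bytes
    static_assert(kScaleOff == kHiddenDimExample * 6, "same total as the bf16 path's base_offset");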