Skip to content

Commit 69b91cf

Browse files
authored
[Kernel] Make SelfAttention prepared for AMX_FP16; More balanced task split in Cross Attention (#466)
1 parent bf57d0b commit 69b91cf

File tree

2 files changed

+134
-112
lines changed

2 files changed

+134
-112
lines changed

src/kernels/attention_kernels.h

Lines changed: 116 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,23 @@ void gemmSV(
9999
}
100100
}
101101

102+
// T is bfloat16_t or float16_t
103+
// ldb is the K value during packing
104+
template <typename T>
105+
void small_amx_gemm_16bits_compute(int m, int n, int k, T *A, int lda, T *packedB, int ldb, T *C, int ldc) {
106+
static_assert(std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float16_t>, "AMX gemm only supports BF16/FP16.");
107+
108+
if (std::is_same_v<T, bfloat16_t>) {
109+
xdnn_small_amx_sgemm_bf16bf16bf16_compute(
110+
m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
111+
} else {
112+
//xdnn_small_amx_sgemm_f16f16f16_compute(m, n, k, (XDNN_FP16 *)A, lda, (XDNN_FP16 *)packedB, ldb, (XDNN_FP16 *)C, ldc);
113+
}
114+
}
115+
102116
// Self attention while KV cache copy is separated
103-
template <bool fusedPack, typename Lambda1, typename Lambda2>
104-
void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bfloat16_t *value, int qHeadNum,
117+
template <bool fusedPack, typename T, typename Lambda1, typename Lambda2>
118+
void selfAttention_SeparateCopy(T *output, T *query, T *key, T *value, int qHeadNum,
105119
int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes,
106120
const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache,
107121
const Lambda2 &getVCache) {
@@ -126,8 +140,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
126140
auto totalPackSize
127141
= fusedPack ? threadNum * (kPackSize + vPackSize) : (batchSize * kvHeadNum) * (kPackSize + vPackSize);
128142

129-
bfloat16_t *packBuf
130-
= (bfloat16_t *)SimpleMemPool::instance().getBuffer("kv_packing", totalPackSize * sizeof(bfloat16_t));
143+
T *packBuf
144+
= (T *)SimpleMemPool::instance().getBuffer("kv_packing", totalPackSize * sizeof(T));
131145

132146
// Copy key/value to cache and pack them
133147
// If packing is not fused into computing, then pack it here
@@ -137,8 +151,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
137151
for (int i = 0; i < kvHeadNum; ++i) {
138152
const int tokens = tokenSizes[b];
139153

140-
bfloat16_t *packedB = packBuf + (b * kvHeadNum + i) * (kPackSize + vPackSize);
141-
bfloat16_t *packedV = packedB + kPackSize;
154+
T *packedB = packBuf + (b * kvHeadNum + i) * (kPackSize + vPackSize);
155+
T *packedV = packedB + kPackSize;
142156

143157
auto B = key + offsets[b] * kvStride + i * headSize;
144158
for (int s = 0; s < tokens; ++s) {
@@ -181,8 +195,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
181195

182196
// Prepare score buffer
183197
auto maxScoreStride = (maxTokenSize + 31) / 32 * 32;
184-
bfloat16_t *scores = (bfloat16_t *)SimpleMemPool::instance().getBuffer(
185-
"qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(bfloat16_t));
198+
T *scores = (T *)SimpleMemPool::instance().getBuffer(
199+
"qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(T));
186200

187201
auto totalBlocks = blkEndIndex[batchSize - 1];
188202
std::pair<int, int> packInfo[threadNum];
@@ -208,8 +222,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
208222
int tid = omp_get_thread_num();
209223
int kvHeadIdx = i / groupNum;
210224
int locationIdx = (fusedPack ? tid : b * kvHeadNum + kvHeadIdx);
211-
bfloat16_t *packedB = packBuf + locationIdx * (kPackSize + vPackSize);
212-
bfloat16_t *packedV = packedB + kPackSize;
225+
T *packedB = packBuf + locationIdx * (kPackSize + vPackSize);
226+
T *packedV = packedB + kPackSize;
213227

214228
const int tokens = tokenSizes[b];
215229
const int startSeq = mb * mBlockSize;
@@ -234,8 +248,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
234248
}
235249

236250
// Causal mask (either with or without Alibi), use endSeq as N
237-
xdnn_small_amx_sgemm_bf16bf16bf16_compute(
238-
m, endSeq, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
251+
small_amx_gemm_16bits_compute(m, endSeq, k, A, lda, packedB, headSize, C, ldc);
239252

240253
#ifdef XFT_DEBUG
241254
if (b == 0 && i == 0) {
@@ -257,7 +270,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
257270
} else {
258271
DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements);
259272
}
260-
memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(bfloat16_t));
273+
memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(T));
261274
}
262275

263276
#ifdef XFT_DEBUG
@@ -274,7 +287,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
274287
lda = ldc;
275288
ldc = oStride;
276289
A = C;
277-
C = (bfloat16_t *)output + (offsets[b] + startSeq) * ldc + i * headSize;
290+
C = (T *)output + (offsets[b] + startSeq) * ldc + i * headSize;
278291

279292
if constexpr (fusedPack) {
280293
if (packInfo[tid].first != b || packInfo[tid].second != kvHeadIdx) {
@@ -287,8 +300,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
287300
}
288301
}
289302

290-
xdnn_small_amx_sgemm_bf16bf16bf16_compute(
291-
m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc);
303+
small_amx_gemm_16bits_compute(m, n, k, A, lda, packedV, tokens, C, ldc);
292304

293305
#ifdef XFT_DEBUG
294306
if (b == 0 && i == 0) {
@@ -301,8 +313,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
301313
});
302314
}
303315

304-
template <typename Lambda1, typename Lambda2>
305-
void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bfloat16_t *value, int qHeadNum,
316+
template <typename T, typename Lambda1, typename Lambda2>
317+
void selfAttention_FusedCopy(T *output, T *query, T *key, T *value, int qHeadNum,
306318
int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes,
307319
const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache,
308320
const Lambda2 &getVCache) {
@@ -331,11 +343,11 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
331343
// Prepare buffers (packing buffer and score buffer)
332344
const int kPackSize = xdnn_small_amx_sgemm_bf16bf16bf16_packb_size(maxTokenSize, headSize, 32, 32);
333345
const int vPackSize = xdnn_small_amx_sgemm_bf16bf16bf16_packb_size(headSize, maxTokenSize, 32, 32);
334-
bfloat16_t *packBuf = (bfloat16_t *)SimpleMemPool::instance().getBuffer(
335-
"kv_packing", threadNum * (kPackSize + vPackSize) * sizeof(bfloat16_t));
346+
T *packBuf = (T *)SimpleMemPool::instance().getBuffer(
347+
"kv_packing", threadNum * (kPackSize + vPackSize) * sizeof(T));
336348
int maxScoreStride = (maxTokenSize + 31) / 32 * 32;
337-
bfloat16_t *scores = (bfloat16_t *)SimpleMemPool::instance().getBuffer(
338-
"qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(bfloat16_t));
349+
T *scores = (T *)SimpleMemPool::instance().getBuffer(
350+
"qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(T));
339351

340352
#ifdef XFT_DEBUG
341353
printf("maxTokenSize=%d, tokenSizes[0]=%d, offsets[0]=%d, kvStride=%d\n", maxTokenSize, tokenSizes[0], offsets[0],
@@ -349,8 +361,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
349361
const int tokens = tokenSizes[b];
350362
const int mBlockNum = (tokens + mBlockSize - 1) / mBlockSize;
351363

352-
bfloat16_t *packedB = packBuf + tid * (kPackSize + vPackSize);
353-
bfloat16_t *packedV = packedB + kPackSize;
364+
T *packedB = packBuf + tid * (kPackSize + vPackSize);
365+
T *packedV = packedB + kPackSize;
354366

355367
// Copy key/value to cache and pack them
356368
auto B = key + offsets[b] * kvStride + i * headSize;
@@ -386,8 +398,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
386398
auto A = query + (offsets[b] + startSeq) * qStride + i * headSize;
387399
auto C = scores + tid * mBlockSize * maxScoreStride;
388400

389-
xdnn_small_amx_sgemm_bf16bf16bf16_compute(
390-
m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
401+
small_amx_gemm_16bits_compute(
402+
m, n, k, A, lda, packedB, headSize, C, ldc);
391403

392404
#ifdef XFT_DEBUG
393405
if (b == 0 && i == 0) {
@@ -408,7 +420,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
408420
} else {
409421
DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements);
410422
}
411-
memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(bfloat16_t));
423+
memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(T));
412424
}
413425

414426
#ifdef XFT_DEBUG
@@ -425,10 +437,9 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
425437
lda = ldc;
426438
ldc = oStride;
427439
A = C;
428-
C = (bfloat16_t *)output + (offsets[b] + startSeq) * ldc + i * headSize;
440+
C = (T *)output + (offsets[b] + startSeq) * ldc + i * headSize;
429441

430-
xdnn_small_amx_sgemm_bf16bf16bf16_compute(
431-
m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc);
442+
small_amx_gemm_16bits_compute(m, n, k, A, lda, packedV, tokens, C, ldc);
432443

433444
#ifdef XFT_DEBUG
434445
if (b == 0 && i == 0) {
@@ -443,8 +454,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
443454
} // end for b
444455
}
445456

446-
template <typename Lambda1, typename Lambda2>
447-
void selfAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bfloat16_t *value, int qHeadNum,
457+
template <typename T, typename Lambda1, typename Lambda2>
458+
void selfAttention(T *output, T *query, T *key, T *value, int qHeadNum,
448459
int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes,
449460
const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache,
450461
const Lambda2 &getVCache) {
@@ -700,91 +711,94 @@ void crossAttnByHead(T *output, const T *query, const T *key, const T *value, in
700711
size_t scoreSizePerThr = 0;
701712
for (int i = 0; i < batchSize; ++i) {
702713
scoreSizePerThr = std::max(scoreSizePerThr, (size_t)inputSeqLens[i] * (inputSeqLens[i] + pastSeqLens[i]));
703-
inputOffsets[i] = (i > 0 ? inputOffsets[i - 1] + inputSeqLens[i - 1] : 0);
714+
inputOffsets[i] = (i > 0 ? inputOffsets[i - 1] + inputSeqLens[i] : 0);
704715
}
705716

706717
scoreSizePerThr = ALIGNED_SIZE(scoreSizePerThr, 16);
707718
size_t scoreSize = scoreSizePerThr * threadNum;
708719
float *scoreBuf = (float *)SimpleMemPool::instance().getBuffer("scoreBuf", sizeof(float) * scoreSize);
709720

710-
#pragma omp parallel for collapse(2)
711-
for (int b = 0; b < batchSize; ++b) {
712-
for (int i = 0; i < responsibleHeads; ++i) {
713-
// Copy current key to cached keys (if needed)
714-
int kvHdx = i / groupNum;
715-
auto keyMatInfo = getKHead(b, kvHdx);
716-
auto valueMat = getVHead(b, kvHdx);
717-
bool bCopyCache = (i % groupNum == 0);
718-
719-
// Q * K
720-
auto Q = query + inputOffsets[b] * qStride + i * headSize;
721-
auto S = scoreBuf + omp_get_thread_num() * scoreSizePerThr;
722-
723-
const int queryLen = inputSeqLens[b];
724-
const int keyLen = pastSeqLens[b] + inputSeqLens[b];
725-
726-
if (bCopyCache) {
727-
int m = queryLen;
728-
int n = keyLen;
729-
int lda = qStride;
730-
int ldc = keyLen;
721+
#pragma omp parallel for collapse(3)
722+
for (int kvh = 0; kvh < kvHeadNum; ++kvh) {
723+
for (int b = 0; b < batchSize; ++b) {
724+
for (int groupOff = 0; groupOff < groupNum; ++groupOff) {
725+
int i = kvh * groupNum + groupOff;
731726

732-
// Copy to Key cache and compute Query * Key
733-
auto src = key + inputOffsets[b] * kvStride + kvHdx * headSize;
734-
storeKVCache(keyMatInfo, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride);
727+
// Copy current key to cached keys (if needed)
728+
int kvHdx = kvh;
729+
auto keyMatInfo = getKHead(b, kvHdx);
730+
auto valueMat = getVHead(b, kvHdx);
731+
bool bCopyCache = (i % groupNum == 0);
735732

736-
gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc);
737-
} else {
738-
// Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
739-
int m = queryLen;
740-
int n = pastSeqLens[b];
741-
int lda = qStride;
742-
int ldc = keyLen;
743-
gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc);
733+
// Q * K
734+
auto Q = query + inputOffsets[b] * qStride + i * headSize;
735+
auto S = scoreBuf + omp_get_thread_num() * scoreSizePerThr;
744736

745-
int ldb = kvStride;
746-
auto B = key + inputOffsets[b] * kvStride + kvHdx * headSize;
747-
small_gemm_transb(Q, B, S + n, m, inputSeqLens[b], headSize, lda, ldb, ldc);
748-
}
737+
const int queryLen = inputSeqLens[b];
738+
const int keyLen = pastSeqLens[b] + inputSeqLens[b];
739+
740+
if (bCopyCache) {
741+
int m = queryLen;
742+
int n = keyLen;
743+
int lda = qStride;
744+
int ldc = keyLen;
749745

750-
// Softmax(Q * K)
751-
for (int seq = 0; seq < queryLen; ++seq) {
752-
int elements = pastSeqLens[b] + seq + 1;
753-
if (alibiSlopes == nullptr) {
754-
small_softmax_f32(S + seq * keyLen, scale, elements);
746+
// Copy to Key cache and compute Query * Key
747+
auto src = key + inputOffsets[b] * kvStride + kvHdx * headSize;
748+
storeKVCache(keyMatInfo, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride);
749+
750+
gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc);
755751
} else {
756-
DecoderUtil::alibiSoftmax(S + seq * keyLen, scale, alibiSlopes[i], elements);
752+
// Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
753+
int m = queryLen;
754+
int n = pastSeqLens[b];
755+
int lda = qStride;
756+
int ldc = keyLen;
757+
gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc);
758+
759+
int ldb = kvStride;
760+
auto B = key + inputOffsets[b] * kvStride + kvHdx * headSize;
761+
small_gemm_transb(Q, B, S + n, m, inputSeqLens[b], headSize, lda, ldb, ldc);
757762
}
758-
if (keyLen > elements) { memset(S + seq * keyLen + elements, 0, (keyLen - elements) * sizeof(float)); }
759-
}
760763

761-
// Softmax * V
762-
if (bCopyCache) {
763-
// Copy current value to cached values
764-
auto src = value + inputOffsets[b] * kvStride + kvHdx * headSize;
765-
storeKVCache(valueMat, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride);
766-
767-
int m = queryLen;
768-
auto result = output + inputOffsets[b] * oStride + i * headSize;
769-
gemmSV(S, valueMat, result, m, headSize, keyLen, keyLen, oStride);
770-
} else {
771-
// Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
772-
int m = queryLen;
773-
float f32Out[m * headSize]; // accumulate in FP32
774-
gemmSV(S, valueMat, f32Out, m, headSize, pastSeqLens[b], keyLen, headSize);
775-
776-
auto B = value + inputOffsets[b] * kvStride + kvHdx * headSize;
777-
small_gemm(S + pastSeqLens[b], B, f32Out, m, headSize, m, keyLen, kvStride, headSize, true);
778-
779-
// f32Out -> output
780-
auto result = output + inputOffsets[b] * oStride + i * headSize;
781-
for (int t = 0; t < m; ++t) {
782-
xft::copy(result + t * oStride, f32Out + t * headSize, headSize);
764+
// Softmax(Q * K)
765+
for (int seq = 0; seq < queryLen; ++seq) {
766+
int elements = pastSeqLens[b] + seq + 1;
767+
if (alibiSlopes == nullptr) {
768+
small_softmax_f32(S + seq * keyLen, scale, elements);
769+
} else {
770+
DecoderUtil::alibiSoftmax(S + seq * keyLen, scale, alibiSlopes[i], elements);
771+
}
772+
if (keyLen > elements) { memset(S + seq * keyLen + elements, 0, (keyLen - elements) * sizeof(float)); }
783773
}
784-
}
785774

786-
} // end for i
787-
} // end for b
775+
// Softmax * V
776+
if (bCopyCache) {
777+
// Copy current value to cached values
778+
auto src = value + inputOffsets[b] * kvStride + kvHdx * headSize;
779+
storeKVCache(valueMat, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride);
780+
781+
int m = queryLen;
782+
auto result = output + inputOffsets[b] * oStride + i * headSize;
783+
gemmSV(S, valueMat, result, m, headSize, keyLen, keyLen, oStride);
784+
} else {
785+
// Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
786+
int m = queryLen;
787+
float f32Out[m * headSize]; // accumulate in FP32
788+
gemmSV(S, valueMat, f32Out, m, headSize, pastSeqLens[b], keyLen, headSize);
789+
790+
auto B = value + inputOffsets[b] * kvStride + kvHdx * headSize;
791+
small_gemm(S + pastSeqLens[b], B, f32Out, m, headSize, m, keyLen, kvStride, headSize, true);
792+
793+
// f32Out -> output
794+
auto result = output + inputOffsets[b] * oStride + i * headSize;
795+
for (int t = 0; t < m; ++t) {
796+
xft::copy(result + t * oStride, f32Out + t * headSize, headSize);
797+
}
798+
}
799+
} // end for groupOff
800+
} // end for b
801+
} // end for kvh
788802
}
789803

790804
// scaled dot-product attention: bmm1 + softmax + bmm2

0 commit comments

Comments
 (0)