@@ -99,9 +99,23 @@ void gemmSV(
9999 }
100100}
101101
102+ // T is bfloat16_t or float16_t
103+ // ldb is the K value during packing
104+ template <typename T>
105+ void small_amx_gemm_16bits_compute (int m, int n, int k, T *A, int lda, T *packedB, int ldb, T *C, int ldc) {
106+ static_assert (std::is_same_v<T, bfloat16_t > || std::is_same_v<T, float16_t >, " AMX gemm only supports BF16/FP16." );
107+
108+ if (std::is_same_v<T, bfloat16_t >) {
109+ xdnn_small_amx_sgemm_bf16bf16bf16_compute (
110+ m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
111+ } else {
112+ // xdnn_small_amx_sgemm_f16f16f16_compute(m, n, k, (XDNN_FP16 *)A, lda, (XDNN_FP16 *)packedB, ldb, (XDNN_FP16 *)C, ldc);
113+ }
114+ }
115+
102116// Self attention while KV cache copy is separated
103- template <bool fusedPack, typename Lambda1, typename Lambda2>
104- void selfAttention_SeparateCopy (bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bfloat16_t *value, int qHeadNum,
117+ template <bool fusedPack, typename T, typename Lambda1, typename Lambda2>
118+ void selfAttention_SeparateCopy (T *output, T *query, T *key, T *value, int qHeadNum,
105119 int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes,
106120 const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache,
107121 const Lambda2 &getVCache) {
@@ -126,8 +140,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
126140 auto totalPackSize
127141 = fusedPack ? threadNum * (kPackSize + vPackSize) : (batchSize * kvHeadNum) * (kPackSize + vPackSize);
128142
129- bfloat16_t *packBuf
130- = (bfloat16_t *)SimpleMemPool::instance ().getBuffer (" kv_packing" , totalPackSize * sizeof (bfloat16_t ));
143+ T *packBuf
144+ = (T *)SimpleMemPool::instance ().getBuffer (" kv_packing" , totalPackSize * sizeof (T ));
131145
132146 // Copy key/value to cache and pack them
133147 // If packing is not fused into computing, then pack it here
@@ -137,8 +151,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
137151 for (int i = 0 ; i < kvHeadNum; ++i) {
138152 const int tokens = tokenSizes[b];
139153
140- bfloat16_t *packedB = packBuf + (b * kvHeadNum + i) * (kPackSize + vPackSize);
141- bfloat16_t *packedV = packedB + kPackSize ;
154+ T *packedB = packBuf + (b * kvHeadNum + i) * (kPackSize + vPackSize);
155+ T *packedV = packedB + kPackSize ;
142156
143157 auto B = key + offsets[b] * kvStride + i * headSize;
144158 for (int s = 0 ; s < tokens; ++s) {
@@ -181,8 +195,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
181195
182196 // Prepare score buffer
183197 auto maxScoreStride = (maxTokenSize + 31 ) / 32 * 32 ;
184- bfloat16_t *scores = (bfloat16_t *)SimpleMemPool::instance ().getBuffer (
185- " qkscore" , threadNum * mBlockSize * maxScoreStride * sizeof (bfloat16_t ));
198+ T *scores = (T *)SimpleMemPool::instance ().getBuffer (
199+ " qkscore" , threadNum * mBlockSize * maxScoreStride * sizeof (T ));
186200
187201 auto totalBlocks = blkEndIndex[batchSize - 1 ];
188202 std::pair<int , int > packInfo[threadNum];
@@ -208,8 +222,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
208222 int tid = omp_get_thread_num ();
209223 int kvHeadIdx = i / groupNum;
210224 int locationIdx = (fusedPack ? tid : b * kvHeadNum + kvHeadIdx);
211- bfloat16_t *packedB = packBuf + locationIdx * (kPackSize + vPackSize);
212- bfloat16_t *packedV = packedB + kPackSize ;
225+ T *packedB = packBuf + locationIdx * (kPackSize + vPackSize);
226+ T *packedV = packedB + kPackSize ;
213227
214228 const int tokens = tokenSizes[b];
215229 const int startSeq = mb * mBlockSize ;
@@ -234,8 +248,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
234248 }
235249
236250 // Causal mask (either with or without Alibi), use endSeq as N
237- xdnn_small_amx_sgemm_bf16bf16bf16_compute (
238- m, endSeq, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
251+ small_amx_gemm_16bits_compute (m, endSeq, k, A, lda, packedB, headSize, C, ldc);
239252
240253#ifdef XFT_DEBUG
241254 if (b == 0 && i == 0 ) {
@@ -257,7 +270,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
257270 } else {
258271 DecoderUtil::alibiSoftmax (C + seq * ldc, scale, alibiSlopes[i], elements);
259272 }
260- memset (C + seq * ldc + elements, 0 , (tokens - elements) * sizeof (bfloat16_t ));
273+ memset (C + seq * ldc + elements, 0 , (tokens - elements) * sizeof (T ));
261274 }
262275
263276#ifdef XFT_DEBUG
@@ -274,7 +287,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
274287 lda = ldc;
275288 ldc = oStride;
276289 A = C;
277- C = (bfloat16_t *)output + (offsets[b] + startSeq) * ldc + i * headSize;
290+ C = (T *)output + (offsets[b] + startSeq) * ldc + i * headSize;
278291
279292 if constexpr (fusedPack) {
280293 if (packInfo[tid].first != b || packInfo[tid].second != kvHeadIdx) {
@@ -287,8 +300,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
287300 }
288301 }
289302
290- xdnn_small_amx_sgemm_bf16bf16bf16_compute (
291- m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc);
303+ small_amx_gemm_16bits_compute (m, n, k, A, lda, packedV, tokens, C, ldc);
292304
293305#ifdef XFT_DEBUG
294306 if (b == 0 && i == 0 ) {
@@ -301,8 +313,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
301313 });
302314}
303315
304- template <typename Lambda1, typename Lambda2>
305- void selfAttention_FusedCopy (bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bfloat16_t *value, int qHeadNum,
316+ template <typename T, typename Lambda1, typename Lambda2>
317+ void selfAttention_FusedCopy (T *output, T *query, T *key, T *value, int qHeadNum,
306318 int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes,
307319 const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache,
308320 const Lambda2 &getVCache) {
@@ -331,11 +343,11 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
331343 // Prepare buffers (packing buffer and score buffer)
332344 const int kPackSize = xdnn_small_amx_sgemm_bf16bf16bf16_packb_size (maxTokenSize, headSize, 32 , 32 );
333345 const int vPackSize = xdnn_small_amx_sgemm_bf16bf16bf16_packb_size (headSize, maxTokenSize, 32 , 32 );
334- bfloat16_t *packBuf = (bfloat16_t *)SimpleMemPool::instance ().getBuffer (
335- " kv_packing" , threadNum * (kPackSize + vPackSize) * sizeof (bfloat16_t ));
346+ T *packBuf = (T *)SimpleMemPool::instance ().getBuffer (
347+ " kv_packing" , threadNum * (kPackSize + vPackSize) * sizeof (T ));
336348 int maxScoreStride = (maxTokenSize + 31 ) / 32 * 32 ;
337- bfloat16_t *scores = (bfloat16_t *)SimpleMemPool::instance ().getBuffer (
338- " qkscore" , threadNum * mBlockSize * maxScoreStride * sizeof (bfloat16_t ));
349+ T *scores = (T *)SimpleMemPool::instance ().getBuffer (
350+ " qkscore" , threadNum * mBlockSize * maxScoreStride * sizeof (T ));
339351
340352#ifdef XFT_DEBUG
341353 printf (" maxTokenSize=%d, tokenSizes[0]=%d, offsets[0]=%d, kvStride=%d\n " , maxTokenSize, tokenSizes[0 ], offsets[0 ],
@@ -349,8 +361,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
349361 const int tokens = tokenSizes[b];
350362 const int mBlockNum = (tokens + mBlockSize - 1 ) / mBlockSize ;
351363
352- bfloat16_t *packedB = packBuf + tid * (kPackSize + vPackSize);
353- bfloat16_t *packedV = packedB + kPackSize ;
364+ T *packedB = packBuf + tid * (kPackSize + vPackSize);
365+ T *packedV = packedB + kPackSize ;
354366
355367 // Copy key/value to cache and pack them
356368 auto B = key + offsets[b] * kvStride + i * headSize;
@@ -386,8 +398,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
386398 auto A = query + (offsets[b] + startSeq) * qStride + i * headSize;
387399 auto C = scores + tid * mBlockSize * maxScoreStride;
388400
389- xdnn_small_amx_sgemm_bf16bf16bf16_compute (
390- m, n, k, (XDNN_BF16 *) A, lda, (XDNN_BF16 *) packedB, (XDNN_BF16 *) C, ldc);
401+ small_amx_gemm_16bits_compute (
402+ m, n, k, A, lda, packedB, headSize, C, ldc);
391403
392404#ifdef XFT_DEBUG
393405 if (b == 0 && i == 0 ) {
@@ -408,7 +420,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
408420 } else {
409421 DecoderUtil::alibiSoftmax (C + seq * ldc, scale, alibiSlopes[i], elements);
410422 }
411- memset (C + seq * ldc + elements, 0 , (tokens - elements) * sizeof (bfloat16_t ));
423+ memset (C + seq * ldc + elements, 0 , (tokens - elements) * sizeof (T ));
412424 }
413425
414426#ifdef XFT_DEBUG
@@ -425,10 +437,9 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
425437 lda = ldc;
426438 ldc = oStride;
427439 A = C;
428- C = (bfloat16_t *)output + (offsets[b] + startSeq) * ldc + i * headSize;
440+ C = (T *)output + (offsets[b] + startSeq) * ldc + i * headSize;
429441
430- xdnn_small_amx_sgemm_bf16bf16bf16_compute (
431- m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc);
442+ small_amx_gemm_16bits_compute (m, n, k, A, lda, packedV, tokens, C, ldc);
432443
433444#ifdef XFT_DEBUG
434445 if (b == 0 && i == 0 ) {
@@ -443,8 +454,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
443454 } // end for b
444455}
445456
446- template <typename Lambda1, typename Lambda2>
447- void selfAttention (bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bfloat16_t *value, int qHeadNum,
457+ template <typename T, typename Lambda1, typename Lambda2>
458+ void selfAttention (T *output, T *query, T *key, T *value, int qHeadNum,
448459 int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes,
449460 const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache,
450461 const Lambda2 &getVCache) {
@@ -700,91 +711,94 @@ void crossAttnByHead(T *output, const T *query, const T *key, const T *value, in
700711 size_t scoreSizePerThr = 0 ;
701712 for (int i = 0 ; i < batchSize; ++i) {
702713 scoreSizePerThr = std::max (scoreSizePerThr, (size_t )inputSeqLens[i] * (inputSeqLens[i] + pastSeqLens[i]));
703- inputOffsets[i] = (i > 0 ? inputOffsets[i - 1 ] + inputSeqLens[i - 1 ] : 0 );
714+         inputOffsets[i] = (i > 0 ? inputOffsets[i - 1 ] + inputSeqLens[i - 1 ] : 0 );
704715 }
705716
706717 scoreSizePerThr = ALIGNED_SIZE (scoreSizePerThr, 16 );
707718 size_t scoreSize = scoreSizePerThr * threadNum;
708719 float *scoreBuf = (float *)SimpleMemPool::instance ().getBuffer (" scoreBuf" , sizeof (float ) * scoreSize);
709720
710- #pragma omp parallel for collapse(2)
711- for (int b = 0 ; b < batchSize; ++b) {
712- for (int i = 0 ; i < responsibleHeads; ++i) {
713- // Copy current key to cached keys (if needed)
714- int kvHdx = i / groupNum;
715- auto keyMatInfo = getKHead (b, kvHdx);
716- auto valueMat = getVHead (b, kvHdx);
717- bool bCopyCache = (i % groupNum == 0 );
718-
719- // Q * K
720- auto Q = query + inputOffsets[b] * qStride + i * headSize;
721- auto S = scoreBuf + omp_get_thread_num () * scoreSizePerThr;
722-
723- const int queryLen = inputSeqLens[b];
724- const int keyLen = pastSeqLens[b] + inputSeqLens[b];
725-
726- if (bCopyCache) {
727- int m = queryLen;
728- int n = keyLen;
729- int lda = qStride;
730- int ldc = keyLen;
721+ #pragma omp parallel for collapse(3)
722+ for (int kvh = 0 ; kvh < kvHeadNum; ++kvh) {
723+ for (int b = 0 ; b < batchSize; ++b) {
724+ for (int groupOff = 0 ; groupOff < groupNum; ++groupOff) {
725+ int i = kvh * groupNum + groupOff;
731726
732- // Copy to Key cache and compute Query * Key
733- auto src = key + inputOffsets[b] * kvStride + kvHdx * headSize;
734- storeKVCache (keyMatInfo, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride);
727+ // Copy current key to cached keys (if needed)
728+ int kvHdx = kvh;
729+ auto keyMatInfo = getKHead (b, kvHdx);
730+ auto valueMat = getVHead (b, kvHdx);
731+ bool bCopyCache = (i % groupNum == 0 );
735732
736- gemmQK (Q, keyMatInfo, S, m, n, headSize, lda, ldc);
737- } else {
738- // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
739- int m = queryLen;
740- int n = pastSeqLens[b];
741- int lda = qStride;
742- int ldc = keyLen;
743- gemmQK (Q, keyMatInfo, S, m, n, headSize, lda, ldc);
733+ // Q * K
734+ auto Q = query + inputOffsets[b] * qStride + i * headSize;
735+ auto S = scoreBuf + omp_get_thread_num () * scoreSizePerThr;
744736
745- int ldb = kvStride;
746- auto B = key + inputOffsets[b] * kvStride + kvHdx * headSize;
747- small_gemm_transb (Q, B, S + n, m, inputSeqLens[b], headSize, lda, ldb, ldc);
748- }
737+ const int queryLen = inputSeqLens[b];
738+ const int keyLen = pastSeqLens[b] + inputSeqLens[b];
739+
740+ if (bCopyCache) {
741+ int m = queryLen;
742+ int n = keyLen;
743+ int lda = qStride;
744+ int ldc = keyLen;
749745
750- // Softmax(Q * K)
751- for ( int seq = 0 ; seq < queryLen; ++seq) {
752- int elements = pastSeqLens[b] + seq + 1 ;
753- if (alibiSlopes == nullptr ) {
754- small_softmax_f32 (S + seq * keyLen, scale, elements );
746+ // Copy to Key cache and compute Query * Key
747+ auto src = key + inputOffsets[b] * kvStride + kvHdx * headSize;
748+ storeKVCache (keyMatInfo, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride) ;
749+
750+ gemmQK (Q, keyMatInfo, S, m, n, headSize, lda, ldc );
755751 } else {
756- DecoderUtil::alibiSoftmax (S + seq * keyLen, scale, alibiSlopes[i], elements);
752+ // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
753+ int m = queryLen;
754+ int n = pastSeqLens[b];
755+ int lda = qStride;
756+ int ldc = keyLen;
757+ gemmQK (Q, keyMatInfo, S, m, n, headSize, lda, ldc);
758+
759+ int ldb = kvStride;
760+ auto B = key + inputOffsets[b] * kvStride + kvHdx * headSize;
761+ small_gemm_transb (Q, B, S + n, m, inputSeqLens[b], headSize, lda, ldb, ldc);
757762 }
758- if (keyLen > elements) { memset (S + seq * keyLen + elements, 0 , (keyLen - elements) * sizeof (float )); }
759- }
760763
761- // Softmax * V
762- if (bCopyCache) {
763- // Copy current value to cached values
764- auto src = value + inputOffsets[b] * kvStride + kvHdx * headSize;
765- storeKVCache (valueMat, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride);
766-
767- int m = queryLen;
768- auto result = output + inputOffsets[b] * oStride + i * headSize;
769- gemmSV (S, valueMat, result, m, headSize, keyLen, keyLen, oStride);
770- } else {
771- // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
772- int m = queryLen;
773- float f32Out[m * headSize]; // accumulate in FP32
774- gemmSV (S, valueMat, f32Out, m, headSize, pastSeqLens[b], keyLen, headSize);
775-
776- auto B = value + inputOffsets[b] * kvStride + kvHdx * headSize;
777- small_gemm (S + pastSeqLens[b], B, f32Out, m, headSize, m, keyLen, kvStride, headSize, true );
778-
779- // f32Out -> output
780- auto result = output + inputOffsets[b] * oStride + i * headSize;
781- for (int t = 0 ; t < m; ++t) {
782- xft::copy (result + t * oStride, f32Out + t * headSize, headSize);
764+ // Softmax(Q * K)
765+ for (int seq = 0 ; seq < queryLen; ++seq) {
766+ int elements = pastSeqLens[b] + seq + 1 ;
767+ if (alibiSlopes == nullptr ) {
768+ small_softmax_f32 (S + seq * keyLen, scale, elements);
769+ } else {
770+ DecoderUtil::alibiSoftmax (S + seq * keyLen, scale, alibiSlopes[i], elements);
771+ }
772+ if (keyLen > elements) { memset (S + seq * keyLen + elements, 0 , (keyLen - elements) * sizeof (float )); }
783773 }
784- }
785774
786- } // end for i
787- } // end for b
775+ // Softmax * V
776+ if (bCopyCache) {
777+ // Copy current value to cached values
778+ auto src = value + inputOffsets[b] * kvStride + kvHdx * headSize;
779+ storeKVCache (valueMat, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride);
780+
781+ int m = queryLen;
782+ auto result = output + inputOffsets[b] * oStride + i * headSize;
783+ gemmSV (S, valueMat, result, m, headSize, keyLen, keyLen, oStride);
784+ } else {
785+ // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
786+ int m = queryLen;
787+ float f32Out[m * headSize]; // accumulate in FP32
788+ gemmSV (S, valueMat, f32Out, m, headSize, pastSeqLens[b], keyLen, headSize);
789+
790+ auto B = value + inputOffsets[b] * kvStride + kvHdx * headSize;
791+ small_gemm (S + pastSeqLens[b], B, f32Out, m, headSize, m, keyLen, kvStride, headSize, true );
792+
793+ // f32Out -> output
794+ auto result = output + inputOffsets[b] * oStride + i * headSize;
795+ for (int t = 0 ; t < m; ++t) {
796+ xft::copy (result + t * oStride, f32Out + t * headSize, headSize);
797+ }
798+ }
799+ } // end for groupOff
800+ } // end for b
801+ } // end for kvh
788802}
789803
790804// scaled dot-product attention: bmm1 + softmax + bmm2
0 commit comments