@@ -759,6 +759,7 @@ struct MHAHelper {
759
759
PlainTensor _wv_scratch_a;
760
760
PlainTensor _wv_scratch_b;
761
761
PlainTensor _alibi_lookup;
762
+ PlainTensor _score_output;
762
763
std::vector<size_t > _wsp;
763
764
size_t _wsp_size_per_thread = 0 ;
764
765
@@ -772,6 +773,8 @@ struct MHAHelper {
772
773
// second token for bhl loop
773
774
PlainTensor _weight_bhl;
774
775
PlainTensor _output_bhl;
776
+ PlainTensor _score_offsets_aligned;
777
+ PlainTensor _score_offsets;
775
778
776
779
MHAHelper () {
777
780
_weight.resize <float >({size_t {1 }, size_t {1 }, size_t {1 }, size_t {1 }});
@@ -867,6 +870,26 @@ struct MHAHelper {
867
870
_wv_scratch_b.resize <DATA_TYPE>({batch, kv_len_in_blocks, _Hk, _block_size * rnd_up (_S, _block_size)});
868
871
}
869
872
873
+ void init_score_buffers (const PlainTensor& past_lens, const PlainTensor& subsequence_begins) {
874
+ static constexpr int cache_line_size = dnnl::impl::cpu::platform::get_cache_line_size ();
875
+ auto seq_cout = static_cast <int32_t >(past_lens.m_dims [0 ]);
876
+ _score_offsets_aligned.resize <int32_t >({past_lens.m_dims [0 ]});
877
+ _score_offsets.resize <int32_t >({past_lens.m_dims [0 ]});
878
+ int32_t total_kv_len_aligned = 0 ;
879
+ int32_t total_kv_len = 0 ;
880
+ for (int32_t i = 0 ; i < seq_cout; i++) {
881
+ auto q_len = subsequence_begins.ptr <int32_t >()[i + 1 ] - subsequence_begins.ptr <int32_t >()[i];
882
+ auto kv_len = past_lens.ptr <int32_t >()[i] + q_len;
883
+ _score_offsets_aligned.ptr <int32_t >()[i] = total_kv_len_aligned;
884
+ _score_offsets.ptr <int32_t >()[i] = total_kv_len;
885
+ // aligned to cache line to avoid false sharing
886
+ total_kv_len_aligned += rnd_up (kv_len, cache_line_size / sizeof (float ));
887
+ total_kv_len += kv_len;
888
+ }
889
+
890
+ _score_output.resize <float >({total_kv_len_aligned * _H});
891
+ }
892
+
870
893
// compute one block (e.g. 32 tokens) of the query along the M dimension: softmax(q_block*k')*v
871
894
// none of the tensors (query, ...) has a batch dimension, because the batch dimension varies
872
895
// query: [H, L, S]
@@ -875,8 +898,8 @@ struct MHAHelper {
875
898
// qk_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size]
876
899
// wv_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size]
877
900
void exec_kernel_multiple (const PlainTensor& query, const PlainTensor& present_value, const PlainTensor& output_emb,
878
- const PlainTensor& qk_scratch_b, const PlainTensor& wv_scratch_b,
879
- const int32_t * block_table, size_t ithr, size_t q_blk, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes) {
901
+ const PlainTensor& qk_scratch_b, const PlainTensor& wv_scratch_b, const int32_t * block_table, size_t ithr, size_t q_blk,
902
+ size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float * score_output ) {
880
903
auto q_start = q_blk * _block_size;
881
904
auto q_end = std::min (q_start + _block_size, q_len);
882
905
auto q_cnt = q_end - q_start;
@@ -947,6 +970,9 @@ struct MHAHelper {
947
970
precision_of<DATA_TYPE>::value,
948
971
alibi_slope);
949
972
}
973
+ if (score_output) {
974
+ cvt_copy (score_output + h * rnd_up (cur_kv_len, 16 ), reinterpret_cast <DATA_TYPE*>(score), cur_kv_len);
975
+ }
950
976
}
951
977
952
978
// reuse float buffer, need to use float to compute offset
@@ -998,7 +1024,7 @@ struct MHAHelper {
998
1024
// weight: [nthr, H, 32, rnd_up(kv_len, block_size)]
999
1025
// output: [nthr, 32, H, S]
1000
1026
void exec_kernel_one_bh (const PlainTensor& query, const PlainTensor& present_key, const PlainTensor& present_value, const PlainTensor& output_emb,
1001
- const int32_t * block_table, size_t ithr, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes) {
1027
+ const int32_t * block_table, size_t ithr, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float * score_output ) {
1002
1028
if (_fastpath_valid) {
1003
1029
_gemv->tile_config ();
1004
1030
for (size_t pk = 0 , i = 0 ; pk < cur_kv_len; pk += _block_size, i++) {
@@ -1044,6 +1070,9 @@ struct MHAHelper {
1044
1070
ov::element::f32,
1045
1071
ov::element::f32,
1046
1072
alibi_slope);
1073
+ if (score_output) {
1074
+ memcpy (score_output + h * rnd_up (cur_kv_len, 16 ), _weight.ptr <float >(ithr, h, pq), cur_kv_len * sizeof (float ));
1075
+ }
1047
1076
}
1048
1077
}
1049
1078
@@ -1078,6 +1107,7 @@ struct MHAHelper {
1078
1107
const PlainTensor& present_key,
1079
1108
const PlainTensor& present_value,
1080
1109
const PlainTensor& output_emb,
1110
+ const PlainTensor& output_score,
1081
1111
size_t max_context_len,
1082
1112
const PlainTensor& past_lens,
1083
1113
const PlainTensor& subsequence_begins,
@@ -1141,6 +1171,16 @@ struct MHAHelper {
1141
1171
alibi_slope);
1142
1172
});
1143
1173
1174
+ if (output_score) {
1175
+ parallel_for2d_dynamic (B, q_len, [&](size_t b, size_t pq) {
1176
+ auto cur_kv_len = static_cast <size_t >(past_lens.ptr <int32_t >()[b]) + 1 ;
1177
+ auto * src = _weight_bhl.ptr <float >(b, 0 , pq);
1178
+ size_t src_stride = _weight_bhl.stride (2 );
1179
+ auto * dst = output_score.ptr <float >() + _score_offsets.ptr <int32_t >()[b];
1180
+ attn_reduce (dst, src, _H, cur_kv_len, src_stride);
1181
+ });
1182
+ }
1183
+
1144
1184
// attn_w * V
1145
1185
_output_bhl.resize <float >({static_cast <size_t >(_nthr), B, q_len, _H, _S});
1146
1186
// m_attn_w {B, H, q_len, kv_len}
@@ -1284,6 +1324,7 @@ struct MHA {
1284
1324
const PlainTensor& k_cache,
1285
1325
const PlainTensor& v_cache,
1286
1326
const PlainTensor& output_emb,
1327
+ const PlainTensor& output_score,
1287
1328
size_t max_context_len,
1288
1329
const PlainTensor& past_lens,
1289
1330
const PlainTensor& subsequence_begins,
@@ -1343,16 +1384,30 @@ struct MHA {
1343
1384
1344
1385
if (q_len == 1 ) {
1345
1386
const auto cur_kv_len = static_cast <size_t >(past_lens.ptr <int32_t >()[batch_in_seq]) + 1 ;
1387
+ float * score_output = nullptr ;
1388
+ if (output_score) {
1389
+ auto score_offset = _helper._score_offsets_aligned .template ptr <int32_t >()[batch_in_seq];
1390
+ score_output = _helper._score_output .template ptr <float >() + score_offset * _helper._H ;
1391
+ }
1346
1392
1347
1393
_helper.exec_kernel_one_bh (q.slice (0 , batch_in_token, batch_in_token), k_cache, v_cache,
1348
1394
output_emb.slice (0 , batch_in_token, batch_in_token),
1349
1395
block_indices.ptr <int32_t >() + block_indices_begins.ptr <int32_t >()[batch_in_seq],
1350
- ithr, hk, 1ul , cur_kv_len, alibi_slopes);
1396
+ ithr, hk, 1ul , cur_kv_len, alibi_slopes,
1397
+ score_output);
1351
1398
} else {
1352
1399
const auto batch_in_reorder = item.batch_in_reorder ;
1353
1400
const auto q_blk = item.q_block_id ;
1354
1401
const auto q_cnt = std::min (_helper._block_size , q_len - q_blk * _helper._block_size );
1355
1402
const auto cur_kv_len = static_cast <size_t >(past_lens.ptr <int32_t >()[batch_in_seq]) + q_blk * _helper._block_size + q_cnt;
1403
+ float * score_output = nullptr ;
1404
+ if (output_score) {
1405
+ // last block
1406
+ if (q_len - q_blk * _helper._block_size <= _helper._block_size ) {
1407
+ auto score_offset = _helper._score_offsets_aligned .template ptr <int32_t >()[batch_in_seq];
1408
+ score_output = _helper._score_output .template ptr <float >() + score_offset * _helper._H ;
1409
+ }
1410
+ }
1356
1411
1357
1412
PlainTensor sub_query;
1358
1413
sub_query.resize ({q_len, _helper._H , _helper._S }, q.ptr <DATA_TYPE>(batch_in_token));
@@ -1368,31 +1423,47 @@ struct MHA {
1368
1423
hk,
1369
1424
q_len,
1370
1425
cur_kv_len,
1371
- alibi_slopes);
1426
+ alibi_slopes,
1427
+ score_output);
1372
1428
}
1373
1429
});
1430
+ if (output_score) {
1431
+ parallel_for2d_dynamic (past_lens.m_dims [0 ], 1 , [&](size_t b, size_t pq) {
1432
+ auto seq_len = static_cast <size_t >(subsequence_begins.ptr <int32_t >()[b + 1 ] - subsequence_begins.ptr <int32_t >()[b]);
1433
+ auto cur_kv_len = static_cast <size_t >(past_lens.ptr <int32_t >()[b]) + seq_len;
1434
+ auto src_offset = _helper._score_offsets_aligned .template ptr <int32_t >()[b];
1435
+ auto * src = _helper._score_output .template ptr <float >() + src_offset * _helper._H ;
1436
+ size_t src_stride = rnd_up (cur_kv_len, 16 );
1437
+ auto dst_offset = _helper._score_offsets .template ptr <int32_t >()[b];
1438
+ auto * dst = output_score.ptr <float >() + dst_offset;
1439
+ attn_reduce (dst, src, _helper._H , cur_kv_len, src_stride);
1440
+ });
1441
+ }
1374
1442
}
1375
1443
1376
1444
// Q, K, V is ready, do attention
1377
1445
void operator ()(PlainTensor& query,
1378
1446
PlainTensor& present_key,
1379
1447
PlainTensor& present_value,
1380
1448
PlainTensor& output_emb,
1449
+ PlainTensor& output_score,
1381
1450
size_t max_context_len,
1382
1451
const PlainTensor& past_lens,
1383
1452
const PlainTensor& subsequence_begins,
1384
1453
const PlainTensor& block_indices,
1385
1454
const PlainTensor& block_indices_begins,
1386
1455
const PlainTensor& alibi_slopes) {
1387
1456
_workitems.reset (query, past_lens, subsequence_begins, _helper._block_size );
1457
+ if (output_score)
1458
+ _helper.init_score_buffers (past_lens, subsequence_begins);
1388
1459
1389
1460
auto nthr = static_cast <size_t >(parallel_get_max_threads ());
1390
1461
1391
1462
if (past_lens.m_dims [0 ] >= nthr || _workitems.get_reorder_max_batch_size () > 0 ) {
1392
- exec_loop_mixed (query, present_key, present_value, output_emb, max_context_len, past_lens, subsequence_begins,
1463
+ exec_loop_mixed (query, present_key, present_value, output_emb, output_score, max_context_len, past_lens, subsequence_begins,
1393
1464
block_indices, block_indices_begins, alibi_slopes);
1394
1465
} else {
1395
- _helper.exec_loop_bhl (query, present_key, present_value, output_emb, max_context_len, past_lens, subsequence_begins,
1466
+ _helper.exec_loop_bhl (query, present_key, present_value, output_emb, output_score, max_context_len, past_lens, subsequence_begins,
1396
1467
block_indices, block_indices_begins, alibi_slopes);
1397
1468
}
1398
1469
}
@@ -1406,9 +1477,9 @@ struct AttentionExecutor : public PagedAttentionExecutor {
1406
1477
1407
1478
AttentionExecutor () : _kernel(_helper) {}
1408
1479
1409
- void init (const std::vector<MemoryPtr>& inputs, const MemoryPtr& output , PlainTensor& q, PlainTensor& k, PlainTensor& v, PlainTensor& k_cache,
1480
+ void init (const std::vector<MemoryPtr>& inputs, const std::vector< MemoryPtr>& outputs , PlainTensor& q, PlainTensor& k, PlainTensor& v, PlainTensor& k_cache,
1410
1481
PlainTensor& v_cache, PlainTensor& past_lens, PlainTensor& subsequence_begins, PlainTensor& block_indices, PlainTensor& block_indices_begins,
1411
- float & scale, size_t & sliding_window, PlainTensor& alibi_slopes, size_t & max_context_len, PlainTensor& output_emb) {
1482
+ float & scale, size_t & sliding_window, PlainTensor& alibi_slopes, size_t & max_context_len, PlainTensor& output_emb, PlainTensor& output_score ) {
1412
1483
q.reset (inputs[ID_Q]); // [B_token, H * S]
1413
1484
k.reset (inputs[ID_K]);
1414
1485
v.reset (inputs[ID_V]);
@@ -1423,7 +1494,9 @@ struct AttentionExecutor : public PagedAttentionExecutor {
1423
1494
if (!inputs[ID_ALIBI_SLOPES]->getShape ().hasZeroDims ())
1424
1495
alibi_slopes.reset (inputs[ID_ALIBI_SLOPES]);
1425
1496
max_context_len = static_cast <size_t >(*inputs[ID_MAX_CONTEXT_LEN]->getDataAs <int32_t >());
1426
- output_emb.reset (output);
1497
+ output_emb.reset (outputs[0 ]);
1498
+ if (outputs.size () == 2 )
1499
+ output_score.reset (outputs[1 ]);
1427
1500
1428
1501
auto B_token = q.size (0 );
1429
1502
auto Hk = k_cache.size (1 );
@@ -1496,20 +1569,22 @@ struct AttentionExecutor : public PagedAttentionExecutor {
1496
1569
}
1497
1570
}
1498
1571
1499
- void execute (const std::vector<MemoryPtr>& inputs, const MemoryPtr output ) override {
1572
+ void execute (const std::vector<MemoryPtr>& inputs, const std::vector< MemoryPtr> outputs ) override {
1500
1573
PlainTensor q, k, v, k_cache, v_cache;
1501
1574
PlainTensor past_lens, subsequence_begins, block_indices, block_indices_begins;
1502
1575
float scale;
1503
1576
size_t sliding_window;
1504
1577
PlainTensor alibi_slopes;
1505
1578
size_t max_context_len;
1506
1579
PlainTensor output_emb;
1580
+ PlainTensor output_score;
1507
1581
1508
- init (inputs, output , q, k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins,
1509
- scale, sliding_window, alibi_slopes, max_context_len, output_emb);
1582
+ init (inputs, outputs , q, k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins,
1583
+ scale, sliding_window, alibi_slopes, max_context_len, output_emb, output_score );
1510
1584
concat_pastkv (k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins);
1511
1585
1512
- _kernel (q, k_cache, v_cache, output_emb, max_context_len, past_lens, subsequence_begins, block_indices, block_indices_begins, alibi_slopes);
1586
+ _kernel (q, k_cache, v_cache, output_emb, output_score, max_context_len, past_lens, subsequence_begins, block_indices,
1587
+ block_indices_begins, alibi_slopes);
1513
1588
}
1514
1589
};
1515
1590
#endif
@@ -1523,6 +1598,7 @@ std::shared_ptr<PagedAttentionExecutor> make_pa_executor(ov::element::Type data_
1523
1598
if (kvcache_type == ov::element::u8) {
1524
1599
executor = std::make_shared<AttentionExecutor<ov::bfloat16, uint8_t >>();
1525
1600
} else {
1601
+ OPENVINO_ASSERT (kvcache_type == ov::element::bf16, " expect kvcache type bf16, current: " , kvcache_type);
1526
1602
executor = std::make_shared<AttentionExecutor<ov::bfloat16, ov::bfloat16>>();
1527
1603
}
1528
1604
#else
@@ -1534,6 +1610,7 @@ std::shared_ptr<PagedAttentionExecutor> make_pa_executor(ov::element::Type data_
1534
1610
} else if (kvcache_type == ov::element::f16) {
1535
1611
executor = std::make_shared<AttentionExecutor<float , ov::float16>>();
1536
1612
} else {
1613
+ OPENVINO_ASSERT (kvcache_type == ov::element::f32, " expect kvcache type f32, current: " , kvcache_type);
1537
1614
executor = std::make_shared<AttentionExecutor<float , float >>();
1538
1615
}
1539
1616
} else {
0 commit comments