vulkan: move common FA code to flash_attn_base.comp #13556

Open · wants to merge 3 commits into master
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp (3 additions, 150 deletions)
@@ -9,60 +9,13 @@
#extension GL_KHR_shader_subgroup_shuffle : enable

#include "types.comp"
#include "flash_attn_base.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;

layout (constant_id = 5) const uint32_t D_split = 16;
const uint32_t D_per_thread = D / D_split;

const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
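// Worked example (illustrative, using the default spec constants above):
// WorkGroupSize = 128 and D_split = 16 give cols_per_iter = 8, so a
// Bc = 32 tile is covered at cols_per_thread = 4 columns per thread,
// and each thread handles D_per_thread = D / 16 elements of the head dim.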

layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;

uint32_t ne1;
uint32_t ne2;
uint32_t ne3;

uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;

uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;

float scale;
float max_bias;
float logit_softcap;

uint32_t mask;
uint32_t n_head_log2;
float m0;
float m1;

uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;

layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
@@ -71,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};

#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
#endif

#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18

vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;

return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
#endif

#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;

return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Store the output when doing grouped query attention.
// Rows are indexed by Q's dimension 2, and the first N rows are valid.
@@ -114,27 +34,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem;
}
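// Hedged note on the (r, c) callback convention: r indexes the gqa_ratio
// rows a workgroup produces and c the head dimension, so a store along the
// lines of data_o[o_offset + (iq2 + r) * D + c] lays consecutive rows out
// in consecutive head slots starting at iq2 (sketch; the exact offset math
// lives in the elided body above).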

// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}

// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);

const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);

return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}

shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];

@@ -146,58 +45,12 @@ void main() {
init_iq_shmem(gl_WorkGroupSize);
#endif

const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t N = p.N;
const uint32_t KV = p.KV;
init_indices();

const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;

uint32_t i = gl_WorkGroupID.x;
uint32_t split_k_index = 0;

if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}

const uint32_t Tr = CEIL_DIV(N, Br);

const uint32_t start_j = split_k_index * p.split_kv / Bc;
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);

// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
const uint32_t iq3 = gl_WorkGroupID.z;

// broadcast factors
const uint32_t rk2 = p.neq2/p.nek2;
const uint32_t rk3 = p.neq3/p.nek3;

const uint32_t rv2 = p.neq2/p.nev2;
const uint32_t rv3 = p.neq3/p.nev3;

// k indices
const uint32_t ik3 = iq3 / rk3;
const uint32_t ik2 = iq2 / rk2;

// v indices
const uint32_t iv3 = iq3 / rv3;
const uint32_t iv2 = iq2 / rv2;

// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;

uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;

[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp (new file, 162 additions)
@@ -0,0 +1,162 @@

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 4) const uint32_t Clamp = 0;
layout (constant_id = 5) const uint32_t D_split = 16;


layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;

uint32_t ne1;
uint32_t ne2;
uint32_t ne3;

uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;

uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;

float scale;
float max_bias;
float logit_softcap;

uint32_t mask;
uint32_t n_head_log2;
float m0;
float m1;

uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
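// For orientation (ggml's usual naming, not restated in the PR): ne* are
// per-dimension element counts and nb* are strides; neq*/nek*/nev* refer
// to Q/K/V and nem1 to the mask. The comment in init_indices() below notes
// which strides are already pre-divided to element units.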

layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};

#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
#endif

#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18

vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;

return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
#endif
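// Worked example (illustrative): a Q4_0 block is one float16 scale d plus
// 16 bytes of nibbles, with element j in the low nibble of byte j and
// element j + 16 in the high nibble. dequantize4(ib, 0, ...) reads qs[0]
// and qs[1], keeps the low nibbles of bytes 0..3 and returns
// d * (vec4(e0, e1, e2, e3) - 8); iqs = 16 makes shift = 4, selecting the
// high nibbles, i.e. elements 16..19.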

#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;

return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
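// Likewise for Q8_0 (sketch): the 34-byte block is one float16 scale d
// plus 32 signed bytes, so the two unpack8 calls split two 16-bit words
// into four int8 values and the result is d * vec4(q[iqs .. iqs+3]).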

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))


// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
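// Why m and L are saved: with split_k each pass covers only part of KV, so
// its softmax is normalized against a partial row max m and partial row
// sum L. A reduction pass can then merge two splits roughly like this
// (hedged sketch, variable names invented for illustration):
//   float m_new = max(m_a, m_b);
//   float L_new = L_a * exp(m_a - m_new) + L_b * exp(m_b - m_new);
//   o_new = (o_a * exp(m_a - m_new) + o_b * exp(m_b - m_new)) / L_new;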

// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);

const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);

return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
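// Worked example (assuming the usual ALiBi setup, where the host picks
// n_head_log2 as the largest power of two <= n_head and, for max_bias = 8,
// m0 = 2^(-8 / n_head_log2), m1 = 2^(-4 / n_head_log2)): with n_head = 8,
// head h = 2 gets slope m0^(h+1) = (1/2)^3 = 1/8; with n_head = 12, a head
// past the boundary such as h = 9 gets m1^(2*(9-8)+1).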

uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
q_stride, k_stride, v_stride, m_stride;

void init_indices()
{
N = p.N;
KV = p.KV;

i = gl_WorkGroupID.x;
split_k_index = 0;

if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}

Tr = CEIL_DIV(N, Br);

start_j = split_k_index * p.split_kv / Bc;
end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);

// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
iq2 = gl_WorkGroupID.y * p.gqa_ratio;
iq3 = gl_WorkGroupID.z;

// broadcast factors
rk2 = p.neq2/p.nek2;
rk3 = p.neq3/p.nek3;

rv2 = p.neq2/p.nev2;
rv3 = p.neq3/p.nev3;

// k indices
ik3 = iq3 / rk3;
ik2 = iq2 / rk2;

// v indices
iv3 = iq3 / rv3;
iv2 = iq2 / rv2;

// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
k_stride = p.nb11;
v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
}
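// Worked example of the split_k indexing (illustrative numbers): with
// KV = 4096, Bc = 32, split_kv = 1024 and k_num = 4, the workgroup with
// gl_WorkGroupID.x == 2 takes split_k_index = 2 and scans KV tiles
// j in [start_j, end_j) = [64, 96), i.e. KV rows 2048..3071; with
// k_num == 1, gl_WorkGroupID.x indexes the Q row tile i instead.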