Skip to content

Commit 9d533a7

Browse files
authored
llama : fix defrag bugs + add parameter (#5735)
* llama : fix defrag bugs + enable by default ggml-ci
* llama : add defrag_thold parameter ggml-ci
* llama : cont
* llama : disable log message ggml-ci
* llama : fix graph size check during defrag
1 parent cbbd1ef commit 9d533a7

File tree

5 files changed

+82
-30
lines changed

5 files changed

+82
-30
lines changed

common/common.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
335335
break;
336336
}
337337
params.yarn_beta_slow = std::stof(argv[i]);
338+
} else if (arg == "--defrag-thold" || arg == "-dt") {
339+
if (++i >= argc) {
340+
invalid_param = true;
341+
break;
342+
}
343+
params.defrag_thold = std::stof(argv[i]);
338344
} else if (arg == "--samplers") {
339345
if (++i >= argc) {
340346
invalid_param = true;
@@ -1004,6 +1010,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
10041010
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
10051011
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
10061012
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
1013+
printf(" -dt N, --defrag-thold N\n");
1014+
printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
10071015
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
10081016
printf(" --no-penalize-nl do not penalize newline token\n");
10091017
printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
@@ -1285,6 +1293,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
12851293
cparams.yarn_beta_fast = params.yarn_beta_fast;
12861294
cparams.yarn_beta_slow = params.yarn_beta_slow;
12871295
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
1296+
cparams.defrag_thold = params.defrag_thold;
12881297
cparams.offload_kqv = !params.no_kv_offload;
12891298

12901299
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);

common/common.h

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ struct gpt_params {
7575
float yarn_beta_fast = 32.0f; // YaRN low correction dim
7676
float yarn_beta_slow = 1.0f; // YaRN high correction dim
7777
int32_t yarn_orig_ctx = 0; // YaRN original context length
78+
float defrag_thold = -1.0f; // KV cache defragmentation threshold
7879
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
7980
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
8081

examples/passkey/passkey.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ int main(int argc, char ** argv) {
182182

183183
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
184184
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
185-
llama_kv_cache_defrag (ctx);
185+
//llama_kv_cache_defrag (ctx);
186186
llama_kv_cache_update (ctx);
187187

188188
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
@@ -213,7 +213,7 @@ int main(int argc, char ** argv) {
213213

214214
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
215215
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
216-
llama_kv_cache_defrag (ctx);
216+
//llama_kv_cache_defrag (ctx);
217217
llama_kv_cache_update (ctx);
218218

219219
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;

llama.cpp

+69-28
Original file line numberDiff line numberDiff line change
@@ -1641,6 +1641,7 @@ struct llama_cparams {
16411641
float yarn_attn_factor;
16421642
float yarn_beta_fast;
16431643
float yarn_beta_slow;
1644+
float defrag_thold;
16441645

16451646
bool mul_mat_q;
16461647
bool offload_kqv;
@@ -5117,16 +5118,16 @@ struct llm_build_context {
51175118
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
51185119
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
51195120

5120-
for (int i = 0; i < n_kv; ++i) {
5121-
const int id = ids[i];
5121+
for (uint32_t i = 0; i < ids.size(); ++i) {
5122+
const uint32_t id = ids[i];
51225123

5123-
if (i == id || id == n_kv) {
5124+
if (i == id || id == ids.size()) {
51245125
continue;
51255126
}
51265127

5127-
int nm = 1;
5128+
uint32_t nm = 1;
51285129

5129-
while (i + nm < n_kv && (int) ids[i + nm] == id + nm) {
5130+
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
51305131
nm++;
51315132
}
51325133

@@ -5158,6 +5159,8 @@ struct llm_build_context {
51585159
i += nm - 1;
51595160
}
51605161

5162+
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
5163+
51615164
return gf;
51625165
}
51635166

@@ -7938,6 +7941,8 @@ static int llama_decode_internal(
79387941
batch.seq_id = seq_id_arr.data();
79397942
}
79407943

7944+
llama_kv_cache_update(&lctx);
7945+
79417946
// if we have enough unused cells before the current head ->
79427947
// better to start searching from the beginning of the cache, hoping to fill it
79437948
if (kv_self.head > kv_self.used + 2*n_tokens) {
@@ -7956,8 +7961,6 @@ static int llama_decode_internal(
79567961

79577962
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
79587963

7959-
llama_kv_cache_update(&lctx);
7960-
79617964
ggml_backend_sched_reset(lctx.sched);
79627965
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
79637966

@@ -8007,6 +8010,18 @@ static int llama_decode_internal(
80078010
}
80088011
}
80098012

8013+
// decide if we need to defrag the kv cache
8014+
if (cparams.defrag_thold >= 0.0f) {
8015+
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f;
8016+
8017+
// queue defragmentation for next llama_kv_cache_update
8018+
if (fragmentation > cparams.defrag_thold) {
8019+
//LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
8020+
8021+
llama_kv_cache_defrag(kv_self);
8022+
}
8023+
}
8024+
80108025
#ifdef GGML_PERF
80118026
// print timing information per ggml operation (for debugging purposes)
80128027
// requires GGML_PERF to be defined
@@ -8098,12 +8113,16 @@ static int llama_decode_internal(
80988113
static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
80998114
auto & kv_self = lctx.kv_self;
81008115

8116+
const auto & hparams = lctx.model.hparams;
8117+
8118+
const uint32_t n_layer = hparams.n_layer;
8119+
81018120
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
81028121
const uint32_t n_used = kv_self.used;
81038122

81048123
assert(n_used <= n_kv);
81058124

8106-
const int64_t t_start = ggml_time_us();
8125+
//const int64_t t_start = ggml_time_us();
81078126

81088127
// number of cells moved
81098128
uint32_t n_moves = 0;
@@ -8127,15 +8146,26 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
81278146

81288147
// found a hole - fill it with data from the end of the cache
81298148

8130-
// determine the size of the hole
81318149
uint32_t nh = 1;
8150+
8151+
// determine the size of the hole
81328152
while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
81338153
nh++;
81348154
}
81358155

8136-
// starting from the end, find nh non-empty cells
8156+
// each move requires 6*n_layer tensors (see build_defrag)
8157+
// - source view, destination view, copy operation
8158+
// - x2 for keys and values
8159+
//
8160+
if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) {
8161+
// the graph is too big, we cannot move more cells
8162+
break;
8163+
}
8164+
81378165
uint32_t nf = 0;
81388166
uint32_t is = n_kv - 1;
8167+
8168+
// starting from the end, find nh non-empty cells
81398169
for (; is > i0; --is) {
81408170
const auto & cell1 = kv_self.cells[is];
81418171

@@ -8156,11 +8186,17 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
81568186

81578187
nf = 0;
81588188

8189+
uint32_t i1 = is;
8190+
8191+
// are we moving a continuous block of memory?
8192+
bool cont = false;
8193+
81598194
// go back and move the nf cells to the hole
8160-
for (uint32_t i1 = is; i1 < n_kv; ++i1) {
8161-
const auto & cell1 = kv_self.cells[i1];
8195+
for (; i1 < n_kv; ++i1) {
8196+
auto & cell1 = kv_self.cells[i1];
81628197

81638198
if (cell1.is_empty() || ids[i1] != n_kv) {
8199+
cont = false;
81648200
continue;
81658201
}
81668202

@@ -8170,11 +8206,23 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
81708206
// move the cell meta data
81718207
kv_self.cells[i0 + nf] = cell1;
81728208

8173-
n_moves++;
8209+
// clear the old cell and move the head there
8210+
cell1 = llama_kv_cell();
8211+
kv_self.head = n_used;
8212+
8213+
if (!cont) {
8214+
n_moves++;
8215+
cont = true;
8216+
}
8217+
81748218
nf++;
8219+
8220+
if (nf == nh) {
8221+
break;
8222+
}
81758223
}
81768224

8177-
LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh);
8225+
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
81788226

81798227
i0 += nh - 1;
81808228
}
@@ -8183,15 +8231,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
81838231
return;
81848232
}
81858233

8186-
LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
8187-
8188-
kv_self.head = n_used;
8189-
kv_self.used = n_used;
8234+
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
81908235

8191-
// zero the rest of the cells
8192-
for (uint32_t i = n_used; i < n_kv; ++i) {
8193-
kv_self.cells[i] = llama_kv_cell();
8194-
}
8236+
//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
81958237

81968238
#if 0
81978239
// CPU defrag
@@ -8203,9 +8245,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
82038245
// likely not worth the effort, as we have ggml_graph based defrag
82048246
//
82058247

8206-
const auto & hparams = lctx.model.hparams;
8207-
8208-
const uint32_t n_layer = hparams.n_layer;
82098248
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
82108249
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
82118250

@@ -8274,9 +8313,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
82748313
llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
82758314
#endif
82768315

8277-
const int64_t t_end = ggml_time_us();
8316+
//const int64_t t_end = ggml_time_us();
82788317

8279-
LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
8318+
//LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
82808319
}
82818320

82828321
static void llama_kv_cache_update_internal(struct llama_context & lctx) {
@@ -11670,6 +11709,7 @@ struct llama_context_params llama_context_default_params() {
1167011709
/*.yarn_beta_fast =*/ 32.0f,
1167111710
/*.yarn_beta_slow =*/ 1.0f,
1167211711
/*.yarn_orig_ctx =*/ 0,
11712+
/*.defrag_thold =*/ -1.0f,
1167311713
/*.cb_eval =*/ nullptr,
1167411714
/*.cb_eval_user_data =*/ nullptr,
1167511715
/*.type_k =*/ GGML_TYPE_F16,
@@ -11834,6 +11874,7 @@ struct llama_context * llama_new_context_with_model(
1183411874
cparams.yarn_attn_factor = params.yarn_attn_factor;
1183511875
cparams.yarn_beta_fast = params.yarn_beta_fast;
1183611876
cparams.yarn_beta_slow = params.yarn_beta_slow;
11877+
cparams.defrag_thold = params.defrag_thold;
1183711878
cparams.mul_mat_q = params.mul_mat_q;
1183811879
cparams.offload_kqv = params.offload_kqv;
1183911880
cparams.do_pooling = params.do_pooling;
@@ -12035,7 +12076,7 @@ struct llama_context * llama_new_context_with_model(
1203512076
}
1203612077

1203712078
// buffer used to store the computation graph and the tensor meta data
12038-
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
12079+
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));
1203912080

1204012081
ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);
1204112082

llama.h

+1
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ extern "C" {
245245
float yarn_beta_fast; // YaRN low correction dim
246246
float yarn_beta_slow; // YaRN high correction dim
247247
uint32_t yarn_orig_ctx; // YaRN original context size
248+
float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
248249

249250
ggml_backend_sched_eval_callback cb_eval;
250251
void * cb_eval_user_data;

0 commit comments

Comments (0)