#include "common.h"
#include "llama.h"
#include "llama_model.h"
+#include <algorithm>
#include <godot_cpp/classes/engine.hpp>
#include <godot_cpp/classes/os.hpp>
#include <godot_cpp/classes/worker_thread_pool.hpp>
#include <godot_cpp/core/class_db.hpp>
+#include <godot_cpp/variant/dictionary.hpp>
#include <godot_cpp/variant/utility_functions.hpp>

using namespace godot;
@@ -15,31 +17,41 @@ void LlamaContext::_bind_methods() {
	ClassDB::bind_method(D_METHOD("get_model"), &LlamaContext::get_model);
	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::OBJECT, "model", PROPERTY_HINT_RESOURCE_TYPE, "LlamaModel"), "set_model", "get_model");

-	ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
-	ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");
+	ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
+	ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");

-	ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
-	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
+	ClassDB::bind_method(D_METHOD("get_temperature"), &LlamaContext::get_temperature);
+	ClassDB::bind_method(D_METHOD("set_temperature", "temperature"), &LlamaContext::set_temperature);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "temperature"), "set_temperature", "get_temperature");

-	ClassDB::bind_method(D_METHOD("get_n_threads"), &LlamaContext::get_n_threads);
-	ClassDB::bind_method(D_METHOD("set_n_threads", "n_threads"), &LlamaContext::set_n_threads);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads"), "set_n_threads", "get_n_threads");
+	ClassDB::bind_method(D_METHOD("get_top_p"), &LlamaContext::get_top_p);
+	ClassDB::bind_method(D_METHOD("set_top_p", "top_p"), &LlamaContext::set_top_p);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "top_p"), "set_top_p", "get_top_p");

-	ClassDB::bind_method(D_METHOD("get_n_threads_batch"), &LlamaContext::get_n_threads_batch);
-	ClassDB::bind_method(D_METHOD("set_n_threads_batch", "n_threads_batch"), &LlamaContext::set_n_threads_batch);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads_batch"), "set_n_threads_batch", "get_n_threads_batch");
+	ClassDB::bind_method(D_METHOD("get_frequency_penalty"), &LlamaContext::get_frequency_penalty);
+	ClassDB::bind_method(D_METHOD("set_frequency_penalty", "frequency_penalty"), &LlamaContext::set_frequency_penalty);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "frequency_penalty"), "set_frequency_penalty", "get_frequency_penalty");
+
+	ClassDB::bind_method(D_METHOD("get_presence_penalty"), &LlamaContext::get_presence_penalty);
+	ClassDB::bind_method(D_METHOD("set_presence_penalty", "presence_penalty"), &LlamaContext::set_presence_penalty);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "presence_penalty"), "set_presence_penalty", "get_presence_penalty");
+
+	ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
+	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
+
+	ClassDB::bind_method(D_METHOD("get_n_len"), &LlamaContext::get_n_len);
+	ClassDB::bind_method(D_METHOD("set_n_len", "n_len"), &LlamaContext::set_n_len);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_len"), "set_n_len", "get_n_len");

	ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion);
-	ClassDB::bind_method(D_METHOD("_fulfill_completion", "prompt"), &LlamaContext::_fulfill_completion);
+	ClassDB::bind_method(D_METHOD("__thread_loop"), &LlamaContext::__thread_loop);

-	ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::STRING, "completion"), PropertyInfo(Variant::BOOL, "is_final")));
+	ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::DICTIONARY, "chunk")));
}

LlamaContext::LlamaContext() {
-	batch = llama_batch_init(4096, 0, 1);
-
	ctx_params = llama_context_default_params();
	ctx_params.seed = -1;
	ctx_params.n_ctx = 4096;
@@ -60,109 +72,186 @@ void LlamaContext::_ready() {
		return;
	}

+	mutex.instantiate();
+	semaphore.instantiate();
+	thread.instantiate();
+
+	llama_backend_init();
+	llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_DISABLED);
+
	ctx = llama_new_context_with_model(model->model, ctx_params);
	if (ctx == NULL) {
		UtilityFunctions::printerr(vformat("%s: Failed to initialize llama context, null ctx", __func__));
		return;
	}
+
+	sampling_ctx = llama_sampling_init(sampling_params);
+
	UtilityFunctions::print(vformat("%s: Context initialized", __func__));
-}

-PackedStringArray LlamaContext::_get_configuration_warnings() const {
-	PackedStringArray warnings;
-	if (model == NULL) {
-		warnings.push_back("Model resource property not defined");
-	}
-	return warnings;
+	thread->start(callable_mp(this, &LlamaContext::__thread_loop));
}

-Variant LlamaContext::request_completion(const String &prompt) {
-	UtilityFunctions::print(vformat("%s: Requesting completion for prompt: %s", __func__, prompt));
-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
-	}
-	task_id = WorkerThreadPool::get_singleton()->add_task(Callable(this, "_fulfill_completion").bind(prompt));
-	return OK;
-}
+void LlamaContext::__thread_loop() {
+	while (true) {
+		semaphore->wait();

-void LlamaContext::_fulfill_completion(const String &prompt) {
-	UtilityFunctions::print(vformat("%s: Fulfilling completion for prompt: %s", __func__, prompt));
-	std::vector<llama_token> tokens_list;
-	tokens_list = ::llama_tokenize(ctx, std::string(prompt.utf8().get_data()), true);
+		mutex->lock();
+		if (exit_thread) {
+			mutex->unlock();
+			break;
+		}
+		if (completion_requests.size() == 0) {
+			mutex->unlock();
+			continue;
+		}
+		completion_request req = completion_requests.get(0);
+		completion_requests.remove_at(0);
+		mutex->unlock();

-	const int n_len = 128;
-	const int n_ctx = llama_n_ctx(ctx);
-	const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
-	if (n_kv_req > n_ctx) {
-		UtilityFunctions::printerr(vformat("%s: n_kv_req > n_ctx, the required KV cache size is not big enough\neither reduce n_len or increase n_ctx", __func__));
-		return;
-	}
+		UtilityFunctions::print(vformat("%s: Running completion for prompt id: %d", __func__, req.id));

-	for (size_t i = 0; i < tokens_list.size(); i++) {
-		llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
-	}
+		std::vector<llama_token> request_tokens;
+		request_tokens = ::llama_tokenize(ctx, req.prompt.utf8().get_data(), true, true);

-	batch.logits[batch.n_tokens - 1] = true;
+		size_t shared_prefix_idx = 0;
+		auto diff = std::mismatch(context_tokens.begin(), context_tokens.end(), request_tokens.begin(), request_tokens.end());
+		if (diff.first != context_tokens.end()) {
+			shared_prefix_idx = std::distance(context_tokens.begin(), diff.first);
+		} else {
+			shared_prefix_idx = std::min(context_tokens.size(), request_tokens.size());
+		}

-	llama_kv_cache_clear(ctx);
+		bool rm_success = llama_kv_cache_seq_rm(ctx, -1, shared_prefix_idx, -1);
+		if (!rm_success) {
+			UtilityFunctions::printerr(vformat("%s: Failed to remove tokens from kv cache", __func__));
+			Dictionary response;
+			response["id"] = req.id;
+			response["error"] = "Failed to remove tokens from kv cache";
+			call_thread_safe("emit_signal", "completion_generated", response);
+			continue;
+		}
+		context_tokens.erase(context_tokens.begin() + shared_prefix_idx, context_tokens.end());
+		request_tokens.erase(request_tokens.begin(), request_tokens.begin() + shared_prefix_idx);

-	int decode_res = llama_decode(ctx, batch);
-	if (decode_res != 0) {
-		UtilityFunctions::printerr(vformat("%s: Failed to decode prompt with error code: %d", __func__, decode_res));
-		return;
-	}
+		uint batch_size = std::min(ctx_params.n_batch, (uint)request_tokens.size());
+
+		llama_batch batch = llama_batch_init(batch_size, 0, 1);
+
+		// chunk request_tokens into sequences of size batch_size
+		std::vector<std::vector<llama_token>> sequences;
+		for (size_t i = 0; i < request_tokens.size(); i += batch_size) {
+			sequences.push_back(std::vector<llama_token>(request_tokens.begin() + i, request_tokens.begin() + std::min(i + batch_size, request_tokens.size())));
+		}
+
+		printf("Request tokens: \n");
+		for (auto sequence : sequences) {
+			for (auto token : sequence) {
+				printf("%s", llama_token_to_piece(ctx, token).c_str());
+			}
+		}
+		printf("\n");

-	int n_cur = batch.n_tokens;
-	int n_decode = 0;
-	llama_model *llama_model = model->model;
+		int curr_token_pos = context_tokens.size();
+		bool decode_failed = false;

-	while (n_cur <= n_len) {
-		// sample the next token
-		{
-			auto n_vocab = llama_n_vocab(llama_model);
-			auto *logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+		for (size_t i = 0; i < sequences.size(); i++) {
+			llama_batch_clear(batch);

-			std::vector<llama_token_data> candidates;
-			candidates.reserve(n_vocab);
+			std::vector<llama_token> sequence = sequences[i];

-			for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-				candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+			for (size_t j = 0; j < sequence.size(); j++) {
+				llama_batch_add(batch, sequence[j], j + curr_token_pos, { 0 }, false);
+				curr_token_pos++;
			}

-			llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+			if (i == sequences.size() - 1) {
+				batch.logits[batch.n_tokens - 1] = true;
+			}

-			// sample the most likely token
-			const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+			if (llama_decode(ctx, batch) != 0) {
+				decode_failed = true;
+				break;
+			}
+		}

-			// is it an end of stream?
-			if (new_token_id == llama_token_eos(llama_model) || n_cur == n_len) {
-				call_thread_safe("emit_signal", "completion_generated", "\n", true);
+		if (decode_failed) {
+			Dictionary response;
+			response["id"] = req.id;
+			response["error"] = "llama_decode() failed";
+			call_thread_safe("emit_signal", "completion_generated", response);
+			continue;
+		}
+
+		context_tokens.insert(context_tokens.end(), request_tokens.begin(), request_tokens.end());
+
+		while (true) {
+			if (exit_thread) {
+				return;
+			}
+			llama_token new_token_id = llama_sampling_sample(sampling_ctx, ctx, NULL, batch.n_tokens - 1);
+			llama_sampling_accept(sampling_ctx, ctx, new_token_id, false);

+			Dictionary response;
+			response["id"] = req.id;
+
+			context_tokens.push_back(new_token_id);
+
+			if (llama_token_is_eog(model->model, new_token_id) || curr_token_pos == n_len) {
+				response["done"] = true;
+				call_thread_safe("emit_signal", "completion_generated", response);
				break;
			}

-			call_thread_safe("emit_signal", "completion_generated", vformat("%s", llama_token_to_piece(ctx, new_token_id).c_str()), false);
+			response["text"] = llama_token_to_piece(ctx, new_token_id).c_str();
+			response["done"] = false;
+			call_thread_safe("emit_signal", "completion_generated", response);

-			// prepare the next batch
			llama_batch_clear(batch);

-			// push this new token for next evaluation
-			llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+			llama_batch_add(batch, new_token_id, curr_token_pos, { 0 }, true);

-			n_decode += 1;
-		}
+			curr_token_pos++;

-		n_cur += 1;
+			if (llama_decode(ctx, batch) != 0) {
+				decode_failed = true;
+				break;
+			}
+		}

-		// evaluate the current batch with the transformer model
-		int decode_res = llama_decode(ctx, batch);
-		if (decode_res != 0) {
-			UtilityFunctions::printerr(vformat("%s: Failed to decode batch with error code: %d", __func__, decode_res));
-			break;
+		if (decode_failed) {
+			Dictionary response;
+			response["id"] = req.id;
+			response["error"] = "llama_decode() failed";
+			call_thread_safe("emit_signal", "completion_generated", response);
+			continue;
		}
	}
}

+PackedStringArray LlamaContext::_get_configuration_warnings() const {
+	PackedStringArray warnings;
+	if (model == NULL) {
+		warnings.push_back("Model resource property not defined");
+	}
+	return warnings;
+}
+
+int LlamaContext::request_completion(const String &prompt) {
+	int id = request_id++;
+
+	UtilityFunctions::print(vformat("%s: Requesting completion for prompt id: %d", __func__, id));
+
+	mutex->lock();
+	completion_request req = { id, prompt };
+	completion_requests.append(req);
+	mutex->unlock();
+
+	semaphore->post();
+
+	return id;
+}
+
void LlamaContext::set_model(const Ref<LlamaModel> p_model) {
	model = p_model;
}
@@ -184,28 +273,58 @@ void LlamaContext::set_n_ctx(int n_ctx) {
	ctx_params.n_ctx = n_ctx;
}

-int LlamaContext::get_n_threads() {
-	return ctx_params.n_threads;
+int LlamaContext::get_n_len() {
+	return n_len;
}
-void LlamaContext::set_n_threads(int n_threads) {
-	ctx_params.n_threads = n_threads;
+void LlamaContext::set_n_len(int n_len) {
+	this->n_len = n_len;
}

-int LlamaContext::get_n_threads_batch() {
-	return ctx_params.n_threads_batch;
+float LlamaContext::get_temperature() {
+	return sampling_params.temp;
}
-void LlamaContext::set_n_threads_batch(int n_threads_batch) {
-	ctx_params.n_threads_batch = n_threads_batch;
+void LlamaContext::set_temperature(float temperature) {
+	sampling_params.temp = temperature;
}

-LlamaContext::~LlamaContext() {
-	if (ctx) {
-		llama_free(ctx);
+float LlamaContext::get_top_p() {
+	return sampling_params.top_p;
+}
+void LlamaContext::set_top_p(float top_p) {
+	sampling_params.top_p = top_p;
+}
+
+float LlamaContext::get_frequency_penalty() {
+	return sampling_params.penalty_freq;
+}
+void LlamaContext::set_frequency_penalty(float frequency_penalty) {
+	sampling_params.penalty_freq = frequency_penalty;
+}
+
+float LlamaContext::get_presence_penalty() {
+	return sampling_params.penalty_present;
+}
+void LlamaContext::set_presence_penalty(float presence_penalty) {
+	sampling_params.penalty_present = presence_penalty;
+}
+
+void LlamaContext::_exit_tree() {
+	if (Engine::get_singleton()->is_editor_hint()) {
+		return;
	}

-	llama_batch_free(batch);
+	mutex->lock();
+	exit_thread = true;
+	mutex->unlock();
+
+	semaphore->post();

-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
+	thread->wait_to_finish();
+
+	if (ctx) {
+		llama_free(ctx);
	}
+
+	llama_sampling_free(sampling_ctx);
+	llama_backend_free();
}
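
Note (not part of the patch): the corresponding llama_context.h changes are not shown in this diff. As a rough, hypothetical sketch only — the member names and types below are inferred from how they are used in the code above, not taken from the real header — the class would need roughly this state for the file to compile:

#include "common.h"
#include "llama.h"
#include "llama_model.h"
#include <godot_cpp/classes/mutex.hpp>
#include <godot_cpp/classes/node.hpp>
#include <godot_cpp/classes/semaphore.hpp>
#include <godot_cpp/classes/thread.hpp>
#include <godot_cpp/templates/vector.hpp>
#include <vector>

using namespace godot;

// One queued prompt: request_completion() appends these, __thread_loop() drains them.
struct completion_request {
	int id;
	String prompt;
};

class LlamaContext : public Node {
	GDCLASS(LlamaContext, Node)

private:
	Ref<LlamaModel> model;

	llama_context *ctx = nullptr;
	llama_context_params ctx_params;
	llama_sampling_params sampling_params;
	llama_sampling_context *sampling_ctx = nullptr;

	int n_len; // generation cap compared against curr_token_pos; default not visible in this diff
	int request_id = 0; // id handed back by request_completion()
	std::vector<llama_token> context_tokens; // tokens currently resident in the KV cache

	Vector<completion_request> completion_requests; // pending prompts, guarded by mutex
	Ref<Mutex> mutex;
	Ref<Semaphore> semaphore;
	Ref<Thread> thread;
	bool exit_thread = false;

	// _bind_methods(), _ready(), _exit_tree(), __thread_loop(), request_completion(),
	// _get_configuration_warnings() and the bound getters/setters are declared here as well.
};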