@@ -2,12 +2,13 @@
#include "common.h"
#include "llama.h"
#include "llama_model.h"
+#include <algorithm>
#include <godot_cpp/classes/engine.hpp>
#include <godot_cpp/classes/os.hpp>
#include <godot_cpp/classes/worker_thread_pool.hpp>
#include <godot_cpp/core/class_db.hpp>
-#include <godot_cpp/variant/utility_functions.hpp>
#include <godot_cpp/variant/dictionary.hpp>
+#include <godot_cpp/variant/utility_functions.hpp>

using namespace godot;

@@ -24,6 +25,10 @@ void LlamaContext::_bind_methods() {
    ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
    ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");

+   ClassDB::bind_method(D_METHOD("get_n_len"), &LlamaContext::get_n_len);
+   ClassDB::bind_method(D_METHOD("set_n_len", "n_len"), &LlamaContext::set_n_len);
+   ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_len"), "set_n_len", "get_n_len");
+
    ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion);
    ClassDB::bind_method(D_METHOD("__thread_loop"), &LlamaContext::__thread_loop);

@@ -63,6 +68,9 @@ void LlamaContext::_ready() {
        UtilityFunctions::printerr(vformat("%s: Failed to initialize llama context, null ctx", __func__));
        return;
    }
+
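+   // Set up the sampling context from llama.cpp's common sampling helpers; the worker thread uses it to pick tokens.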
+   sampling_ctx = llama_sampling_init(sampling_params);
+
    UtilityFunctions::print(vformat("%s: Context initialized", __func__));

    thread->start(callable_mp(this, &LlamaContext::__thread_loop));
@@ -73,6 +81,10 @@ void LlamaContext::__thread_loop() {
        semaphore->wait();

        mutex->lock();
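+       // Bail out of the worker loop when _exit_tree() has requested shutdown.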
+       if (exit_thread) {
+           mutex->unlock();
+           break;
+       }
        if (completion_requests.size() == 0) {
            mutex->unlock();
            continue;
@@ -83,10 +95,115 @@

        UtilityFunctions::print(vformat("%s: Running completion for prompt id: %d", __func__, req.id));

-       Dictionary chunk;
-       chunk["id"] = req.id;
-       chunk["text"] = "Hello, world!";
-       call_deferred("emit_signal", "completion_generated", chunk);
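+       // Tokenize the full prompt with the llama.cpp common helper (adds the BOS token).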
+       std::vector<llama_token> request_tokens;
+       request_tokens = ::llama_tokenize(ctx, req.prompt.utf8().get_data(), true);
+
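+       // Find the longest prefix the new prompt shares with the tokens already in the KV cache,
+       // so only the divergent tail has to be evaluated again.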
+       size_t shared_prefix_idx = 0;
+       auto diff = std::mismatch(context_tokens.begin(), context_tokens.end(), request_tokens.begin(), request_tokens.end());
+       if (diff.first != context_tokens.end()) {
+           shared_prefix_idx = std::distance(context_tokens.begin(), diff.first);
+       } else {
+           shared_prefix_idx = std::min(context_tokens.size(), request_tokens.size());
+       }
+
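+       // Evict everything after the shared prefix from sequence 0 of the KV cache ([shared_prefix_idx, end)).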
+       bool rm_success = llama_kv_cache_seq_rm(ctx, 0, shared_prefix_idx, -1);
+       if (!rm_success) {
+           UtilityFunctions::printerr(vformat("%s: Failed to remove tokens from kv cache", __func__));
+           Dictionary response;
+           response["id"] = req.id;
+           response["error"] = "Failed to remove tokens from kv cache";
+           call_deferred("emit_signal", "completion_generated", response);
+           continue;
+       }
+       context_tokens.erase(context_tokens.begin() + shared_prefix_idx, context_tokens.end());
+       request_tokens.erase(request_tokens.begin(), request_tokens.begin() + shared_prefix_idx);
+
+       uint batch_size = std::min(ctx_params.n_batch, (uint)request_tokens.size());
+
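+       // llama_batch_init(n_tokens, embd = 0, n_seq_max = 1): a token-id batch for a single sequence.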
+       llama_batch batch = llama_batch_init(batch_size, 0, 1);
+
+       // chunk request_tokens into sequences of size batch_size
+       std::vector<std::vector<llama_token>> sequences;
+       for (size_t i = 0; i < request_tokens.size(); i += batch_size) {
+           sequences.push_back(std::vector<llama_token>(request_tokens.begin() + i, request_tokens.begin() + std::min(i + batch_size, request_tokens.size())));
+       }
+
+       int curr_token_pos = context_tokens.size();
+       bool decode_failed = false;
+
+       for (size_t i = 0; i < sequences.size(); i++) {
+           llama_batch_clear(batch);
+
+           std::vector<llama_token> sequence = sequences[i];
+
+           for (size_t j = 0; j < sequence.size(); j++) {
+               llama_batch_add(batch, sequence[j], curr_token_pos, { 0 }, false);
+               curr_token_pos++;
+           }
+
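+           // Request logits only for the last token of the final chunk; that is the position we sample from.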
+           if (i == sequences.size() - 1) {
+               batch.logits[batch.n_tokens - 1] = true;
+           }
+
+           if (llama_decode(ctx, batch) != 0) {
+               decode_failed = true;
+               break;
+           }
+       }
+
+       if (decode_failed) {
+           Dictionary response;
+           response["id"] = req.id;
+           response["error"] = "llama_decode() failed";
+           call_deferred("emit_signal", "completion_generated", response);
+           continue;
+       }
+
+       context_tokens.insert(context_tokens.end(), request_tokens.begin(), request_tokens.end());
+
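+       // Generation loop: sample one token at a time, stream each piece back through the
+       // completion_generated signal, then decode that token to extend the KV cache.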
+       while (true) {
+           if (exit_thread) {
+               return;
+           }
+           llama_token new_token_id = llama_sampling_sample(sampling_ctx, ctx, NULL, batch.n_tokens - 1);
+           llama_sampling_accept(sampling_ctx, ctx, new_token_id, true);
+
+           Dictionary response;
+           response["id"] = req.id;
+
+           if (llama_token_is_eog(model->model, new_token_id) || curr_token_pos == n_len) {
+               response["done"] = true;
+               call_deferred("emit_signal", "completion_generated", response);
+               break;
+           }
+
+           context_tokens.push_back(new_token_id);
+
+           response["text"] = llama_token_to_piece(ctx, new_token_id).c_str();
+           response["done"] = false;
+           call_deferred("emit_signal", "completion_generated", response);
+
+           llama_batch_clear(batch);
+
+           llama_batch_add(batch, new_token_id, curr_token_pos, { 0 }, true);
+
+           curr_token_pos++;
+
+           if (llama_decode(ctx, batch) != 0) {
+               decode_failed = true;
+               break;
+           }
+       }
+
+       if (decode_failed) {
+           Dictionary response;
+           response["id"] = req.id;
+           response["error"] = "llama_decode() failed";
+           call_deferred("emit_signal", "completion_generated", response);
+           continue;
+       }
+
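+       // Reset the sampler state (previous-token history) before handling the next request.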
+       llama_sampling_reset(sampling_ctx);
    }
}

@@ -134,7 +251,26 @@ void LlamaContext::set_n_ctx(int n_ctx) {
    ctx_params.n_ctx = n_ctx;
}

-LlamaContext::~LlamaContext() {
+int LlamaContext::get_n_len() {
+   return n_len;
+}
+void LlamaContext::set_n_len(int n_len) {
+   this->n_len = n_len;
+}
+
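+// Called when the node leaves the tree: stop the worker thread (set the flag, wake it, join it) before freeing llama resources.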
+void LlamaContext::_exit_tree() {
+   if (Engine::get_singleton()->is_editor_hint()) {
+       return;
+   }
+
+   mutex->lock();
+   exit_thread = true;
+   mutex->unlock();
+
+   semaphore->post();
+
+   thread->wait_to_finish();
+
    if (ctx) {
        llama_free(ctx);
    }