#include <godot_cpp/classes/os.hpp>
#include <godot_cpp/classes/worker_thread_pool.hpp>
#include <godot_cpp/core/class_db.hpp>
+#include <godot_cpp/variant/string.hpp>
#include <godot_cpp/variant/utility_functions.hpp>

using namespace godot;
@@ -15,26 +16,18 @@ void LlamaContext::_bind_methods() {
	ClassDB::bind_method(D_METHOD("get_model"), &LlamaContext::get_model);
	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::OBJECT, "model", PROPERTY_HINT_RESOURCE_TYPE, "LlamaModel"), "set_model", "get_model");

-	ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
-	ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");
+	ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
+	ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");

-	ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
-	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
+	ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
+	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");

-	ClassDB::bind_method(D_METHOD("get_n_threads"), &LlamaContext::get_n_threads);
-	ClassDB::bind_method(D_METHOD("set_n_threads", "n_threads"), &LlamaContext::set_n_threads);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads"), "set_n_threads", "get_n_threads");
+	ClassDB::bind_method(D_METHOD("prompt", "prompt", "max_new_tokens", "temperature", "top_p", "top_k", "presence_penalty", "frequency_penalty"), &LlamaContext::prompt, DEFVAL(32), DEFVAL(0.80f), DEFVAL(0.95f), DEFVAL(40), DEFVAL(0.0), DEFVAL(0.0));
+	ClassDB::bind_method(D_METHOD("_thread_prompt_loop"), &LlamaContext::_thread_prompt_loop);

-	ClassDB::bind_method(D_METHOD("get_n_threads_batch"), &LlamaContext::get_n_threads_batch);
-	ClassDB::bind_method(D_METHOD("set_n_threads_batch", "n_threads_batch"), &LlamaContext::set_n_threads_batch);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads_batch"), "set_n_threads_batch", "get_n_threads_batch");
-
-	ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion);
-	ClassDB::bind_method(D_METHOD("_fulfill_completion", "prompt"), &LlamaContext::_fulfill_completion);
-
-	ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::STRING, "completion"), PropertyInfo(Variant::BOOL, "is_final")));
+	ADD_SIGNAL(MethodInfo("text_generated", PropertyInfo(Variant::INT, "id"), PropertyInfo(Variant::STRING, "text"), PropertyInfo(Variant::BOOL, "is_final")));
}

LlamaContext::LlamaContext() {
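As an aside, a hypothetical consumer of the new binding surface might look like the sketch below (not part of this commit). Only `prompt()`, its default arguments, and the `text_generated(id, text, is_final)` signal come from the diff; the class name `PromptDemo`, the header path `llama_context.h`, and the scene node path are illustrative assumptions.

#include <godot_cpp/classes/node.hpp>
#include <godot_cpp/core/class_db.hpp>
#include <godot_cpp/variant/utility_functions.hpp>

#include "llama_context.h" // assumed header name for the class in this diff

using namespace godot;

class PromptDemo : public Node {
	GDCLASS(PromptDemo, Node)

	LlamaContext *llama = nullptr;
	int request_id = -1;

protected:
	static void _bind_methods() {
		ClassDB::bind_method(D_METHOD("_on_text_generated", "id", "text", "is_final"), &PromptDemo::_on_text_generated);
	}

public:
	void _ready() override {
		// Assumes a sibling LlamaContext node in the scene tree.
		llama = Object::cast_to<LlamaContext>(get_node_or_null(NodePath("../LlamaContext")));
		if (llama == nullptr) {
			return;
		}
		// Stream generated text back through the signal added in this diff.
		llama->connect("text_generated", Callable(this, "_on_text_generated"));
		// prompt() queues the request and returns an id immediately;
		// generation runs on LlamaContext's worker thread.
		request_id = llama->prompt("Once upon a time", 64, 0.8f, 0.95f, 40, 0.0f, 0.0f);
	}

	void _on_text_generated(int id, const String &text, bool is_final) {
		if (id != request_id) {
			return; // another caller's request
		}
		UtilityFunctions::print(text);
	}
};

Such a class would still need the usual GDREGISTER_CLASS call in the extension's init function.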
@@ -47,6 +40,11 @@ LlamaContext::LlamaContext() {
	int32_t n_threads = OS::get_singleton()->get_processor_count();
	ctx_params.n_threads = n_threads;
	ctx_params.n_threads_batch = n_threads;
+
+	sampling_params = llama_sampling_params();
+
+	n_prompts = 0;
+	should_exit = false;
}

void LlamaContext::_ready() {
@@ -66,101 +64,66 @@ void LlamaContext::_ready() {
		return;
	}
	UtilityFunctions::print(vformat("%s: Context initialized", __func__));
-}

-PackedStringArray LlamaContext::_get_configuration_warnings() const {
-	PackedStringArray warnings;
-	if (model == NULL) {
-		warnings.push_back("Model resource property not defined");
-	}
-	return warnings;
+	sampling_ctx = llama_sampling_init(sampling_params);
+
+	prompt_mutex.instantiate();
+	prompt_semaphore.instantiate();
+	prompt_thread.instantiate();
+
+	prompt_thread->start(callable_mp(this, &LlamaContext::_thread_prompt_loop));
}

-Variant LlamaContext::request_completion(const String &prompt) {
-	UtilityFunctions::print(vformat("%s: Requesting completion for prompt: %s", __func__, prompt));
-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
-	}
-	task_id = WorkerThreadPool::get_singleton()->add_task(Callable(this, "_fulfill_completion").bind(prompt));
-	return OK;
+int LlamaContext::prompt(const String &prompt, const int max_new_tokens, const float temperature, const float top_p, const int top_k, const float presence_penalty, const float frequency_penalty) {
+	UtilityFunctions::print(vformat("%s: Prompting with prompt: %s, max_new_tokens: %d, temperature: %f, top_p: %f, top_k: %d, presence_penalty: %f, frequency_penalty: %f", __func__, prompt, max_new_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty));
+	prompt_mutex->lock();
+	int id = n_prompts++;
+	prompt_requests.push_back({ id, prompt, max_new_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty });
+	prompt_mutex->unlock();
+	prompt_semaphore->post();
+	return id;
}

-void LlamaContext::_fulfill_completion(const String &prompt) {
-	UtilityFunctions::print(vformat("%s: Fulfilling completion for prompt: %s", __func__, prompt));
-	std::vector<llama_token> tokens_list;
-	tokens_list = ::llama_tokenize(ctx, std::string(prompt.utf8().get_data()), true);
+void LlamaContext::_thread_prompt_loop() {
+	while (true) {
+		prompt_semaphore->wait();

-	const int n_len = 128;
-	const int n_ctx = llama_n_ctx(ctx);
-	const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
-	if (n_kv_req > n_ctx) {
-		UtilityFunctions::printerr(vformat("%s: n_kv_req > n_ctx, the required KV cache size is not big enough\neither reduce n_len or increase n_ctx", __func__));
-		return;
-	}
+		prompt_mutex->lock();
+		if (should_exit) {
+			prompt_mutex->unlock();
+			return;
+		}
+		if (prompt_requests.size() == 0) {
+			prompt_mutex->unlock();
+			continue;
+		}
+		prompt_request req = prompt_requests.get(0);
+		prompt_requests.remove_at(0);
+		prompt_mutex->unlock();

-	for (size_t i = 0; i < tokens_list.size(); i++) {
-		llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
-	}
+		UtilityFunctions::print(vformat("%s: Running prompt %d: %s, max_new_tokens: %d, temperature: %f, top_p: %f, top_k: %d, presence_penalty: %f, frequency_penalty: %f", __func__, req.id, req.prompt, req.max_new_tokens, req.temperature, req.top_p, req.top_k, req.presence_penalty, req.frequency_penalty));

-	batch.logits[batch.n_tokens - 1] = true;
+		llama_sampling_reset(sampling_ctx);
+		llama_batch_clear(batch);
+		llama_kv_cache_clear(ctx);

-	llama_kv_cache_clear(ctx);
+		auto &params = sampling_ctx->params;
+		params.temp = req.temperature;
+		params.top_p = req.top_p;
+		params.top_k = req.top_k;
+		params.penalty_present = req.presence_penalty;
+		params.penalty_freq = req.frequency_penalty;

-	int decode_res = llama_decode(ctx, batch);
-	if (decode_res != 0) {
-		UtilityFunctions::printerr(vformat("%s: Failed to decode prompt with error code: %d", __func__, decode_res));
-		return;
+		std::vector<llama_token> tokens = ::llama_tokenize(ctx, req.prompt.utf8().get_data(), false, true);
	}
+}

-	int n_cur = batch.n_tokens;
-	int n_decode = 0;
-	llama_model *llama_model = model->model;
-
-	while (n_cur <= n_len) {
-		// sample the next token
-		{
-			auto n_vocab = llama_n_vocab(llama_model);
-			auto *logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-
-			std::vector<llama_token_data> candidates;
-			candidates.reserve(n_vocab);
-
-			for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-				candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-			}
-
-			llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-			// sample the most likely token
-			const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
-
-			// is it an end of stream?
-			if (new_token_id == llama_token_eos(llama_model) || n_cur == n_len) {
-				call_thread_safe("emit_signal", "completion_generated", "\n", true);
-
-				break;
-			}
-
-			call_thread_safe("emit_signal", "completion_generated", vformat("%s", llama_token_to_piece(ctx, new_token_id).c_str()), false);
-
-			// prepare the next batch
-			llama_batch_clear(batch);
-
-			// push this new token for next evaluation
-			llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
-
-			n_decode += 1;
-		}
-
-		n_cur += 1;
-
-		// evaluate the current batch with the transformer model
-		int decode_res = llama_decode(ctx, batch);
-		if (decode_res != 0) {
-			UtilityFunctions::printerr(vformat("%s: Failed to decode batch with error code: %d", __func__, decode_res));
-			break;
-		}
+PackedStringArray LlamaContext::_get_configuration_warnings() const {
+	PackedStringArray warnings;
+	if (model == NULL) {
+		warnings.push_back("Model resource property not defined");
	}
+	return warnings;
}

void LlamaContext::set_model(const Ref<LlamaModel> p_model) {
@@ -173,39 +136,38 @@ Ref<LlamaModel> LlamaContext::get_model() {
int LlamaContext::get_seed() {
	return ctx_params.seed;
}
-void LlamaContext::set_seed(int seed) {
+void LlamaContext::set_seed(const int seed) {
	ctx_params.seed = seed;
}

int LlamaContext::get_n_ctx() {
	return ctx_params.n_ctx;
}
-void LlamaContext::set_n_ctx(int n_ctx) {
+void LlamaContext::set_n_ctx(const int n_ctx) {
	ctx_params.n_ctx = n_ctx;
}

-int LlamaContext::get_n_threads() {
-	return ctx_params.n_threads;
-}
-void LlamaContext::set_n_threads(int n_threads) {
-	ctx_params.n_threads = n_threads;
-}
+void LlamaContext::_exit_tree() {
+	prompt_mutex->lock();
+	should_exit = true;
+	prompt_requests.clear();
+	prompt_mutex->unlock();

-int LlamaContext::get_n_threads_batch() {
-	return ctx_params.n_threads_batch;
-}
-void LlamaContext::set_n_threads_batch(int n_threads_batch) {
-	ctx_params.n_threads_batch = n_threads_batch;
-}
+	prompt_semaphore->post();
+
+	if (prompt_thread.is_valid()) {
+		prompt_thread->wait_to_finish();
+	}
+	prompt_thread.unref();

-LlamaContext::~LlamaContext() {
	if (ctx) {
		llama_free(ctx);
	}
+	if (sampling_ctx) {
+		llama_sampling_free(sampling_ctx);
+	}
+}

+LlamaContext::~LlamaContext() {
	llama_batch_free(batch);
-
-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
-	}
}
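Taken together, the `_ready()`, `_thread_prompt_loop()`, and `_exit_tree()` changes form a small producer/consumer queue with a shutdown handshake: requests are appended under a mutex and announced with a semaphore, and teardown sets an exit flag before posting the semaphore so the worker always wakes up, observes the flag, and returns before `wait_to_finish()` joins it. Below is a minimal, self-contained sketch of just that pattern, detached from llama.cpp; the class and member names are illustrative, and the header paths assume a recent godot-cpp.

#include <godot_cpp/classes/mutex.hpp>
#include <godot_cpp/classes/node.hpp>
#include <godot_cpp/classes/ref.hpp>
#include <godot_cpp/classes/semaphore.hpp>
#include <godot_cpp/classes/thread.hpp>
#include <godot_cpp/core/class_db.hpp>
#include <godot_cpp/templates/vector.hpp>
#include <godot_cpp/variant/callable_method_pointer.hpp> // callable_mp; path per recent godot-cpp
#include <godot_cpp/variant/string.hpp>
#include <godot_cpp/variant/utility_functions.hpp>

using namespace godot;

class RequestQueueDemo : public Node {
	GDCLASS(RequestQueueDemo, Node)

	Ref<Mutex> mutex;
	Ref<Semaphore> semaphore;
	Ref<Thread> thread;
	Vector<String> requests;
	bool should_exit = false;

	void _worker_loop() {
		while (true) {
			semaphore->wait(); // sleep until there is work or we are told to exit

			mutex->lock();
			if (should_exit) {
				mutex->unlock();
				return; // leave the loop so wait_to_finish() can join the thread
			}
			if (requests.size() == 0) {
				mutex->unlock();
				continue; // nothing queued; go back to waiting
			}
			String req = requests.get(0);
			requests.remove_at(0);
			mutex->unlock();

			UtilityFunctions::print(req); // stand-in for the real per-request work
		}
	}

protected:
	static void _bind_methods() {}

public:
	void _ready() override {
		mutex.instantiate();
		semaphore.instantiate();
		thread.instantiate();
		thread->start(callable_mp(this, &RequestQueueDemo::_worker_loop));
	}

	void _exit_tree() override {
		mutex->lock();
		should_exit = true;
		requests.clear();
		mutex->unlock();

		semaphore->post(); // wake the worker so it can see should_exit

		if (thread.is_valid()) {
			thread->wait_to_finish();
		}
		thread.unref();
	}
};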