4
4
#include " llama_model.h"
5
5
#include < godot_cpp/classes/engine.hpp>
6
6
#include < godot_cpp/classes/os.hpp>
7
- #include < godot_cpp/classes/worker_thread_pool.hpp>
7
+ #include < godot_cpp/classes/semaphore.hpp>
8
+ #include < godot_cpp/classes/thread.hpp>
8
9
#include < godot_cpp/core/class_db.hpp>
9
10
#include < godot_cpp/variant/utility_functions.hpp>
10
11
@@ -15,29 +16,42 @@ void LlamaContext::_bind_methods() {
15
16
ClassDB::bind_method (D_METHOD (" get_model" ), &LlamaContext::get_model);
16
17
ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::OBJECT, " model" , PROPERTY_HINT_RESOURCE_TYPE, " LlamaModel" ), " set_model" , " get_model" );
17
18
18
- ClassDB::bind_method (D_METHOD (" get_seed" ), &LlamaContext::get_seed);
19
- ClassDB::bind_method (D_METHOD (" set_seed" , " seed" ), &LlamaContext::set_seed);
20
- ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::INT, " seed" ), " set_seed" , " get_seed" );
19
+ ClassDB::bind_method (D_METHOD (" get_seed" ), &LlamaContext::get_seed);
20
+ ClassDB::bind_method (D_METHOD (" set_seed" , " seed" ), &LlamaContext::set_seed);
21
+ ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::INT, " seed" ), " set_seed" , " get_seed" );
21
22
22
- ClassDB::bind_method (D_METHOD (" get_n_ctx" ), &LlamaContext::get_n_ctx);
23
- ClassDB::bind_method (D_METHOD (" set_n_ctx" , " n_ctx" ), &LlamaContext::set_n_ctx);
24
- ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::INT, " n_ctx" ), " set_n_ctx" , " get_n_ctx" );
23
+ ClassDB::bind_method (D_METHOD (" get_n_ctx" ), &LlamaContext::get_n_ctx);
24
+ ClassDB::bind_method (D_METHOD (" set_n_ctx" , " n_ctx" ), &LlamaContext::set_n_ctx);
25
+ ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::INT, " n_ctx" ), " set_n_ctx" , " get_n_ctx" );
25
26
26
- ClassDB::bind_method (D_METHOD (" get_n_threads " ), &LlamaContext::get_n_threads );
27
- ClassDB::bind_method (D_METHOD (" set_n_threads " , " n_threads " ), &LlamaContext::set_n_threads );
28
- ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::INT , " n_threads " ), " set_n_threads " , " get_n_threads " );
27
+ ClassDB::bind_method (D_METHOD (" get_temperature " ), &LlamaContext::get_temperature );
28
+ ClassDB::bind_method (D_METHOD (" set_temperature " , " temperature " ), &LlamaContext::set_temperature );
29
+ ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::FLOAT , " temperature " ), " set_temperature " , " get_temperature " );
29
30
30
- ClassDB::bind_method (D_METHOD (" get_n_threads_batch " ), &LlamaContext::get_n_threads_batch );
31
- ClassDB::bind_method (D_METHOD (" set_n_threads_batch " , " n_threads_batch " ), &LlamaContext::set_n_threads_batch );
32
- ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::INT , " n_threads_batch " ), " set_n_threads_batch " , " get_n_threads_batch " );
31
+ ClassDB::bind_method (D_METHOD (" get_top_p " ), &LlamaContext::get_top_p );
32
+ ClassDB::bind_method (D_METHOD (" set_top_p " , " top_p " ), &LlamaContext::set_top_p );
33
+ ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::FLOAT , " top_p " ), " set_top_p " , " get_top_p " );
33
34
34
- ClassDB::bind_method (D_METHOD (" request_completion" , " prompt" ), &LlamaContext::request_completion);
35
- ClassDB::bind_method (D_METHOD (" _fulfill_completion" , " prompt" ), &LlamaContext::_fulfill_completion);
35
+ ClassDB::bind_method (D_METHOD (" get_top_k" ), &LlamaContext::get_top_k);
36
+ ClassDB::bind_method (D_METHOD (" set_top_k" , " top_k" ), &LlamaContext::set_top_k);
37
+ ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::INT, " top_k" ), " set_top_k" , " get_top_k" );
36
38
37
- ADD_SIGNAL (MethodInfo (" completion_generated" , PropertyInfo (Variant::STRING, " completion" ), PropertyInfo (Variant::BOOL, " is_final" )));
39
+ ClassDB::bind_method (D_METHOD (" get_presence_penalty" ), &LlamaContext::get_presence_penalty);
40
+ ClassDB::bind_method (D_METHOD (" set_presence_penalty" , " presence_penalty" ), &LlamaContext::set_presence_penalty);
41
+ ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::FLOAT, " presence_penalty" ), " set_presence_penalty" , " get_presence_penalty" );
42
+
43
+ ClassDB::bind_method (D_METHOD (" get_frequency_penalty" ), &LlamaContext::get_frequency_penalty);
44
+ ClassDB::bind_method (D_METHOD (" set_frequency_penalty" , " frequency_penalty" ), &LlamaContext::set_frequency_penalty);
45
+ ClassDB::add_property (" LlamaContext" , PropertyInfo (Variant::FLOAT, " frequency_penalty" ), " set_frequency_penalty" , " get_frequency_penalty" );
46
+
47
+ ClassDB::bind_method (D_METHOD (" prompt" , " prompt" , " max_new_tokens" ), &LlamaContext::prompt);
48
+ ClassDB::bind_method (D_METHOD (" _run_prompts" ), &LlamaContext::_run_prompts);
49
+
50
+ ADD_SIGNAL (MethodInfo (" prompt_completion" , PropertyInfo (Variant::STRING, " prompt_id" ), PropertyInfo (Variant::STRING, " completion" ), PropertyInfo (Variant::BOOL, " is_final" )));
38
51
}
39
52
40
- LlamaContext::LlamaContext () {
53
+ LlamaContext::LlamaContext () :
54
+ sampling_params() {
41
55
batch = llama_batch_init (4096 , 0 , 1 );
42
56
43
57
ctx_params = llama_context_default_params ();
@@ -66,100 +80,57 @@ void LlamaContext::_ready() {
66
80
return ;
67
81
}
68
82
UtilityFunctions::print (vformat (" %s: Context initialized" , __func__));
69
- }
70
-
71
- PackedStringArray LlamaContext::_get_configuration_warnings () const {
72
- PackedStringArray warnings;
73
- if (model == NULL ) {
74
- warnings.push_back (" Model resource property not defined" );
75
- }
76
- return warnings;
77
- }
78
-
79
- Variant LlamaContext::request_completion (const String &prompt) {
80
- UtilityFunctions::print (vformat (" %s: Requesting completion for prompt: %s" , __func__, prompt));
81
- if (task_id) {
82
- WorkerThreadPool::get_singleton ()->wait_for_task_completion (task_id);
83
- }
84
- task_id = WorkerThreadPool::get_singleton ()->add_task (Callable (this , " _fulfill_completion" ).bind (prompt));
85
- return OK;
86
- }
87
-
88
- void LlamaContext::_fulfill_completion (const String &prompt) {
89
- UtilityFunctions::print (vformat (" %s: Fulfilling completion for prompt: %s" , __func__, prompt));
90
- std::vector<llama_token> tokens_list;
91
- tokens_list = ::llama_tokenize (ctx, std::string (prompt.utf8 ().get_data ()), true );
92
83
93
- const int n_len = 128 ;
94
- const int n_ctx = llama_n_ctx (ctx);
95
- const int n_kv_req = tokens_list.size () + (n_len - tokens_list.size ());
96
- if (n_kv_req > n_ctx) {
97
- UtilityFunctions::printerr (vformat (" %s: n_kv_req > n_ctx, the required KV cache size is not big enough\n either reduce n_len or increase n_ctx" , __func__));
98
- return ;
99
- }
100
-
101
- for (size_t i = 0 ; i < tokens_list.size (); i++) {
102
- llama_batch_add (batch, tokens_list[i], i, { 0 }, false );
103
- }
84
+ sampling_ctx = llama_sampling_init (sampling_params);
104
85
105
- batch.logits [batch.n_tokens - 1 ] = true ;
86
+ semaphore.instantiate ();
87
+ mutex.instantiate ();
88
+ worker_thread.instantiate ();
106
89
107
- llama_kv_cache_clear (ctx);
90
+ worker_thread->start (Callable (this , " _run_prompts" ));
91
+ }
108
92
109
- int decode_res = llama_decode (ctx, batch);
110
- if (decode_res != 0 ) {
111
- UtilityFunctions::printerr ( vformat ( " %s: Failed to decode prompt with error code: %d " , __func__, decode_res));
112
- return ;
93
+ PackedStringArray LlamaContext::_get_configuration_warnings () const {
94
+ PackedStringArray warnings;
95
+ if (model == NULL ) {
96
+ warnings. push_back ( " Model resource property not defined " ) ;
113
97
}
98
+ return warnings;
99
+ }
114
100
115
- int n_cur = batch.n_tokens ;
116
- int n_decode = 0 ;
117
- llama_model *llama_model = model->model ;
118
-
119
- while (n_cur <= n_len) {
120
- // sample the next token
121
- {
122
- auto n_vocab = llama_n_vocab (llama_model);
123
- auto *logits = llama_get_logits_ith (ctx, batch.n_tokens - 1 );
124
-
125
- std::vector<llama_token_data> candidates;
126
- candidates.reserve (n_vocab);
127
-
128
- for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
129
- candidates.emplace_back (llama_token_data{ token_id, logits[token_id], 0 .0f });
130
- }
131
-
132
- llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
133
-
134
- // sample the most likely token
135
- const llama_token new_token_id = llama_sample_token_greedy (ctx, &candidates_p);
136
-
137
- // is it an end of stream?
138
- if (new_token_id == llama_token_eos (llama_model) || n_cur == n_len) {
139
- call_thread_safe (" emit_signal" , " completion_generated" , " \n " , true );
101
+ int LlamaContext::prompt (const String &prompt, int max_new_tokens) {
102
+ mutex->lock ();
103
+ int prompt_id = n_prompts++;
104
+ prompts.push_back (prompt);
105
+ mutex->unlock ();
140
106
141
- break ;
142
- }
107
+ semaphore->post ();
143
108
144
- call_thread_safe ( " emit_signal " , " completion_generated " , vformat (" % s" , llama_token_to_piece (ctx, new_token_id). c_str ()), false );
109
+ UtilityFunctions::print ( vformat (" New prompt %d: % s" , prompt_id, prompt) );
145
110
146
- // prepare the next batch
147
- llama_batch_clear (batch);
111
+ return prompt_id;
112
+ }
148
113
149
- // push this new token for next evaluation
150
- llama_batch_add (batch, new_token_id, n_cur, { 0 }, true );
114
+ void LlamaContext::_run_prompts () {
115
+ while (true ) {
116
+ semaphore->wait ();
151
117
152
- n_decode += 1 ;
118
+ mutex->lock ();
119
+ if (should_exit) {
120
+ mutex->unlock ();
121
+ break ;
122
+ }
123
+ if (prompts.is_empty ()) {
124
+ mutex->unlock ();
125
+ continue ;
153
126
}
127
+ String prompt = prompts.get (0 );
128
+ prompts.remove_at (0 );
129
+ mutex->unlock ();
154
130
155
- n_cur += 1 ;
131
+ UtilityFunctions::print ( vformat ( " Running prompt %s " , prompt)) ;
156
132
157
- // evaluate the current batch with the transformer model
158
- int decode_res = llama_decode (ctx, batch);
159
- if (decode_res != 0 ) {
160
- UtilityFunctions::printerr (vformat (" %s: Failed to decode batch with error code: %d" , __func__, decode_res));
161
- break ;
162
- }
133
+ OS::get_singleton ()->delay_msec (2000 );
163
134
}
164
135
}
165
136
@@ -184,28 +155,52 @@ void LlamaContext::set_n_ctx(int n_ctx) {
184
155
ctx_params.n_ctx = n_ctx;
185
156
}
186
157
187
- int LlamaContext::get_n_threads () {
188
- return ctx_params. n_threads ;
158
+ float LlamaContext::get_temperature () {
159
+ return sampling_params. temp ;
189
160
}
190
- void LlamaContext::set_n_threads (int n_threads) {
191
- ctx_params.n_threads = n_threads;
161
+ void LlamaContext::set_temperature (float temperature) {
162
+ sampling_params.temp = temperature;
163
+ }
164
+
165
+ float LlamaContext::get_top_p () {
166
+ return sampling_params.top_p ;
167
+ }
168
+ void LlamaContext::set_top_p (float top_p) {
169
+ sampling_params.top_p = top_p;
170
+ }
171
+
172
+ int LlamaContext::get_top_k () {
173
+ return sampling_params.top_k ;
174
+ }
175
+ void LlamaContext::set_top_k (int top_k) {
176
+ sampling_params.top_k = top_k;
192
177
}
193
178
194
- int LlamaContext::get_n_threads_batch () {
195
- return ctx_params. n_threads_batch ;
179
+ float LlamaContext::get_presence_penalty () {
180
+ return sampling_params. penalty_present ;
196
181
}
197
- void LlamaContext::set_n_threads_batch (int n_threads_batch) {
198
- ctx_params.n_threads_batch = n_threads_batch;
182
+ void LlamaContext::set_presence_penalty (float presence_penalty) {
183
+ sampling_params.penalty_present = presence_penalty;
184
+ }
185
+
186
+ float LlamaContext::get_frequency_penalty () {
187
+ return sampling_params.penalty_freq ;
188
+ }
189
+ void LlamaContext::set_frequency_penalty (float frequency_penalty) {
190
+ sampling_params.penalty_freq = frequency_penalty;
199
191
}
200
192
201
193
LlamaContext::~LlamaContext () {
194
+ llama_batch_free (batch);
195
+ llama_sampling_free (sampling_ctx);
202
196
if (ctx) {
203
197
llama_free (ctx);
204
198
}
205
199
206
- llama_batch_free (batch);
207
-
208
- if (task_id) {
209
- WorkerThreadPool::get_singleton ()->wait_for_task_completion (task_id);
210
- }
200
+ mutex->lock ();
201
+ prompts.clear ();
202
+ should_exit = true ;
203
+ mutex->unlock ();
204
+ semaphore->post ();
205
+ worker_thread->wait_to_finish ();
211
206
}
0 commit comments