#include "llama_context.h"
+#include "common.h"
#include "llama.h"
#include "llama_model.h"
#include <godot_cpp/classes/engine.hpp>
#include <godot_cpp/classes/os.hpp>
+#include <godot_cpp/classes/worker_thread_pool.hpp>
#include <godot_cpp/core/class_db.hpp>
#include <godot_cpp/variant/utility_functions.hpp>

using namespace godot;

-void LlamaContext::set_model(const Ref<LlamaModel> p_model) {
-	model = p_model;
-}
-
-Ref<LlamaModel> LlamaContext::get_model() {
-	return model;
+void LlamaContext::_bind_methods() {
+	ClassDB::bind_method(D_METHOD("set_model", "model"), &LlamaContext::set_model);
+	ClassDB::bind_method(D_METHOD("get_model"), &LlamaContext::get_model);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::OBJECT, "model", PROPERTY_HINT_RESOURCE_TYPE, "LlamaModel"), "set_model", "get_model");
+	ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion);
+	ClassDB::bind_method(D_METHOD("_fulfill_completion", "prompt"), &LlamaContext::_fulfill_completion);
+	ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::STRING, "completion"), PropertyInfo(Variant::BOOL, "is_final")));
}

void LlamaContext::_ready() {
@@ -40,14 +43,103 @@ void LlamaContext::_ready() {
	UtilityFunctions::print(vformat("%s: Context initialized", __func__));
}

+Variant LlamaContext::request_completion(const String &prompt) {
+	UtilityFunctions::print(vformat("%s: Requesting completion for prompt: %s", __func__, prompt));
+	if (task_id) {
+		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
+	}
+	task_id = WorkerThreadPool::get_singleton()->add_task(Callable(this, "_fulfill_completion").bind(prompt));
+	return OK;
+}
+
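+// Worker-thread entry point: evaluates the prompt, then greedily samples up to n_len
+// tokens, streaming each decoded piece back via the completion_generated signal.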
+void LlamaContext::_fulfill_completion(const String &prompt) {
+	UtilityFunctions::print(vformat("%s: Fulfilling completion for prompt: %s", __func__, prompt));
+	std::vector<llama_token> tokens_list;
+	tokens_list = ::llama_tokenize(ctx, std::string(prompt.utf8().get_data()), true);
+
+	const int n_len = 128;
+	const int n_ctx = llama_n_ctx(ctx);
+	const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+	if (n_kv_req > n_ctx) {
+		UtilityFunctions::printerr(vformat("%s: n_kv_req > n_ctx, the required KV cache size is not big enough\neither reduce n_len or increase n_ctx", __func__));
+		return;
+	}
+
+	for (size_t i = 0; i < tokens_list.size(); i++) {
+		llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+	}
+	batch.logits[batch.n_tokens - 1] = true;
+
+	int decode_res = llama_decode(ctx, batch);
+	if (decode_res != 0) {
+		UtilityFunctions::printerr(vformat("%s: Failed to decode prompt with error code: %d", __func__, decode_res));
+		return;
+	}
+
+	int n_cur = batch.n_tokens;
+	int n_decode = 0;
+	llama_model *llama_model = model->model;
+	while (n_cur <= n_len) {
+		// sample the next token
+		{
+			auto n_vocab = llama_n_vocab(llama_model);
+			auto *logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+
+			std::vector<llama_token_data> candidates;
+			candidates.reserve(n_vocab);
+
+			for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+				candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+			}
+
+			llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+			// sample the most likely token
+			const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+
+			// is it an end of stream?
+			if (new_token_id == llama_token_eos(llama_model) || n_cur == n_len) {
+				call_thread_safe("emit_signal", "completion_generated", "\n", true);
+
+				break;
+			}
+
+			call_thread_safe("emit_signal", "completion_generated", vformat("%s", llama_token_to_piece(ctx, new_token_id).c_str()), false);
+
+			// prepare the next batch
+			llama_batch_clear(batch);
+
+			// push this new token for next evaluation
+			llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+
+			n_decode += 1;
+		}
+
+		n_cur += 1;
+
+		// evaluate the current batch with the transformer model
+		int decode_res = llama_decode(ctx, batch);
+		if (decode_res != 0) {
+			UtilityFunctions::printerr(vformat("%s: Failed to decode batch with error code: %d", __func__, decode_res));
+			return;
+		}
+	}
+}
+
+void LlamaContext::set_model(const Ref<LlamaModel> p_model) {
+	model = p_model;
+}
+
+Ref<LlamaModel> LlamaContext::get_model() {
+	return model;
+}
+
LlamaContext::~LlamaContext() {
	if (ctx) {
		llama_free(ctx);
	}
-}
-
-void LlamaContext::_bind_methods() {
-	ClassDB::bind_method(D_METHOD("set_model", "model"), &LlamaContext::set_model);
-	ClassDB::bind_method(D_METHOD("get_model"), &LlamaContext::get_model);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::OBJECT, "model", PROPERTY_HINT_RESOURCE_TYPE, "LlamaModel"), "set_model", "get_model");
+	llama_batch_free(batch);
+	if (task_id) {
+		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
+	}
}
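
For context, a minimal sketch of the `llama_context.h` declarations this diff appears to rely on. The header is not part of this hunk, so the base class, member types, and defaults below are inferred from the .cpp and should be read as assumptions, not the actual header:

```cpp
// Hypothetical llama_context.h outline (names inferred from the diff above).
#ifndef LLAMA_CONTEXT_H
#define LLAMA_CONTEXT_H

#include "llama.h"
#include "llama_model.h"

#include <godot_cpp/classes/node.hpp>

namespace godot {

class LlamaContext : public Node { // base class assumed; the real header may differ
	GDCLASS(LlamaContext, Node)

private:
	Ref<LlamaModel> model;
	llama_context *ctx = nullptr;
	llama_batch batch;   // created in _ready(), released with llama_batch_free() in the destructor
	int64_t task_id = 0; // WorkerThreadPool task handle; exact type is an assumption

protected:
	static void _bind_methods();

public:
	void set_model(const Ref<LlamaModel> p_model);
	Ref<LlamaModel> get_model();
	Variant request_completion(const String &prompt);
	void _fulfill_completion(const String &prompt);
	void _ready() override;
	~LlamaContext();
};

} // namespace godot

#endif
```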