Skip to content

Commit b19ff2f

Browse files
committed
Finer control over gradient of substitution model
1 parent 5691ea3 commit b19ff2f

File tree

1 file changed

+51
-8
lines changed

1 file changed

+51
-8
lines changed

src/phyc/treelikelihood.c

+51-8
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ size_t TreeLikelihood_initialize_gradient(Model *self, int flags){
234234
int prepare_branch_model = tlk->prepared_gradient & TREELIKELIHOOD_FLAG_BRANCH_MODEL;
235235
int prepare_substitution_model_unconstrained = tlk->prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_UNCONSTRAINED;
236236
int prepare_substitution_model = tlk->prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL;
237+
int prepare_substitution_model_rates = prepare_substitution_model | (tlk->prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_RATES);
238+
int prepare_substitution_model_frequencies = prepare_substitution_model | (tlk->prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_FREQUENCIES);
237239

238240
if(prepare_substitution_model_unconstrained && prepare_substitution_model){
239241
fprintf(stderr, "Can only request unconstrained and constrained gradient at the same time\n");
@@ -280,12 +282,18 @@ size_t TreeLikelihood_initialize_gradient(Model *self, int flags){
280282
gradient_length += tlk->m->simplex->K - 1;
281283
tlk->include_root_freqs = false;
282284
}
283-
else if (prepare_substitution_model) {
284-
gradient_length += tlk->m->rates_simplex == NULL ? Parameters_count(tlk->m->rates) : tlk->m->rates_simplex->K;
285-
gradient_length += tlk->m->simplex->K;
286-
tlk->m->grad_wrt_reparam = false;
287-
tlk->include_root_freqs = false;
288-
}
285+
else{
286+
if (prepare_substitution_model_rates) {
287+
gradient_length += tlk->m->rates_simplex == NULL ? Parameters_count(tlk->m->rates) : tlk->m->rates_simplex->K;
288+
tlk->m->grad_wrt_reparam = false;
289+
tlk->include_root_freqs = false;
290+
}
291+
if (prepare_substitution_model_frequencies) {
292+
gradient_length += tlk->m->simplex->K;
293+
tlk->m->grad_wrt_reparam = false;
294+
tlk->include_root_freqs = false;
295+
}
296+
}
289297

290298
if(tlk->gradient == NULL){
291299
tlk->gradient = calloc(gradient_length, sizeof(double));
@@ -3002,6 +3010,29 @@ void gradient_PMatrix(SingleTreeLikelihood* tlk, const double* pattern_likelihoo
30023010
}
30033011
}
30043012

3013+
void gradient_PMatrix_rates(SingleTreeLikelihood* tlk, const double* pattern_likelihoods, double* gradient){
3014+
size_t parameter_count = tlk->m->rates_simplex == NULL ? Parameters_count(tlk->m->rates) : tlk->m->rates_simplex->K;
3015+
if(tlk->m->rates_simplex != NULL && tlk->m->grad_wrt_reparam){
3016+
parameter_count--;
3017+
}
3018+
for(size_t i = 0; i < parameter_count; i++){
3019+
gradient[i] = calculate_dlnl_dQ(tlk, i, pattern_likelihoods);
3020+
}
3021+
}
3022+
3023+
void gradient_PMatrix_frequencies(SingleTreeLikelihood* tlk, const double* pattern_likelihoods, double* gradient){
3024+
size_t start = tlk->m->rates_simplex == NULL ? Parameters_count(tlk->m->rates) : tlk->m->rates_simplex->K;
3025+
size_t parameter_count = tlk->m->simplex->K;
3026+
if(tlk->m->grad_wrt_reparam) parameter_count--;
3027+
3028+
if(tlk->m->rates_simplex != NULL && tlk->m->grad_wrt_reparam){
3029+
start--;
3030+
}
3031+
for(size_t i = 0; i < parameter_count; i++){
3032+
gradient[i] = calculate_dlnl_dQ(tlk, i+start, pattern_likelihoods);
3033+
}
3034+
}
3035+
30053036
void gradient_branch_length_from_cat(SingleTreeLikelihood* tlk, const double* cat_branch_gradient, double* gradient){
30063037
size_t nodeCount = Tree_node_count(tlk->tree);
30073038
size_t catCount = tlk->sm->cat_count;
@@ -3073,7 +3104,9 @@ void SingleTreeLikelihood_gradient( SingleTreeLikelihood *tlk, double* grads ){
30733104
bool prepare_tree = tlk->prepared_gradient & TREELIKELIHOOD_FLAG_TREE;
30743105
bool prepare_site_model = tlk->prepared_gradient & TREELIKELIHOOD_FLAG_SITE_MODEL;
30753106
bool prepare_branch_model = tlk->prepared_gradient & TREELIKELIHOOD_FLAG_BRANCH_MODEL;
3076-
bool prepare_subsitution_model = tlk->prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_UNCONSTRAINED || tlk->prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL;
3107+
bool prepare_substitution_model = tlk->prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_UNCONSTRAINED || tlk->prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL;
3108+
bool prepare_substitution_model_rates = prepare_substitution_model || (tlk->prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_RATES);
3109+
bool prepare_substitution_model_frequencies = prepare_substitution_model || (tlk->prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_FREQUENCIES);
30773110

30783111
size_t nodeCount = Tree_node_count(tlk->tree);
30793112
size_t catCount = tlk->sm->cat_count;
@@ -3167,9 +3200,19 @@ void SingleTreeLikelihood_gradient( SingleTreeLikelihood *tlk, double* grads ){
31673200
offset += Parameters_count(tlk->bm->rates);
31683201
}
31693202

3170-
if(prepare_subsitution_model){
3203+
if(prepare_substitution_model){
31713204
gradient_PMatrix(tlk, pattern_likelihoods, grads+offset);
31723205
}
3206+
else{
3207+
if(prepare_substitution_model_rates){
3208+
gradient_PMatrix_rates(tlk, pattern_likelihoods, grads+offset);
3209+
offset += tlk->m->rates_simplex == NULL ? Parameters_count(tlk->m->rates) : tlk->m->rates_simplex->K;
3210+
if(tlk->m->grad_wrt_reparam) offset--;
3211+
}
3212+
if(prepare_substitution_model_frequencies){
3213+
gradient_PMatrix_frequencies(tlk, pattern_likelihoods, grads+offset);
3214+
}
3215+
}
31733216

31743217
free(branch_lengths);
31753218
free(cat_branch_gradient);

0 commit comments

Comments
 (0)