@@ -234,6 +234,8 @@ size_t TreeLikelihood_initialize_gradient(Model *self, int flags){
234
234
int prepare_branch_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_BRANCH_MODEL ;
235
235
int prepare_substitution_model_unconstrained = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_UNCONSTRAINED ;
236
236
int prepare_substitution_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL ;
237
+ int prepare_substitution_model_rates = prepare_substitution_model | (tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_RATES );
238
+ int prepare_substitution_model_frequencies = prepare_substitution_model | (tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_FREQUENCIES );
237
239
238
240
if (prepare_substitution_model_unconstrained && prepare_substitution_model ){
239
241
fprintf (stderr , "Can only request unconstrained and constrained gradient at the same time\n" );
@@ -280,12 +282,18 @@ size_t TreeLikelihood_initialize_gradient(Model *self, int flags){
280
282
gradient_length += tlk -> m -> simplex -> K - 1 ;
281
283
tlk -> include_root_freqs = false;
282
284
}
283
- else if (prepare_substitution_model ) {
284
- gradient_length += tlk -> m -> rates_simplex == NULL ? Parameters_count (tlk -> m -> rates ) : tlk -> m -> rates_simplex -> K ;
285
- gradient_length += tlk -> m -> simplex -> K ;
286
- tlk -> m -> grad_wrt_reparam = false;
287
- tlk -> include_root_freqs = false;
288
- }
285
+ else {
286
+ if (prepare_substitution_model_rates ) {
287
+ gradient_length += tlk -> m -> rates_simplex == NULL ? Parameters_count (tlk -> m -> rates ) : tlk -> m -> rates_simplex -> K ;
288
+ tlk -> m -> grad_wrt_reparam = false;
289
+ tlk -> include_root_freqs = false;
290
+ }
291
+ if (prepare_substitution_model_frequencies ) {
292
+ gradient_length += tlk -> m -> simplex -> K ;
293
+ tlk -> m -> grad_wrt_reparam = false;
294
+ tlk -> include_root_freqs = false;
295
+ }
296
+ }
289
297
290
298
if (tlk -> gradient == NULL ){
291
299
tlk -> gradient = calloc (gradient_length , sizeof (double ));
@@ -3002,6 +3010,29 @@ void gradient_PMatrix(SingleTreeLikelihood* tlk, const double* pattern_likelihoo
3002
3010
}
3003
3011
}
3004
3012
3013
+ void gradient_PMatrix_rates (SingleTreeLikelihood * tlk , const double * pattern_likelihoods , double * gradient ){
3014
+ size_t parameter_count = tlk -> m -> rates_simplex == NULL ? Parameters_count (tlk -> m -> rates ) : tlk -> m -> rates_simplex -> K ;
3015
+ if (tlk -> m -> rates_simplex != NULL && tlk -> m -> grad_wrt_reparam ){
3016
+ parameter_count -- ;
3017
+ }
3018
+ for (size_t i = 0 ; i < parameter_count ; i ++ ){
3019
+ gradient [i ] = calculate_dlnl_dQ (tlk , i , pattern_likelihoods );
3020
+ }
3021
+ }
3022
+
3023
+ void gradient_PMatrix_frequencies (SingleTreeLikelihood * tlk , const double * pattern_likelihoods , double * gradient ){
3024
+ size_t start = tlk -> m -> rates_simplex == NULL ? Parameters_count (tlk -> m -> rates ) : tlk -> m -> rates_simplex -> K ;
3025
+ size_t parameter_count = tlk -> m -> simplex -> K ;
3026
+ if (tlk -> m -> grad_wrt_reparam ) parameter_count -- ;
3027
+
3028
+ if (tlk -> m -> rates_simplex != NULL && tlk -> m -> grad_wrt_reparam ){
3029
+ start -- ;
3030
+ }
3031
+ for (size_t i = 0 ; i < parameter_count ; i ++ ){
3032
+ gradient [i ] = calculate_dlnl_dQ (tlk , i + start , pattern_likelihoods );
3033
+ }
3034
+ }
3035
+
3005
3036
void gradient_branch_length_from_cat (SingleTreeLikelihood * tlk , const double * cat_branch_gradient , double * gradient ){
3006
3037
size_t nodeCount = Tree_node_count (tlk -> tree );
3007
3038
size_t catCount = tlk -> sm -> cat_count ;
@@ -3073,7 +3104,9 @@ void SingleTreeLikelihood_gradient( SingleTreeLikelihood *tlk, double* grads ){
3073
3104
bool prepare_tree = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_TREE ;
3074
3105
bool prepare_site_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SITE_MODEL ;
3075
3106
bool prepare_branch_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_BRANCH_MODEL ;
3076
- bool prepare_subsitution_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_UNCONSTRAINED || tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL ;
3107
+ bool prepare_substitution_model = tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_UNCONSTRAINED || tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL ;
3108
+ bool prepare_substitution_model_rates = prepare_substitution_model || (tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_RATES );
3109
+ bool prepare_substitution_model_frequencies = prepare_substitution_model || (tlk -> prepared_gradient & TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_FREQUENCIES );
3077
3110
3078
3111
size_t nodeCount = Tree_node_count (tlk -> tree );
3079
3112
size_t catCount = tlk -> sm -> cat_count ;
@@ -3167,9 +3200,19 @@ void SingleTreeLikelihood_gradient( SingleTreeLikelihood *tlk, double* grads ){
3167
3200
offset += Parameters_count (tlk -> bm -> rates );
3168
3201
}
3169
3202
3170
- if (prepare_subsitution_model ){
3203
+ if (prepare_substitution_model ){
3171
3204
gradient_PMatrix (tlk , pattern_likelihoods , grads + offset );
3172
3205
}
3206
+ else {
3207
+ if (prepare_substitution_model_rates ){
3208
+ gradient_PMatrix_rates (tlk , pattern_likelihoods , grads + offset );
3209
+ offset += tlk -> m -> rates_simplex == NULL ? Parameters_count (tlk -> m -> rates ) : tlk -> m -> rates_simplex -> K ;
3210
+ if (tlk -> m -> grad_wrt_reparam ) offset -- ;
3211
+ }
3212
+ if (prepare_substitution_model_frequencies ){
3213
+ gradient_PMatrix_frequencies (tlk , pattern_likelihoods , grads + offset );
3214
+ }
3215
+ }
3173
3216
3174
3217
free (branch_lengths );
3175
3218
free (cat_branch_gradient );
0 commit comments