@@ -182,40 +182,27 @@ SimpleClockModelInterface::SimpleClockModelInterface(const std::vector<double> &
182
182
new_BranchModel2 (" simple_clock" , branchModel_, treeModel->GetModel (), nullptr );
183
183
}
184
184
185
- Model *SubstitutionModelInterface::Initialize (const std::string &name,
186
- Parameters *rates, Model *frequencies,
187
- Model *rates_model) {
188
- DataType *datatype = new_NucleotideDataType ();
189
- Simplex *rates_simplex = rates_model == nullptr
190
- ? nullptr
191
- : reinterpret_cast <Simplex *>(rates_model->obj );
192
- substModel_ = SubstitutionModel_factory (
193
- name.c_str (), datatype, reinterpret_cast <Simplex *>(frequencies->obj ),
194
- rates_simplex, rates, nullptr );
195
- Model *model =
196
- new_SubstitutionModel2 (name.c_str (), substModel_, frequencies, rates_model);
197
- free_DataType (datatype);
198
- return model;
199
- }
200
-
201
185
JC69Interface::JC69Interface () {
202
186
Simplex *frequencies_simplex = new_Simplex (" jc69_frequency_simplex" , 4 );
203
187
Model *frequencies_model =
204
188
new_SimplexModel (" jc69_frequencies" , frequencies_simplex);
205
- model_ = Initialize (" jc69" , nullptr , frequencies_model, nullptr );
189
+ dataType_ = new NucleotideDataTypeInterface ();
190
+ substModel_ = new_JC69 (frequencies_simplex);
191
+ model_ = new_SubstitutionModel2 (" jc69" , substModel_, frequencies_model, nullptr );
206
192
frequencies_model->free (frequencies_model);
207
193
parameterCount_ = 0 ;
208
194
}
209
195
210
196
HKYInterface::HKYInterface (double kappa, const std::vector<double > &frequencies) {
211
- Parameters *kappa_parameters = new_Parameters (1 );
212
- Parameters_move (kappa_parameters,
213
- new_Parameter (" hky_kappa" , kappa, new_Constraint (0 , INFINITY)));
197
+ Parameter *kappa_parameter =
198
+ new_Parameter (" hky_kappa" , kappa, new_Constraint (0 , INFINITY));
214
199
Simplex *frequencies_simplex = new_Simplex_with_values (
215
200
" hky_frequency_simplex" , frequencies.data (), frequencies.size ());
216
201
Model *frequencies_model = new_SimplexModel (" hky_frequencies" , frequencies_simplex);
217
- model_ = Initialize (" hky" , kappa_parameters, frequencies_model, nullptr );
218
- free_Parameters (kappa_parameters);
202
+ dataType_ = new NucleotideDataTypeInterface ();
203
+ substModel_ = new_HKY_with_parameters (frequencies_simplex, kappa_parameter);
204
+ model_ = new_SubstitutionModel2 (" hky" , substModel_, frequencies_model, nullptr );
205
+ free_Parameter (kappa_parameter);
219
206
frequencies_model->free (frequencies_model);
220
207
parameterCount_ = 5 ;
221
208
}
@@ -255,7 +242,17 @@ GTRInterface::GTRInterface(const std::vector<double> &rates,
255
242
Simplex *frequencies_simplex = new_Simplex_with_values (
256
243
" gtr_frequency_simplex" , frequencies.data (), frequencies.size ());
257
244
Model *frequencies_model = new_SimplexModel (" gtr_frequencies" , frequencies_simplex);
258
- model_ = Initialize (" gtr" , rates_parameters, frequencies_model, rates_model);
245
+
246
+ dataType_ = new NucleotideDataTypeInterface ();
247
+ const char *assigments[5 ] = {" AC" , " AG" , " AT" , " CG" , " CT" };
248
+ Simplex *rates_simplex = rates_model == nullptr
249
+ ? nullptr
250
+ : reinterpret_cast <Simplex *>(rates_model->obj );
251
+ substModel_ =
252
+ SubstitutionModel_factory (" gtr" , dataType_->dataType_ ,
253
+ reinterpret_cast <Simplex *>(frequencies_model->obj ),
254
+ rates_simplex, rates_parameters, assigments);
255
+ model_ = new_SubstitutionModel2 (" gtr" , substModel_, frequencies_model, rates_model);
259
256
free_Parameters (rates_parameters);
260
257
frequencies_model->free (frequencies_model);
261
258
if (rates_model != nullptr ) {
@@ -278,6 +275,53 @@ void GTRInterface::SetParameters(const double *parameters) {
278
275
substModel_->simplex ->set_values (substModel_->simplex , parameters + 3 );
279
276
}
280
277
278
+ GeneralSubstitutionModelInterface::GeneralSubstitutionModelInterface (
279
+ DataTypeInterface *dataType, const std::vector<double > &rates,
280
+ const std::vector<double > &frequencies, const std::vector<unsigned > &structure,
281
+ bool normalize) {
282
+ Parameters *rate_parameters = new_Parameters (rates.size ());
283
+ for (auto rate : rates) {
284
+ Parameters_move (rate_parameters, new_Parameter (" subst_rates" , rate,
285
+ new_Constraint (0 , INFINITY)));
286
+ }
287
+
288
+ parameterCount_ = rates.size ();
289
+ Simplex *frequencies_simplex = new_Simplex_with_values (
290
+ " subst_frequency_simplex" , frequencies.data (), frequencies.size ());
291
+ Model *frequencies_model =
292
+ new_SimplexModel (" subst_frequencies" , frequencies_simplex);
293
+
294
+ DiscreteParameter *dp =
295
+ new_DiscreteParameter_with_values (structure.data (), structure.size ());
296
+ Model *mdp = new_DiscreteParameterModel (" structure" , dp);
297
+
298
+ substModel_ = new_GeneralModel_with_parameters (
299
+ dataType->dataType_ , reinterpret_cast <DiscreteParameter *>(mdp->obj ),
300
+ rate_parameters, frequencies_simplex, -1 , normalize);
301
+ model_ = new_SubstitutionModel3 (" id" , substModel_, frequencies_model, nullptr , mdp);
302
+
303
+ free_Parameters (rate_parameters);
304
+ frequencies_model->free (frequencies_model);
305
+ dataType_ = dataType;
306
+ }
307
+
308
+ void GeneralSubstitutionModelInterface::SetRates (const double *rates) {
309
+ if (substModel_->rates == nullptr ) {
310
+ substModel_->rates_simplex ->set_values (substModel_->rates_simplex , rates);
311
+ } else {
312
+ Parameters_set_values (substModel_->rates , rates);
313
+ }
314
+ }
315
+
316
+ void GeneralSubstitutionModelInterface::SetFrequencies (const double *frequencies) {
317
+ substModel_->simplex ->set_values (substModel_->simplex , frequencies);
318
+ }
319
+
320
+ void GeneralSubstitutionModelInterface::SetParameters (const double *parameters) {
321
+ Parameters_set_values (substModel_->rates , parameters);
322
+ substModel_->simplex ->set_values (substModel_->simplex , parameters + 3 );
323
+ }
324
+
281
325
void SiteModelInterface::SetMu (double mu) { Parameter_set_value (siteModel_->mu , mu); }
282
326
283
327
ConstantSiteModelInterface::ConstantSiteModelInterface (std::optional<double > mu) {
@@ -453,7 +497,7 @@ TreeLikelihoodInterface::TreeLikelihoodInterface(
453
497
Sequences_add (sequences,
454
498
new_Sequence (sequence.first .c_str (), sequence.second .c_str ()));
455
499
}
456
- sequences->datatype = new_NucleotideDataType () ;
500
+ sequences->datatype = substitutionModel-> GetDataType ()-> dataType_ ;
457
501
SitePattern *sitePattern = new_SitePattern (sequences);
458
502
free_Sequences (sequences);
459
503
Model *mbm = branchModel.has_value () ? branchModel_->GetModel () : nullptr ;
@@ -472,6 +516,43 @@ TreeLikelihoodInterface::TreeLikelihoodInterface(
472
516
RequestGradient ();
473
517
}
474
518
519
+ TreeLikelihoodInterface::TreeLikelihoodInterface (
520
+ const std::vector<std::string> &taxa, const std::vector<std::string> &attributes,
521
+ TreeModelInterface *treeModel, SubstitutionModelInterface *substitutionModel,
522
+ SiteModelInterface *siteModel, std::optional<BranchModelInterface *> branchModel,
523
+ bool use_ambiguities, bool use_tip_states, bool include_jacobian)
524
+ : treeModel_(treeModel),
525
+ substitutionModel_(substitutionModel),
526
+ siteModel_(siteModel) {
527
+ branchModel_ = branchModel.has_value () ? *branchModel : nullptr ;
528
+ const char **taxaPtr = new const char *[taxa.size ()];
529
+ const char **attributesPtr = new const char *[taxa.size ()];
530
+ for (size_t i = 0 ; i < taxa.size (); i++) {
531
+ taxaPtr[i] = const_cast <char *>(taxa[i].c_str ());
532
+ attributesPtr[i] = const_cast <char *>(attributes[i].c_str ());
533
+ }
534
+ DataType *dataType = substitutionModel->GetDataType ()->dataType_ ;
535
+ SitePattern *sitePattern =
536
+ new_AttributePattern (dataType, taxaPtr, attributesPtr, taxa.size ());
537
+ delete[] taxaPtr;
538
+ delete[] attributesPtr;
539
+
540
+ Model *mbm = branchModel.has_value () ? branchModel_->GetModel () : nullptr ;
541
+ BranchModel *bm =
542
+ branchModel.has_value () ? reinterpret_cast <BranchModel *>(mbm->obj ) : nullptr ;
543
+ SingleTreeLikelihood *tlk = new_SingleTreeLikelihood (
544
+ reinterpret_cast <Tree *>(treeModel_->GetManagedObject ()),
545
+ reinterpret_cast <SubstitutionModel *>(substitutionModel_->GetManagedObject ()),
546
+ reinterpret_cast <SiteModel *>(siteModel_->GetManagedObject ()), sitePattern, bm,
547
+ use_tip_states);
548
+ model_ = new_TreeLikelihoodModel (" id" , tlk, treeModel_->GetModel (),
549
+ substitutionModel_->GetModel (),
550
+ siteModel_->GetModel (), mbm);
551
+
552
+ tlk->include_jacobian = include_jacobian;
553
+ RequestGradient ();
554
+ }
555
+
475
556
void TreeLikelihoodInterface::RequestGradient (
476
557
std::vector<TreeLikelihoodGradientFlags> flags) {
477
558
int flags_int = 0 ;
0 commit comments