Skip to content

Commit 5691ea3

Browse files
committed
Add GeneralSubstitutionModel to C++ wrapper
- Implement AttributePattern for use in C++ wrapper - make gradient of GeneralSubstitutionModel slightly faster
1 parent 346c1b2 commit 5691ea3

8 files changed

+237
-30
lines changed

src/phyc/gensubst.c

+30-1
Original file line numberDiff line numberDiff line change
@@ -281,14 +281,43 @@ static void _general_dQdp(SubstitutionModel *m, size_t index){
281281
}
282282
}
283283

284+
/**
 * Fill dP with the derivative of the transition probability matrix at time t,
 * given a rate-matrix derivative dQ.
 *
 * For every entry (i,j), with v the eigenvalues of the model:
 *   dP[i][j] = dQ[i][j] * (exp(v[i]*t) - exp(v[j]*t)) / (v[i] - v[j])   if v[i] != v[j]
 *   dP[i][j] = dQ[i][j] * t * exp(v[i]*t)                               otherwise
 * The result is then multiplied by evec on the left and Invevec on the right.
 *
 * @param m   substitution model; nstate and eigendcmp (eval, evec, Invevec) are read
 * @param dQ  flat nstate*nstate matrix. NOTE(review): presumably already expressed
 *            in the eigenbasis (Invevec * dQ * evec), as _general_dPdp prepares
 *            it before calling - confirm for other callers.
 * @param dP  flat nstate*nstate output matrix, overwritten
 * @param t   time (branch length) at which the derivative is evaluated
 */
void dPdp_with_dQdp_general(SubstitutionModel *m, double* dQ, double* dP, double t){
	size_t stateCount = m->nstate;
	double *v = m->eigendcmp->eval;
	// The first stateCount entries hold exp(v[i]*t); the full stateCount*stateCount
	// buffer is reused below as scratch space for the matrix products.
	double* temp = dvector(stateCount*stateCount);
	for(size_t i = 0; i < stateCount; i++){
		temp[i] = exp(v[i]*t);
	}
	double* dPPtr = dP;
	// Read from the dQ argument: the previous code ignored it and always used m->dQ.
	const double* dQPtr = dQ;
	for(size_t i = 0; i < stateCount; i++){
		for(size_t j = 0; j < stateCount; j++){
			if(v[i] != v[j]){
				*dPPtr = *dQPtr*(temp[i] - temp[j])/(v[i]-v[j]);
			}
			else{
				// Equal eigenvalues: use the limit of the divided difference.
				*dPPtr = *dQPtr*t*temp[i];
			}
			dPPtr++;
			dQPtr++;
		}
	}
	// NOTE(review): presumably temp = evec * dP, then dP = temp * Invevec - confirm
	// against the Matrix_mult3/Matrix_mult4 definitions.
	Matrix_mult3(temp, (const double**)m->eigendcmp->evec, dP, stateCount, stateCount, stateCount, stateCount);
	Matrix_mult4(dP, temp, (const double**)m->eigendcmp->Invevec, stateCount, stateCount, stateCount, stateCount);
	free(temp);
}
309+
284310
/**
 * Derivative of the transition probability matrix with respect to the
 * parameter at `index`, written into `mat` for time `t`.
 *
 * Q is refreshed first when the model is flagged as needing an update, which
 * also forces dQ to be recomputed. When dQ is recomputed it is multiplied by
 * Invevec and evec (using `mat` as scratch) before being handed to
 * dPdp_with_dQdp_general.
 */
static void _general_dPdp(SubstitutionModel *m, int index, double* mat, double t){
	if(m->need_update){
		m->update_Q(m);
		// A refreshed Q makes any cached dQ stale.
		m->dQ_need_update = true;
	}

	if(m->dQ_need_update){
		_general_dQdp(m, index);

		const size_t n = m->nstate;
		const double** invEvec = (const double**)m->eigendcmp->Invevec;
		const double** evec = (const double**)m->eigendcmp->evec;
		// mat is only scratch here; it is overwritten by the final call below.
		Matrix_mult3(mat, invEvec, m->dQ, n, n, n, n);
		Matrix_mult4(m->dQ, mat, evec, n, n, n, n);

		m->dQ_need_update = false;
	}

	dPdp_with_dQdp_general(m, m->dQ, mat, t);
}

src/phyc/sitepattern.c

+34
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,40 @@ SitePattern * new_SitePattern3( const Sequences *aln, int start, int length, int
318318
return sp;
319319
}
320320

321+
/**
 * Build a SitePattern holding exactly one site (one pattern of weight 1) from
 * a single discrete attribute per taxon.
 *
 * @param datatype     data type used to encode attributes; its ref_count is incremented
 * @param taxa         taxon_count taxon names; each one is cloned
 * @param attributes   taxon_count attribute strings; attributes[i] is encoded with
 *                     datatype->encoding_string and stored as the state of taxon i
 * @param taxon_count  number of taxa
 * @return newly allocated SitePattern with ref_count 1
 */
SitePattern * new_AttributePattern( DataType* datatype, const char** taxa, const char** attributes, size_t taxon_count ){
	SitePattern *sp = (SitePattern *)malloc( sizeof(SitePattern) );
	assert(sp);
	sp->id = 0;

	sp->datatype = datatype;
	sp->datatype->ref_count++;
	sp->nstate = sp->datatype->state_count(sp->datatype);
	sp->get_partials = _sitepattern_get_partials;
	sp->nsites = 1;
	sp->count = 1;
	sp->size = taxon_count;
	sp->indexes = ivector(sp->count);
	sp->weights = dvector(sp->count);
	sp->indexes[0] = -1;
	sp->weights[0] = 1;

	// One row per taxon, each holding the single encoded state.
	// malloc(0) may legitimately return NULL, hence the taxon_count guard.
	sp->patterns = malloc(sizeof(uint8_t*)*taxon_count);
	assert(taxon_count == 0 || sp->patterns);
	for(size_t i = 0; i < taxon_count; i++){
		sp->patterns[i] = malloc(sizeof(uint8_t));
		assert(sp->patterns[i]);
		sp->patterns[i][0] = datatype->encoding_string(datatype, attributes[i]);
	}

	sp->names = (char**)malloc( taxon_count * sizeof(char*) );
	assert(taxon_count == 0 || sp->names);

	for ( size_t i = 0; i < taxon_count; i++) {
		sp->names[i] = String_clone( taxa[i] );
	}
	sp->ref_count = 1;

	return sp;
}
354+
321355
void free_SitePattern( SitePattern *sp ){
322356
if(sp->ref_count == 1){
323357
if(sp->patterns != NULL ){

src/phyc/sitepattern.h

+2
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ SitePattern * new_SitePattern( const Sequences *aln );
8585

8686
SitePattern * new_SitePattern2( const Sequences *aln, int start, int length, int every );
8787

88+
/**
 * Create a SitePattern containing a single site built from one attribute per taxon.
 *
 * @param datatype     data type used to encode each attribute string (its ref_count is incremented)
 * @param taxa         array of taxon_count taxon names (cloned)
 * @param attributes   array of taxon_count attribute strings, one per taxon
 * @param taxon_count  number of taxa
 * @return newly allocated SitePattern with ref_count set to 1
 */
SitePattern * new_AttributePattern( DataType* datatype, const char** taxa, const char** attributes, size_t taxon_count );
89+
8890
SitePattern ** SitePattern_split( const SitePattern *sitePattern, const int count );
8991

9092
SitePattern* SitePattern_merge( const SitePattern* sitePattern1, const SitePattern* sitePattern2 );

src/phyc/substmodel.c

+4
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,7 @@ SubstitutionModel * create_substitution_model( const char *name, const modeltype
870870
m->name = String_clone(name);
871871
m->modeltype = modelname;
872872
m->datatype = datatype;
873+
m->datatype->ref_count++;
873874

874875
// Parameters and Q matrix
875876
m->Q = NULL;
@@ -922,6 +923,7 @@ SubstitutionModel * create_substitution_model( const char *name, const modeltype
922923
SubstitutionModel * create_nucleotide_model( const char *name, const modeltype modelname, Simplex* freqs ){
923924
DataType* datatype = new_NucleotideDataType();
924925
SubstitutionModel *m = create_substitution_model(name, modelname, datatype, freqs);
926+
free_DataType(datatype);
925927
m->nstate = 4;
926928
m->Q = dmatrix(m->nstate, m->nstate);
927929
m->PP = dmatrix(m->nstate, m->nstate);
@@ -932,6 +934,7 @@ SubstitutionModel * create_nucleotide_model( const char *name, const modeltype m
932934
SubstitutionModel * create_codon_model( const char *name, const modeltype modelname, unsigned gen_code, Simplex* freqs ){
933935
DataType* datatype = new_CodonDataType(gen_code);
934936
SubstitutionModel *m = create_substitution_model(name, modelname, datatype, freqs);
937+
free_DataType(datatype);
935938
m->nstate = NUMBER_OF_CODONS[gen_code];
936939
m->gen_code = gen_code;
937940
m->Q = dmatrix(m->nstate, m->nstate);
@@ -943,6 +946,7 @@ SubstitutionModel * create_codon_model( const char *name, const modeltype modeln
943946
SubstitutionModel * create_aa_model( const char *name, const modeltype modelname, Simplex* freqs ){
944947
DataType* datatype = new_AminoAcidDataType();
945948
SubstitutionModel *m = create_substitution_model(name, modelname, datatype, freqs);
949+
free_DataType(datatype);
946950
m->nstate = 20;
947951
m->Q = dmatrix(m->nstate, m->nstate);
948952
m->PP = dmatrix(m->nstate, m->nstate);

src/phyc/substmodel.h

+1
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ typedef struct SubstitutionModel{
117117
}SubstitutionModel;
118118

119119
Model * new_SubstitutionModel2( const char* name, SubstitutionModel *sm, Model* freqs_simplex, Model* rates_simplex );
120+
// Same leading arguments as new_SubstitutionModel2, plus a discrete_model argument.
// NOTE(review): presumably discrete_model wraps the DiscreteParameter describing the
// rate-matrix structure (the C++ wrapper passes a DiscreteParameterModel) - confirm
// against the definition.
Model * new_SubstitutionModel3( const char* name, SubstitutionModel *sm, Model* freqs_simplex, Model* rates_simplex, Model* discrete_model );
120121

121122
void generale_update_freqs( SubstitutionModel *model );
122123

src/phyc/treelikelihood.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ static double TWENTY_DOUBLE_ONES[20] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,
3535
/*
 * Tree-likelihood bit flags. Each flag occupies a distinct bit so several can
 * be OR-ed into one integer mask.
 * The replacement lists are parenthesized so the macros expand correctly when
 * used next to operators that bind tighter than <<, such as * or +.
 */
#define TREELIKELIHOOD_FLAG_TREE (1 << 0)
#define TREELIKELIHOOD_FLAG_SITE_MODEL (1 << 1)
#define TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL (1 << 2)
#define TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_RATES (1 << 3)
#define TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_FREQUENCIES (1 << 4)
#define TREELIKELIHOOD_FLAG_SUBSTITUTION_MODEL_UNCONSTRAINED (1 << 5)
#define TREELIKELIHOOD_FLAG_BRANCH_MODEL (1 << 6)
4042

4143
struct _SingleTreeLikelihood;
4244

src/phycpp/physher.cpp

+105-24
Original file line numberDiff line numberDiff line change
@@ -182,40 +182,27 @@ SimpleClockModelInterface::SimpleClockModelInterface(const std::vector<double> &
182182
new_BranchModel2("simple_clock", branchModel_, treeModel->GetModel(), nullptr);
183183
}
184184

185-
Model *SubstitutionModelInterface::Initialize(const std::string &name,
186-
Parameters *rates, Model *frequencies,
187-
Model *rates_model) {
188-
DataType *datatype = new_NucleotideDataType();
189-
Simplex *rates_simplex = rates_model == nullptr
190-
? nullptr
191-
: reinterpret_cast<Simplex *>(rates_model->obj);
192-
substModel_ = SubstitutionModel_factory(
193-
name.c_str(), datatype, reinterpret_cast<Simplex *>(frequencies->obj),
194-
rates_simplex, rates, nullptr);
195-
Model *model =
196-
new_SubstitutionModel2(name.c_str(), substModel_, frequencies, rates_model);
197-
free_DataType(datatype);
198-
return model;
199-
}
200-
201185
// Builds a JC69 substitution model over a 4-state frequency simplex and wraps
// it in a Model. JC69 exposes no parameters through this interface.
JC69Interface::JC69Interface() {
  parameterCount_ = 0;
  dataType_ = new NucleotideDataTypeInterface();

  Simplex *simplex = new_Simplex("jc69_frequency_simplex", 4);
  Model *simplexModel = new_SimplexModel("jc69_frequencies", simplex);

  substModel_ = new_JC69(simplex);
  model_ = new_SubstitutionModel2("jc69", substModel_, simplexModel, nullptr);

  // NOTE(review): presumably ref-counted, so this only drops the local
  // reference now that model_ holds the simplex model - confirm.
  simplexModel->free(simplexModel);
}
209195

210196
// Builds an HKY substitution model from a kappa value (constrained to
// [0, INFINITY]) and a vector of state frequencies, then wraps it in a Model.
HKYInterface::HKYInterface(double kappa, const std::vector<double> &frequencies) {
  parameterCount_ = 5;

  Constraint *kappaBounds = new_Constraint(0, INFINITY);
  Parameter *kappaParam = new_Parameter("hky_kappa", kappa, kappaBounds);

  Simplex *simplex = new_Simplex_with_values("hky_frequency_simplex",
                                             frequencies.data(), frequencies.size());
  Model *simplexModel = new_SimplexModel("hky_frequencies", simplex);

  dataType_ = new NucleotideDataTypeInterface();
  substModel_ = new_HKY_with_parameters(simplex, kappaParam);
  model_ = new_SubstitutionModel2("hky", substModel_, simplexModel, nullptr);

  // Release the handles created locally in this constructor.
  free_Parameter(kappaParam);
  simplexModel->free(simplexModel);
}
@@ -255,7 +242,17 @@ GTRInterface::GTRInterface(const std::vector<double> &rates,
255242
Simplex *frequencies_simplex = new_Simplex_with_values(
256243
"gtr_frequency_simplex", frequencies.data(), frequencies.size());
257244
Model *frequencies_model = new_SimplexModel("gtr_frequencies", frequencies_simplex);
258-
model_ = Initialize("gtr", rates_parameters, frequencies_model, rates_model);
245+
246+
dataType_ = new NucleotideDataTypeInterface();
247+
const char *assigments[5] = {"AC", "AG", "AT", "CG", "CT"};
248+
Simplex *rates_simplex = rates_model == nullptr
249+
? nullptr
250+
: reinterpret_cast<Simplex *>(rates_model->obj);
251+
substModel_ =
252+
SubstitutionModel_factory("gtr", dataType_->dataType_,
253+
reinterpret_cast<Simplex *>(frequencies_model->obj),
254+
rates_simplex, rates_parameters, assigments);
255+
model_ = new_SubstitutionModel2("gtr", substModel_, frequencies_model, rates_model);
259256
free_Parameters(rates_parameters);
260257
frequencies_model->free(frequencies_model);
261258
if (rates_model != nullptr) {
@@ -278,6 +275,53 @@ void GTRInterface::SetParameters(const double *parameters) {
278275
substModel_->simplex->set_values(substModel_->simplex, parameters + 3);
279276
}
280277

278+
// Builds a general substitution model for an arbitrary data type.
//   dataType    - data type interface; stored in dataType_ without copying
//   rates       - one non-negative rate parameter is created per entry
//   frequencies - initial values of the frequency simplex
//   structure   - values of the discrete parameter passed to the model
//                 (NOTE(review): presumably maps rate-matrix entries to rate
//                 indices - confirm against new_GeneralModel_with_parameters)
//   normalize   - forwarded to new_GeneralModel_with_parameters
GeneralSubstitutionModelInterface::GeneralSubstitutionModelInterface(
    DataTypeInterface *dataType, const std::vector<double> &rates,
    const std::vector<double> &frequencies, const std::vector<unsigned> &structure,
    bool normalize) {
  dataType_ = dataType;
  parameterCount_ = rates.size();

  Parameters *rateParams = new_Parameters(rates.size());
  for (const double value : rates) {
    Constraint *bounds = new_Constraint(0, INFINITY);
    Parameters_move(rateParams, new_Parameter("subst_rates", value, bounds));
  }

  Simplex *simplex = new_Simplex_with_values("subst_frequency_simplex",
                                             frequencies.data(), frequencies.size());
  Model *simplexModel = new_SimplexModel("subst_frequencies", simplex);

  DiscreteParameter *discrete =
      new_DiscreteParameter_with_values(structure.data(), structure.size());
  Model *structureModel = new_DiscreteParameterModel("structure", discrete);

  substModel_ = new_GeneralModel_with_parameters(
      dataType->dataType_, reinterpret_cast<DiscreteParameter *>(structureModel->obj),
      rateParams, simplex, -1, normalize);
  model_ =
      new_SubstitutionModel3("id", substModel_, simplexModel, nullptr, structureModel);

  free_Parameters(rateParams);
  simplexModel->free(simplexModel);
  // NOTE(review): structureModel is not released here, unlike simplexModel;
  // confirm whether new_SubstitutionModel3 takes its own reference (possible leak).
}
307+
308+
// Writes new rate values into the model: into the rate Parameters when the
// model has them, otherwise into its rates simplex.
void GeneralSubstitutionModelInterface::SetRates(const double *rates) {
  if (substModel_->rates != nullptr) {
    Parameters_set_values(substModel_->rates, rates);
    return;
  }
  auto *ratesSimplex = substModel_->rates_simplex;
  ratesSimplex->set_values(ratesSimplex, rates);
}
315+
316+
// Writes new values into the model's frequency simplex.
void GeneralSubstitutionModelInterface::SetFrequencies(const double *frequencies) {
  auto *frequencySimplex = substModel_->simplex;
  frequencySimplex->set_values(frequencySimplex, frequencies);
}
319+
320+
void GeneralSubstitutionModelInterface::SetParameters(const double *parameters) {
321+
Parameters_set_values(substModel_->rates, parameters);
322+
substModel_->simplex->set_values(substModel_->simplex, parameters + 3);
323+
}
324+
281325
// Sets the value of the site model's mu parameter.
void SiteModelInterface::SetMu(double mu) {
  Parameter_set_value(siteModel_->mu, mu);
}
282326

283327
ConstantSiteModelInterface::ConstantSiteModelInterface(std::optional<double> mu) {
@@ -453,7 +497,7 @@ TreeLikelihoodInterface::TreeLikelihoodInterface(
453497
Sequences_add(sequences,
454498
new_Sequence(sequence.first.c_str(), sequence.second.c_str()));
455499
}
456-
sequences->datatype = new_NucleotideDataType();
500+
sequences->datatype = substitutionModel->GetDataType()->dataType_;
457501
SitePattern *sitePattern = new_SitePattern(sequences);
458502
free_Sequences(sequences);
459503
Model *mbm = branchModel.has_value() ? branchModel_->GetModel() : nullptr;
@@ -472,6 +516,43 @@ TreeLikelihoodInterface::TreeLikelihoodInterface(
472516
RequestGradient();
473517
}
474518

519+
// Builds a tree likelihood from one discrete attribute per taxon instead of
// from sequences.
//   taxa, attributes - parallel vectors; attributes[i] is the state of taxa[i].
//                      Throws std::out_of_range (from vector::at) when
//                      attributes has fewer entries than taxa.
//   treeModel, substitutionModel, siteModel - stored, not copied
//   branchModel      - optional branch model; an empty optional or a null
//                      pointer means no branch model is used
//   use_ambiguities  - not used by this constructor
//   use_tip_states   - forwarded to new_SingleTreeLikelihood
//   include_jacobian - stored on the created SingleTreeLikelihood
TreeLikelihoodInterface::TreeLikelihoodInterface(
    const std::vector<std::string> &taxa, const std::vector<std::string> &attributes,
    TreeModelInterface *treeModel, SubstitutionModelInterface *substitutionModel,
    SiteModelInterface *siteModel, std::optional<BranchModelInterface *> branchModel,
    bool use_ambiguities, bool use_tip_states, bool include_jacobian)
    : treeModel_(treeModel),
      substitutionModel_(substitutionModel),
      siteModel_(siteModel) {
  branchModel_ = branchModel.has_value() ? *branchModel : nullptr;

  // Borrowed C-string views into taxa/attributes. The vectors own the pointer
  // arrays, so nothing leaks if an allocation or the bounds check throws.
  std::vector<const char *> taxaPtr;
  std::vector<const char *> attributesPtr;
  taxaPtr.reserve(taxa.size());
  attributesPtr.reserve(taxa.size());
  for (size_t i = 0; i < taxa.size(); i++) {
    taxaPtr.push_back(taxa[i].c_str());
    // at() rejects an attributes vector shorter than taxa instead of reading
    // past its end.
    attributesPtr.push_back(attributes.at(i).c_str());
  }

  DataType *dataType = substitutionModel->GetDataType()->dataType_;
  // new_AttributePattern clones the names and encodes the attributes, so the
  // pointer arrays only need to live for the duration of this call.
  SitePattern *sitePattern = new_AttributePattern(dataType, taxaPtr.data(),
                                                  attributesPtr.data(), taxa.size());

  // Test the stored pointer rather than has_value() so that an engaged
  // optional holding nullptr is not dereferenced.
  Model *mbm = branchModel_ != nullptr ? branchModel_->GetModel() : nullptr;
  BranchModel *bm =
      mbm != nullptr ? reinterpret_cast<BranchModel *>(mbm->obj) : nullptr;
  SingleTreeLikelihood *tlk = new_SingleTreeLikelihood(
      reinterpret_cast<Tree *>(treeModel_->GetManagedObject()),
      reinterpret_cast<SubstitutionModel *>(substitutionModel_->GetManagedObject()),
      reinterpret_cast<SiteModel *>(siteModel_->GetManagedObject()), sitePattern, bm,
      use_tip_states);
  // NOTE(review): sitePattern is created with ref_count 1 and never released
  // here; confirm whether new_SingleTreeLikelihood takes ownership.
  model_ = new_TreeLikelihoodModel("id", tlk, treeModel_->GetModel(),
                                   substitutionModel_->GetModel(),
                                   siteModel_->GetModel(), mbm);

  tlk->include_jacobian = include_jacobian;
  RequestGradient();
}
555+
475556
void TreeLikelihoodInterface::RequestGradient(
476557
std::vector<TreeLikelihoodGradientFlags> flags) {
477558
int flags_int = 0;

0 commit comments

Comments
 (0)