diff --git a/.github/workflows/docker-ci.yml b/.github/workflows/docker-ci.yml index c8b3e748..32d8560b 100644 --- a/.github/workflows/docker-ci.yml +++ b/.github/workflows/docker-ci.yml @@ -191,6 +191,22 @@ jobs: docker load --input /tmp/pysages.tar docker run -t pysages bash -c "cd PySAGES/examples/openmm/cv_ermsd/ermsd_cg && python3 ./unbiased.py && python3 ./check_ermsd.py" + cv-q-openmm-abf: + runs-on: ubuntu-latest + needs: build + steps: + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - name: Download artifact + uses: actions/download-artifact@v3 + with: + name: pysages + path: /tmp + - name: Load and run test + run: | + docker load --input /tmp/pysages.tar + docker run -t pysages bash -c "cd PySAGES/examples/openmm/cv_Q && python3 ./abf.py && python3 ./check_Q.py && python3 ./check_force.py" + alanine-dipeptide-string: runs-on: ubuntu-latest needs: build diff --git a/docs/source/module-pysages-colvars-angles.rst b/docs/source/module-pysages-colvars-angles.rst index 2b3a279a..7cf25440 100644 --- a/docs/source/module-pysages-colvars-angles.rst +++ b/docs/source/module-pysages-colvars-angles.rst @@ -7,6 +7,7 @@ Angles as collective variables pysages.colvars.angles.Angle pysages.colvars.angles.DihedralAngle + pysages.colvars.angles.VectorAngle pysages.colvars.angles.RingPuckeringCoordinates pysages.colvars.angles.RingPhaseAngle pysages.colvars.angles.RingAmplitude diff --git a/docs/source/module-pysages-colvars-contacts.rst b/docs/source/module-pysages-colvars-contacts.rst new file mode 100644 index 00000000..914398e3 --- /dev/null +++ b/docs/source/module-pysages-colvars-contacts.rst @@ -0,0 +1,14 @@ +Contacts as collective variables +------------------------------------ + +.. rubric:: Overview + +.. autosummary:: + + pysages.colvars.contacts.NativeContactFraction + +.. rubric:: Details + +.. automodule:: pysages.colvars.contacts + :synopsis: Python classes for contact-based collective variables. 
+ :members: diff --git a/docs/source/module-pysages-colvars.rst b/docs/source/module-pysages-colvars.rst index bf2d1419..a288897e 100644 --- a/docs/source/module-pysages-colvars.rst +++ b/docs/source/module-pysages-colvars.rst @@ -22,6 +22,8 @@ Collective Variables available in PySAGES pysages.colvars.orientation.ERMSD pysages.colvars.orientation.ERMSDCG + pysages.colvars.contacts.NativeContactFraction + Abstract base classes .. autosummary:: @@ -41,6 +43,7 @@ Abstract base classes module-pysages-colvars-coordinates module-pysages-colvars-core module-pysages-colvars-orientation + module-pysages-colvars-contacts .. automodule:: pysages.colvars :synopsis: Python classes for collective variables. diff --git a/docs/source/pysages_wordlist.txt b/docs/source/pysages_wordlist.txt index ee241703..c702b8d9 100644 --- a/docs/source/pysages_wordlist.txt +++ b/docs/source/pysages_wordlist.txt @@ -96,3 +96,4 @@ unitless nm nucleotides nt +Mittal diff --git a/examples/openmm/cv_Q/abf.py b/examples/openmm/cv_Q/abf.py new file mode 100644 index 00000000..8fc3bbc9 --- /dev/null +++ b/examples/openmm/cv_Q/abf.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python + +import numpy as np +import openmm as mm +import openmm.app as app +from openmm import unit +from scipy.spatial import distance as sd + +import pysages +from pysages import Grid +from pysages.colvars.contacts import NativeContactFraction +from pysages.methods import ABF, HistogramLogger + +AUGC = ["A", "U", "G", "C"] + + +def create_exclusions_from_bonds(particles, bonds, bond_cutoff=3): + """ + create exclusion from bond. 
+ """ + n_particles = max(particles) + 1 + exclusions = [set() for _ in range(n_particles)] + bonded12 = [set() for _ in range(n_particles)] + for bond in bonds: + p1, p2 = bond + exclusions[p1].add(p2) + exclusions[p2].add(p1) + bonded12[p1].add(p2) + bonded12[p2].add(p1) + + for level in range(bond_cutoff - 1): + current_exclusions = [exclusion.copy() for exclusion in exclusions] + for i in range(n_particles): + for j in current_exclusions[i]: + exclusions[j].update(bonded12[i]) + + final_exclusions = [] + for i in range(len(exclusions)): + for j in exclusions[i]: + if j < i: + final_exclusions.append((j, i)) + + return final_exclusions + + +step_size = 2 * unit.femtosecond +nsteps = 100 + +contact_cutoff = 0.5 # nanometer + +pdb = app.PDBFile("../../inputs/GAGA.box_0mM.pdb") +positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + +rna_indices = [] +for atom in pdb.topology.atoms(): + if atom.residue.name in AUGC and atom.element.name != "hydrogen": + rna_indices.append(atom.index) + +rna_bonds = [] +for bond in pdb.topology.bonds(): + if ( + bond.atom1.residue.name in AUGC + and bond.atom1.element.name != "hydrogen" + and bond.atom2.residue.name in AUGC + and bond.atom2.element.name != "hydrogen" + ): + rna_bonds.append([bond.atom1.index, bond.atom2.index]) + +exclusions = create_exclusions_from_bonds(rna_indices, rna_bonds) + +rna_pos = positions.astype("float")[np.asarray(rna_indices)] +contact_matrix = sd.squareform(sd.pdist(rna_pos)) < contact_cutoff +contacts = np.transpose(np.nonzero(contact_matrix)) +rna_id_contacts = np.array( + [ + [rna_indices[i], rna_indices[j]] + for i, j in contacts + if i != j + and (rna_indices[i], rna_indices[j]) not in exclusions + and (rna_indices[j], rna_indices[i]) not in exclusions + ] +) +# notice that we need to get rid of self-self contact! 
+indices = np.unique(rna_id_contacts) +references = positions.astype("float")[np.asarray(indices)] + + +def generate_simulation(): + forcefield = app.ForceField("amber14-all.xml", "amber14/tip3pfb.xml") + system = forcefield.createSystem( + pdb.topology, + nonbondedMethod=app.PME, + nonbondedCutoff=1.2 * unit.nanometer, + constraints=app.HBonds, + ) + + integrator = mm.LangevinIntegrator( + 298 * unit.kelvin, 5 / unit.picosecond, step_size.in_units_of(unit.picosecond) + ) + simulation = app.Simulation(pdb.topology, system, integrator) + simulation.context.setPositions(pdb.positions) + + print("minimizing energy...") + simulation.minimizeEnergy() + + print("Using {} platform".format(simulation.context.getPlatform().getName())) + simulation.reporters.append(app.PDBReporter("output.pdb", 1, enforcePeriodicBox=False)) + simulation.reporters.append( + app.StateDataReporter( + "log", + 1, + step=True, + time=True, + speed=True, + remainingTime=True, + elapsedTime=True, + totalSteps=nsteps, + ) + ) + + return simulation + + +def main(): + cvs = [ + NativeContactFraction(indices, rna_id_contacts, references, clip=True), + ] + + method = ABF(cvs, Grid(lower=(0,), upper=(1,), shape=(32,)), use_pinv=True) + callback = HistogramLogger(1) + + raw_result = pysages.run(method, generate_simulation, nsteps, callback) + pysages.save(raw_result, "state.pkl") + + np.savetxt("Q.txt", raw_result.callbacks[0].data[:, :1]) + np.savetxt("references.txt", references) + np.save("contact_pairs.npy", rna_id_contacts) + np.save("contact_pairs_remapped.npy", cvs[0].contact_pairs) + + +if __name__ == "__main__": + main() diff --git a/examples/openmm/cv_Q/check_Q.py b/examples/openmm/cv_Q/check_Q.py new file mode 100644 index 00000000..f130089a --- /dev/null +++ b/examples/openmm/cv_Q/check_Q.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python +# the ermsd calculation code is taken from barnaba package +# github.com/srnas/barnaba + +import numpy as np +from openmm import app, unit +from scipy.spatial 
import distance as sd + + +def calc_q(contact_dists, contact_dists0, gamma=50, lambda_d=1.5, clip=False, clip_val=5): + N_contacts = len(contact_dists) + diff = gamma * (contact_dists - lambda_d * contact_dists0) + if clip: + diff = np.clip(diff, None, clip_val) + + Q = 1 / N_contacts * np.sum(1 / (1 + np.exp(diff))) + + return Q + + +def pos2q(pos, contact_pairs, contact_dists0, gamma=50, lambda_d=1.5, clip=False, clip_val=5): + dist_matrix = sd.squareform(sd.pdist(pos)) + contact_dists = dist_matrix[contact_pairs[:, 0], contact_pairs[:, 1]] + Q = calc_q( + contact_dists, contact_dists0, gamma=gamma, lambda_d=lambda_d, clip=clip, clip_val=clip_val + ) + + return Q + + +contact_pairs = np.load("contact_pairs.npy", allow_pickle=True) +contact_pairs_remapped = np.load("contact_pairs_remapped.npy", allow_pickle=True) +references = np.loadtxt("references.txt") +dist_matrix_ref = sd.squareform(sd.pdist(references)) +contact_dists0 = dist_matrix_ref[contact_pairs_remapped[:, 0], contact_pairs_remapped[:, 1]] + +traj = app.PDBFile("output.pdb") +n_frames = traj.getNumFrames() +Q_posthoc = [] +for i in range(n_frames): + pos = traj.getPositions(asNumpy=True, frame=i).value_in_unit(unit.nanometer).astype("float") + Q_hot = pos2q(pos, contact_pairs, contact_dists0, gamma=50, lambda_d=1.5, clip=True) + Q_posthoc.append(Q_hot) + +Q_pysages = np.loadtxt("Q.txt") +np.savetxt("Q_posthoc.txt", Q_posthoc) + +assert ( + np.mean((Q_pysages - Q_posthoc) ** 2) < 1e-6 +), "the difference between pysages Q and post-hoc calculation is too large!" 
+ +print("checking for Q passed!") diff --git a/examples/openmm/cv_Q/check_force.py b/examples/openmm/cv_Q/check_force.py new file mode 100755 index 00000000..863eb86c --- /dev/null +++ b/examples/openmm/cv_Q/check_force.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python + +import numpy as np +import openmm.app as app +from jax import grad +from openmm import unit +from scipy.spatial import distance as sd + +from pysages.colvars.contacts import NativeContactFraction + +contact_cutoff = 0.5 # nm + +pdb = app.PDBFile("../../inputs/GAGA.box_0mM.pdb") +positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) +rna_indices = [] +for i, residue in enumerate(pdb.topology.residues()): + if residue.name in ["A", "U", "G", "C"]: + for atom in residue.atoms(): + if atom.element.name != "hydrogen": + rna_indices.append(atom.index) + +rna_pos = positions.astype("float")[np.asarray(rna_indices)] +contact_matrix = sd.squareform(sd.pdist(rna_pos)) < contact_cutoff +contacts = np.transpose(np.nonzero(contact_matrix)) +rna_id_contacts = np.array([[rna_indices[i], rna_indices[j]] for i, j in contacts if i != j]) +indices = np.unique(rna_id_contacts) +references = positions.astype("float")[np.asarray(indices)] +ncf = NativeContactFraction(indices, rna_id_contacts, references) +ncf_grad = grad(ncf.function) +grad_hot = ncf_grad(references + np.random.random(references.shape)) +assert not np.any(np.isnan(grad_hot)), "force contains NaN values" +print("checking for forces of Q passed!") diff --git a/examples/openmm/cv_Q/contact_pairs.npy b/examples/openmm/cv_Q/contact_pairs.npy new file mode 100644 index 00000000..9f05251e Binary files /dev/null and b/examples/openmm/cv_Q/contact_pairs.npy differ diff --git a/examples/openmm/cv_Q/contact_pairs_remapped.npy b/examples/openmm/cv_Q/contact_pairs_remapped.npy new file mode 100644 index 00000000..5ab0f0bc Binary files /dev/null and b/examples/openmm/cv_Q/contact_pairs_remapped.npy differ diff --git a/examples/openmm/cv_Q/log 
b/examples/openmm/cv_Q/log new file mode 100644 index 00000000..8af35127 --- /dev/null +++ b/examples/openmm/cv_Q/log @@ -0,0 +1,101 @@ +#"Step","Time (ps)","Speed (ns/day)","Elapsed Time (s)","Time Remaining" +1,0.002,0,0.00017595291137695312,-- +2,0.004,0.21,0.8229403495788574,1:20 +3,0.006,0.238,1.453230381011963,1:10 +4,0.008,0.338,1.5342180728912354,0:49 +5,0.01,0.428,1.6147615909576416,0:38 +6,0.012,0.51,1.6954145431518555,0:31 +7,0.014,0.57,1.8184168338775635,0:28 +8,0.016,0.636,1.9012887477874756,0:24 +9,0.018000000000000002,0.697,1.9821698665618896,0:22 +10,0.020000000000000004,0.754,2.06231951713562,0:20 +11,0.022000000000000006,0.804,2.1495120525360107,0:19 +12,0.024000000000000007,0.852,2.2315711975097656,0:17 +13,0.02600000000000001,0.897,2.312931776046753,0:16 +14,0.02800000000000001,0.938,2.393878698348999,0:15 +15,0.030000000000000013,0.977,2.4750168323516846,0:15 +16,0.032000000000000015,1.01,2.5556459426879883,0:14 +17,0.034000000000000016,1.05,2.636949062347412,0:13 +18,0.03600000000000002,1.06,2.7603514194488525,0:13 +19,0.03800000000000002,1.09,2.844087839126587,0:12 +20,0.04000000000000002,1.12,2.9256513118743896,0:12 +21,0.04200000000000002,1.15,3.006678581237793,0:11 +22,0.044000000000000025,1.18,3.0876474380493164,0:11 +23,0.04600000000000003,1.2,3.1686651706695557,0:11 +24,0.04800000000000003,1.22,3.250168561935425,0:10 +25,0.05000000000000003,1.24,3.331664800643921,0:10 +26,0.05200000000000003,1.27,3.4126813411712646,0:10 +27,0.054000000000000034,1.27,3.5363223552703857,0:09 +28,0.056000000000000036,1.29,3.620398759841919,0:09 +29,0.05800000000000004,1.31,3.701617479324341,0:09 +30,0.06000000000000004,1.32,3.783266305923462,0:09 +31,0.06200000000000004,1.34,3.86579966545105,0:08 +32,0.06400000000000004,1.36,3.9472033977508545,0:08 +33,0.06600000000000004,1.37,4.028719902038574,0:08 +34,0.06800000000000005,1.39,4.109215259552002,0:08 +35,0.07000000000000005,1.4,4.1900269985198975,0:08 +36,0.07200000000000005,1.42,4.2716147899627686,0:07 
+37,0.07400000000000005,1.42,4.3958916664123535,0:07 +38,0.07600000000000005,1.43,4.479905605316162,0:07 +39,0.07800000000000006,1.44,4.561229467391968,0:07 +40,0.08000000000000006,1.45,4.642746925354004,0:07 +41,0.08200000000000006,1.46,4.725122690200806,0:06 +42,0.08400000000000006,1.47,4.807098865509033,0:06 +43,0.08600000000000006,1.48,4.888307809829712,0:06 +44,0.08800000000000006,1.5,4.969586610794067,0:06 +45,0.09000000000000007,1.51,5.051191091537476,0:06 +46,0.09200000000000007,1.52,5.132446527481079,0:06 +47,0.09400000000000007,1.52,5.214017629623413,0:06 +48,0.09600000000000007,1.52,5.342374563217163,0:05 +49,0.09800000000000007,1.53,5.426711320877075,0:05 +50,0.10000000000000007,1.54,5.5087950229644775,0:05 +51,0.10200000000000008,1.55,5.590127468109131,0:05 +52,0.10400000000000008,1.55,5.670928716659546,0:05 +53,0.10600000000000008,1.56,5.751867055892944,0:05 +54,0.10800000000000008,1.57,5.832394361495972,0:05 +55,0.11000000000000008,1.58,5.913583517074585,0:04 +56,0.11200000000000009,1.59,5.995504856109619,0:04 +57,0.11400000000000009,1.58,6.120810031890869,0:04 +58,0.11600000000000009,1.59,6.2051613330841064,0:04 +59,0.11800000000000009,1.59,6.286706924438477,0:04 +60,0.12000000000000009,1.6,6.367578744888306,0:04 +61,0.1220000000000001,1.61,6.448859930038452,0:04 +62,0.1240000000000001,1.61,6.530088663101196,0:04 +63,0.12600000000000008,1.62,6.612604141235352,0:03 +64,0.12800000000000009,1.63,6.693545818328857,0:03 +65,0.1300000000000001,1.63,6.7746665477752686,0:03 +66,0.1320000000000001,1.64,6.855661392211914,0:03 +67,0.1340000000000001,1.64,6.936746597290039,0:03 +68,0.1360000000000001,1.64,7.06310510635376,0:03 +69,0.1380000000000001,1.64,7.1471946239471436,0:03 +70,0.1400000000000001,1.65,7.2290332317352295,0:03 +71,0.1420000000000001,1.65,7.312357425689697,0:03 +72,0.1440000000000001,1.66,7.393218040466309,0:02 +73,0.1460000000000001,1.66,7.473695278167725,0:02 +74,0.1480000000000001,1.67,7.555282354354858,0:02 
+75,0.1500000000000001,1.67,7.637059450149536,0:02 +76,0.1520000000000001,1.68,7.718220233917236,0:02 +77,0.1540000000000001,1.68,7.799330472946167,0:02 +78,0.1560000000000001,1.68,7.929299831390381,0:02 +79,0.1580000000000001,1.68,8.013344287872314,0:02 +80,0.16000000000000011,1.69,8.094869375228882,0:02 +81,0.16200000000000012,1.69,8.176614761352539,0:01 +82,0.16400000000000012,1.69,8.258145809173584,0:01 +83,0.16600000000000012,1.7,8.340224504470825,0:01 +84,0.16800000000000012,1.7,8.421715021133423,0:01 +85,0.17000000000000012,1.71,8.503269672393799,0:01 +86,0.17200000000000013,1.71,8.585331201553345,0:01 +87,0.17400000000000013,1.71,8.666454792022705,0:01 +88,0.17600000000000013,1.72,8.747462272644043,0:01 +89,0.17800000000000013,1.71,8.872422695159912,0:01 +90,0.18000000000000013,1.72,8.956512451171875,0:01 +91,0.18200000000000013,1.72,9.037654161453247,0:00 +92,0.18400000000000014,1.72,9.119468927383423,0:00 +93,0.18600000000000014,1.73,9.201036214828491,0:00 +94,0.18800000000000014,1.73,9.282272815704346,0:00 +95,0.19000000000000014,1.73,9.363948345184326,0:00 +96,0.19200000000000014,1.74,9.445691585540771,0:00 +97,0.19400000000000014,1.74,9.526889324188232,0:00 +98,0.19600000000000015,1.74,9.608045816421509,0:00 +99,0.19800000000000015,1.75,9.690306901931763,0:00 +100,0.20000000000000015,1.74,9.817447423934937,0:00 diff --git a/examples/openmm/cv_Q/plot_Q.py b/examples/openmm/cv_Q/plot_Q.py new file mode 100644 index 00000000..d619d025 --- /dev/null +++ b/examples/openmm/cv_Q/plot_Q.py @@ -0,0 +1,24 @@ +import matplotlib.pyplot as plt +import numpy as np + +a = np.loadtxt("Q.txt") + +b = np.loadtxt("Q_posthoc.txt") + +mi, ma = np.min(a), np.max(a) + +plt.figure(figsize=(5, 5)) + +plt.scatter(a[1::], b[:-1], alpha=0.5) +# plt.plot(a[1::], '-o') +# plt.plot(b[:-1], '-o') + +plt.plot([mi * 0.999, ma * 1.001], [mi * 0.999, ma * 1.001]) + +plt.xlim(mi * 0.999, ma * 1.001) + +plt.ylim(mi * 0.999, ma * 1.001) +plt.xlabel("Q from pysages", fontsize=15) 
+plt.ylabel("Q from post-hoc calculation", fontsize=15) +plt.savefig("Q_comparison.png", dpi=300) +plt.show() diff --git a/examples/openmm/cv_Q/state.pkl b/examples/openmm/cv_Q/state.pkl new file mode 100644 index 00000000..c8d18d95 Binary files /dev/null and b/examples/openmm/cv_Q/state.pkl differ diff --git a/pysages/colvars/angles.py b/pysages/colvars/angles.py index 2c9c2cdf..eec86453 100644 --- a/pysages/colvars/angles.py +++ b/pysages/colvars/angles.py @@ -40,18 +40,34 @@ def function(self): Function that calculates the angle value from a simulation snapshot. Look at `pysages.colvars.angles.angle` for details. """ + return lambda p1, p2, p3: angle(p2, p1, p2, p3) + + +class VectorAngle(FourPointCV): + """ + Compute the angle formed by two vectors (four points). + Notice that the two vectors are defined as p1->p2, p3->p4. + """ + + @property + def function(self): + """ + Returns + -------- + Function that calculates the angle value from a simulation snapshot. + """ return angle -def angle(p1, p2, p3): +def angle(p1, p2, p3, p4): r""" - Calculates angle between 3 points in space. + Calculates angle between two vectors (4 points) in space. - Takes 3 positions in space and calculates the angle between them. + Takes 4 positions in space and calculates the angle between them. 
- :math:`\vec{q} = \vec{p}_1 - \vec{p}_2` + :math:`\vec{q} = \vec{p}_2 - \vec{p}_1` - :math:`\vec{r} = \vec{p}_3 - \vec{p}_2` + :math:`\vec{r} = \vec{p}_4 - \vec{p}_3` :math:`\theta = \arctan(|\vec{q} \times \vec{r}|, \vec{q} \cdot \vec{r})` @@ -63,14 +79,17 @@ def angle(p1, p2, p3): :math:`\vec{p}_2` 3D vector in space p3: jax.Array :math:`\vec{p}_3` 3D vector in space + p4: jax.Array + :math:`\vec{p}_4` 3D vector in space Returns ------- float :math:`\theta` """ - q = p1 - p2 - r = p3 - p2 + + q = p2 - p1 + r = p4 - p3 return np.arctan2(linalg.norm(np.cross(q, r)), np.dot(q, r)) @@ -131,7 +150,8 @@ def dihedral_angle(p1, p2, p3, p4): @multicomponent class RingPuckeringCoordinates(CollectiveVariable): """ - Computes the amplitude and the phase angle of a monocyclic ring by the Cremer-Pople method. + Computes the amplitude and the phase angle of a monocyclic ring + by the Cremer-Pople method. Mathematical definitions can be found in [D. Cremer and J. A. Pople, JACS, 1974](https://pubs.acs.org/doi/10.1021/ja00839a011) Equations 4-14. diff --git a/pysages/colvars/contacts.py b/pysages/colvars/contacts.py new file mode 100644 index 00000000..9c6c1869 --- /dev/null +++ b/pysages/colvars/contacts.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: MIT +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +Collective variable related to native contacts. 
+ +Q describes the fraction of the native contacts +for folding of proteins or nucleic acids +""" +from jax import numpy as np + +from pysages.colvars.core import CollectiveVariable + + +def remap_indices(contacts, sorted_unique_indices): + # Generate the mapping by using argsort on sorted_unique_indices + old_indices = np.array(sorted_unique_indices) + new_indices = np.argsort(old_indices) + + index_mapping = np.zeros(old_indices.max() + 1, dtype=new_indices.dtype) + index_mapping = index_mapping.at[old_indices].set(new_indices) + + contacts = np.array(contacts) + remapped_contacts = index_mapping[contacts] + + return remapped_contacts + + +class NativeContactFraction(CollectiveVariable): + r""" + Calculate the native contact fraction Q. + + Parameters + ------------ + indices: list[int] + List of the indices of the atoms of interest. + indices should be equal to unique(contact_pairs) + contact_pairs: list[tuple(int)] + List of pairs of indices of atoms (list of contacts) to be considered. + The overall shape should be (n_contacts, 2). + These pairs are usually generated by a cutoff radius. + For example, in a system with 5 atoms, if in the reference structure, + only 2-3, 2-4 and 1-5 are in contact with each other, + we should pass in [(2, 3), (2, 4), (1, 5)] + references: list[tuple(float)] + Reference coordinates for the selected atoms. + gamma: float + Smoothing parameter. Default is 50 nm^-1. + lambda_d: float + Scaling factor for the distances. Default is 1.5. + clip: bool + Clip the :math:`\gamma(r_{ij}-\lambda_d*r_{ij}^0)` by some value. + Default is False + clip_val: float + Upper bound applied to the clipped value when clip is True. Default is 5. 
+ """ + + def __init__( + self, indices, contact_pairs, references, gamma=50, lambda_d=1.5, clip=False, clip_val=5 + ): + super().__init__(indices) + indices = np.sort(np.asarray(indices)) + contact_pairs = np.asarray(contact_pairs) + assert all( + indices == np.unique(contact_pairs) + ), "contact pairs should contain and only contain index from indices" + assert not np.any( + contact_pairs[:, 0] == contact_pairs[:, 1] + ), "contact pairs contain self-self interaction" + self.contact_pairs = remap_indices(contact_pairs, indices) + + self.references = np.asarray(references) + self.gamma = gamma + self.lambda_d = lambda_d + self.clip = clip + self.clip_val = clip_val + + @property + def function(self): + return lambda r: native_contact_fraction( + r, + self.contact_pairs, + self.references, + self.gamma, + self.lambda_d, + self.clip, + self.clip_val, + ) + + +def native_contact_fraction(r, contact_pairs, references, gamma, lambda_d, clip, clip_val): + r""" + Calculate the native contact fraction Q. + Mathematical details can be found in + [Best, Mittal, JPCB, 2010](https://pubs.acs.org/doi/10.1021/jp102575b) + + :math:`Q=\frac{1}{N_\mathrm{contacts}}\sum_{(i, j)} \ + \frac{1}{1+\exp(\gamma(r_{ij}-\lambda_d r_{ij}^0))}` + + Parameters + ---------- + r: (n_atoms, 3) array + Current positions of the atoms. + contact_pairs: (n_contacts, 2) array + Pairs of indices of atoms that are in contact. + references: (n_atoms, 3) array + Reference positions of the atoms, from which the reference contact distances are computed. + gamma: float + Smoothing parameter. + lambda_d: float + Scaling factor for the distances. + clip: bool + Clip the :math:`\gamma(r_{ij}-\lambda_d*r_{ij}^0)` by some value. + clip_val: float + Upper bound applied to the clipped value when clip is True. + + Returns + ------- + Q: float + Native contact fraction. 
+ + + """ + # Calculate pairwise distances + distance_matrix = r[:, None, :] - r[None, :, :] + N = distance_matrix.shape[0] + diagonal_indices = np.arange(N), np.arange(N) + mask_val = 100 # mask trick to avoid divergence + mask_array = np.array([mask_val, mask_val, mask_val]) + distance_matrix = distance_matrix.at[diagonal_indices].set(mask_array) + distances = np.linalg.norm(distance_matrix, axis=-1) + + reference_distance_matrix = references[:, None, :] - references[None, :, :] + reference_distance_matrix = reference_distance_matrix.at[diagonal_indices].set(mask_array) + reference_distances = np.linalg.norm(reference_distance_matrix, axis=-1) + + distances_contacts = distances[contact_pairs[:, 0], contact_pairs[:, 1]] + reference_distances_contacts = reference_distances[contact_pairs[:, 0], contact_pairs[:, 1]] + diff = gamma * (distances_contacts - lambda_d * reference_distances_contacts) + if clip: + diff = np.clip(diff, None, clip_val) + Q_contribution = 1 / (1 + np.exp(diff)) + + N_contacts = contact_pairs.shape[0] + + return np.sum(Q_contribution) / N_contacts diff --git a/pysages/colvars/orientation.py b/pysages/colvars/orientation.py index 867cdce0..6736d1ad 100644 --- a/pysages/colvars/orientation.py +++ b/pysages/colvars/orientation.py @@ -21,7 +21,7 @@ from jax.numpy import linalg from pysages.colvars.coordinates import weighted_barycenter -from pysages.colvars.core import CollectiveVariable, multicomponent +from pysages.colvars.core import CollectiveVariable def fitted_positions(positions, weights): @@ -126,7 +126,6 @@ def function(self): return lambda r: rmsd(r, self.Q, self.weights, self.optimal_rotation) -@multicomponent class ERMSD(CollectiveVariable): """ Use a reference to calculate the eRMSD of a set of RNA structures. @@ -378,7 +377,6 @@ def ermsd(rs, reference, cutoff, a, b): return ermsd_core(rs, reference, cutoff, a, b) -@multicomponent class ERMSDCG(CollectiveVariable): """ Use a reference to calculate the eRMSD of