Skip to content

Commit

Permalink
Add options for abf to use pseudo inverse (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz authored Aug 22, 2024
2 parents 2a2b4c3 + 3058016 commit 9c8d2a4
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 17 deletions.
3 changes: 1 addition & 2 deletions examples/openmm/abf/alanine-dipeptide_openmm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3


import matplotlib.pyplot as plt
import numpy

Expand Down Expand Up @@ -115,7 +114,7 @@ def post_run_action(**kwargs):
def main():
cvs = [DihedralAngle((4, 6, 8, 14)), DihedralAngle((6, 8, 14, 16))]
grid = pysages.Grid(lower=(-pi, -pi), upper=(pi, pi), shape=(32, 32), periodic=True)
method = ABF(cvs, grid)
method = ABF(cvs, grid, use_pinv=True)

raw_result = pysages.run(method, generate_simulation, 25, post_run_action=post_run_action)
result = pysages.analyze(raw_result, topology=(14,))
Expand Down
15 changes: 9 additions & 6 deletions pysages/methods/abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pysages.methods.restraints import apply_restraints
from pysages.methods.utils import numpyfy_vals
from pysages.typing import JaxArray, NamedTuple
from pysages.utils import dispatch, solve_pos_def
from pysages.utils import dispatch, linear_solver


class ABFState(NamedTuple):
Expand Down Expand Up @@ -103,13 +103,19 @@ class ABF(GriddedSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_pinv: Optional[Bool] = False
If set to True, the product `W @ p` will be estimated using
`np.linalg.pinv` rather than using the `scipy.linalg.solve` function.
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}

def __init__(self, cvs, grid, **kwargs):
super().__init__(cvs, grid, **kwargs)
self.N = np.asarray(self.kwargs.get("N", 500))
self.use_pinv = self.kwargs.get("use_pinv", False)

def build(self, snapshot, helpers, *args, **kwargs):
"""
Expand Down Expand Up @@ -158,6 +164,7 @@ def _abf(method, snapshot, helpers):
dt = snapshot.dt
dims = grid.shape.size
natoms = np.size(snapshot.positions, 0)
tsolve = linear_solver(method.use_pinv)
get_grid_index = build_indexer(grid)
estimate_force = build_force_estimator(method)

Expand Down Expand Up @@ -201,11 +208,7 @@ def update(state, data):
xi, Jxi = cv(data)

p = data.momenta
# The following could equivalently be computed as `linalg.pinv(Jxi.T) @ p`
# (both seem to have the same performance).
# Another option to benchmark against is
# Wp = linalg.tensorsolve(Jxi @ Jxi.T, Jxi @ p)
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
Wp = tsolve(Jxi, p)
# Second order backward finite difference
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt

Expand Down
11 changes: 9 additions & 2 deletions pysages/methods/cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pysages.ml.training import NNData, build_fitting_function, convolve, normalize
from pysages.ml.utils import blackman_kernel, pack, unpack
from pysages.typing import JaxArray, NamedTuple, Tuple
from pysages.utils import dispatch, first_or_all, solve_pos_def
from pysages.utils import dispatch, first_or_all, linear_solver

# Aliases
f32 = np.float32
Expand Down Expand Up @@ -148,6 +148,11 @@ class CFF(NNSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_pinv: Optional[Bool] = False
If set to True, the product `W @ p` will be estimated using
`np.linalg.pinv` rather than using the `scipy.linalg.solve` function.
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}
Expand All @@ -171,6 +176,7 @@ def __init__(self, cvs, grid, topology, kT, **kwargs):
self.fmodel = MLP(dims, dims, topology, transform=scale)
self.optimizer = kwargs.get("optimizer", default_optimizer)
self.foptimizer = kwargs.get("foptimizer", default_foptimizer)
self.use_pinv = self.kwargs.get("use_pinv", False)

def build(self, snapshot, helpers):
return _cff(self, snapshot, helpers)
Expand All @@ -187,6 +193,7 @@ def _cff(method: CFF, snapshot, helpers):
fps, _ = unpack(method.fmodel.parameters)

# Helper methods
tsolve = linear_solver(method.use_pinv)
get_grid_index = build_indexer(grid)
learn_free_energy = build_free_energy_learner(method)
estimate_force = build_force_estimator(method)
Expand Down Expand Up @@ -221,7 +228,7 @@ def update(state, data):
xi, Jxi = cv(data)
#
p = data.momenta
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
Wp = tsolve(Jxi, p)
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt
#
I_xi = get_grid_index(xi)
Expand Down
11 changes: 9 additions & 2 deletions pysages/methods/funn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from pysages.ml.training import NNData, build_fitting_function, convolve, normalize
from pysages.ml.utils import blackman_kernel, pack, unpack
from pysages.typing import JaxArray, NamedTuple, Tuple
from pysages.utils import dispatch, first_or_all, solve_pos_def
from pysages.utils import dispatch, first_or_all, linear_solver


class FUNNState(NamedTuple):
Expand Down Expand Up @@ -126,6 +126,11 @@ class FUNN(NNSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_pinv: Optional[Bool] = False
If set to True, the product `W @ p` will be estimated using
`np.linalg.pinv` rather than using the `scipy.linalg.solve` function.
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}
Expand All @@ -142,6 +147,7 @@ def __init__(self, cvs, grid, topology, **kwargs):
self.model = MLP(dims, dims, topology, transform=scale)
default_optimizer = LevenbergMarquardt(reg=L2Regularization(1e-6))
self.optimizer = kwargs.get("optimizer", default_optimizer)
self.use_pinv = self.kwargs.get("use_pinv", False)

def build(self, snapshot, helpers):
return _funn(self, snapshot, helpers)
Expand All @@ -160,6 +166,7 @@ def _funn(method, snapshot, helpers):
ps, _ = unpack(method.model.parameters)

# Helper methods
tsolve = linear_solver(method.use_pinv)
get_grid_index = build_indexer(grid)
learn_free_energy_grad = build_free_energy_grad_learner(method)
estimate_free_energy_grad = build_force_estimator(method)
Expand All @@ -186,7 +193,7 @@ def update(state, data):
xi, Jxi = cv(data)
#
p = data.momenta
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
Wp = tsolve(Jxi, p)
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt
#
I_xi = get_grid_index(xi)
Expand Down
11 changes: 9 additions & 2 deletions pysages/methods/sirens.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from pysages.ml.training import NNData, build_fitting_function, convolve
from pysages.ml.utils import blackman_kernel, pack, unpack
from pysages.typing import JaxArray, NamedTuple, Tuple
from pysages.utils import dispatch, first_or_all, solve_pos_def
from pysages.utils import dispatch, first_or_all, linear_solver


class SirensState(NamedTuple): # pylint: disable=R0903
Expand Down Expand Up @@ -146,6 +146,11 @@ class Sirens(NNSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_pinv: Optional[Bool] = False
If set to True, the product `W @ p` will be estimated using
`np.linalg.pinv` rather than using the `scipy.linalg.solve` function.
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}
Expand All @@ -172,6 +177,7 @@ def __init__(self, cvs, grid, topology, **kwargs):
scale = partial(_scale, grid=grid)
self.model = Siren(dims, 1, topology, transform=scale)
self.optimizer = optimizer
self.use_pinv = self.kwargs.get("use_pinv", False)

def __check_init_invariants__(self, mode, kT, optimizer):
if mode not in ("abf", "cff"):
Expand Down Expand Up @@ -202,6 +208,7 @@ def _sirens(method: Sirens, snapshot, helpers):
ps, _ = unpack(method.model.parameters)

# Helper methods
tsolve = linear_solver(method.use_pinv)
get_grid_index = build_indexer(grid)
learn_free_energy = build_free_energy_learner(method)
estimate_force = build_force_estimator(method)
Expand Down Expand Up @@ -244,7 +251,7 @@ def update(state, data):
xi, Jxi = cv(data)
#
p = data.momenta
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
Wp = tsolve(Jxi, p)
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt
#
I_xi = get_grid_index(xi)
Expand Down
11 changes: 9 additions & 2 deletions pysages/methods/spectral_abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pysages.methods.restraints import apply_restraints
from pysages.methods.utils import numpyfy_vals
from pysages.typing import JaxArray, NamedTuple, Tuple
from pysages.utils import dispatch, first_or_all, solve_pos_def
from pysages.utils import dispatch, first_or_all, linear_solver


class SpectralABFState(NamedTuple):
Expand Down Expand Up @@ -124,6 +124,11 @@ class SpectralABF(GriddedSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_pinv: Optional[Bool] = False
If set to True, the product `W @ p` will be estimated using
`np.linalg.pinv` rather than using the `scipy.linalg.solve` function.
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}
Expand All @@ -135,6 +140,7 @@ def __init__(self, cvs, grid, **kwargs):
self.fit_threshold = self.kwargs.get("fit_threshold", 500)
self.grid = self.grid if self.grid.is_periodic else convert(self.grid, Grid[Chebyshev])
self.model = SpectralGradientFit(self.grid)
self.use_pinv = self.kwargs.get("use_pinv", False)

def build(self, snapshot, helpers, *_args, **_kwargs):
"""
Expand All @@ -154,6 +160,7 @@ def _spectral_abf(method, snapshot, helpers):
natoms = np.size(snapshot.positions, 0)

# Helper methods
tsolve = linear_solver(method.use_pinv)
get_grid_index = build_indexer(grid)
fit = build_fitter(method.model)
fit_forces = build_free_energy_fitter(method, fit)
Expand Down Expand Up @@ -181,7 +188,7 @@ def update(state, data):
xi, Jxi = cv(data)
#
p = data.momenta
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
Wp = tsolve(Jxi, p)
# Second order backward finite difference
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt
#
Expand Down
11 changes: 10 additions & 1 deletion pysages/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,14 @@
solve_pos_def,
try_import,
)
from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity
from .core import (
ToCPU,
copy,
dispatch,
eps,
first_or_all,
gaussian,
identity,
linear_solver,
)
from .transformations import quaternion_from_euler, quaternion_matrix
20 changes: 20 additions & 0 deletions pysages/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from plum import Dispatcher

from pysages.typing import JaxArray, Scalar
from pysages.utils.compat import solve_pos_def

# PySAGES main dispatcher
dispatch = Dispatcher()
Expand Down Expand Up @@ -70,3 +71,22 @@ def gaussian(a, sigma, x):
N-dimensional origin-centered gaussian with height `a` and standard deviation `sigma`.
"""
return a * np.exp(-row_sum((x / sigma) ** 2) / 2)


def linear_solver(use_pinv: bool):
"""
Returns a function that solves the linear system `A.T @ X = B` for `X`.
When `use_pinv == True`, `numpy.linalg.pinv` is used rather than `scipy.linalg.solve`
(this is computationally more expensive but numerically more stable).
"""
if use_pinv:
# This is numerically more robust
def tsolve(A, B):
return np.linalg.pinv(A.T) @ B

else:
# Another option to benchmark against is `linalg.tensorsolve(A @ A.T, A @ B)`
def tsolve(A, B):
return solve_pos_def(A @ A.T, A @ B)

return tsolve

0 comments on commit 9c8d2a4

Please sign in to comment.