From 42ac1aa838a4e0f2b5b9cf96c994b50db0fb8ade Mon Sep 17 00:00:00 2001 From: yihengwuKP Date: Fri, 9 Aug 2024 17:02:56 -0500 Subject: [PATCH 1/6] add options for using np.linalg.pinv for Wp calc --- pysages/methods/abf.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/pysages/methods/abf.py b/pysages/methods/abf.py index ba713d15..c4caf38f 100644 --- a/pysages/methods/abf.py +++ b/pysages/methods/abf.py @@ -103,6 +103,11 @@ class ABF(GriddedSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_np_pinv: Optional[Bool] = False + If set to True, the Wp will be calculated using np.linalg.pinv(Jxi.T)@p + rather than solve_pos_def(Jxi @ Jxi.T, Jxi @ p). + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -110,6 +115,7 @@ class ABF(GriddedSamplingMethod): def __init__(self, cvs, grid, **kwargs): super().__init__(cvs, grid, **kwargs) self.N = np.asarray(self.kwargs.get("N", 500)) + self.use_np_pinv = self.kwargs.get("use_np_pinv", False) def build(self, snapshot, helpers, *args, **kwargs): """ @@ -201,11 +207,16 @@ def update(state, data): xi, Jxi = cv(data) p = data.momenta + use_np_pinv = data.use_np_pinv + # The following could equivalently be computed as `linalg.pinv(Jxi.T) @ p` # (both seem to have the same performance). # Another option to benchmark against is # Wp = linalg.tensorsolve(Jxi @ Jxi.T, Jxi @ p) - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + if use_np_pinv: + Wp = np.linalg.pinv(Jxi.T) @ p + else: + Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) # Second order backward finite difference dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt From ca18225354c9f433bff1e3be58200f2258ee49f3 Mon Sep 17 00:00:00 2001 From: yihengwuKP Date: Fri, 9 Aug 2024 17:27:29 -0500 Subject: [PATCH 2/6] bug fix: correct assignment of use_np_pinv --- examples/openmm/abf/alanine-dipeptide_openmm.py | 3 +-- pysages/methods/abf.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/openmm/abf/alanine-dipeptide_openmm.py b/examples/openmm/abf/alanine-dipeptide_openmm.py index f913b23f..632fb693 100644 --- a/examples/openmm/abf/alanine-dipeptide_openmm.py +++ b/examples/openmm/abf/alanine-dipeptide_openmm.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 - import matplotlib.pyplot as plt import numpy @@ -115,7 +114,7 @@ def post_run_action(**kwargs): def main(): cvs = [DihedralAngle((4, 6, 8, 14)), DihedralAngle((6, 8, 14, 16))] grid = pysages.Grid(lower=(-pi, -pi), upper=(pi, pi), shape=(32, 32), periodic=True) - method = ABF(cvs, grid) + method = ABF(cvs, grid, use_np_pinv=True) raw_result = pysages.run(method, generate_simulation, 25, post_run_action=post_run_action) result = pysages.analyze(raw_result, topology=(14,)) diff --git a/pysages/methods/abf.py b/pysages/methods/abf.py index c4caf38f..f9a60137 100644 --- a/pysages/methods/abf.py +++ b/pysages/methods/abf.py @@ -160,6 +160,7 @@ def _abf(method, snapshot, helpers): """ cv = method.cv grid = method.grid + use_np_pinv = method.use_np_pinv dt = snapshot.dt dims = grid.shape.size @@ -207,7 +208,6 @@ def update(state, data): xi, Jxi = cv(data) p = data.momenta - use_np_pinv = data.use_np_pinv # The following could equivalently be computed as `linalg.pinv(Jxi.T) @ p` # (both seem to have the same performance). From 4bf3ac3f41d8447a101466661c4cbdc6c942f29d Mon Sep 17 00:00:00 2001 From: yihengwuKP Date: Fri, 9 Aug 2024 18:35:16 -0500 Subject: [PATCH 3/6] fix Wp word issue & impl option for other methods --- pysages/methods/abf.py | 5 +++-- pysages/methods/cff.py | 13 ++++++++++++- pysages/methods/funn.py | 13 ++++++++++++- pysages/methods/sirens.py | 13 ++++++++++++- pysages/methods/spectral_abf.py | 13 ++++++++++++- 5 files changed, 51 insertions(+), 6 deletions(-) diff --git a/pysages/methods/abf.py b/pysages/methods/abf.py index f9a60137..1c12cd74 100644 --- a/pysages/methods/abf.py +++ b/pysages/methods/abf.py @@ -105,8 +105,9 @@ class ABF(GriddedSamplingMethod): `restraints.upper`. use_np_pinv: Optional[Bool] = False - If set to True, the Wp will be calculated using np.linalg.pinv(Jxi.T)@p - rather than solve_pos_def(Jxi @ Jxi.T, Jxi @ p). + If set to True, the product W times momentum p + will be calculated using pseudo-inverse from numpy + rather than using the solving function from scipy This is computationally more expensive but numerically more stable. """ diff --git a/pysages/methods/cff.py b/pysages/methods/cff.py index 57b5fa96..ad21fca7 100644 --- a/pysages/methods/cff.py +++ b/pysages/methods/cff.py @@ -148,6 +148,12 @@ class CFF(NNSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_np_pinv: Optional[Bool] = False + If set to True, the product W times momentum p + will be calculated using pseudo-inverse from numpy + rather than using the solving function from scipy + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -171,6 +177,7 @@ def __init__(self, cvs, grid, topology, kT, **kwargs): self.fmodel = MLP(dims, dims, topology, transform=scale) self.optimizer = kwargs.get("optimizer", default_optimizer) self.foptimizer = kwargs.get("foptimizer", default_foptimizer) + self.use_np_pinv = self.kwargs.get("use_np_pinv", False) def build(self, snapshot, helpers): return _cff(self, snapshot, helpers) @@ -180,6 +187,7 @@ def _cff(method: CFF, snapshot, helpers): cv = method.cv grid = method.grid train_freq = method.train_freq + use_np_pinv = method.use_np_pinv dt = snapshot.dt # Neural network paramters @@ -221,7 +229,10 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + if use_np_pinv: + Wp = np.linalg.pinv(Jxi.T) @ p + else: + Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # I_xi = get_grid_index(xi) diff --git a/pysages/methods/funn.py b/pysages/methods/funn.py index 6130d396..1f6171d5 100644 --- a/pysages/methods/funn.py +++ b/pysages/methods/funn.py @@ -126,6 +126,12 @@ class FUNN(NNSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_np_pinv: Optional[Bool] = False + If set to True, the product W times momentum p + will be calculated using pseudo-inverse from numpy + rather than using the solving function from scipy + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -142,6 +148,7 @@ def __init__(self, cvs, grid, topology, **kwargs): self.model = MLP(dims, dims, topology, transform=scale) default_optimizer = LevenbergMarquardt(reg=L2Regularization(1e-6)) self.optimizer = kwargs.get("optimizer", default_optimizer) + self.use_np_pinv = self.kwargs.get("use_np_pinv", False) def build(self, snapshot, helpers): return _funn(self, snapshot, helpers) @@ -151,6 +158,7 @@ def _funn(method, snapshot, helpers): cv = method.cv grid = method.grid train_freq = method.train_freq + use_np_pinv = method.use_np_pinv dt = snapshot.dt dims = grid.shape.size @@ -186,7 +194,10 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + if use_np_pinv: + Wp = np.linalg.pinv(Jxi.T) @ p + else: + Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # I_xi = get_grid_index(xi) diff --git a/pysages/methods/sirens.py b/pysages/methods/sirens.py index b1342f9b..403b8545 100644 --- a/pysages/methods/sirens.py +++ b/pysages/methods/sirens.py @@ -146,6 +146,12 @@ class Sirens(NNSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_np_pinv: Optional[Bool] = False + If set to True, the product W times momentum p + will be calculated using pseudo-inverse from numpy + rather than using the solving function from scipy + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -172,6 +178,7 @@ def __init__(self, cvs, grid, topology, **kwargs): scale = partial(_scale, grid=grid) self.model = Siren(dims, 1, topology, transform=scale) self.optimizer = optimizer + self.use_np_pinv = self.kwargs.get("use_np_pinv", False) def __check_init_invariants__(self, mode, kT, optimizer): if mode not in ("abf", "cff"): @@ -196,6 +203,7 @@ def _sirens(method: Sirens, snapshot, helpers): cv = method.cv grid = method.grid train_freq = method.train_freq + use_np_pinv = method.use_np_pinv dt = snapshot.dt # Neural network paramters @@ -244,7 +252,10 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + if use_np_pinv: + Wp = np.linalg.pinv(Jxi.T) @ p + else: + Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # I_xi = get_grid_index(xi) diff --git a/pysages/methods/spectral_abf.py b/pysages/methods/spectral_abf.py index 63b47955..81623016 100644 --- a/pysages/methods/spectral_abf.py +++ b/pysages/methods/spectral_abf.py @@ -124,6 +124,12 @@ class SpectralABF(GriddedSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_np_pinv: Optional[Bool] = False + If set to True, the product W times momentum p + will be calculated using pseudo-inverse from numpy + rather than using the solving function from scipy + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -135,6 +141,7 @@ def __init__(self, cvs, grid, **kwargs): self.fit_threshold = self.kwargs.get("fit_threshold", 500) self.grid = self.grid if self.grid.is_periodic else convert(self.grid, Grid[Chebyshev]) self.model = SpectralGradientFit(self.grid) + self.use_np_pinv = self.kwargs.get("use_np_pinv", False) def build(self, snapshot, helpers, *_args, **_kwargs): """ @@ -148,6 +155,7 @@ def _spectral_abf(method, snapshot, helpers): grid = method.grid fit_freq = method.fit_freq fit_threshold = method.fit_threshold + use_np_pinv = method.use_np_pinv dt = snapshot.dt dims = grid.shape.size @@ -181,7 +189,10 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + if use_np_pinv: + Wp = np.linalg.pinv(Jxi.T) @ p + else: + Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) # Second order backward finite difference dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # From d7199eccd8762443256f4ffa9be29cd29b0adab2 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Thu, 15 Aug 2024 14:26:34 -0500 Subject: [PATCH 4/6] Reduce code duplication --- .../openmm/abf/alanine-dipeptide_openmm.py | 2 +- pysages/methods/abf.py | 23 ++++++------------- pysages/methods/cff.py | 18 ++++++--------- pysages/methods/funn.py | 18 ++++++--------- pysages/methods/sirens.py | 18 ++++++--------- pysages/methods/spectral_abf.py | 18 ++++++--------- pysages/utils/__init__.py | 2 +- pysages/utils/core.py | 22 ++++++++++++++++++ 8 files changed, 59 insertions(+), 62 deletions(-) diff --git a/examples/openmm/abf/alanine-dipeptide_openmm.py b/examples/openmm/abf/alanine-dipeptide_openmm.py index 632fb693..810c362d 100644 --- a/examples/openmm/abf/alanine-dipeptide_openmm.py +++ b/examples/openmm/abf/alanine-dipeptide_openmm.py @@ -114,7 +114,7 @@ def post_run_action(**kwargs): def main(): cvs = [DihedralAngle((4, 6, 8, 14)), DihedralAngle((6, 8, 14, 16))] grid = pysages.Grid(lower=(-pi, -pi), upper=(pi, pi), shape=(32, 32), periodic=True) - method = ABF(cvs, grid, use_np_pinv=True) + method = ABF(cvs, grid, use_pinv=True) raw_result = pysages.run(method, generate_simulation, 25, post_run_action=post_run_action) result = pysages.analyze(raw_result, topology=(14,)) diff --git a/pysages/methods/abf.py b/pysages/methods/abf.py index 1c12cd74..dbba8888 100644 --- a/pysages/methods/abf.py +++ b/pysages/methods/abf.py @@ -27,7 +27,7 @@ from pysages.methods.restraints import apply_restraints from pysages.methods.utils import numpyfy_vals from pysages.typing import JaxArray, NamedTuple -from pysages.utils import dispatch, solve_pos_def +from pysages.utils import dispatch, linear_solver class ABFState(NamedTuple): @@ -104,10 +104,9 @@ class ABF(GriddedSamplingMethod): collective variable lies outside the box from `restraints.lower` to `restraints.upper`. - use_np_pinv: Optional[Bool] = False - If set to True, the product W times momentum p - will be calculated using pseudo-inverse from numpy - rather than using the solving function from scipy + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. This is computationally more expensive but numerically more stable. """ @@ -116,7 +115,7 @@ class ABF(GriddedSamplingMethod): def __init__(self, cvs, grid, **kwargs): super().__init__(cvs, grid, **kwargs) self.N = np.asarray(self.kwargs.get("N", 500)) - self.use_np_pinv = self.kwargs.get("use_np_pinv", False) + self.use_pinv = self.kwargs.get("use_pinv", False) def build(self, snapshot, helpers, *args, **kwargs): """ @@ -161,11 +160,11 @@ def _abf(method, snapshot, helpers): """ cv = method.cv grid = method.grid - use_np_pinv = method.use_np_pinv dt = snapshot.dt dims = grid.shape.size natoms = np.size(snapshot.positions, 0) + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) estimate_force = build_force_estimator(method) @@ -209,15 +208,7 @@ def update(state, data): xi, Jxi = cv(data) p = data.momenta - - # The following could equivalently be computed as `linalg.pinv(Jxi.T) @ p` - # (both seem to have the same performance). - # Another option to benchmark against is - # Wp = linalg.tensorsolve(Jxi @ Jxi.T, Jxi @ p) - if use_np_pinv: - Wp = np.linalg.pinv(Jxi.T) @ p - else: - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) # Second order backward finite difference dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt diff --git a/pysages/methods/cff.py b/pysages/methods/cff.py index ad21fca7..3a3306c1 100644 --- a/pysages/methods/cff.py +++ b/pysages/methods/cff.py @@ -31,7 +31,7 @@ from pysages.ml.training import NNData, build_fitting_function, convolve, normalize from pysages.ml.utils import blackman_kernel, pack, unpack from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, first_or_all, solve_pos_def +from pysages.utils import dispatch, first_or_all, linear_solver # Aliases f32 = np.float32 @@ -149,10 +149,9 @@ class CFF(NNSamplingMethod): collective variable lies outside the box from `restraints.lower` to `restraints.upper`. - use_np_pinv: Optional[Bool] = False - If set to True, the product W times momentum p - will be calculated using pseudo-inverse from numpy - rather than using the solving function from scipy + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. This is computationally more expensive but numerically more stable. """ @@ -177,7 +176,7 @@ def __init__(self, cvs, grid, topology, kT, **kwargs): self.fmodel = MLP(dims, dims, topology, transform=scale) self.optimizer = kwargs.get("optimizer", default_optimizer) self.foptimizer = kwargs.get("foptimizer", default_foptimizer) - self.use_np_pinv = self.kwargs.get("use_np_pinv", False) + self.use_pinv = self.kwargs.get("use_pinv", False) def build(self, snapshot, helpers): return _cff(self, snapshot, helpers) @@ -187,7 +186,6 @@ def _cff(method: CFF, snapshot, helpers): cv = method.cv grid = method.grid train_freq = method.train_freq - use_np_pinv = method.use_np_pinv dt = snapshot.dt # Neural network paramters @@ -195,6 +193,7 @@ def _cff(method: CFF, snapshot, helpers): fps, _ = unpack(method.fmodel.parameters) # Helper methods + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) learn_free_energy = build_free_energy_learner(method) estimate_force = build_force_estimator(method) @@ -229,10 +228,7 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - if use_np_pinv: - Wp = np.linalg.pinv(Jxi.T) @ p - else: - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # I_xi = get_grid_index(xi) diff --git a/pysages/methods/funn.py b/pysages/methods/funn.py index 1f6171d5..593dda5a 100644 --- a/pysages/methods/funn.py +++ b/pysages/methods/funn.py @@ -34,7 +34,7 @@ from pysages.ml.training import NNData, build_fitting_function, convolve, normalize from pysages.ml.utils import blackman_kernel, pack, unpack from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, first_or_all, solve_pos_def +from pysages.utils import dispatch, first_or_all, linear_solver class FUNNState(NamedTuple): @@ -127,10 +127,9 @@ class FUNN(NNSamplingMethod): collective variable lies outside the box from `restraints.lower` to `restraints.upper`. - use_np_pinv: Optional[Bool] = False - If set to True, the product W times momentum p - will be calculated using pseudo-inverse from numpy - rather than using the solving function from scipy + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. This is computationally more expensive but numerically more stable. """ @@ -148,7 +147,7 @@ def __init__(self, cvs, grid, topology, **kwargs): self.model = MLP(dims, dims, topology, transform=scale) default_optimizer = LevenbergMarquardt(reg=L2Regularization(1e-6)) self.optimizer = kwargs.get("optimizer", default_optimizer) - self.use_np_pinv = self.kwargs.get("use_np_pinv", False) + self.use_pinv = self.kwargs.get("use_pinv", False) def build(self, snapshot, helpers): return _funn(self, snapshot, helpers) @@ -158,7 +157,6 @@ def _funn(method, snapshot, helpers): cv = method.cv grid = method.grid train_freq = method.train_freq - use_np_pinv = method.use_np_pinv dt = snapshot.dt dims = grid.shape.size @@ -168,6 +166,7 @@ def _funn(method, snapshot, helpers): ps, _ = unpack(method.model.parameters) # Helper methods + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) learn_free_energy_grad = build_free_energy_grad_learner(method) estimate_free_energy_grad = build_force_estimator(method) @@ -194,10 +193,7 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - if use_np_pinv: - Wp = np.linalg.pinv(Jxi.T) @ p - else: - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # I_xi = get_grid_index(xi) diff --git a/pysages/methods/sirens.py b/pysages/methods/sirens.py index 403b8545..836e31b6 100644 --- a/pysages/methods/sirens.py +++ b/pysages/methods/sirens.py @@ -32,7 +32,7 @@ from pysages.ml.training import NNData, build_fitting_function, convolve from pysages.ml.utils import blackman_kernel, pack, unpack from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, first_or_all, solve_pos_def +from pysages.utils import dispatch, first_or_all, linear_solver class SirensState(NamedTuple): # pylint: disable=R0903 @@ -147,10 +147,9 @@ class Sirens(NNSamplingMethod): collective variable lies outside the box from `restraints.lower` to `restraints.upper`. - use_np_pinv: Optional[Bool] = False - If set to True, the product W times momentum p - will be calculated using pseudo-inverse from numpy - rather than using the solving function from scipy + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. This is computationally more expensive but numerically more stable. """ @@ -178,7 +177,7 @@ def __init__(self, cvs, grid, topology, **kwargs): scale = partial(_scale, grid=grid) self.model = Siren(dims, 1, topology, transform=scale) self.optimizer = optimizer - self.use_np_pinv = self.kwargs.get("use_np_pinv", False) + self.use_pinv = self.kwargs.get("use_pinv", False) def __check_init_invariants__(self, mode, kT, optimizer): if mode not in ("abf", "cff"): @@ -203,13 +202,13 @@ def _sirens(method: Sirens, snapshot, helpers): cv = method.cv grid = method.grid train_freq = method.train_freq - use_np_pinv = method.use_np_pinv dt = snapshot.dt # Neural network paramters ps, _ = unpack(method.model.parameters) # Helper methods + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) learn_free_energy = build_free_energy_learner(method) estimate_force = build_force_estimator(method) @@ -252,10 +251,7 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - if use_np_pinv: - Wp = np.linalg.pinv(Jxi.T) @ p - else: - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # I_xi = get_grid_index(xi) diff --git a/pysages/methods/spectral_abf.py b/pysages/methods/spectral_abf.py index 81623016..c50d2f11 100644 --- a/pysages/methods/spectral_abf.py +++ b/pysages/methods/spectral_abf.py @@ -30,7 +30,7 @@ from pysages.methods.restraints import apply_restraints from pysages.methods.utils import numpyfy_vals from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, first_or_all, solve_pos_def +from pysages.utils import dispatch, first_or_all, linear_solver class SpectralABFState(NamedTuple): @@ -125,10 +125,9 @@ class SpectralABF(GriddedSamplingMethod): collective variable lies outside the box from `restraints.lower` to `restraints.upper`. - use_np_pinv: Optional[Bool] = False - If set to True, the product W times momentum p - will be calculated using pseudo-inverse from numpy - rather than using the solving function from scipy + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. This is computationally more expensive but numerically more stable. """ @@ -141,7 +140,7 @@ def __init__(self, cvs, grid, **kwargs): self.fit_threshold = self.kwargs.get("fit_threshold", 500) self.grid = self.grid if self.grid.is_periodic else convert(self.grid, Grid[Chebyshev]) self.model = SpectralGradientFit(self.grid) - self.use_np_pinv = self.kwargs.get("use_np_pinv", False) + self.use_pinv = self.kwargs.get("use_pinv", False) def build(self, snapshot, helpers, *_args, **_kwargs): """ @@ -155,13 +154,13 @@ def _spectral_abf(method, snapshot, helpers): grid = method.grid fit_freq = method.fit_freq fit_threshold = method.fit_threshold - use_np_pinv = method.use_np_pinv dt = snapshot.dt dims = grid.shape.size natoms = np.size(snapshot.positions, 0) # Helper methods + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) fit = build_fitter(method.model) fit_forces = build_free_energy_fitter(method, fit) @@ -189,10 +188,7 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - if use_np_pinv: - Wp = np.linalg.pinv(Jxi.T) @ p - else: - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) # Second order backward finite difference dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # diff --git a/pysages/utils/__init__.py b/pysages/utils/__init__.py index 3351d48e..f87d403f 100644 --- a/pysages/utils/__init__.py +++ b/pysages/utils/__init__.py @@ -18,5 +18,5 @@ solve_pos_def, try_import, ) -from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity +from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity, linear_solver from .transformations import quaternion_from_euler, quaternion_matrix diff --git a/pysages/utils/core.py b/pysages/utils/core.py index 06afdbc8..c60ff23d 100644 --- a/pysages/utils/core.py +++ b/pysages/utils/core.py @@ -8,6 +8,7 @@ from plum import Dispatcher from pysages.typing import JaxArray, Scalar +from pysages.utils.compat import solve_pos_def # PySAGES main dispatcher dispatch = Dispatcher() @@ -70,3 +71,24 @@ def gaussian(a, sigma, x): N-dimensional origin-centered gaussian with height `a` and standard deviation `sigma`. """ return a * np.exp(-row_sum((x / sigma) ** 2) / 2) + + +def linear_solver(use_pinv: bool): + """ + Returns a function that solves the linear system `A.T @ X = B` for `X`. + When `use_pinv == True`, `numpy.linalg.pinv` is used rather than `scipy.linalg.solve` + (this is computationally more expensive but numerically more stable). + """ + if method.use_pinv: + + # This is numerically more robust + def tsolve(A, B): + return np.linalg.pinv(A.T) @ B + + else: + + # Another option to benchmark against is `linalg.tensorsolve(A @ A.T, A @ B)` + def tsolve(A, B): + return solve_pos_def(A @ A.T, A @ B) + + return tsolve From caa41c06d752a575aa4b3e22e554998d234d8a80 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Thu, 15 Aug 2024 14:27:51 -0500 Subject: [PATCH 5/6] Fix copy-paste typo --- pysages/utils/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysages/utils/core.py b/pysages/utils/core.py index c60ff23d..25e2a2d1 100644 --- a/pysages/utils/core.py +++ b/pysages/utils/core.py @@ -79,7 +79,7 @@ def linear_solver(use_pinv: bool): When `use_pinv == True`, `numpy.linalg.pinv` is used rather than `scipy.linalg.solve` (this is computationally more expensive but numerically more stable). """ - if method.use_pinv: + if use_pinv: # This is numerically more robust def tsolve(A, B): From 30580165a9ed01a7ff09d4fab4fb804b663b5194 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Thu, 15 Aug 2024 16:25:50 -0500 Subject: [PATCH 6/6] Make it black complaint --- pysages/utils/__init__.py | 11 ++++++++++- pysages/utils/core.py | 2 -- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pysages/utils/__init__.py b/pysages/utils/__init__.py index f87d403f..ffde6466 100644 --- a/pysages/utils/__init__.py +++ b/pysages/utils/__init__.py @@ -18,5 +18,14 @@ solve_pos_def, try_import, ) -from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity, linear_solver +from .core import ( + ToCPU, + copy, + dispatch, + eps, + first_or_all, + gaussian, + identity, + linear_solver, +) from .transformations import quaternion_from_euler, quaternion_matrix diff --git a/pysages/utils/core.py b/pysages/utils/core.py index 25e2a2d1..20fe0f3e 100644 --- a/pysages/utils/core.py +++ b/pysages/utils/core.py @@ -80,13 +80,11 @@ def linear_solver(use_pinv: bool): (this is computationally more expensive but numerically more stable). """ if use_pinv: - # This is numerically more robust def tsolve(A, B): return np.linalg.pinv(A.T) @ B else: - # Another option to benchmark against is `linalg.tensorsolve(A @ A.T, A @ B)` def tsolve(A, B): return solve_pos_def(A @ A.T, A @ B)