diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26838b99..34273681 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/mwouts/jupytext - rev: v1.16.2 + rev: v1.16.4b hooks: - id: jupytext args: [--sync] diff --git a/examples/Advanced_Sampling_Introduction.md b/examples/Advanced_Sampling_Introduction.md index c773b51b..73ec65b0 100644 --- a/examples/Advanced_Sampling_Introduction.md +++ b/examples/Advanced_Sampling_Introduction.md @@ -7,7 +7,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/examples/Install_PySAGES_Environment.md b/examples/Install_PySAGES_Environment.md index b8ff796d..50a79ecc 100644 --- a/examples/Install_PySAGES_Environment.md +++ b/examples/Install_PySAGES_Environment.md @@ -7,7 +7,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/examples/hoomd-blue/ann/Butane_ANN.md b/examples/hoomd-blue/ann/Butane_ANN.md index c2106e61..38914349 100644 --- a/examples/hoomd-blue/ann/Butane_ANN.md +++ b/examples/hoomd-blue/ann/Butane_ANN.md @@ -7,7 +7,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/examples/hoomd-blue/cff/Butane_CFF.md b/examples/hoomd-blue/cff/Butane_CFF.md index 7aa76271..35aea699 100644 --- a/examples/hoomd-blue/cff/Butane_CFF.md +++ b/examples/hoomd-blue/cff/Butane_CFF.md @@ -7,7 +7,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/examples/hoomd-blue/funn/Butane_FUNN.md b/examples/hoomd-blue/funn/Butane_FUNN.md index e4c22704..dd301eee 100644 --- a/examples/hoomd-blue/funn/Butane_FUNN.md +++ b/examples/hoomd-blue/funn/Butane_FUNN.md @@ -7,7 +7,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/examples/hoomd-blue/harmonic_bias/Harmonic_Bias.md b/examples/hoomd-blue/harmonic_bias/Harmonic_Bias.md index 1ecfe52e..70231897 100644 --- a/examples/hoomd-blue/harmonic_bias/Harmonic_Bias.md +++ b/examples/hoomd-blue/harmonic_bias/Harmonic_Bias.md @@ -7,7 +7,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/examples/hoomd-blue/spectral_abf/Butane-SpectralABF.ipynb b/examples/hoomd-blue/spectral_abf/Butane-SpectralABF.ipynb index 0ae6b9df..e346d8cd 100644 --- a/examples/hoomd-blue/spectral_abf/Butane-SpectralABF.ipynb +++ b/examples/hoomd-blue/spectral_abf/Butane-SpectralABF.ipynb @@ -638,8 +638,7 @@ "A = result[\"free_energy\"]\n", "# Alternatively:\n", "# fes_fn = result[\"fes_fn\"]\n", - "# A = fes_fn(mesh)\n", - "A = A.max() - A" + "# A = fes_fn(mesh)" ] }, { diff --git a/examples/hoomd-blue/spectral_abf/Butane-SpectralABF.md b/examples/hoomd-blue/spectral_abf/Butane-SpectralABF.md index 472b12ca..78d15f3f 100644 --- a/examples/hoomd-blue/spectral_abf/Butane-SpectralABF.md +++ b/examples/hoomd-blue/spectral_abf/Butane-SpectralABF.md @@ -6,7 +6,7 @@ jupyter: 
extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -386,7 +386,6 @@ A = result["free_energy"] # Alternatively: # fes_fn = result["fes_fn"] # A = fes_fn(mesh) -A = A.max() - A ``` ```python colab={"base_uri": "https://localhost:8080/", "height": 302} id="7_d_XfVLLkbI" outputId="e35db259-31f8-4a3b-b1fa-7e91a8a5c88a" diff --git a/examples/hoomd-blue/spectral_abf/butane.py b/examples/hoomd-blue/spectral_abf/butane.py index 77a68bb7..405010ba 100644 --- a/examples/hoomd-blue/spectral_abf/butane.py +++ b/examples/hoomd-blue/spectral_abf/butane.py @@ -257,7 +257,7 @@ def get_args(argv): ("time-steps", "t", int, 5e5, "Number of simulation steps"), ] parser = argparse.ArgumentParser(description="Example script to run SpectralABF") - for (name, short, T, val, doc) in available_args: + for name, short, T, val, doc in available_args: parser.add_argument("--" + name, "-" + short, type=T, default=T(val), help=doc) return parser.parse_args(argv) @@ -283,7 +283,6 @@ def main(argv=[]): mesh = result["mesh"] fes_fn = result["fes_fn"] A = fes_fn(mesh) - A = A.max() - A # plot the free energy fig, ax = plt.subplots() diff --git a/examples/hoomd-blue/umbrella_integration/Umbrella_Integration.md b/examples/hoomd-blue/umbrella_integration/Umbrella_Integration.md index c847a621..8f7a25ca 100644 --- a/examples/hoomd-blue/umbrella_integration/Umbrella_Integration.md +++ b/examples/hoomd-blue/umbrella_integration/Umbrella_Integration.md @@ -7,7 +7,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/examples/hoomd3/spectral_abf/butane.py b/examples/hoomd3/spectral_abf/butane.py index baedfa2e..ae6c11f3 100644 --- a/examples/hoomd3/spectral_abf/butane.py +++ b/examples/hoomd3/spectral_abf/butane.py @@ -298,7 +298,6 @@ def main(argv=[]): mesh = result["mesh"] fes_fn = result["fes_fn"] A = fes_fn(mesh) - A = A.max() - A # plot the free energy fig, ax = plt.subplots() diff --git a/examples/lammps/spectral_abf/adp.input b/examples/lammps/spectral_abf/adp.input new file mode 100644 index 00000000..5873472e --- /dev/null +++ b/examples/lammps/spectral_abf/adp.input @@ -0,0 +1,244 @@ +LAMMPS data file for ACE + +22 atoms +21 bonds +36 angles +66 dihedrals + +7 atom types +8 bond types +16 angle types +19 dihedral types + +-100 100 xlo xhi +-100 100 ylo yhi +-100 100 zlo zhi + +Masses + +1 1.008 +2 12.01 +3 12.01 +4 16.0 +5 14.01 +6 1.008 +7 1.008 + +Atoms + + 1 1 1 0.1123 2.0177833000000001 1.0353295 0.054216800000000002 + 2 1 2 -0.36620000000000003 1.9998066999999999 2.1250665999999998 0.0538082 + 3 1 1 0.1123 1.4254165999999999 2.4843904000000001 0.90724000000000005 + 4 1 1 0.1123 1.5326313 2.4823206 -0.86386980000000002 + 5 1 3 0.5972000021951126 3.4222907 2.6597886000000002 0.13890169999999999 + 6 1 4 -0.56790000164633436 4.3607215999999998 1.8984493 0.36315979999999998 + 7 1 5 -0.41570000000000001 3.5879191000000001 3.9603571 -0.1138393 + 8 1 6 0.27190000000000003 2.7859197999999998 4.5436604999999997 -0.3046603 + 9 1 2 0.033699999999999994 4.8919525999999998 4.5905944999999999 -0.067514500000000005 +10 1 7 0.082299999999999998 5.4432974999999999 4.2337999000000002 0.80264219999999997 +11 1 2 -0.1825 5.6616692000000004 4.2349395999999997 -1.3354246999999999 +12 1 1 0.060299999999999992 5.1089564000000003 4.5854832999999999 -2.2068146999999998 +13 1 1 0.060299999999999992 
6.6419078000000003 4.7111339000000001 -1.3103305999999999 +14 1 1 0.060299999999999992 5.7851258000000003 3.1535202 -1.3931851 +15 1 3 0.59730000054877808 4.7222410999999997 6.1004128 0.021521100000000001 +16 1 4 -0.56790000164633436 3.6035081999999998 6.6057630999999999 -0.039872400000000002 +17 1 5 -0.41570000000000001 5.8431179000000002 6.8242649999999996 0.073516899999999996 +18 1 6 0.27190000000000003 6.7415383999999996 6.3635967999999998 0.087583599999999998 +19 1 2 -0.14899999999999999 5.8209701000000003 8.2733497000000007 0.056992800000000003 +20 1 7 0.097600000000000006 4.7879085999999997 8.6215931000000001 0.0518138 +21 1 7 0.097600000000000006 6.3303031000000001 8.6556759000000003 0.94173300000000004 +22 1 7 0.097600000000000006 6.3290771000000001 8.6329496999999993 -0.83830830000000001 + +Bonds + + 1 3 2 3 + 2 3 2 4 + 3 3 1 2 + 4 3 11 12 + 5 3 11 13 + 6 3 11 14 + 7 5 9 10 + 8 7 7 8 + 9 5 19 20 +10 5 19 21 +11 5 19 22 +12 7 17 18 +13 1 5 6 +14 2 5 7 +15 4 2 5 +16 1 15 16 +17 2 15 17 +18 6 9 11 +19 4 9 15 +20 8 7 9 +21 8 17 19 + +Angles + +1 2 5 7 8 +2 4 4 2 5 +3 5 3 2 4 +4 4 3 2 5 +5 5 1 2 3 +6 5 1 2 4 +7 4 1 2 5 +8 2 15 17 18 +9 5 13 11 14 +10 5 12 11 13 +11 5 12 11 14 +12 9 10 9 11 +13 10 10 9 15 +14 11 9 11 12 +15 11 9 11 13 +16 11 9 11 14 +17 12 8 7 9 +18 13 7 9 10 +19 16 21 19 22 +20 16 20 19 21 +21 16 20 19 22 +22 12 18 17 19 +23 13 17 19 20 +24 13 17 19 21 +25 13 17 19 22 +26 1 6 5 7 +27 3 5 7 9 +28 6 2 5 6 +29 7 2 5 7 +30 1 16 15 17 +31 3 15 17 19 +32 8 11 9 15 +33 6 9 15 16 +34 7 9 15 17 +35 14 7 9 11 +36 15 7 9 15 + +Dihedrals + +1 1 6 5 7 8 +2 2 6 5 7 8 +3 3 5 7 9 10 +4 10 4 2 5 6 +5 3 4 2 5 6 +6 11 4 2 5 6 +7 3 4 2 5 7 +8 10 3 2 5 6 +9 3 3 2 5 6 +10 11 3 2 5 6 +11 3 3 2 5 7 +12 2 2 5 7 8 +13 10 1 2 5 6 +14 3 1 2 5 6 +15 11 1 2 5 6 +16 3 1 2 5 7 +17 1 16 15 17 18 +18 2 16 15 17 18 +19 3 15 17 19 20 +20 3 15 17 19 21 +21 3 15 17 19 22 +22 12 14 11 9 15 +23 12 13 11 9 15 +24 12 12 11 9 15 +25 12 10 9 11 12 +26 12 10 9 11 13 +27 12 10 9 11 14 +28 10 10 9 15 16 +29 11 10 9 15 16 +30 3 10 9 15 17 +31 2 9 15 17 18 +32 3 8 7 9 10 +33 3 8 7 9 11 +34 3 8 7 9 15 +35 12 7 9 11 12 +36 12 7 9 11 13 +37 12 7 9 11 14 +38 3 18 17 19 20 +39 3 18 17 19 21 +40 3 18 17 19 22 +41 19 5 9 7 8 +42 19 15 19 17 18 +43 2 6 5 7 9 +44 1 5 7 9 11 +45 4 5 7 9 11 +46 5 5 7 9 11 +47 6 5 7 9 11 +48 7 5 7 9 15 +49 8 5 7 9 15 +50 9 5 7 9 15 +51 6 5 7 9 15 +52 2 2 5 7 9 +53 2 16 15 17 19 +54 3 11 9 15 16 +55 13 11 9 15 17 +56 14 11 9 15 17 +57 5 11 9 15 17 +58 6 11 9 15 17 +59 2 9 15 17 19 +60 3 7 9 15 16 +61 15 7 9 15 17 +62 16 7 9 15 17 +63 17 7 9 15 17 +64 6 7 9 15 17 +65 18 2 7 5 6 +66 18 9 17 15 16 + +Pair Coeffs + +1 0.01570000002623629 2.6495327872602221 +2 0.10939999991572773 3.3996695084507409 +3 0.086000000128358844 3.3996695079448309 +4 0.20999999984182244 2.9599219016446874 +5 0.16999999991766696 3.2499985240310356 +6 0.015700000004219245 1.0690784617205229 +7 0.015700000098461422 2.4713530426421655 + +Bond Coeffs + +1 570.0 1.229 +2 490.0 1.335 +3 340.0 1.090 +4 317.0 1.522 +5 340.0 1.090 +6 310.0 1.526 +7 434.0 1.010 +8 337.0 1.449 + +Angle Coeffs + +1 80.0 122.90005267195104 +2 50.0 120.00005142908158 +3 50.0 121.90005224337536 +4 50.0 109.50004692903693 +5 35.0 109.50004692903693 +6 80.0 120.40005160051184 +7 70.0 116.60004997192425 +8 63.0 111.10004761475803 +9 50.0 109.50004692903693 +10 50.0 109.50004692903693 +11 50.0 109.50004692903693 +12 50.0 118.04005047448166 +13 50.0 109.50004692903693 +14 80.0 109.70004701475206 +15 63.0 110.10004718618234 +16 35.0 109.50004692903693 + 
+Dihedral Coeffs + + 1 2.0 1 1 + 2 2.5 -1 2 + 3 0.0 1 2 + 4 2.0 1 2 + 5 0.4 1 3 + 6 0.0 1 4 + 7 0.0 1 1 + 8 0.272 1 2 + 9 0.43 1 3 +10 0.8 1 1 +11 0.08 -1 3 +12 0.155555556 1 3 +13 0.20 1 1 +14 0.20 1 2 +15 0.45 -1 1 +16 1.58 -1 2 +17 0.55 -1 3 +18 10.5 -1 2 +19 1.10 -1 2 diff --git a/examples/lammps/spectral_abf/adp.lmp b/examples/lammps/spectral_abf/adp.lmp new file mode 100644 index 00000000..4bbd000d --- /dev/null +++ b/examples/lammps/spectral_abf/adp.lmp @@ -0,0 +1,24 @@ +units real + +neigh_modify once yes one 22 page 2200 + +atom_style full +bond_style harmonic +angle_style harmonic +dihedral_style harmonic +pair_style lj/cut/coul/cut 10.0 +pair_modify mix arithmetic + +kspace_style none +read_data adp.input + + +velocity all create 300 3 + +timestep 1.0 + +fix 1 all nve +fix 2 all langevin 300 300 1000 63683 +fix 3 all momentum 100 linear 1 1 1 +fix 4 all shake 0.0001 10 100 b 3 5 7 +special_bonds amber diff --git a/examples/lammps/spectral_abf/run.py b/examples/lammps/spectral_abf/run.py new file mode 100644 index 00000000..59b382f4 --- /dev/null +++ b/examples/lammps/spectral_abf/run.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +""" +Example SpectralABF simulation with pysages and lammps. + +For a list of possible options for running the script pass `-h` as argument from the +command line, or call `get_args(["-h"])` if the module was loaded interactively. +""" + +# %% +import argparse +import sys + +import numpy +from lammps import lammps + +import pysages +from pysages.colvars import DihedralAngle +from pysages.methods import HistogramLogger, SpectralABF + + +# %% +def generate_context(args="", script="adp.lmp", store_freq=1): + """ + Returns a lammps simulation defined by the contents of `script` using `args` as + initialization arguments. 
+ """ + context = lammps(cmdargs=args.split()) + context.file(script) + # Allow for the retrieval of the wrapped positions + context.command(f"dump 4a all custom {store_freq} dump.myforce id type x y z") + return context + + +def get_args(argv): + """Process the command-line arguments to this script.""" + + available_args = [ + ("time-steps", "t", int, 2e6, "Number of simulation steps"), + ("kokkos", "k", bool, True, "Whether to use Kokkos acceleration"), + ("log-steps", "l", int, 2e3, "Number of simulation steps for logging"), + ] + parser = argparse.ArgumentParser(description="Example script to run pysages with lammps") + + for name, short, T, val, doc in available_args: + if T is bool: + action = "store_" + str(val).lower() + parser.add_argument("--" + name, "-" + short, action=action, help=doc) + else: + convert = (lambda x: int(float(x))) if T is int else T + parser.add_argument("--" + name, "-" + short, type=convert, default=T(val), help=doc) + + return parser.parse_args(argv) + + +def main(argv): + """Example SpectralABF simulation with pysages and lammps.""" + args = get_args(argv) + + context_args = {"store_freq": args.log_steps} + if args.kokkos: + # Passed to the lammps constructor as `cmdargs` when running the script + # with the --kokkos (or -k) option + context_args["args"] = "-k on g 1 -sf kk -pk kokkos newton on neigh half" + # context_args["args"] = "-k on -sf kk -pk kokkos newton on neigh half" + + # Setting the collective variable, method, and running the simulation + cvs = [DihedralAngle([4, 6, 8, 14]), DihedralAngle([6, 8, 14, 16])] + grid = pysages.Grid( + lower=(-numpy.pi, -numpy.pi), + upper=(numpy.pi, numpy.pi), + shape=(32, 32), + periodic=True, + ) + method = SpectralABF(cvs, grid) + callback = HistogramLogger(args.log_steps) + raw_result = pysages.run( + method, + generate_context, + args.time_steps, + callback=callback, + context_args=context_args, + ) + # Post-run analysis + result = pysages.analyze(raw_result) + mesh = result["mesh"] + fes_fn = result["fes_fn"] + A = fes_fn(mesh) + hist = result["histogram"] + A = A.reshape(32, 32) + numpy.savetxt("FES.dat", numpy.hstack([mesh, A.reshape(-1, 1)])) + numpy.savetxt("hist.dat", numpy.hstack([mesh, hist.reshape(-1, 1)])) + bins = 50 + histo, xedges, yedges = numpy.histogram2d( + callback.data[:, 0], + callback.data[:, 1], + bins=bins, + range=[[-numpy.pi, numpy.pi], [-numpy.pi, numpy.pi]], + ) + xedges = (xedges[1:] + xedges[:-1]) / 2 + yedges = (yedges[1:] + yedges[:-1]) / 2 + mesh = numpy.dstack(numpy.meshgrid(xedges, yedges)).reshape(-1, 2) + numpy.savetxt("hist-from-logger.dat", numpy.hstack([mesh, histo.reshape(-1, 1)])) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/openmm/Harmonic_Bias.md b/examples/openmm/Harmonic_Bias.md index 2b87ebdc..f89ee2de 100644 --- a/examples/openmm/Harmonic_Bias.md +++ b/examples/openmm/Harmonic_Bias.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/examples/openmm/abf/alanine-dipeptide_openmm.py b/examples/openmm/abf/alanine-dipeptide_openmm.py index f913b23f..810c362d 100644 --- a/examples/openmm/abf/alanine-dipeptide_openmm.py +++ b/examples/openmm/abf/alanine-dipeptide_openmm.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 - import matplotlib.pyplot as plt import numpy @@ -115,7 +114,7 @@ def post_run_action(**kwargs): def main(): cvs = [DihedralAngle((4, 6, 8, 14)), DihedralAngle((6, 8, 14, 16))] 
grid = pysages.Grid(lower=(-pi, -pi), upper=(pi, pi), shape=(32, 32), periodic=True) - method = ABF(cvs, grid) + method = ABF(cvs, grid, use_pinv=True) raw_result = pysages.run(method, generate_simulation, 25, post_run_action=post_run_action) result = pysages.analyze(raw_result, topology=(14,)) diff --git a/examples/openmm/metad/Metadynamics-ADP.md b/examples/openmm/metad/Metadynamics-ADP.md index 298401f5..6ba6a4e3 100644 --- a/examples/openmm/metad/Metadynamics-ADP.md +++ b/examples/openmm/metad/Metadynamics-ADP.md @@ -7,7 +7,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/examples/openmm/metad/nacl/Metadynamics_NaCl.md b/examples/openmm/metad/nacl/Metadynamics_NaCl.md index 60af462e..a7c990c1 100644 --- a/examples/openmm/metad/nacl/Metadynamics_NaCl.md +++ b/examples/openmm/metad/nacl/Metadynamics_NaCl.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/examples/openmm/spectral_abf/ADP_SpectralABF.ipynb b/examples/openmm/spectral_abf/ADP_SpectralABF.ipynb index ac651e9c..2fc5ecb4 100644 --- a/examples/openmm/spectral_abf/ADP_SpectralABF.ipynb +++ b/examples/openmm/spectral_abf/ADP_SpectralABF.ipynb @@ -342,7 +342,6 @@ "# mesh = result[\"mesh\"]\n", "# fes_fn = result[\"fes_fn\"]\n", "# A = fes_fn(mesh)\n", - "A = A.max() - A\n", "A = A.reshape(grid.shape)" ] }, diff --git a/examples/openmm/spectral_abf/ADP_SpectralABF.md b/examples/openmm/spectral_abf/ADP_SpectralABF.md index de443d4e..f0bd66a1 100644 --- a/examples/openmm/spectral_abf/ADP_SpectralABF.md +++ b/examples/openmm/spectral_abf/ADP_SpectralABF.md @@ -7,7 +7,7 @@ jupyter: extension: .md format_name: markdown format_version: '1.3' - jupytext_version: 1.16.2 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 @@ -211,7 +211,6 @@ A = result["free_energy"] # mesh = result["mesh"] # fes_fn = result["fes_fn"] # A = fes_fn(mesh) -A = A.max() - A A = A.reshape(grid.shape) ``` diff --git a/examples/openmm/spectral_abf/alanine-dipeptide.py b/examples/openmm/spectral_abf/alanine-dipeptide.py index 3f5c611d..166dce33 100644 --- a/examples/openmm/spectral_abf/alanine-dipeptide.py +++ b/examples/openmm/spectral_abf/alanine-dipeptide.py @@ -78,7 +78,7 @@ def get_args(argv): ("time-steps", "t", int, 5e5, "Number of simulation steps"), ] parser = argparse.ArgumentParser(description="Example script to run Spectral ABF") - for (name, short, T, val, doc) in available_args: + for name, short, T, val, doc in available_args: parser.add_argument("--" + name, "-" + short, type=T, default=T(val), help=doc) return parser.parse_args(argv) @@ -108,7 +108,6 @@ def main(argv=[]): # Set min free energy to zero A = fes_fn(xi) - A = A.max() - A A = A.reshape(plot_grid.shape) # plot and save free energy to a PNG file diff --git a/pysages/backends/lammps.py b/pysages/backends/lammps.py index 37508ff8..63892f90 100644 --- a/pysages/backends/lammps.py +++ b/pysages/backends/lammps.py @@ -30,6 +30,7 @@ from pysages.typing import Callable, Optional from pysages.utils import copy, identity +kConversionFactors = {"real": 2390.0573615334906, "metal": 1.0364269e-4, "electron": 1.06657236} kDefaultLocation = dlext.kOnHost if not hasattr(ExecutionSpace, "kOnDevice") else dlext.kOnDevice @@ -129,6 +130,8 @@ def build_helpers(context, 
sampling_method, on_gpu, restore_fn): """ utils = importlib.import_module(".utils", package="pysages.backends") dim = context.extract_setting("dimension") + units = context.extract_global("units") + factor = kConversionFactors.get(units) # Depending on the device being used we need to use either cupy or numpy # (or numba) to generate a view of jax's DeviceArrays @@ -141,6 +144,16 @@ def build_helpers(context, sampling_method, on_gpu, restore_fn): def sync_forces(): pass + if factor is None: + + def add_bias(forces, biases): + forces[:, :3] += biases + + else: + + def add_bias(forces, biases): + forces[:, :3] += factor * biases + # TODO: check if this can be sped up. # pylint: disable=W0511 def bias(snapshot, state): """Adds the computed bias to the forces.""" @@ -148,7 +161,7 @@ def bias(snapshot, state): return forces = view(snapshot.forces) biases = view(state.bias.block_until_ready()) - forces[:, :3] += biases + add_bias(forces, biases) sync_forces() snapshot_methods = build_snapshot_methods(sampling_method, on_gpu) diff --git a/pysages/backends/openmm.py b/pysages/backends/openmm.py index 0c5f8e44..b2cc83d8 100644 --- a/pysages/backends/openmm.py +++ b/pysages/backends/openmm.py @@ -157,13 +157,12 @@ def bias(snapshot, state, sync_backend): """Adds the computed bias to the forces.""" if state.bias is None: return - biases = adapt(state.bias) # Forces may be computed asynchronously on the GPU, so we need to # synchronize them before applying the bias. sync_backend() + biases = adapt(state.bias) forces = view(snapshot.forces) - biases = view(biases.block_until_ready()) - forces += biases + forces += view(biases.block_until_ready()) sync_forces() def dimensionality(): diff --git a/pysages/methods/__init__.py b/pysages/methods/__init__.py index dbeeab15..875baa91 100644 --- a/pysages/methods/__init__.py +++ b/pysages/methods/__init__.py @@ -69,6 +69,7 @@ from .harmonic_bias import HarmonicBias from .metad import Metadynamics from .restraints import CVRestraints +from .sirens import Sirens from .spectral_abf import SpectralABF from .spline_string import SplineString from .umbrella_integration import UmbrellaIntegration diff --git a/pysages/methods/abf.py b/pysages/methods/abf.py index ba713d15..dbba8888 100644 --- a/pysages/methods/abf.py +++ b/pysages/methods/abf.py @@ -27,7 +27,7 @@ from pysages.methods.restraints import apply_restraints from pysages.methods.utils import numpyfy_vals from pysages.typing import JaxArray, NamedTuple -from pysages.utils import dispatch, solve_pos_def +from pysages.utils import dispatch, linear_solver class ABFState(NamedTuple): @@ -103,6 +103,11 @@ class ABF(GriddedSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. 
""" snapshot_flags = {"positions", "indices", "momenta"} @@ -110,6 +115,7 @@ class ABF(GriddedSamplingMethod): def __init__(self, cvs, grid, **kwargs): super().__init__(cvs, grid, **kwargs) self.N = np.asarray(self.kwargs.get("N", 500)) + self.use_pinv = self.kwargs.get("use_pinv", False) def build(self, snapshot, helpers, *args, **kwargs): """ @@ -158,6 +164,7 @@ def _abf(method, snapshot, helpers): dt = snapshot.dt dims = grid.shape.size natoms = np.size(snapshot.positions, 0) + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) estimate_force = build_force_estimator(method) @@ -201,11 +208,7 @@ def update(state, data): xi, Jxi = cv(data) p = data.momenta - # The following could equivalently be computed as `linalg.pinv(Jxi.T) @ p` - # (both seem to have the same performance). - # Another option to benchmark against is - # Wp = linalg.tensorsolve(Jxi @ Jxi.T, Jxi @ p) - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) # Second order backward finite difference dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt diff --git a/pysages/methods/ann.py b/pysages/methods/ann.py index 1c7b1948..df139cb6 100644 --- a/pysages/methods/ann.py +++ b/pysages/methods/ann.py @@ -120,7 +120,7 @@ def __init__(self, cvs, grid, topology, kT, **kwargs): default_optimizer = LevenbergMarquardt(reg=L2Regularization(1e-6)) self.optimizer = kwargs.get("optimizer", default_optimizer) - def build(self, snapshot, helpers): + def build(self, snapshot, helpers, *_args, **_kwargs): return _ann(self, snapshot, helpers) @@ -155,7 +155,7 @@ def update(state, data): in_training_regime = ncalls > train_freq # We only train every `train_freq` timesteps in_training_step = in_training_regime & (ncalls % train_freq == 1) - hist, phi, prob, nn = learn_free_energy(state, in_training_step) + hist, prob, phi, nn = learn_free_energy(state, in_training_step) # Compute the collective variable and its jacobian xi, Jxi = cv(data) I_xi = get_grid_index(xi) @@ -208,10 +208,10 @@ def learn_free_energy(state): # hist = np.zeros_like(state.hist) # - return hist, phi, prob, nn + return hist, prob, phi, nn def skip_learning(state): - return state.hist, state.phi, state.prob, state.nn + return state.hist, state.prob, state.phi, state.nn def _learn_free_energy(state, in_training_step): return cond(in_training_step, learn_free_energy, skip_learning, state) @@ -241,7 +241,7 @@ def predict_force(data): params = pack(nn.params, layout) return nn.std * f64(model_grad(params, f32(x)).flatten()) - def zero_force(data): + def zero_force(_data): return np.zeros(dims) def estimate_force(xi, I_xi, nn, in_training_regime): @@ -282,7 +282,6 @@ def analyze(result: Result[ANN]): """ method = result.method - states = result.states grid = method.grid mesh = (compute_mesh(grid) + 1) * grid.size / 2 + grid.lower @@ -306,7 +305,7 @@ def fes_fn(x): transpose = grid_transposer(grid) d = mesh.shape[-1] - for s in states: + for s in result.states: histograms.append(transpose(s.hist)) free_energies.append(transpose(s.phi.max() - s.phi)) nns.append(s.nn) diff --git a/pysages/methods/cff.py b/pysages/methods/cff.py index 27d856e9..3a3306c1 100644 --- a/pysages/methods/cff.py +++ b/pysages/methods/cff.py @@ -31,7 +31,7 @@ from pysages.ml.training import NNData, build_fitting_function, convolve, normalize from pysages.ml.utils import blackman_kernel, pack, unpack from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, first_or_all, solve_pos_def +from pysages.utils import dispatch, 
first_or_all, linear_solver # Aliases f32 = np.float32 @@ -148,6 +148,11 @@ class CFF(NNSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -171,6 +176,7 @@ def __init__(self, cvs, grid, topology, kT, **kwargs): self.fmodel = MLP(dims, dims, topology, transform=scale) self.optimizer = kwargs.get("optimizer", default_optimizer) self.foptimizer = kwargs.get("foptimizer", default_foptimizer) + self.use_pinv = self.kwargs.get("use_pinv", False) def build(self, snapshot, helpers): return _cff(self, snapshot, helpers) @@ -187,6 +193,7 @@ def _cff(method: CFF, snapshot, helpers): fps, _ = unpack(method.fmodel.parameters) # Helper methods + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) learn_free_energy = build_free_energy_learner(method) estimate_force = build_force_estimator(method) @@ -216,12 +223,12 @@ def update(state, data): ncalls = state.ncalls + 1 in_training_regime = ncalls > train_freq in_training_step = in_training_regime & (ncalls % train_freq == 1) - histp, fe, prob, nn, fnn = learn_free_energy(state, in_training_step) + histp, prob, fe, nn, fnn = learn_free_energy(state, in_training_step) # Compute the collective variable and its jacobian xi, Jxi = cv(data) # p = data.momenta - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # I_xi = get_grid_index(xi) @@ -281,7 +288,7 @@ def train(nn, fnn, data): return NNData(params, nn.mean, s), NNData(fparams, f_mean, s) def skip_learning(state): - return state.hist, state.fe, state.prob, state.nn, state.fnn + return state.histp, state.prob, state.fe, state.nn, state.fnn def learn_free_energy(state): prob = state.prob + state.histp * np.exp(state.fe / kT) @@ -294,7 +301,7 @@ def learn_free_energy(state): fe = nn.std * model.apply(params, inputs).reshape(fe.shape) fe = fe - fe.min() - return histp, fe, prob, nn, fnn + return histp, prob, fe, nn, fnn def _learn_free_energy(state, in_training_step): return cond(in_training_step, learn_free_energy, skip_learning, state) diff --git a/pysages/methods/funn.py b/pysages/methods/funn.py index 6130d396..593dda5a 100644 --- a/pysages/methods/funn.py +++ b/pysages/methods/funn.py @@ -34,7 +34,7 @@ from pysages.ml.training import NNData, build_fitting_function, convolve, normalize from pysages.ml.utils import blackman_kernel, pack, unpack from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, first_or_all, solve_pos_def +from pysages.utils import dispatch, first_or_all, linear_solver class FUNNState(NamedTuple): @@ -126,6 +126,11 @@ class FUNN(NNSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. 
""" snapshot_flags = {"positions", "indices", "momenta"} @@ -142,6 +147,7 @@ def __init__(self, cvs, grid, topology, **kwargs): self.model = MLP(dims, dims, topology, transform=scale) default_optimizer = LevenbergMarquardt(reg=L2Regularization(1e-6)) self.optimizer = kwargs.get("optimizer", default_optimizer) + self.use_pinv = self.kwargs.get("use_pinv", False) def build(self, snapshot, helpers): return _funn(self, snapshot, helpers) @@ -160,6 +166,7 @@ def _funn(method, snapshot, helpers): ps, _ = unpack(method.model.parameters) # Helper methods + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) learn_free_energy_grad = build_free_energy_grad_learner(method) estimate_free_energy_grad = build_force_estimator(method) @@ -186,7 +193,7 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # I_xi = get_grid_index(xi) diff --git a/pysages/methods/sirens.py b/pysages/methods/sirens.py new file mode 100644 index 00000000..836e31b6 --- /dev/null +++ b/pysages/methods/sirens.py @@ -0,0 +1,500 @@ +# SPDX-License-Identifier: MIT +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +SIRENs Sampling + +Sirens is a free energy sampling method which uses artificial neural networks to +generate an on-the-fly adaptive bias capable of rapidly resolving free energy landscapes. +The method learns both from the frequency of visits to bins in a CV-space and generalized +force estimates. It can be seen as a generalization of ANN and FUNN sampling methods that +uses one neural network to approximate the free energy and its derivatives. +""" + +import numbers +from functools import partial + +from jax import grad, jit +from jax import numpy as np +from jax import vmap +from jax.lax import cond + +from pysages.approxfun import compute_mesh +from pysages.approxfun import scale as _scale +from pysages.grids import build_indexer, grid_transposer +from pysages.methods.core import NNSamplingMethod, Result, generalize +from pysages.methods.restraints import apply_restraints +from pysages.methods.utils import numpyfy_vals +from pysages.ml import objectives +from pysages.ml.models import Siren +from pysages.ml.objectives import GradientsSSE, L2Regularization, Sobolev1SSE +from pysages.ml.optimizers import LevenbergMarquardt +from pysages.ml.training import NNData, build_fitting_function, convolve +from pysages.ml.utils import blackman_kernel, pack, unpack +from pysages.typing import JaxArray, NamedTuple, Tuple +from pysages.utils import dispatch, first_or_all, linear_solver + + +class SirensState(NamedTuple): # pylint: disable=R0903 + """ + Parameters + ---------- + + xi: JaxArray + Last collective variable recorded in the simulation. + + bias: JaxArray + Array with biasing forces for each particle. + + hist: JaxArray + Histogram of visits to the bins in the collective variable grid. + + histp: JaxArray + Partial histogram of visits to the bins in the collective variable grid, + resets to zero after each training sweep. + + prob: JaxArray + Probability distribution of the CV evaluated ath each bin of the CV grid. + + fe: JaxArray + Current estimate of the free energy evaluated at each bin of the CV grid. + + Fsum: JaxArray + The cumulative force recorded at each bin of the CV grid. + + force: JaxArray + Average force at each bin of the CV grid. 
+ + Wp: JaxArray + Estimate of the product $W p$ where `p` is the matrix of momenta and + `W` the Moore-Penrose inverse of the Jacobian of the CVs. + + Wp_: JaxArray + The value of `Wp` for the previous integration step. + + nn: NNData + Bundle of the neural network parameters and output scaling coefficients. + + ncalls: int + Counts the number of times the method's update has been called. + """ + + xi: JaxArray + bias: JaxArray + hist: JaxArray + histp: JaxArray + prob: JaxArray + fe: JaxArray + Fsum: JaxArray + force: JaxArray + Wp: JaxArray + Wp_: JaxArray + nn: NNData + ncalls: int + + def __repr__(self): + return repr("PySAGES " + type(self).__name__) + + +class PartialSirensState(NamedTuple): # pylint: disable=C0115,R0903 + xi: JaxArray + hist: JaxArray + Fsum: JaxArray + ind: Tuple + nn: NNData + pred: bool + + +class Sirens(NNSamplingMethod): + """ + Implementation of the sampling method described in + "Sobolev Sampling of Free Energy Landscapes" + + Parameters + ---------- + cvs: Union[List, Tuple] + List of collective variables. + + grid: Grid + Specifies the CV domain and number of bins for discretizing the CV space + along each CV dimension. + + topology: Tuple[int] + Defines the architecture of the neural network + (number of nodes of each hidden layer). + + mode: Literal["abf", "cff"] + If `mode == "cff"`, the model will be trained on both the histogram of visits + and the mean-force information. Otherwise, only forces will be used as in `ABF` and `FUNN`. + + kT: Optional[numbers.Real] + Value of `kT` in the same units as the backend internal energy units. + When `mode == "cff"` this parameter has to be a machine real, otherwise it is + ignored even if provided. + + N: Optional[int] = 500 + Threshold parameter before accounting for the full average of the + binned generalized mean force. + + train_freq: Optional[int] = 5000 + Training frequency. + + optimizer: Optional[Optimizer] = None + Optimization method used for training. Must be compatible with the selected `mode`. + + restraints: Optional[CVRestraints] = None + If provided, indicate that harmonic restraints will be applied when any + collective variable lies outside the box from `restraints.lower` to + `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. + """ + + snapshot_flags = {"positions", "indices", "momenta"} + + def __init__(self, cvs, grid, topology, **kwargs): + mode = kwargs.get("mode", "abf") + kT = kwargs.get("kT", None) + + regularizer = L2Regularization(1e-4) + loss = GradientsSSE() if mode == "abf" else Sobolev1SSE() + default_optimizer = LevenbergMarquardt(loss=loss, reg=regularizer, max_iters=250) + optimizer = kwargs.get("optimizer", default_optimizer) + + self.__check_init_invariants__(mode, kT, optimizer) + + super().__init__(cvs, grid, topology, **kwargs) + + self.mode = mode + self.kT = kT + self.N = np.asarray(self.kwargs.get("N", 500)) + self.train_freq = self.kwargs.get("train_freq", 5000) + + dims = grid.shape.size + scale = partial(_scale, grid=grid) + self.model = Siren(dims, 1, topology, transform=scale) + self.optimizer = optimizer + self.use_pinv = self.kwargs.get("use_pinv", False) + + def __check_init_invariants__(self, mode, kT, optimizer): + if mode not in ("abf", "cff"): + raise ValueError(f"Invalid mode {mode}.
Possible values are 'abf' and 'cff'") + + if mode == "cff": + if kT is None: + raise KeyError( + "When running in 'cff' mode, a keyword argument `kT` in the same " + "units as the backend internal energy units must be provided." + ) + assert isinstance(kT, numbers.Real) + assert isinstance(optimizer.loss, objectives.Sobolev1Loss) + else: + assert isinstance(optimizer.loss, objectives.GradientsLoss) + + def build(self, snapshot, helpers, *_args, **_kwargs): + return _sirens(self, snapshot, helpers) + + +def _sirens(method: Sirens, snapshot, helpers): + cv = method.cv + grid = method.grid + train_freq = method.train_freq + dt = snapshot.dt + + # Neural network parameters + ps, _ = unpack(method.model.parameters) + + # Helper methods + tsolve = linear_solver(method.use_pinv) + get_grid_index = build_indexer(grid) + learn_free_energy = build_free_energy_learner(method) + estimate_force = build_force_estimator(method) + + if method.mode == "cff": + increase_me_maybe = jit(lambda a, idx, v: a.at[idx].add(v)) + else: + increase_me_maybe = jit(lambda a, idx, v: a) + + def initialize(): + dims = grid.shape.size + natoms = np.size(snapshot.positions, 0) + gshape = grid.shape if dims > 1 else (*grid.shape, 1) + + xi, _ = cv(helpers.query(snapshot)) + bias = np.zeros((natoms, helpers.dimensionality())) + hist = np.zeros(gshape, dtype=np.uint32) + Fsum = np.zeros((*grid.shape, dims)) + force = np.zeros(dims) + Wp = np.zeros(dims) + Wp_ = np.zeros(dims) + nn = NNData(ps, np.array(0.0), np.array(1.0)) + + if method.mode == "cff": + histp = np.zeros(gshape, dtype=np.uint32) + prob = np.zeros(gshape) + fe = np.zeros(gshape) + else: + histp = prob = fe = None + + return SirensState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, Wp_, nn, 0) + + def update(state, data): + # During the initial stage, when there are not enough collected samples, use ABF + ncalls = state.ncalls + 1 + in_training_regime = ncalls > train_freq + in_training_step = in_training_regime & (ncalls % train_freq == 1) + histp, prob, fe, nn = learn_free_energy(state, in_training_step) + # Compute the collective variable and its Jacobian + xi, Jxi = cv(data) + # + p = data.momenta + Wp = tsolve(Jxi, p) + dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt + # + I_xi = get_grid_index(xi) + hist = state.hist.at[I_xi].add(1) + Fsum = state.Fsum.at[I_xi].add(dWp_dt + state.force) + histp = increase_me_maybe(histp, I_xi, 1) # Special handling depending on the mode + # + force = estimate_force(PartialSirensState(xi, hist, Fsum, I_xi, nn, in_training_regime)) + bias = (-Jxi.T @ force).reshape(state.bias.shape) + # + return SirensState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, state.Wp, nn, ncalls) + + return snapshot, initialize, generalize(update, helpers) + + +def build_free_energy_learner(method: Sirens): + """ + Returns a function that given a `SirensState` trains the method's neural networks which + give estimates of the probability density and the gradient of the free energy.
+ """ + # Local aliases + f32 = np.float32 + + mode = method.mode + kT = method.kT + grid = method.grid + model = method.model + optimizer = method.optimizer + + dims = grid.shape.size + shape = (*grid.shape, 1) + inputs = f32((compute_mesh(grid) + 1) * grid.size / 2 + grid.lower) + smoothing_kernel = f32(blackman_kernel(dims, 5)) + padding = "wrap" if grid.is_periodic else "edge" + conv = partial(convolve, kernel=smoothing_kernel, boundary=padding) + + _, layout = unpack(model.parameters) + fit = build_fitting_function(model, optimizer) + normalize_gradients = _build_gradient_normalizer(grid) + + def vsmooth(y): + return vmap(conv)(y.T).T + + smooth = conv if dims > 1 else vsmooth + + if mode == "cff": + + def preprocess(y, dy): + dy, dy_std = normalize_gradients(dy) + s = np.maximum(y.std(), dy_std) + return (smooth(f32(y / s)), dy / s), s + + def preupdate_free_energy(state): + prob = state.prob + state.histp * np.exp(state.fe / kT) + fe = kT * np.log(np.maximum(1, prob)) + histp = np.zeros_like(state.histp) + return histp, prob, fe + + def postupdate_free_energy(nn, fe): + params = pack(nn.params, layout) + fe = nn.std * model.apply(params, inputs).reshape(fe.shape) + return fe - fe.min() + + else: + + def preprocess(_, dy): + dy, s = normalize_gradients(dy) + return dy / s, s + + def preupdate_free_energy(state): + return state.histp, state.prob, state.fe + + def postupdate_free_energy(_, fe): + return fe + + def train(nn, data): + targets, s = preprocess(*data) + params = fit(nn.params, inputs, targets).params + return NNData(params, nn.mean, s) + + def skip_learning(state): + return state.histp, state.prob, state.fe, state.nn + + def learn_free_energy(state): + force = state.Fsum / np.maximum(1, state.hist.reshape(shape)) + histp, prob, fe = preupdate_free_energy(state) + + nn = train(state.nn, (fe, force)) + fe = postupdate_free_energy(nn, fe) + + return histp, prob, fe, nn + + def _learn_free_energy(state, in_training_step): + return cond(in_training_step, learn_free_energy, skip_learning, state) + + return _learn_free_energy + + +@dispatch +def build_force_estimator(method: Sirens): + """ + Returns a function that given a `PartialSirensState` computes an estimate to the force. 
+ """ + # Local aliases + f32 = np.float32 + f64 = np.float64 + + N = method.N + grid = method.grid + model = method.model + _, layout = unpack(model.parameters) + + model_grad = grad(lambda p, x: model.apply(p, x).sum(), argnums=-1) + + def average_force(state): + i = state.ind + return state.Fsum[i] / np.maximum(N, state.hist[i]) + + def predict_force(state): + nn = state.nn + x = state.xi + params = pack(nn.params, layout) + return nn.std * f64(model_grad(params, f32(x)).flatten()) + + def _estimate_force(state): + return cond(state.pred, predict_force, average_force, state) + + if method.restraints is None: + estimate_force = _estimate_force + else: + lo, hi, kl, kh = method.restraints + + def restraints_force(state): + xi = state.xi.reshape(grid.shape.size) + return apply_restraints(lo, hi, kl, kh, xi) + + def estimate_force(state): + ob = np.any(np.array(state.ind) == grid.shape) + return cond(ob, restraints_force, _estimate_force, state) + + return estimate_force + + +def _build_gradient_normalizer(grid): + if grid.is_periodic: + + def normalize_gradients(data): + axes = tuple(range(data.ndim - 1)) + mean = data.mean(axis=axes) + std = data.std(axis=axes).max() + return data - mean, std + + else: + + def normalize_gradients(data): + axes = tuple(range(data.ndim - 1)) + s = data.std(axis=axes).max() + return data, s + + return normalize_gradients + + +@dispatch +def analyze(result: Result[Sirens]): + """ + Parameters + ---------- + + result: Result[Sirens] + Result bundle containing the method, final states, and callbacks. + + Returns + ------- + + dict: + A dictionary with the following keys: + + histogram: JaxArray + A histogram of the visits to each bin in the CV grid. + + mean_force: JaxArray + Generalized mean forces at each bin in the CV grid. + + free_energy: JaxArray + Free energy at each bin in the CV grid. + + mesh: JaxArray + These are the values of the CVs that are used as inputs for training. + + nn: NNData + Parameters of the resulting trained neural network. + + fnn: NNData + Parameters of the resulting trained force-based neural network. + + fes_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy in the CV domain defined + by the grid. 
+ """ + method = result.method + + grid = method.grid + mesh = (compute_mesh(grid) + 1) * grid.size / 2 + grid.lower + model = method.model + _, layout = unpack(model.parameters) + + def average_forces(hist, Fsum): + hist = hist.reshape(*Fsum.shape[:-1], 1) + return Fsum / np.maximum(hist, 1) + + def build_fes_fn(nn): + def fes_fn(x): + params = pack(nn.params, layout) + A = nn.std * model.apply(params, x) + nn.mean + return A.max() - A + + return jit(fes_fn) + + histograms = [] + mean_forces = [] + free_energies = [] + nns = [] + fes_fns = [] + + # We transpose the data for convenience when plotting + transpose = grid_transposer(grid) + d = mesh.shape[-1] + + for s in result.states: + fes_fn = build_fes_fn(s.nn) + histograms.append(s.hist) + mean_forces.append(transpose(average_forces(s.hist, s.Fsum))) + free_energies.append(transpose(fes_fn(mesh))) + nns.append(s.nn) + fes_fns.append(fes_fn) + + ana_result = { + "histogram": first_or_all(histograms), + "mean_force": first_or_all(mean_forces), + "free_energy": first_or_all(free_energies), + "mesh": transpose(mesh).reshape(-1, d).squeeze(), + "nn": first_or_all(nns), + "fes_fn": first_or_all(fes_fns), + } + + return numpyfy_vals(ana_result) diff --git a/pysages/methods/spectral_abf.py b/pysages/methods/spectral_abf.py index dcba5317..c50d2f11 100644 --- a/pysages/methods/spectral_abf.py +++ b/pysages/methods/spectral_abf.py @@ -30,7 +30,7 @@ from pysages.methods.restraints import apply_restraints from pysages.methods.utils import numpyfy_vals from pysages.typing import JaxArray, NamedTuple, Tuple -from pysages.utils import dispatch, first_or_all, solve_pos_def +from pysages.utils import dispatch, first_or_all, linear_solver class SpectralABFState(NamedTuple): @@ -124,6 +124,11 @@ class SpectralABF(GriddedSamplingMethod): If provided, indicate that harmonic restraints will be applied when any collective variable lies outside the box from `restraints.lower` to `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. """ snapshot_flags = {"positions", "indices", "momenta"} @@ -135,8 +140,9 @@ def __init__(self, cvs, grid, **kwargs): self.fit_threshold = self.kwargs.get("fit_threshold", 500) self.grid = self.grid if self.grid.is_periodic else convert(self.grid, Grid[Chebyshev]) self.model = SpectralGradientFit(self.grid) + self.use_pinv = self.kwargs.get("use_pinv", False) - def build(self, snapshot, helpers): + def build(self, snapshot, helpers, *_args, **_kwargs): """ Returns the `initialize` and `update` functions for the sampling method. 
""" @@ -154,6 +160,7 @@ def _spectral_abf(method, snapshot, helpers): natoms = np.size(snapshot.positions, 0) # Helper methods + tsolve = linear_solver(method.use_pinv) get_grid_index = build_indexer(grid) fit = build_fitter(method.model) fit_forces = build_free_energy_fitter(method, fit) @@ -181,7 +188,7 @@ def update(state, data): xi, Jxi = cv(data) # p = data.momenta - Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p) + Wp = tsolve(Jxi, p) # Second order backward finite difference dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # @@ -199,7 +206,7 @@ def update(state, data): return snapshot, initialize, generalize(update, helpers) -def build_free_energy_fitter(method: SpectralABF, fit): +def build_free_energy_fitter(_method: SpectralABF, fit): """ Returns a function that given a `SpectralABFState` performs a least squares fit of the generalized average forces for finding the coefficients of a basis functions expansion @@ -293,7 +300,6 @@ def analyze(result: Result[SpectralABF]): For multiple-replicas runs we return a list (one item per-replica) for each attribute. """ method = result.method - states = result.states grid = method.grid mesh = (compute_mesh(grid) + 1) * grid.size / 2 + grid.lower @@ -320,7 +326,7 @@ def fes_fn(x): transpose = grid_transposer(grid) d = mesh.shape[-1] - for s in states: + for s in result.states: fes_fn = build_fes_fn(s.fun) hists.append(transpose(s.hist)) mean_forces.append(transpose(average_forces(s.hist, s.Fsum))) diff --git a/pysages/ml/models.py b/pysages/ml/models.py index a89aa00a..19e8390d 100644 --- a/pysages/ml/models.py +++ b/pysages/ml/models.py @@ -5,11 +5,11 @@ from itertools import chain from jax import numpy as np -from jax.nn.initializers import variance_scaling +from jax import random -from pysages.ml.utils import dispatch, rng_key +from pysages.ml.utils import dispatch, rng_key, uniform_scaling from pysages.typing import Callable, JaxArray, Optional -from pysages.utils import try_import +from pysages.utils import identity, try_import stax = try_import("jax.example_libraries.stax", "jax.experimental.stax") @@ -27,14 +27,16 @@ class Model: @dataclass class MLPBase(Model): + """ + Multilayer-perceptron network base class. Provides a convenience constructor. + """ + def __init__(self, indim, layers, seed): # Build initialization and application functions for the network init, apply = stax.serial(*layers) # Randomly initialize network parameters with seed _, parameters = init(rng_key(seed), (-1, indim)) - # - self.parameters = parameters - self.apply = apply + super().__init__(parameters, apply) @dataclass @@ -80,14 +82,14 @@ def __init__(self, indim, outdim, topology, activation=stax.Tanh, transform=None @dataclass class Siren(MLPBase): """ - Siren network as decribed in [1] + Sinusoidal Representation Network as decribed in [1] [1] Sitzmann, V., Martel, J., Bergman, A., Lindell, D., and Wetzstein, G. "Implicit neural representations with periodic activation functions." Advances in Neural Information Processing Systems 33 (2020). """ - def __init__(self, indim, outdim, topology, omega=1.0, transform=None, seed=0): + def __init__(self, indim, outdim, topology, omega_0=16, omega=1, transform=None, seed=0): """ Arguments --------- @@ -101,7 +103,7 @@ def __init__(self, indim, outdim, topology, omega=1.0, transform=None, seed=0): topology: Tuple Defines the structure of the inner layers for the network. - omega: + omega_0: Weight distribution factor ω₀ for the first layer (as described in [1]). 
transform: Optional[Callable] @@ -112,20 +114,54 @@ def __init__(self, indim, outdim, topology, omega=1.0, transform=None, seed=0): """ tlayer = build_transform_layer(transform) - # Weight initialization for Sirens - pdf_in = variance_scaling(1.0 / 3, "fan_in", "uniform") - pdf = variance_scaling(2.0 / omega**2, "fan_in", "uniform") - # Sine activation function - σ_in = stax.elementwise(lambda x: np.sin(omega * x)) - σ = stax.elementwise(lambda x: np.sin(x)) + # Weight initialization for the first layer of Sirens + pdf_0 = uniform_scaling(1.0, "fan_in", scale_transform=identity) # Build layers - layer_in = [stax.Flatten, *tlayer, stax.Dense(topology[0], pdf_in), σ_in] - layers = list(chain.from_iterable((stax.Dense(i, pdf), σ) for i in topology[1:])) - layers = layer_in + layers + [stax.Dense(outdim, pdf)] + layers_0 = [stax.Flatten, *tlayer, SirenLayer(topology[0], omega_0, pdf_0)] + layers = [SirenLayer(i, omega) for i in topology[1:]] + layers = layers_0 + layers + [SirenLayer(outdim, omega, is_linear=True)] # super().__init__(indim, layers, seed) +def SirenLayer( + out_dim, + omega, + W_init: Optional[Callable] = None, + b_init: Optional[Callable] = None, + is_linear: bool = False, +): + """ + Constructor function for a dense (fully-connected) layer for Sinusoidal Representation + Networks. Similar to `jax.example_libraries.stax.Dense`. + """ + + if W_init is None: + W_init = uniform_scaling(6.0 / omega**2, "fan_in") + + if b_init is None: + b_init = uniform_scaling(1.0, "fan_in", bias_like=True) + + if is_linear: + σ = identity(lambda x, omega: x) + else: + σ = identity(lambda x, omega: np.sin(omega * x)) + + def init(rng, input_shape): + k1, k2 = random.split(rng) + inner_shape = (input_shape[-1], out_dim) + output_shape = input_shape[:-1] + (out_dim,) + W, b = W_init(k1, inner_shape), b_init(k2, inner_shape) + return output_shape, (W, b) + + def apply(params, inputs, **_): + # NOTE: experiment with having omega in params + W, b = params + return σ(np.dot(inputs, W) + b, omega) + + return init, apply + + @dispatch def build_transform_layer(transform: Optional[Callable] = None): return () if transform is None else (stax.elementwise(transform),) diff --git a/pysages/ml/objectives.py b/pysages/ml/objectives.py index cde6094a..3be1ddce 100644 --- a/pysages/ml/objectives.py +++ b/pysages/ml/objectives.py @@ -22,48 +22,36 @@ class Loss: Abstract base class for all losses. """ - pass - class GradientsLoss(Loss): """ Abstract base class for gradient-based losses. """ - pass - class Sobolev1Loss(Loss): """ Abstract base class for losses that depend on both target values and gradients. """ - pass - class SSE(Loss): """ Sum-of-Squared-Errors Loss. """ - pass - class GradientsSSE(GradientsLoss): """ Sum-of-Squared-Gradient-Errors Loss. """ - pass - class Sobolev1SSE(Sobolev1Loss): """ Sum-of-Squared-Errors and Squared-Gradient-Errors Loss. """ - pass - # Regularizers class Regularizer: @@ -71,14 +59,12 @@ class Regularizer: Abstract base class for all regularizers. """ - pass - # On Python >= 3.9 NamedTuple cannot be used directly as superclass -L2Regularization = NamedTuple("L2Regularization", [("coeff", float)]) +L2RegularizationBase = NamedTuple("L2Regularization", [("coeff", float)]) -class L2Regularization(Regularizer, L2Regularization): +class L2Regularization(Regularizer, L2RegularizationBase): """ L2-norm regularization. @@ -92,11 +78,9 @@ class VarRegularization(Regularizer): Weights-variance regularization. 
""" - pass - @dispatch.abstract -def build_objective_function(model, loss, reg): +def build_objective_function(model, loss, reg): # pylint: disable=W0613 """ Given a model, loss and regularizer, it builds an objective function that takes a set of parameters, input and reference values for the model, and returns the @@ -106,6 +90,7 @@ def build_objective_function(model, loss, reg): @dispatch def build_objective_function(model, loss: Loss, reg: Regularizer): + # pylint: disable=C0116,E0102 cost = build_cost_function(loss, reg) def objective(params, inputs, reference): @@ -118,7 +103,8 @@ def objective(params, inputs, reference): @dispatch -def build_objective_function(model, loss: GradientsSSE, reg: Regularizer): +def build_objective_function(model, _loss: GradientsSSE, reg: Regularizer): + # pylint: disable=C0116,E0102 apply = grad(lambda p, x: model.apply(p, x.reshape(1, -1)).sum(), argnums=1) cost = build_cost_function(SSE(), reg) @@ -134,6 +120,7 @@ def objective(params, inputs, reference): @dispatch def build_objective_function(model, loss: Sobolev1Loss, reg: Regularizer): + # pylint: disable=C0116,E0102 apply = value_and_grad(lambda p, x: model.apply(p, x.reshape(1, -1)).sum(), argnums=1) cost = build_cost_function(loss, reg) @@ -151,7 +138,7 @@ def objective(params, inputs, refs): @dispatch.abstract -def build_cost_function(loss, reg): +def build_cost_function(loss, reg): # pylint: disable=W0613 """ Given a loss and regularizer, it builds a regularized cost function that takes a set of model parameters and the error differences between such model's @@ -160,7 +147,8 @@ def build_cost_function(loss, reg): @dispatch -def build_cost_function(loss: Union[SSE, GradientsSSE], reg: L2Regularization): +def build_cost_function(_loss: Union[SSE, GradientsSSE], reg: L2Regularization): + # pylint: disable=C0116,E0102 r = reg.coeff def cost(errors, ps): @@ -170,7 +158,9 @@ def cost(errors, ps): @dispatch -def build_cost_function(loss: Union[SSE, GradientsSSE], reg: VarRegularization): +def build_cost_function(_loss: Union[SSE, GradientsSSE], _reg: VarRegularization): + # pylint: disable=C0116,E0102 + def cost(errors, ps): # k = ps.size return (sum_squares(errors) + ps.var()) / 2 @@ -179,7 +169,8 @@ def cost(errors, ps): @dispatch -def build_cost_function(loss: Sobolev1SSE, reg: L2Regularization): +def build_cost_function(_loss: Sobolev1SSE, reg: L2Regularization): + # pylint: disable=C0116,E0102 r = reg.coeff def cost(errors, ps): @@ -190,7 +181,9 @@ def cost(errors, ps): @dispatch -def build_cost_function(loss: Sobolev1SSE, reg: VarRegularization): +def build_cost_function(_loss: Sobolev1SSE, _reg: VarRegularization): + # pylint: disable=C0116,E0102 + def cost(errors, ps): # k = ps.size e, ge = errors @@ -200,7 +193,7 @@ def cost(errors, ps): @dispatch.abstract -def build_error_function(model, loss): +def build_error_function(model, loss): # pylint: disable=W0613 """ Given a model and loss, it builds a function that computes the error differences between the model's predictions at each input value and some @@ -209,7 +202,8 @@ def build_error_function(model, loss): @dispatch -def build_error_function(model, loss: Loss): +def build_error_function(model, _loss: Loss): + # pylint: disable=C0116,E0102 _, layout = unpack(model.parameters) def error(ps, inputs, reference): @@ -221,7 +215,8 @@ def error(ps, inputs, reference): @dispatch -def build_error_function(model, loss: GradientsLoss): +def build_error_function(model, _loss: GradientsLoss): + # pylint: disable=C0116,E0102 apply = grad(lambda p, x: 
diff --git a/pysages/ml/optimizers.py b/pysages/ml/optimizers.py
index b193aa0b..7a2e9f56 100644
--- a/pysages/ml/optimizers.py
+++ b/pysages/ml/optimizers.py
@@ -103,8 +103,6 @@ class Optimizer:
     Abstract base class for all optimizers.
     """
 
-    pass
-
 
 @dataclass
 class Adam(Optimizer):
@@ -144,7 +142,7 @@ class LevenbergMarquardtBR(Optimizer):
 
 
 @dispatch.abstract
-def build(optimizer, model):
+def build(optimizer, model):  # pylint: disable=W0613
     """
     Given an optimizer and a model, builds and returns three functions `initialize`,
     `keep_iterating` and `update` that respectively handle the initialization of
@@ -158,6 +156,7 @@ def build(optimizer, model):
 
 @dispatch
 def build(optimizer: Adam, model):
+    # pylint: disable=C0116,E0102
    _init, _update, repack = jopt.adam(*optimizer.params)
     objective = build_objective_function(model, optimizer.loss, optimizer.reg)
     gradient = jax.grad(objective)
@@ -183,6 +182,7 @@ def update(state):
 
 @dispatch
 def build(optimizer: LevenbergMarquardt, model):
+    # pylint: disable=C0116,E0102
     error, cost = build_split_cost_function(model, optimizer.loss, optimizer.reg)
     jac_err_prod = build_jac_err_prod(optimizer.loss, optimizer.reg)
     damped_hessian = build_damped_hessian(optimizer.loss, optimizer.reg)
@@ -229,6 +229,7 @@ def update(state):
 
 @dispatch
 def build(optimizer: LevenbergMarquardtBR, model):
+    # pylint: disable=C0116,E0102
     error = build_error_function(model, SSE())
     jacobian = jax.jacobian(error)
     _, c, mu_min, mu_max, rho_c, rho_min = optimizer.params
@@ -236,7 +237,7 @@ def build(optimizer: LevenbergMarquardtBR, model):
     # m = len(model.parameters) / 2 - 1
     k = unpack(model.parameters)[0].size
 
-    update_hyperparams = partial(optimizer.update, m, k, optimizer.alpha)
+    _update_hyperparams = partial(optimizer.update, m, k, optimizer.alpha)
 
     def initialize(params, x, y):
         e = error(params, x, y)
@@ -280,7 +281,7 @@ def update(state):
         improved = (C_ > C) | bad_step
         #
         bundle = (alpha, H, idx, sse, ssp, x.size)
-        alpha, *_ = cond(bad_step, lambda t: t, update_hyperparams, bundle)
+        alpha, *_ = cond(bad_step, lambda t: t, _update_hyperparams, bundle)
         C = (sse + alpha * ssp) / 2
         #
         return LevenbergMarquardtBRState(data, p, e, C, mu, alpha, iters + ~bad_step, improved)
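NOTE (editor, not part of the patch): for context, the triple returned by `build` is consumed roughly as below; the driver loop here is a hypothetical sketch (the real one presumably lives in `pysages.ml.training`), and `optimizer`, `inputs` and `reference` are placeholder names:

    from pysages.ml.optimizers import build

    initialize, keep_iterating, update = build(optimizer, model)

    state = initialize(model.parameters, inputs, reference)
    while keep_iterating(state):
        state = update(state)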
""" - pass - @dataclass class Adam(Optimizer): @@ -144,7 +142,7 @@ class LevenbergMarquardtBR(Optimizer): @dispatch.abstract -def build(optimizer, model): +def build(optimizer, model): # pylint: disable=W0613 """ Given an optimizer and a model, builds and return three functions `initialize`, `keep_iterating` and `update` that respectively handle the initialization of @@ -158,6 +156,7 @@ def build(optimizer, model): @dispatch def build(optimizer: Adam, model): + # pylint: disable=C0116,E0102 _init, _update, repack = jopt.adam(*optimizer.params) objective = build_objective_function(model, optimizer.loss, optimizer.reg) gradient = jax.grad(objective) @@ -183,6 +182,7 @@ def update(state): @dispatch def build(optimizer: LevenbergMarquardt, model): + # pylint: disable=C0116,E0102 error, cost = build_split_cost_function(model, optimizer.loss, optimizer.reg) jac_err_prod = build_jac_err_prod(optimizer.loss, optimizer.reg) damped_hessian = build_damped_hessian(optimizer.loss, optimizer.reg) @@ -229,6 +229,7 @@ def update(state): @dispatch def build(optimizer: LevenbergMarquardtBR, model): + # pylint: disable=C0116,E0102 error = build_error_function(model, SSE()) jacobian = jax.jacobian(error) _, c, mu_min, mu_max, rho_c, rho_min = optimizer.params @@ -236,7 +237,7 @@ def build(optimizer: LevenbergMarquardtBR, model): # m = len(model.parameters) / 2 - 1 k = unpack(model.parameters)[0].size - update_hyperparams = partial(optimizer.update, m, k, optimizer.alpha) + _update_hyperparams = partial(optimizer.update, m, k, optimizer.alpha) def initialize(params, x, y): e = error(params, x, y) @@ -280,7 +281,7 @@ def update(state): improved = (C_ > C) | bad_step # bundle = (alpha, H, idx, sse, ssp, x.size) - alpha, *_ = cond(bad_step, lambda t: t, update_hyperparams, bundle) + alpha, *_ = cond(bad_step, lambda t: t, _update_hyperparams, bundle) C = (sse + alpha * ssp) / 2 # return LevenbergMarquardtBRState(data, p, e, C, mu, alpha, iters + ~bad_step, improved) diff --git a/pysages/ml/training.py b/pysages/ml/training.py index c4895ca9..44838942 100644 --- a/pysages/ml/training.py +++ b/pysages/ml/training.py @@ -31,7 +31,7 @@ def convolve(data, kernel, boundary="edge"): if n == 1: padding = (kernel.size - 1) // 2 else: - padding = [tuple((s - 1) // 2 for _ in range(n)) for s in kernel.shape] + padding = [((s - 1) // 2, (s - 1) // 2) for s in kernel.shape] def pad(slice): return np.pad(slice, padding, mode=boundary) diff --git a/pysages/ml/utils.py b/pysages/ml/utils.py index 74466073..b3a796a5 100644 --- a/pysages/ml/utils.py +++ b/pysages/ml/utils.py @@ -3,13 +3,15 @@ from jax import numpy as np from jax import random, vmap +from jax._src.nn import initializers +from jax.core import as_named_shape from jax.numpy.linalg import norm from jax.tree_util import PyTreeDef, tree_flatten from numpy import cumsum from plum import Dispatcher from pysages.typing import NamedTuple -from pysages.utils import prod +from pysages.utils import identity, prod # Dispatcher for the `ml` submodule dispatch = Dispatcher() @@ -62,6 +64,43 @@ def pack(params, layout): return structure.unflatten(ps) +def uniform_scaling( + scale, mode, in_axis=-2, out_axis=-1, dtype=np.float_, bias_like=False, scale_transform=np.sqrt +): + """ + Similar to `jax.nn.initializers.variance_scaling`, but the sampling distribution is + always uniform and the scaling can be transformed by `scale_transform` (defaults to + `jax.numpy.sqrt`). In addition, it also works for biases if `bias_like == True`. 
+ """ + # Local aliases + idem = identity + transform = scale_transform + + if mode == "fan_in": + denominator = idem(lambda fan_in, fan_out: fan_in) + elif mode == "fan_out": + denominator = idem(lambda fan_in, fan_out: fan_out) + elif mode == "fan_avg": + denominator = idem(lambda fan_in, fan_out: (fan_in + fan_out) / 2) + else: + raise ValueError(f"invalid mode for variance scaling initializer: {mode}") + + if bias_like: + trim_named_shape = idem(lambda named_shp, shp, axis: as_named_shape(shp[axis:])) + else: + trim_named_shape = idem(lambda named_shp, shp, axis: named_shp) + + def init(key, shape, dtype=dtype): + args_named_shape = as_named_shape(shape) + named_shape = trim_named_shape(args_named_shape, shape, out_axis) + # pylint: disable-next=W0212 + fan_in, fan_out = initializers._compute_fans(args_named_shape, in_axis, out_axis) + s = np.array(scale / denominator(fan_in, fan_out), dtype=dtype) + return random.uniform(key, named_shape, dtype, -1) * transform(s) + + return init + + def number_of_weights(topology): k = topology[0] n = 0 diff --git a/pysages/utils/__init__.py b/pysages/utils/__init__.py index 3351d48e..ffde6466 100644 --- a/pysages/utils/__init__.py +++ b/pysages/utils/__init__.py @@ -18,5 +18,14 @@ solve_pos_def, try_import, ) -from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity +from .core import ( + ToCPU, + copy, + dispatch, + eps, + first_or_all, + gaussian, + identity, + linear_solver, +) from .transformations import quaternion_from_euler, quaternion_matrix diff --git a/pysages/utils/compat.py b/pysages/utils/compat.py index 1544d42a..443b70f5 100644 --- a/pysages/utils/compat.py +++ b/pysages/utils/compat.py @@ -103,13 +103,23 @@ def has_method(fn, T, index): _typing = import_module("plum" if _plum_version_tuple < (2, 2, 1) else "typing") _util = _typing.type if _plum_version_tuple < (2, 2, 1) else _typing + if _plum_version_tuple < (2, 3, 0): + + def _signature_types(sig): + return sig.types + + else: + + def _signature_types(sig): + return sig.signature.types + def dispatch_table(dispatch): return dispatch.functions def has_method(fn, T, index): types_at_index = set() for sig in fn.methods: - typ = sig.types[index] + typ = _signature_types(sig)[index] if _util.get_origin(typ) is _typing.Union: types_at_index.update(_util.get_args(typ)) else: diff --git a/pysages/utils/core.py b/pysages/utils/core.py index 06afdbc8..20fe0f3e 100644 --- a/pysages/utils/core.py +++ b/pysages/utils/core.py @@ -8,6 +8,7 @@ from plum import Dispatcher from pysages.typing import JaxArray, Scalar +from pysages.utils.compat import solve_pos_def # PySAGES main dispatcher dispatch = Dispatcher() @@ -70,3 +71,22 @@ def gaussian(a, sigma, x): N-dimensional origin-centered gaussian with height `a` and standard deviation `sigma`. """ return a * np.exp(-row_sum((x / sigma) ** 2) / 2) + + +def linear_solver(use_pinv: bool): + """ + Returns a function that solves the linear system `A.T @ X = B` for `X`. + When `use_pinv == True`, `numpy.linalg.pinv` is used rather than `scipy.linalg.solve` + (this is computationally more expensive but numerically more stable). 
+ """ + if use_pinv: + # This is numerically more robust + def tsolve(A, B): + return np.linalg.pinv(A.T) @ B + + else: + # Another option to benchmark against is `linalg.tensorsolve(A @ A.T, A @ B)` + def tsolve(A, B): + return solve_pos_def(A @ A.T, A @ B) + + return tsolve diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 14ad5886..26f56cd7 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -82,6 +82,12 @@ def build_neighbor_list(box_size, positions, r_cutoff, capacity_multiplier): "cvs": [pysages.colvars.Component([0], 0), pysages.colvars.Component([0], 1)], "grid": pysages.Grid(lower=(1, 1), upper=(5, 5), shape=(32, 32)), }, + "Sirens": { + "cvs": [pysages.colvars.Component([0], 0), pysages.colvars.Component([0], 1)], + "grid": pysages.Grid(lower=(1, 1), upper=(5, 5), shape=(32, 32)), + "topology": (14,), + "mode": "abf", + }, "HistogramLogger": { "period": 1, "offset": 1,