Commit 180ffea

ajhoffman1229, ajh1229, and janosh authored
Dispersion (#192)
* ENH: add dispersion tutorial
* STY: format dispersion notebook
* add extra deps set dispersion = ["dftd4>=3.6", "torch-dftd>=0.4"]
  also bump ruff
* tweak dispersion.ipynb var names

---------

Co-authored-by: alex <ajhoff29@mit.edu>
Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
1 parent be9bab9 commit 180ffea

8 files changed, +283 -9 lines changed

.pre-commit-config.yaml

+2 -2
@@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.5
+    rev: v0.6.2
     hooks:
       - id: ruff
         args: [--fix]
@@ -48,7 +48,7 @@ repos:
           - svelte

   - repo: https://github.com/pre-commit/mirrors-eslint
-    rev: v9.8.0
+    rev: v9.9.0
     hooks:
       - id: eslint
         types: [file]

examples/dispersion.ipynb

+274
@@ -0,0 +1,274 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      " %reload_ext autoreload\n"
+     ]
+    }
+   ],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Adding dispersion to the CHGNet pre-trained model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This notebook describes the process of adding a dispersion correction to the CHGNet pre-trained model. CHGNet is trained on PBE (GGA) DFT calculations; as such, it does not include a correction for van der Waals or dispersive forces. This kind of correction may be particularly useful for those studying porous materials, such as MOFs or zeolites, but who do not wish to fine-tune the pre-trained model on data that include a dispersion correction.\n",
+    "\n",
+    "This notebook uses both the [torch-dftd](https://github.com/pfnet-research/torch-dftd/tree/master) and [DFT-D4](https://dftd4.readthedocs.io/en/latest/reference/ase.html) repositories to add dispersion to CHGNet. The torch-dftd repository currently has DFT-D2 and DFT-D3 implementations and does not have the most recent DFT-D4 version, but is GPU-accelerated where DFT-D4 is not. The Grimme group has released a version of [DFT-D4 implemented in PyTorch](https://github.com/dftd4/tad-dftd4); however, this version does not have an ASE-compatible calculator available.\n",
+    "\n",
+    "You will need to install CHGNet, [ASE](https://wiki.fysik.dtu.dk/ase/install.html), [torch-dftd](https://github.com/pfnet-research/torch-dftd/tree/master?tab=readme-ov-file#install), and [DFT-D4](https://dftd4.readthedocs.io/en/latest/recipe/installation.html) to run this notebook (links are to their installation instructions)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ase.build import fcc111\n",
+    "from ase.calculators.mixing import SumCalculator\n",
+    "from dftd4.ase import DFTD4\n",
+    "from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator\n",
+    "\n",
+    "from chgnet.model.dynamics import CHGNetCalculator"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CHGNet v0.3.0 initialized with 412,525 parameters\n",
+      "CHGNet will run on cpu\n",
+      "CHGNet will run on cpu\n"
+     ]
+    }
+   ],
+   "source": [
+    "# pre-trained chgnet model\n",
+    "chgnet_calc = CHGNetCalculator()\n",
+    "\n",
+    "d3_calc = TorchDFTD3Calculator() # uses PBE parameters by default\n",
+    "\n",
+    "d4_calc = DFTD4(method=\"PBE\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## A simple example"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This example shows how to initialize an Atoms object (of a Cu(111) surface) and compute its energy with and without the dispersion correction."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Disp calculator: sumcalculator\n",
+      "Cu4\n",
+      "E without dispersion: -12.876540184020996\n",
+      "E with DFT-D3 dispersion: -13.272150007989707\n",
+      "E with DFT-D4 dispersion: -13.573149493981\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Create a 2x2x1 fcc(111) Cu slab\n",
+    "atoms = fcc111(\"Cu\", (2, 2, 1), vacuum=10.0)\n",
+    "atoms.set_pbc([True, True, True])\n",
+    "\n",
+    "atoms_disp = atoms.copy()\n",
+    "atoms_d4 = atoms.copy()\n",
+    "\n",
+    "atoms.calc = chgnet_calc\n",
+    "\n",
+    "chgnet_d3 = SumCalculator([chgnet_calc, d3_calc])\n",
+    "chgnet_d4 = SumCalculator([chgnet_calc, d4_calc])\n",
+    "atoms_disp.calc = chgnet_d3\n",
+    "atoms_d4.calc = chgnet_d4\n",
+    "\n",
+    "e_chg = atoms.get_potential_energy()\n",
+    "e_disp = atoms_disp.get_potential_energy()\n",
+    "e_d4 = atoms_d4.get_potential_energy()\n",
+    "\n",
+    "print(f\"Disp calculator: {chgnet_d3.name}\")\n",
+    "print(atoms.get_chemical_formula())\n",
+    "print(f\"E without dispersion: {e_chg}\")\n",
+    "print(f\"E with DFT-D3 dispersion: {e_disp}\")\n",
+    "print(f\"E with DFT-D4 dispersion: {e_d4}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Optimization example"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Below is a simple example of an optimization of a Cu cell with a displaced atom and perturbed unit cell."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ase.build import bulk\n",
+    "from ase.filters import FrechetCellFilter\n",
+    "from ase.optimize import BFGS\n",
+    "\n",
+    "atoms = bulk(\"Cu\", cubic=True)\n",
+    "\n",
+    "atoms[0].x += 0.1\n",
+    "atoms.cell[0] += 0.1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Cell([[3.71, 0.1, 0.1], [0.0, 3.61, 0.0], [0.0, 0.0, 3.61]])"
+      ]
+     },
+     "execution_count": null,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "atoms.cell"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " Step Time Energy fmax\n",
+      "BFGS: 0 17:55:31 -18.448409 0.718576\n",
+      "BFGS: 1 17:55:33 -18.479250 0.649700\n",
+      "BFGS: 2 17:55:34 -18.602240 1352.772997\n",
+      "BFGS: 3 17:55:35 -18.495526 0.666888\n",
+      "BFGS: 4 17:55:36 -18.505485 0.674898\n",
+      "BFGS: 5 17:55:37 -18.524050 0.707074\n",
+      "BFGS: 6 17:55:39 -18.524685 0.708438\n",
+      "BFGS: 7 17:55:40 -18.527744 0.722830\n",
+      "BFGS: 8 17:55:41 -18.529436 0.740565\n",
+      "BFGS: 9 17:55:42 -18.530648 0.755404\n",
+      "BFGS: 10 17:55:43 -18.530939 0.757174\n",
+      "BFGS: 11 17:55:44 -18.531477 0.756699\n",
+      "BFGS: 12 17:55:46 -18.532681 0.750869\n",
+      "BFGS: 13 17:55:47 -18.534685 0.741311\n",
+      "BFGS: 14 17:55:48 -18.537005 0.728779\n",
+      "BFGS: 15 17:55:49 -18.539452 0.706232\n",
+      "BFGS: 16 17:55:50 -18.540848 0.684654\n",
+      "BFGS: 17 17:55:52 -18.541594 0.680409\n",
+      "BFGS: 18 17:55:53 -18.544656 0.659496\n",
+      "BFGS: 19 17:55:54 -18.548455 0.625415\n",
+      "BFGS: 20 17:55:55 -18.554630 0.578256\n",
+      "BFGS: 21 17:55:56 -18.562180 0.520110\n",
+      "BFGS: 22 17:55:58 -18.568905 0.460355\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "True"
+      ]
+     },
+     "execution_count": null,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "atoms.calc = chgnet_d4\n",
+    "cell_filter = FrechetCellFilter(atoms)\n",
+    "opt = BFGS(cell_filter, trajectory=\"Cu.traj\")\n",
+    "\n",
+    "opt.run(fmax=0.5, steps=100)"
+   ]
+  },
239+
{
240+
"cell_type": "markdown",
241+
"metadata": {},
242+
"source": [
243+
"The output of this optimization can be viewed by running\n",
244+
"\n",
245+
"```bash\n",
246+
"ase gui Cu.traj\n",
247+
"```\n",
248+
"\n",
249+
"in the command line in this folder."
250+
]
251+
}
252+
],
253+
"metadata": {
254+
"kernelspec": {
255+
"display_name": "htvs",
256+
"language": "python",
257+
"name": "python3"
258+
},
259+
"language_info": {
260+
"codemirror_mode": {
261+
"name": "ipython",
262+
"version": 3
263+
},
264+
"file_extension": ".py",
265+
"mimetype": "text/x-python",
266+
"name": "python",
267+
"nbconvert_exporter": "python",
268+
"pygments_lexer": "ipython3",
269+
"version": "3.12.5"
270+
}
271+
},
272+
"nbformat": 4,
273+
"nbformat_minor": 2
274+
}
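
Reading note on the energies printed in the notebook's simple example: ASE's `SumCalculator` adds the energies of its member calculators, so the dispersion contribution can be recovered by subtracting the plain CHGNet energy from each combined total. The sketch below only redoes that arithmetic with the three values copied from the notebook output. The helper name `dispersion_contribution` is hypothetical, and the values are assumed to be in eV (ASE's default energy unit).

```python
# Energies copied from the notebook output for the 4-atom Cu(111) slab ("Cu4").
# Assumption: values are in eV, ASE's default energy unit.
e_chgnet = -12.876540184020996      # CHGNet only
e_chgnet_d3 = -13.272150007989707   # CHGNet + DFT-D3
e_chgnet_d4 = -13.573149493981      # CHGNet + DFT-D4
n_atoms = 4


def dispersion_contribution(e_total, e_base, n_atoms):
    """Return (total, per-atom) dispersion energy, assuming e_total = e_base + e_disp."""
    e_disp = e_total - e_base
    return e_disp, e_disp / n_atoms


for label, e_total in [("DFT-D3", e_chgnet_d3), ("DFT-D4", e_chgnet_d4)]:
    e_disp, e_disp_per_atom = dispersion_contribution(e_total, e_chgnet, n_atoms)
    print(f"{label}: {e_disp:.4f} eV total, {e_disp_per_atom:.4f} eV/atom")
```

Both corrections are negative (stabilizing) for this slab, and the D4 term is larger in magnitude than the D3 term.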

pyproject.toml

+1 -1
@@ -34,6 +34,7 @@ test = ["pytest-cov>=4", "pytest>=8"]
 examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"]
 docs = ["lazydocs>=0.4"]
 logging = ["wandb>=0.17"]
+dispersion = ["dftd4>=3.6", "torch-dftd>=0.4"]

 [project.urls]
 Source = "https://github.com/CederGroupHub/chgnet"
@@ -52,7 +53,6 @@ build-backend = "setuptools.build_meta"

 [tool.ruff]
 target-version = "py39"
-extend-include = ["*.ipynb"]

 [tool.ruff.lint]
 select = ["ALL"]
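
Reading note on the new `dispersion` extra: because `dftd4` and `torch-dftd` are optional, code that uses them may want to check for them before importing. The following is a minimal sketch using only the standard library. The helper name `missing_modules` is hypothetical; the module names `dftd4` and `torch_dftd` are taken from the notebook's imports, and the suggested `pip install "chgnet[dispersion]"` command assumes the extra is available in the installed release.

```python
from importlib.util import find_spec

# Import names used by the dispersion notebook (top-level modules only).
DISPERSION_MODULES = ("dftd4", "torch_dftd")


def missing_modules(modules=DISPERSION_MODULES):
    """Return the names in `modules` that cannot be found on the import path."""
    return [name for name in modules if find_spec(name) is None]


missing = missing_modules()
if missing:
    # Assumption: the extra is published with the package under this name.
    print(f'Missing optional packages {missing}; try: pip install "chgnet[dispersion]"')
else:
    print("Dispersion dependencies found.")
```

Checking with `find_spec` avoids importing the heavy packages just to learn whether they are present.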

tests/conftest.py

+1 -1
@@ -6,6 +6,6 @@
 from chgnet import ROOT


-@pytest.fixture()
+@pytest.fixture
 def li_mn_o2() -> Structure:
     return Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")

tests/test_converter.py

+1 -1
@@ -17,7 +17,7 @@
 NaCl = Structure(lattice, species, coords)


-@pytest.fixture()
+@pytest.fixture
 def _set_make_graph() -> Generator[None, None, None]:
     # fixture to force make_graph to be None and then restore it after test
     from chgnet.graph import converter

tests/test_dataset.py

+1 -1
@@ -16,7 +16,7 @@
 NaCl = Structure(lattice, species, coords)


-@pytest.fixture()
+@pytest.fixture
 def structure_data() -> StructureData:
     """Create a graph with 3 nodes and 3 directed edges."""
     random.seed(42)

tests/test_graph.py

+2 -2
@@ -6,7 +6,7 @@
 from chgnet.graph.graph import DirectedEdge, Graph, Node, UndirectedEdge


-@pytest.fixture()
+@pytest.fixture
 def graph() -> Graph:
     """Create a graph with 3 nodes and 3 directed edges."""
     nodes = [Node(index=idx) for idx in range(3)]
@@ -50,7 +50,7 @@ def test_as_dict(graph: Graph) -> None:
     assert len(graph_dict["undirected_edges_list"]) == 3


-@pytest.fixture()
+@pytest.fixture
 def bigraph() -> Graph:
     """Create a bi-directional graph with 3 nodes and 4 bi-directed edges."""
     nodes = [Node(index=idx) for idx in range(3)]

tests/test_trainer.py

+1 -1
@@ -114,7 +114,7 @@ def test_trainer_composition_model(tmp_path: Path) -> None:
     assert torch.all(comparison == expect)


-@pytest.fixture()
+@pytest.fixture
 def mock_wandb():
     with patch("chgnet.trainer.trainer.wandb") as mock:
         yield mock
