Skip to content

Commit 84e8d55

Browse files
authored
CHGNetCalculator add kwarg task: PredTask = "efsm" (#215)
1 parent 0da2d15 commit 84e8d55

11 files changed

+86
-61
lines changed

.pre-commit-config.yaml

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
default_stages: [commit]
1+
default_stages: [pre-commit]
22

33
default_install_hook_types: [pre-commit, commit-msg]
44

55
repos:
66
- repo: https://github.com/astral-sh/ruff-pre-commit
7-
rev: v0.6.9
7+
rev: v0.7.4
88
hooks:
99
- id: ruff
1010
args: [--fix]
@@ -28,11 +28,11 @@ repos:
2828
rev: v2.3.0
2929
hooks:
3030
- id: codespell
31-
stages: [commit, commit-msg]
31+
stages: [pre-commit, commit-msg]
3232
args: [--check-filenames]
3333

3434
- repo: https://github.com/kynan/nbstripout
35-
rev: 0.7.1
35+
rev: 0.8.0
3636
hooks:
3737
- id: nbstripout
3838
args: [--drop-empty-cells, --keep-output]
@@ -48,7 +48,7 @@ repos:
4848
- svelte
4949

5050
- repo: https://github.com/pre-commit/mirrors-eslint
51-
rev: v9.12.0
51+
rev: v9.15.0
5252
hooks:
5353
- id: eslint
5454
types: [file]

chgnet/model/dynamics.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from ase.optimize.optimize import Optimizer
3434
from typing_extensions import Self
3535

36+
from chgnet import PredTask
37+
3638
# We would like to thank M3GNet develop team for this module
3739
# source: https://github.com/materialsvirtuallab/m3gnet
3840

@@ -59,7 +61,7 @@ def __init__(
5961
*,
6062
use_device: str | None = None,
6163
check_cuda_mem: bool = False,
62-
stress_weight: float | None = 1 / 160.21766208,
64+
stress_weight: float = units.GPa, # GPa to eV/A^3
6365
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
6466
return_site_energies: bool = False,
6567
**kwargs,
@@ -124,6 +126,7 @@ def calculate(
124126
atoms: Atoms | None = None,
125127
properties: list | None = None,
126128
system_changes: list | None = None,
129+
task: PredTask = "efsm",
127130
) -> None:
128131
"""Calculate various properties of the atoms using CHGNet.
129132
@@ -133,6 +136,8 @@ def calculate(
133136
Default is all properties.
134137
system_changes (list | None): The changes made to the system.
135138
Default is all changes.
139+
task (PredTask): The task to perform. One of "e", "ef", "em", "efs", "efsm".
140+
Default = "efsm"
136141
"""
137142
properties = properties or all_properties
138143
system_changes = system_changes or all_changes
@@ -147,23 +152,28 @@ def calculate(
147152
graph = self.model.graph_converter(structure)
148153
model_prediction = self.model.predict_graph(
149154
graph.to(self.device),
150-
task="efsm",
155+
task=task,
151156
return_crystal_feas=True,
152157
return_site_energies=self.return_site_energies,
153158
)
154159

155160
# Convert Result
156-
factor = 1 if not self.model.is_intensive else structure.composition.num_atoms
157-
self.results.update(
158-
energy=model_prediction["e"] * factor,
159-
forces=model_prediction["f"],
160-
free_energy=model_prediction["e"] * factor,
161-
magmoms=model_prediction["m"],
162-
stress=model_prediction["s"] * self.stress_weight,
163-
crystal_fea=model_prediction["crystal_fea"],
161+
extensive_factor = len(structure) if self.model.is_intensive else 1
162+
key_map = dict(
163+
e=("energy", extensive_factor),
164+
f=("forces", 1),
165+
m=("magmoms", 1),
166+
s=("stress", self.stress_weight),
164167
)
168+
self.results |= {
169+
long_key: model_prediction[key] * factor
170+
for key, (long_key, factor) in key_map.items()
171+
if key in model_prediction
172+
}
173+
self.results["free_energy"] = self.results["energy"]
174+
self.results["crystal_fea"] = model_prediction["crystal_fea"]
165175
if self.return_site_energies:
166-
self.results.update(energies=model_prediction["site_energies"])
176+
self.results["energies"] = model_prediction["site_energies"]
167177

168178

169179
class StructOptimizer:
@@ -174,7 +184,7 @@ def __init__(
174184
model: CHGNet | CHGNetCalculator | None = None,
175185
optimizer_class: Optimizer | str | None = "FIRE",
176186
use_device: str | None = None,
177-
stress_weight: float = 1 / 160.21766208,
187+
stress_weight: float = units.GPa,
178188
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
179189
) -> None:
180190
"""Provide a trained CHGNet model and an optimizer to relax crystal structures.
@@ -773,7 +783,7 @@ def __init__(
773783
model: CHGNet | CHGNetCalculator | None = None,
774784
optimizer_class: Optimizer | str | None = "FIRE",
775785
use_device: str | None = None,
776-
stress_weight: float = 1 / 160.21766208,
786+
stress_weight: float = units.GPa,
777787
on_isolated_atoms: Literal["ignore", "warn", "error"] = "error",
778788
) -> None:
779789
"""Initialize a structure optimizer object for calculation of bulk modulus.

chgnet/model/model.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import os
55
from collections.abc import Sequence
66
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, Literal
7+
from typing import TYPE_CHECKING, Literal, get_args
88

99
import torch
1010
from pymatgen.core import Structure
1111
from torch import Tensor, nn
1212

13+
from chgnet import PredTask
1314
from chgnet.graph import CrystalGraph, CrystalGraphConverter
1415
from chgnet.graph.crystalgraph import TORCH_DTYPE
1516
from chgnet.model.composition_model import AtomRef
@@ -27,7 +28,6 @@
2728
if TYPE_CHECKING:
2829
from typing_extensions import Self
2930

30-
from chgnet import PredTask
3131

3232
module_dir = os.path.dirname(os.path.abspath(__file__))
3333

@@ -603,7 +603,7 @@ def predict_graph(
603603
604604
Args:
605605
graph (CrystalGraph | Sequence[CrystalGraph]): CrystalGraph(s) to predict.
606-
task (str): can be 'e' 'ef', 'em', 'efs', 'efsm'
606+
task (PredTask): one of 'e', 'ef', 'em', 'efs', 'efsm'
607607
Default = "efsm"
608608
return_site_energies (bool): whether to return per-site energies.
609609
Default = False
@@ -626,6 +626,9 @@ def predict_graph(
626626
raise TypeError(
627627
f"{type(graph)=} must be CrystalGraph or list of CrystalGraphs"
628628
)
629+
valid_tasks = get_args(PredTask)
630+
if task not in valid_tasks:
631+
raise ValueError(f"Invalid {task=}. Must be one of {valid_tasks}.")
629632

630633
model_device = next(self.parameters()).device
631634

chgnet/trainer/trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ def forward(
858858
for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True):
859859
# exclude structures without magmom labels
860860
if self.allow_missing_labels:
861-
if mag_target is not None and not np.isnan(mag_target).any():
861+
if mag_target is not None and not torch.isnan(mag_target).any():
862862
mag_preds.append(mag_pred)
863863
mag_targets.append(mag_target)
864864
m_mae_size += mag_target.shape[0]

site/.gitignore

-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,3 @@ node_modules
55
.svelte-kit
66
build
77
src/routes/api/*.md
8-
src/MetricsTable.svelte

site/package.json

+18-18
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,28 @@
1515
"changelog": "npx auto-changelog --package --output ../changelog.md --hide-credit --commit-limit false"
1616
},
1717
"devDependencies": {
18-
"@sveltejs/adapter-static": "^3.0.2",
19-
"@sveltejs/kit": "^2.5.17",
20-
"@sveltejs/vite-plugin-svelte": "^3.1.1",
21-
"eslint": "^9.5.0",
22-
"eslint-plugin-svelte": "^2.41.0",
18+
"@sveltejs/adapter-static": "^3.0.6",
19+
"@sveltejs/kit": "^2.8.1",
20+
"@sveltejs/vite-plugin-svelte": "^4.0.1",
21+
"eslint": "^9.15.0",
22+
"eslint-plugin-svelte": "^2.46.0",
2323
"hastscript": "^9.0.0",
24-
"mdsvex": "^0.11.2",
25-
"prettier": "^3.3.2",
26-
"prettier-plugin-svelte": "^3.2.5",
24+
"mdsvex": "^0.12.3",
25+
"prettier": "^3.3.3",
26+
"prettier-plugin-svelte": "^3.2.8",
2727
"rehype-autolink-headings": "^7.1.0",
2828
"rehype-slug": "^6.0.0",
29-
"svelte": "^4.2.18",
30-
"svelte-check": "^3.8.4",
31-
"svelte-multiselect": "^10.3.0",
32-
"svelte-preprocess": "^6.0.1",
29+
"svelte": "^5.2.1",
30+
"svelte-check": "^4.0.8",
31+
"svelte-multiselect": "11.0.0-rc.1",
32+
"svelte-preprocess": "^6.0.3",
3333
"svelte-toc": "^0.5.9",
34-
"svelte-zoo": "^0.4.10",
35-
"svelte2tsx": "^0.7.13",
36-
"tslib": "^2.6.3",
37-
"typescript": "^5.5.2",
38-
"typescript-eslint": "^7.14.1",
39-
"vite": "^5.3.1"
34+
"svelte-zoo": "^0.4.13",
35+
"svelte2tsx": "^0.7.25",
36+
"tslib": "^2.8.1",
37+
"typescript": "^5.6.3",
38+
"typescript-eslint": "^8.14.0",
39+
"vite": "^5.4.11"
4040
},
4141
"prettier": {
4242
"semi": false,

site/src/routes/+page.svelte

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
<script lang="ts">
22
import Readme from '$root/README.md'
3-
import MetricsTable from '$src/MetricsTable.svelte'
43
</script>
54

65
<main>
7-
<Readme>
8-
<MetricsTable slot="metrics-table" />
9-
</Readme>
6+
<Readme />
107
</main>
118

129
<style>

site/vite.config.ts

-10
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
11
import { sveltekit } from '@sveltejs/kit/vite'
2-
import * as fs from 'fs'
32
import type { UserConfig } from 'vite'
43

5-
// fetch latest Matbench Discovery metrics table at build time and save to src/ dir
6-
await fetch(
7-
`https://github.com/janosh/matbench-discovery/raw/main/site/src/figs/metrics-table-uniq-protos.svelte`,
8-
)
9-
.then((res) => res.text())
10-
.then((text) => {
11-
fs.writeFileSync(`src/MetricsTable.svelte`, text)
12-
})
13-
144
export default {
155
plugins: [sveltekit()],
166

tests/test_md.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
import pickle
5-
from typing import TYPE_CHECKING, Literal
5+
from typing import TYPE_CHECKING, Literal, get_args
66

77
import numpy as np
88
import pytest
@@ -22,7 +22,7 @@
2222
from chgnet.graph import CrystalGraphConverter
2323
from chgnet.model import StructOptimizer
2424
from chgnet.model.dynamics import CHGNetCalculator, EquationOfState, MolecularDynamics
25-
from chgnet.model.model import CHGNet
25+
from chgnet.model.model import CHGNet, PredTask
2626

2727
if TYPE_CHECKING:
2828
from pathlib import Path
@@ -314,3 +314,27 @@ def test_md_crystal_feas_log(tmp_path: Path, monkeypatch: MonkeyPatch):
314314
assert crystal_feas[0][1] == approx(-1.4285042, abs=1e-5)
315315
assert crystal_feas[10][0] == approx(-0.0020592688, abs=1e-5)
316316
assert crystal_feas[10][1] == approx(-1.4284436, abs=1e-5)
317+
318+
319+
@pytest.mark.parametrize("task", [*get_args(PredTask)])
320+
def test_calculator_task_valid(task: PredTask):
321+
"""Test that the task kwarg of CHGNetCalculator.calculate() works correctly."""
322+
key_map = dict(e="energy", f="forces", m="magmoms", s="stress")
323+
calculator = CHGNetCalculator()
324+
atoms = AseAtomsAdaptor.get_atoms(structure)
325+
atoms.calc = calculator
326+
327+
calculator.calculate(atoms=atoms, task=task)
328+
329+
for key, prop in key_map.items():
330+
assert (prop in calculator.results) == (key in task)
331+
332+
333+
def test_calculator_task_invalid():
334+
"""Test that invalid task raises ValueError."""
335+
calculator = CHGNetCalculator()
336+
atoms = AseAtomsAdaptor.get_atoms(structure)
337+
atoms.calc = calculator
338+
339+
with pytest.raises(ValueError, match="Invalid task='invalid'."):
340+
calculator.calculate(atoms=atoms, task="invalid")

tests/test_relaxation.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def test_relaxation(
5050
assert {*traj.__dict__} == {
5151
*"atoms energies forces stresses magmoms atom_positions cells".split()
5252
}
53-
assert len(traj) == 2 if algorithm == "legacy" else 4
53+
assert len(traj) == (
54+
2 if algorithm == "legacy" else 4
55+
), f"{len(traj)=}, {algorithm=}"
5456

5557
# make sure final structure is more relaxed than initial one
5658
assert traj.energies[-1] == pytest.approx(-58.94209, rel=1e-4)

tests/test_trainer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
7474
for param in chgnet.composition_model.parameters():
7575
assert param.requires_grad is False
7676
assert tmp_path.is_dir(), "Training dir was not created"
77-
for target_str in ["e", "f", "s", "m"]:
78-
assert ~np.isnan(trainer.training_history[target_str]["train"]).any()
79-
assert ~np.isnan(trainer.training_history[target_str]["val"]).any()
77+
for prop in "efsm":
78+
assert ~np.isnan(trainer.training_history[prop]["train"]).any()
79+
assert ~np.isnan(trainer.training_history[prop]["val"]).any()
8080
output_files = [file.name for file in tmp_path.iterdir()]
8181
for prefix in ("epoch", "bestE_", "bestF_"):
8282
n_matches = sum(file.startswith(prefix) for file in output_files)

0 commit comments

Comments
 (0)