diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5f0a13d2..bc3acb2d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,10 +1,10 @@
-default_stages: [commit]
+default_stages: [pre-commit]
default_install_hook_types: [pre-commit, commit-msg]
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.6.9
+ rev: v0.7.4
hooks:
- id: ruff
args: [--fix]
@@ -28,11 +28,11 @@ repos:
rev: v2.3.0
hooks:
- id: codespell
- stages: [commit, commit-msg]
+ stages: [pre-commit, commit-msg]
args: [--check-filenames]
- repo: https://github.com/kynan/nbstripout
- rev: 0.7.1
+ rev: 0.8.0
hooks:
- id: nbstripout
args: [--drop-empty-cells, --keep-output]
@@ -48,7 +48,7 @@ repos:
- svelte
- repo: https://github.com/pre-commit/mirrors-eslint
- rev: v9.12.0
+ rev: v9.15.0
hooks:
- id: eslint
types: [file]
diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py
index b5b01f97..8b03bf0b 100644
--- a/chgnet/model/dynamics.py
+++ b/chgnet/model/dynamics.py
@@ -33,6 +33,8 @@
from ase.optimize.optimize import Optimizer
from typing_extensions import Self
+ from chgnet import PredTask
+
# We would like to thank M3GNet develop team for this module
# source: https://github.com/materialsvirtuallab/m3gnet
@@ -59,7 +61,7 @@ def __init__(
*,
use_device: str | None = None,
check_cuda_mem: bool = False,
- stress_weight: float | None = 1 / 160.21766208,
+ stress_weight: float = units.GPa, # GPa to eV/A^3
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
return_site_energies: bool = False,
**kwargs,
@@ -124,6 +126,7 @@ def calculate(
atoms: Atoms | None = None,
properties: list | None = None,
system_changes: list | None = None,
+ task: PredTask = "efsm",
) -> None:
"""Calculate various properties of the atoms using CHGNet.
@@ -133,6 +136,8 @@ def calculate(
Default is all properties.
system_changes (list | None): The changes made to the system.
Default is all changes.
+ task (PredTask): The task to perform. One of "e", "ef", "em", "efs", "efsm".
+ Default = "efsm"
"""
properties = properties or all_properties
system_changes = system_changes or all_changes
@@ -147,23 +152,28 @@ def calculate(
graph = self.model.graph_converter(structure)
model_prediction = self.model.predict_graph(
graph.to(self.device),
- task="efsm",
+ task=task,
return_crystal_feas=True,
return_site_energies=self.return_site_energies,
)
# Convert Result
- factor = 1 if not self.model.is_intensive else structure.composition.num_atoms
- self.results.update(
- energy=model_prediction["e"] * factor,
- forces=model_prediction["f"],
- free_energy=model_prediction["e"] * factor,
- magmoms=model_prediction["m"],
- stress=model_prediction["s"] * self.stress_weight,
- crystal_fea=model_prediction["crystal_fea"],
+ extensive_factor = len(structure) if self.model.is_intensive else 1
+ key_map = dict(
+ e=("energy", extensive_factor),
+ f=("forces", 1),
+ m=("magmoms", 1),
+ s=("stress", self.stress_weight),
)
+ self.results |= {
+ long_key: model_prediction[key] * factor
+ for key, (long_key, factor) in key_map.items()
+ if key in model_prediction
+ }
+ self.results["free_energy"] = self.results["energy"]
+ self.results["crystal_fea"] = model_prediction["crystal_fea"]
if self.return_site_energies:
- self.results.update(energies=model_prediction["site_energies"])
+ self.results["energies"] = model_prediction["site_energies"]
class StructOptimizer:
@@ -174,7 +184,7 @@ def __init__(
model: CHGNet | CHGNetCalculator | None = None,
optimizer_class: Optimizer | str | None = "FIRE",
use_device: str | None = None,
- stress_weight: float = 1 / 160.21766208,
+ stress_weight: float = units.GPa,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
) -> None:
"""Provide a trained CHGNet model and an optimizer to relax crystal structures.
@@ -773,7 +783,7 @@ def __init__(
model: CHGNet | CHGNetCalculator | None = None,
optimizer_class: Optimizer | str | None = "FIRE",
use_device: str | None = None,
- stress_weight: float = 1 / 160.21766208,
+ stress_weight: float = units.GPa,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "error",
) -> None:
"""Initialize a structure optimizer object for calculation of bulk modulus.
diff --git a/chgnet/model/model.py b/chgnet/model/model.py
index d42c61c9..c1bd58f8 100644
--- a/chgnet/model/model.py
+++ b/chgnet/model/model.py
@@ -4,12 +4,13 @@
import os
from collections.abc import Sequence
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, Literal, get_args
import torch
from pymatgen.core import Structure
from torch import Tensor, nn
+from chgnet import PredTask
from chgnet.graph import CrystalGraph, CrystalGraphConverter
from chgnet.graph.crystalgraph import TORCH_DTYPE
from chgnet.model.composition_model import AtomRef
@@ -27,7 +28,6 @@
if TYPE_CHECKING:
from typing_extensions import Self
- from chgnet import PredTask
module_dir = os.path.dirname(os.path.abspath(__file__))
@@ -603,7 +603,7 @@ def predict_graph(
Args:
graph (CrystalGraph | Sequence[CrystalGraph]): CrystalGraph(s) to predict.
- task (str): can be 'e' 'ef', 'em', 'efs', 'efsm'
+ task (PredTask): one of 'e', 'ef', 'em', 'efs', 'efsm'
Default = "efsm"
return_site_energies (bool): whether to return per-site energies.
Default = False
@@ -626,6 +626,9 @@ def predict_graph(
raise TypeError(
f"{type(graph)=} must be CrystalGraph or list of CrystalGraphs"
)
+ valid_tasks = get_args(PredTask)
+ if task not in valid_tasks:
+ raise ValueError(f"Invalid {task=}. Must be one of {valid_tasks}.")
model_device = next(self.parameters()).device
diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py
index e3637212..b742118a 100644
--- a/chgnet/trainer/trainer.py
+++ b/chgnet/trainer/trainer.py
@@ -858,7 +858,7 @@ def forward(
for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True):
# exclude structures without magmom labels
if self.allow_missing_labels:
- if mag_target is not None and not np.isnan(mag_target).any():
+ if mag_target is not None and not torch.isnan(mag_target).any():
mag_preds.append(mag_pred)
mag_targets.append(mag_target)
m_mae_size += mag_target.shape[0]
diff --git a/site/.gitignore b/site/.gitignore
index 59078f29..bded1f72 100644
--- a/site/.gitignore
+++ b/site/.gitignore
@@ -5,4 +5,3 @@ node_modules
.svelte-kit
build
src/routes/api/*.md
-src/MetricsTable.svelte
diff --git a/site/package.json b/site/package.json
index 3474e4be..2f8156fc 100644
--- a/site/package.json
+++ b/site/package.json
@@ -15,28 +15,28 @@
"changelog": "npx auto-changelog --package --output ../changelog.md --hide-credit --commit-limit false"
},
"devDependencies": {
- "@sveltejs/adapter-static": "^3.0.2",
- "@sveltejs/kit": "^2.5.17",
- "@sveltejs/vite-plugin-svelte": "^3.1.1",
- "eslint": "^9.5.0",
- "eslint-plugin-svelte": "^2.41.0",
+ "@sveltejs/adapter-static": "^3.0.6",
+ "@sveltejs/kit": "^2.8.1",
+ "@sveltejs/vite-plugin-svelte": "^4.0.1",
+ "eslint": "^9.15.0",
+ "eslint-plugin-svelte": "^2.46.0",
"hastscript": "^9.0.0",
- "mdsvex": "^0.11.2",
- "prettier": "^3.3.2",
- "prettier-plugin-svelte": "^3.2.5",
+ "mdsvex": "^0.12.3",
+ "prettier": "^3.3.3",
+ "prettier-plugin-svelte": "^3.2.8",
"rehype-autolink-headings": "^7.1.0",
"rehype-slug": "^6.0.0",
- "svelte": "^4.2.18",
- "svelte-check": "^3.8.4",
- "svelte-multiselect": "^10.3.0",
- "svelte-preprocess": "^6.0.1",
+ "svelte": "^5.2.1",
+ "svelte-check": "^4.0.8",
+ "svelte-multiselect": "11.0.0-rc.1",
+ "svelte-preprocess": "^6.0.3",
"svelte-toc": "^0.5.9",
- "svelte-zoo": "^0.4.10",
- "svelte2tsx": "^0.7.13",
- "tslib": "^2.6.3",
- "typescript": "^5.5.2",
- "typescript-eslint": "^7.14.1",
- "vite": "^5.3.1"
+ "svelte-zoo": "^0.4.13",
+ "svelte2tsx": "^0.7.25",
+ "tslib": "^2.8.1",
+ "typescript": "^5.6.3",
+ "typescript-eslint": "^8.14.0",
+ "vite": "^5.4.11"
},
"prettier": {
"semi": false,
diff --git a/site/src/routes/+page.svelte b/site/src/routes/+page.svelte
index 7e2c6975..201fe721 100644
--- a/site/src/routes/+page.svelte
+++ b/site/src/routes/+page.svelte
@@ -1,12 +1,9 @@
-
-
-
+