Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use MPS backend if available and use_device=None, add CHGNET_DEVICE env var #131

Merged
merged 7 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ jobs:

- name: Run Tests
run: pytest --capture=no --cov --cov-report=xml .
env:
CHGNET_DEVICE: cpu

- name: Codacy coverage reporter
if: ${{ matrix.os == 'ubuntu-latest' && github.event_name == 'push' }}
Expand Down
8 changes: 5 additions & 3 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import inspect
import io
import os
import pickle
import sys
from typing import TYPE_CHECKING, Literal
Expand Down Expand Up @@ -51,7 +52,7 @@
class CHGNetCalculator(Calculator):
"""CHGNet Calculator for ASE applications."""

implemented_properties = ("energy", "forces", "stress", "magmoms")
implemented_properties = ("energy", "forces", "stress", "magmoms") # type: ignore

def __init__(
self,
Expand Down Expand Up @@ -81,7 +82,8 @@ def __init__(
super().__init__(**kwargs)

# Determine the device to use
if use_device == "mps" and torch.backends.mps.is_available():
use_device = use_device or os.getenv("CHGNET_DEVICE")
if use_device in ("mps", None) and torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -95,7 +97,7 @@ def __init__(
print(f"CHGNet will run on {self.device}")

@property
def version(self) -> str:
def version(self) -> str | None:
"""The version of CHGNet."""
return self.model.version

Expand Down
10 changes: 5 additions & 5 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,8 @@ def load(
)

# Determine the device to use
if use_device == "mps" and torch.backends.mps.is_available():
use_device = use_device or os.getenv("CHGNET_DEVICE")
if use_device in ("mps", None) and torch.backends.mps.is_available():
device = "mps"
else:
device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -763,7 +764,7 @@ class BatchedGraph:
directed2undirected: Tensor
atom_positions: Sequence[Tensor]
strains: Sequence[Tensor]
volumes: Sequence[Tensor]
volumes: Sequence[Tensor] | Tensor

@classmethod
def from_graphs(
Expand All @@ -790,8 +791,7 @@ def from_graphs(
batched_atom_graph, batched_bond_graph = [], []
directed2undirected = []
atom_owners = []
atom_offset_idx = 0
n_undirected = 0
atom_offset_idx = n_undirected = 0

for graph_idx, graph in enumerate(graphs):
# Atoms
Expand All @@ -807,7 +807,7 @@ def from_graphs(
else:
strain = None
lattice = graph.lattice
volumes.append(torch.det(lattice))
volumes.append(torch.dot(lattice[0], torch.cross(lattice[1], lattice[2])))
strains.append(strain)

# Bonds
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }
[tool.ruff]
target-version = "py39"
include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"]
[tool.ruff.lint]
select = [
"B", # flake8-bugbear
"C4", # flake8-comprehensions
Expand Down Expand Up @@ -94,6 +95,7 @@ ignore = [
"D104", # Missing docstring in public package
"D205", # 1 blank line required between summary line and description
"DTZ005", # use of datetime.now() without timezone
"E731", # do not assign a lambda expression, use a def
"EM",
"ERA001", # found commented out code
"FBT001", # Boolean positional argument in function
Expand All @@ -114,7 +116,7 @@ pydocstyle.convention = "google"
isort.required-imports = ["from __future__ import annotations"]
isort.split-on-trailing-comma = false

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"site/*" = ["INP001", "S602"]
"tests/*" = ["ANN201", "D103", "INP001", "S101"]
# E402 Module level import not at top of file
Expand Down
38 changes: 23 additions & 15 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,29 +220,37 @@ def test_as_to_from_dict() -> None:
assert model_3.todict() == to_dict


def test_model_load_version_params(capsys: pytest.CaptureFixture) -> None:
def test_model_load_version_params(
capsys: pytest.CaptureFixture, monkeypatch: pytest.MonkeyPatch
) -> None:
model = CHGNet.load(use_device="cpu")
assert model.version == "0.3.0"
assert model.n_params == 412_525
v030_key, v030_params = "0.3.0", 412_525
assert model.version == v030_key
assert model.n_params == v030_params
stdout, stderr = capsys.readouterr()
assert (
stdout
== f"""CHGNet v{model.version} initialized with {model.n_params:,} parameters
CHGNet will run on cpu\n"""
expected_stdout = lambda version, params: (
f"CHGNet v{version} initialized with {params:,} parameters\n"
"CHGNet will run on cpu\n"
)
assert stdout == expected_stdout(v030_key, v030_params)
assert stderr == ""

model = CHGNet.load(model_name="0.2.0", use_device="cpu")
assert model.version == "0.2.0"
assert model.n_params == 400_438
v020_key, v020_params = "0.2.0", 400_438
model = CHGNet.load(model_name=v020_key, use_device="cpu")
assert model.version == v020_key
assert model.n_params == v020_params
stdout, stderr = capsys.readouterr()
assert (
stdout
== f"""CHGNet v{model.version} initialized with {model.n_params:,} parameters
CHGNet will run on cpu\n"""
)
assert stdout == expected_stdout(v020_key, v020_params)
assert stderr == ""

model_name = "0.1.0" # invalid
with pytest.raises(ValueError, match=f"Unknown {model_name=}"):
CHGNet.load(model_name=model_name)

# # set CHGNET_DEVICE to "cuda" and test
monkeypatch.setenv("CHGNET_DEVICE", env_device := "foobar")
with pytest.raises(
RuntimeError,
match=f"Expected one of cpu, .+type at start of device string: {env_device}",
):
CHGNet.load()
Loading