Skip to content

Feature/24 gmm logic tree classes #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Changelog


## [0.2.0] - 2024-05-30
## [0.2.0] - 2024-06-07
### Changed
- Complete reset, no more django
- all previous code is mothballed
Expand All @@ -13,6 +13,7 @@
- get_model resolver
- get_models resolver
- source logic tree models and resolvers
- gmmm logic tree models and resolvers

## [0.1.3] - 2023-09-04
### Added
Expand Down
142 changes: 142 additions & 0 deletions nshm_model_graphql_api/schema/nshm_model_gmms_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Define graphene model for nzshm_model gmm logic tree classes."""

import json
import logging
from functools import lru_cache

import graphene
from graphene import relay

from .nshm_model_sources_schema import get_model_by_version

log = logging.getLogger(__name__)


# TODO: this method belongs on the nzshm-model gmcm class
@lru_cache
def get_branch_set(model_version, short_name):
glt = get_model_by_version(model_version).gmm_logic_tree
log.debug(f"glt {glt}")
for bs in glt.branch_sets:
if bs.short_name == short_name:
return bs
assert 0, f"branch set {short_name} was not found" # pragma: no cover


# TODO: this method belongs on the nzshm-model gmcm class
@lru_cache
def get_logic_tree_branch(model_version, branch_set_short_name, gsim_name, gsim_args):
log.info(
f"get_logic_tree_branch: {branch_set_short_name} gsim_name: {gsim_name} gsim_args: {gsim_args}"
)
branch_set = get_branch_set(model_version, branch_set_short_name)
for ltb in branch_set.branches:
if (ltb.gsim_name == gsim_name) and (ltb.gsim_args == json.loads(gsim_args)):
return ltb
assert (
0
), f"branch with gsim_name: {gsim_name} gsim_args: {gsim_args} was not found" # pragma: no cover


class GmmLogicTreeBranch(graphene.ObjectType):
class Meta:
interfaces = (relay.Node,)

model_version = graphene.String()
branch_set_short_name = graphene.String()
gsim_name = graphene.String()
gsim_args = graphene.JSONString()
tectonic_region_type = graphene.String() # should be an enum
weight = graphene.Float()

def resolve_id(self, info):
return f"{self.model_version}|{self.branch_set_short_name}|{self.gsim_name}|{json.dumps(self.gsim_args)}"

@classmethod
def get_node(cls, info, node_id: str):
model_version, branch_set_short_name, gsim_name, gsim_args = node_id.split("|")
gltb = get_logic_tree_branch(
model_version, branch_set_short_name, gsim_name, gsim_args
)
return GmmLogicTreeBranch(
model_version=model_version,
branch_set_short_name=branch_set_short_name,
tectonic_region_type=gltb.tectonic_region_type,
gsim_name=gltb.gsim_name,
gsim_args=gltb.gsim_args,
weight=gltb.weight,
)


class GmmBranchSet(graphene.ObjectType):
"""Ground Motion Model branch sets,

to ensure that the wieghts of the enclosed branches sum to 1.0
"""

class Meta:
interfaces = (relay.Node,)

model_version = graphene.String()
short_name = graphene.String()
long_name = graphene.String()
tectonic_region_type = graphene.String()
branches = graphene.List(GmmLogicTreeBranch)

def resolve_id(self, info):
return f"{self.model_version}:{self.short_name}"

@classmethod
def get_node(cls, info, node_id: str):
model_version, short_name = node_id.split(":")
bs = get_branch_set(model_version, short_name)
return GmmBranchSet(
model_version=model_version,
tectonic_region_type=bs.tectonic_region_type,
short_name=bs.short_name,
long_name=bs.long_name,
)

@staticmethod
def resolve_branches(root, info, **kwargs):
log.info(f"resolve_branches root: {root} kwargs: {kwargs}")
bs = get_branch_set(root.model_version, root.short_name)
for ltb in bs.branches:
log.debug(ltb)
yield GmmLogicTreeBranch(
model_version=root.model_version,
branch_set_short_name=root.short_name,
tectonic_region_type=ltb.tectonic_region_type,
weight=ltb.weight,
gsim_name=ltb.gsim_name,
gsim_args=ltb.gsim_args,
)


class GroundMotionModelLogicTree(graphene.ObjectType):
"""A custom Node representing the GMM logic tree of a given model."""

class Meta:
interfaces = (relay.Node,)

model_version = graphene.String()
branch_sets = graphene.List(GmmBranchSet)

def resolve_id(self, info):
return self.model_version

@classmethod
def get_node(cls, info, model_version: str):
return GroundMotionModelLogicTree(model_version=model_version)

@staticmethod
def resolve_branch_sets(root, info, **kwargs):
log.info(f"resolve_branch_sets root: {root} kwargs: {kwargs}")
glt = get_model_by_version(root.model_version).gmm_logic_tree
for bs in glt.branch_sets:
yield GmmBranchSet(
model_version=root.model_version,
short_name=bs.short_name,
long_name=bs.long_name,
tectonic_region_type=bs.tectonic_region_type,
)
11 changes: 8 additions & 3 deletions nshm_model_graphql_api/schema/nshm_model_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import nzshm_model as nm
from graphene import relay

from .nshm_model_gmms_schema import GroundMotionModelLogicTree
from .nshm_model_sources_schema import SourceLogicTree

log = logging.getLogger(__name__)
Expand All @@ -21,16 +22,20 @@ class Meta:
version = graphene.String()
title = graphene.String()
source_logic_tree = graphene.Field(SourceLogicTree)
gmm_logic_tree = graphene.Field(GroundMotionModelLogicTree)

def resolve_id(self, info):
return self.version

@staticmethod
def resolve_source_logic_tree(root, info, **kwargs):
log.info(f"resolve_source_logic_tree root: {root} kwargs: {kwargs}")
return SourceLogicTree(
model_version=root.version
) # , branch_sets=get_branch_sets(slt))
return SourceLogicTree(model_version=root.version)

@staticmethod
def resolve_gmm_logic_tree(root, info, **kwargs):
log.info(f"resolve_gmm_logic_tree root: {root} kwargs: {kwargs}")
return GroundMotionModelLogicTree(model_version=root.version)

@classmethod
def get_node(cls, info, version: str):
Expand Down
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

134 changes: 134 additions & 0 deletions tests/test_schema_gmm_models_as_relay_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import pytest
from graphene.test import Client
from graphql_relay import to_global_id

from nshm_model_graphql_api import schema


@pytest.fixture(scope="module")
def client():
return Client(schema.schema_root)


@pytest.mark.parametrize(
"model_version",
["NSHM_v1.0.0", "NSHM_v1.0.4"],
)
def test_get_model_SourceLogicTree_as_node(client, model_version):
QUERY = """
query {
node(id: "%s")
{
... on Node {
id
}
... on GroundMotionModelLogicTree {
model_version
}
}
}
""" % to_global_id(
"GroundMotionModelLogicTree", model_version
)
print(QUERY)
executed = client.execute(QUERY)
print(executed)
assert executed["data"]["node"]["model_version"] == model_version
assert executed["data"]["node"]["id"] == to_global_id(
"GroundMotionModelLogicTree", model_version
)


@pytest.mark.parametrize(
"model_version, short_name, long_name",
[
("NSHM_v1.0.0", "CRU", "Crustal"),
("NSHM_v1.0.0", "SLAB", "Subduction Intraslab"),
("NSHM_v1.0.4", "CRU", "Crustal"),
("NSHM_v1.0.4", "INTER", "Subduction Interface"),
],
)
def test_get_model_GmmBranchSet_as_node(client, model_version, short_name, long_name):
QUERY = """
query {
node(id: "%s")
{
... on Node {
id
}
... on GmmBranchSet {
model_version
short_name
long_name
tectonic_region_type
}

}
}
""" % to_global_id(
"GmmBranchSet", f"{model_version}:{short_name}"
)
executed = client.execute(QUERY)
print(executed)
assert executed["data"]["node"]["model_version"] == model_version
assert executed["data"]["node"]["short_name"] == short_name
assert executed["data"]["node"]["long_name"] == long_name
assert executed["data"]["node"]["id"] == to_global_id(
"GmmBranchSet", f"{model_version}:{short_name}"
)


@pytest.mark.parametrize(
"model_version, branch_set_short_name, gsim_name, gsim_args, weight",
[
(
"NSHM_v1.0.0",
"CRU",
"Stafford2022",
'{"mu_branch": "Upper"}',
0.117,
),
(
"NSHM_v1.0.4",
"INTER",
"Atkinson2022SInter",
'{"epistemic": "Lower", "modified_sigma": "true"}',
0.081,
),
],
)
def test_get_model_GmmLogicTreeBranch_as_node(
client, model_version, branch_set_short_name, gsim_name, gsim_args, weight
):
QUERY = """
query {
node(id: "%s")
{
... on Node {
id
}
... on GmmLogicTreeBranch {
model_version
branch_set_short_name
gsim_name
gsim_args
weight
}

}
}
""" % to_global_id(
"GmmLogicTreeBranch",
f"{model_version}|{branch_set_short_name}|{gsim_name}|{gsim_args}",
)
executed = client.execute(QUERY)
print(executed)
assert executed["data"]["node"]["id"] == to_global_id(
"GmmLogicTreeBranch",
f"{model_version}|{branch_set_short_name}|{gsim_name}|{gsim_args}",
)

assert executed["data"]["node"]["model_version"] == model_version
assert executed["data"]["node"]["branch_set_short_name"] == branch_set_short_name
assert executed["data"]["node"]["gsim_name"] == gsim_name
assert executed["data"]["node"]["weight"] == weight
Loading
Loading