Skip to content

Commit e664ba6

Browse files
committed
new gmm classes & tests; WIP util LIB is updated;
1 parent ad97850 commit e664ba6

6 files changed

+593
-206
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Define graphene model for nzshm_model gmm logic tree classes."""
2+
3+
import logging
4+
from functools import lru_cache
5+
6+
import graphene
7+
from graphene import relay
8+
9+
from .nshm_model_sources_schema import get_model_by_version
10+
11+
log = logging.getLogger(__name__)
12+
13+
14+
# TODO: this method belongs on the nzshm-model gmcm class
15+
@lru_cache
16+
def get_branch_set(model_version, tectonic_region_type):
17+
glt = get_model_by_version(model_version).gmm_logic_tree
18+
log.debug(f"glt {glt}")
19+
for bs in glt.branch_sets:
20+
if bs.tectonic_region_type == tectonic_region_type:
21+
return bs
22+
assert 0, f"branch set {tectonic_region_type} was not found" # pragma: no cover
23+
24+
25+
# TODO: this method belongs on the nzshm-model gmcm class
26+
@lru_cache
27+
def get_logic_tree_branch(model_version, branch_set_trt, gsim_name, gsim_args):
28+
log.info(
29+
f"get_logic_tree_branch: {branch_set_trt} gsim_name: {gsim_name} gsim_args: {gsim_args}"
30+
)
31+
branch_set = get_branch_set(model_version, branch_set_trt)
32+
for ltb in branch_set.branches:
33+
if (ltb.gsim_name == gsim_name) and (str(ltb.gsim_args) == gsim_args):
34+
return ltb
35+
print(branch_set_trt, ltb.tag)
36+
assert (
37+
0
38+
), f"branch with gsim_name: {gsim_name} gsim_args: {gsim_args} was not found" # pragma: no cover
39+
40+
41+
class GmmLogicTreeBranch(graphene.ObjectType):
42+
class Meta:
43+
interfaces = (relay.Node,)
44+
45+
model_version = graphene.String()
46+
branch_set_trt = graphene.String()
47+
gsim_name = graphene.String()
48+
gsim_args = graphene.String()
49+
tectonic_region_type = graphene.String() # should be an enum
50+
weight = graphene.Float()
51+
52+
def resolve_id(self, info):
53+
return f"{self.model_version}:{self.branch_set_trt}:{self.gsim_name}:{self.gsim_args}"
54+
55+
@classmethod
56+
def get_node(cls, info, node_id: str):
57+
model_version, branch_set_trt, gsim_name, gsim_args = node_id.split(":")
58+
gltb = get_logic_tree_branch(
59+
model_version, branch_set_trt, gsim_name, gsim_args
60+
)
61+
return GmmLogicTreeBranch(
62+
model_version=model_version,
63+
branch_set_trt=branch_set_trt,
64+
gsim_name=gsim_name,
65+
gsim_args=gsim_args,
66+
weight=gltb.weight,
67+
)
68+
69+
70+
class GmmBranchSet(graphene.ObjectType):
71+
"""Ground Motion Model branch sets,
72+
73+
to ensure that the wieghts of the enclosed branches sum to 1.0
74+
"""
75+
76+
class Meta:
77+
interfaces = (relay.Node,)
78+
79+
model_version = graphene.String()
80+
short_name = graphene.String()
81+
long_name = graphene.String()
82+
tectonic_region_type = graphene.String()
83+
branches = graphene.List(GmmLogicTreeBranch)
84+
85+
def resolve_id(self, info):
86+
return f"{self.model_version}:{self.tectonic_region_type}"
87+
88+
@classmethod
89+
def get_node(cls, info, node_id: str):
90+
model_version, tectonic_region_type = node_id.split(":")
91+
bs = get_branch_set(model_version, tectonic_region_type)
92+
return GmmBranchSet(
93+
model_version=model_version,
94+
tectonic_region_type=tectonic_region_type,
95+
short_name=bs.short_name,
96+
long_name=bs.long_name,
97+
)
98+
99+
@staticmethod
100+
def resolve_branches(root, info, **kwargs):
101+
log.info(f"resolve_branches root: {root} kwargs: {kwargs}")
102+
bs = get_branch_set(root.model_version, root.tectonic_region_type)
103+
for ltb in bs.branches:
104+
log.debug(ltb)
105+
ltb = GmmLogicTreeBranch(
106+
model_version=root.model_version,
107+
tectonic_region_type=root.tectonic_region_type,
108+
weight=ltb.weight,
109+
gsim_name=ltb.gsim_name,
110+
gsim_args=str(ltb.gsim_args),
111+
)
112+
yield ltb
113+
114+
115+
class GroundMotionModelLogicTree(graphene.ObjectType):
116+
"""A custom Node representing the GMM logic tree of a given model."""
117+
118+
class Meta:
119+
interfaces = (relay.Node,)
120+
121+
model_version = graphene.String()
122+
branch_sets = graphene.List(GmmBranchSet)
123+
124+
def resolve_id(self, info):
125+
return self.model_version
126+
127+
@classmethod
128+
def get_node(cls, info, model_version: str):
129+
return GroundMotionModelLogicTree(model_version=model_version)
130+
131+
@staticmethod
132+
def resolve_branch_sets(root, info, **kwargs):
133+
log.info(f"resolve_branch_sets root: {root} kwargs: {kwargs}")
134+
glt = get_model_by_version(root.model_version).gmm_logic_tree
135+
for bs in glt.branch_sets:
136+
yield GmmBranchSet(
137+
model_version=root.model_version,
138+
tectonic_region_type=bs.tectonic_region_type,
139+
)

nshm_model_graphql_api/schema/nshm_model_schema.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import nzshm_model as nm
88
from graphene import relay
99

10+
from .nshm_model_gmms_schema import GroundMotionModelLogicTree
1011
from .nshm_model_sources_schema import SourceLogicTree
1112

1213
log = logging.getLogger(__name__)
@@ -21,16 +22,20 @@ class Meta:
2122
version = graphene.String()
2223
title = graphene.String()
2324
source_logic_tree = graphene.Field(SourceLogicTree)
25+
gmm_logic_tree = graphene.Field(GroundMotionModelLogicTree)
2426

2527
def resolve_id(self, info):
2628
return self.version
2729

2830
@staticmethod
2931
def resolve_source_logic_tree(root, info, **kwargs):
3032
log.info(f"resolve_source_logic_tree root: {root} kwargs: {kwargs}")
31-
return SourceLogicTree(
32-
model_version=root.version
33-
) # , branch_sets=get_branch_sets(slt))
33+
return SourceLogicTree(model_version=root.version)
34+
35+
@staticmethod
36+
def resolve_gmm_logic_tree(root, info, **kwargs):
37+
log.info(f"resolve_gmm_logic_tree root: {root} kwargs: {kwargs}")
38+
return GroundMotionModelLogicTree(model_version=root.version)
3439

3540
@classmethod
3641
def get_node(cls, info, version: str):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import pytest
2+
from graphene.test import Client
3+
from graphql_relay import to_global_id
4+
5+
from nshm_model_graphql_api import schema
6+
7+
8+
@pytest.fixture(scope="module")
9+
def client():
10+
return Client(schema.schema_root)
11+
12+
13+
@pytest.mark.parametrize(
14+
"model_version",
15+
["NSHM_v1.0.0", "NSHM_v1.0.4"],
16+
)
17+
def test_get_model_SourceLogicTree_as_node(client, model_version):
18+
QUERY = """
19+
query {
20+
node(id: "%s")
21+
{
22+
... on Node {
23+
id
24+
}
25+
... on GroundMotionModelLogicTree {
26+
model_version
27+
}
28+
}
29+
}
30+
""" % to_global_id(
31+
"GroundMotionModelLogicTree", model_version
32+
)
33+
print(QUERY)
34+
executed = client.execute(QUERY)
35+
print(executed)
36+
assert executed["data"]["node"]["model_version"] == model_version
37+
assert executed["data"]["node"]["id"] == to_global_id(
38+
"GroundMotionModelLogicTree", model_version
39+
)
40+
41+
42+
@pytest.mark.parametrize(
43+
"model_version, short_name, long_name",
44+
[
45+
("NSHM_v1.0.0", "Active Shallow Crust", "Crustal"),
46+
# ("NSHM_v1.0.0", "PUY", "Puysegur"),
47+
# ("NSHM_v1.0.4", "CRU", "Crustal"),
48+
# ("NSHM_v1.0.4", "PUY", "Puysegur"),
49+
],
50+
)
51+
def test_get_model_GmmBranchSet_as_node(client, model_version, short_name, long_name):
52+
QUERY = """
53+
query {
54+
node(id: "%s")
55+
{
56+
... on Node {
57+
id
58+
}
59+
... on GmmBranchSet {
60+
model_version
61+
short_name
62+
long_name
63+
tectonic_region_type
64+
}
65+
66+
}
67+
}
68+
""" % to_global_id(
69+
"GmmBranchSet", f"{model_version}:{short_name}"
70+
)
71+
executed = client.execute(QUERY)
72+
print(executed)
73+
assert executed["data"]["node"]["model_version"] == model_version
74+
assert executed["data"]["node"]["tectonic_region_type"] == short_name
75+
# assert executed["data"]["node"]["long_name"] == long_name
76+
assert executed["data"]["node"]["id"] == to_global_id(
77+
"GmmBranchSet", f"{model_version}:{short_name}"
78+
)
79+
80+
81+
'''
82+
83+
@pytest.mark.parametrize(
84+
"model_version, branch_set_short_name, tag, weight",
85+
[
86+
(
87+
"NSHM_v1.0.0",
88+
"CRU",
89+
"[dmgeologic, tdFalse, bN[1.089, 4.6], C4.2, s1.0]",
90+
0.00541000379473566,
91+
),
92+
("NSHM_v1.0.0", "PUY", "[dm0.7, bN[0.902, 4.6], C4.0, s0.28]", 0.21),
93+
(
94+
"NSHM_v1.0.4",
95+
"CRU",
96+
"[dmgeologic, tdFalse, bN[1.089, 4.6], C4.2, s1.41]",
97+
0.00286782725429677,
98+
),
99+
("NSHM_v1.0.4", "PUY", "[dm0.7, bN[0.902, 4.6], C4.0, s0.28]", 0.21),
100+
],
101+
)
102+
def test_get_model_SourceLogicTreeBranch_as_node(
103+
client, model_version, branch_set_short_name, tag, weight
104+
):
105+
QUERY = """
106+
query {
107+
node(id: "%s")
108+
{
109+
... on Node {
110+
id
111+
}
112+
... on SourceLogicTreeBranch {
113+
model_version
114+
branch_set_short_name
115+
tag
116+
weight
117+
sources {
118+
... on BranchInversionSource {
119+
nrml_id
120+
}
121+
}
122+
}
123+
124+
}
125+
}
126+
""" % to_global_id(
127+
"SourceLogicTreeBranch", f"{model_version}:{branch_set_short_name}:{tag}"
128+
)
129+
executed = client.execute(QUERY)
130+
print(executed)
131+
assert executed["data"]["node"]["id"] == to_global_id(
132+
"SourceLogicTreeBranch", f"{model_version}:{branch_set_short_name}:{tag}"
133+
)
134+
135+
assert executed["data"]["node"]["model_version"] == model_version
136+
assert executed["data"]["node"]["branch_set_short_name"] == branch_set_short_name
137+
assert executed["data"]["node"]["tag"] == tag
138+
assert executed["data"]["node"]["weight"] == weight
139+
'''

0 commit comments

Comments
 (0)