Skip to content

Commit

Permalink
Update outline generation (#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtth authored Nov 29, 2023
1 parent 0b196f4 commit 4b30413
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 31 deletions.
36 changes: 23 additions & 13 deletions opvious/client/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,24 @@ def build(self) -> SolveInputs:
)


async def generate_outline(
executor: Executor, outline_data: Json, transformation_data: Json
) -> ProblemOutline:
if not transformation_data:
return outline_from_json(outline_data)
async with executor.execute(
result_type=JsonExecutorResult,
url="/outlines/transform",
method="POST",
json_data=json_dict(
outline=outline_data,
transformations=transformation_data,
),
) as res:
data = res.json_data()
return outline_from_json(data["outline"])


class ProblemOutlineGenerator:
def __init__(self, executor: Executor, outline_data: Json):
self._executor = executor
Expand Down Expand Up @@ -178,20 +196,12 @@ async def generate(self) -> tuple[ProblemOutline, Json]:

class Context(ProblemTransformationContext):
async def fetch_outline(self) -> ProblemOutline:
transformations = self.get_json()
if not transformations:
transformation_data = self.get_json()
if not transformation_data:
return pristine_outline
async with executor.execute(
result_type=JsonExecutorResult,
url="/outlines/transform",
method="POST",
json_data=json_dict(
outline=pristine_outline_data,
transformations=transformations,
),
) as res:
data = res.json_data()
return outline_from_json(data["outline"])
return await generate_outline(
executor, pristine_outline_data, transformation_data
)

context = Context()
for tf in self._transformations:
Expand Down
32 changes: 23 additions & 9 deletions opvious/client/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from ..data.outcomes import (
AbortedOutcome,
FailedOutcome,
FeasibleOutcome,
InfeasibleOutcome,
SolveOutcome,
Expand All @@ -29,7 +30,7 @@
feasible_outcome_from_graphql,
solve_outcome_status,
)
from ..data.outlines import ProblemOutline, outline_from_json
from ..data.outlines import ProblemOutline
from ..data.solves import (
ProblemSummary,
SolveInputs,
Expand Down Expand Up @@ -60,6 +61,7 @@
ProblemOutlineGenerator,
SolveInputsBuilder,
feasible_outcome_details,
generate_outline,
log_progress,
)

Expand Down Expand Up @@ -530,8 +532,8 @@ async def queue_solve(self, problem: Problem) -> QueuedSolve:
uuid = res.json_data()["uuid"]
return QueuedSolve(
uuid=uuid,
started_at=datetime.now(timezone.utc),
outline=outline,
started_at=datetime.now(timezone.utc),
)

async def fetch_solve(self, uuid: str) -> Optional[QueuedSolve]:
Expand All @@ -547,10 +549,12 @@ async def fetch_solve(self, uuid: str) -> Optional[QueuedSolve]:
solve = data["queuedSolve"]
if not solve:
return None
return queued_solve_from_graphql(
data=solve,
outline=outline_from_json(solve["outline"]),
outline = await generate_outline(
self._executor,
solve["specification"]["outline"],
solve["transformations"],
)
return queued_solve_from_graphql(solve, outline)

async def cancel_solve(self, uuid: str) -> bool:
"""Cancels a running solve
Expand Down Expand Up @@ -580,13 +584,26 @@ async def poll_solve(
variables=json_dict(uuid=solve.uuid),
)
solve_data = data["queuedSolve"]

error_status = solve_data["attempt"]["errorStatus"]
if error_status:
failure_data = solve_data["failure"]
if failure_data:
return failed_outcome_from_graphql(failure_data)
else:
return FailedOutcome(
error_status,
"The problem's inputs did not match its specification",
)

outcome_data = solve_data["outcome"]
if not outcome_data:
edges = solve_data["notifications"]["edges"]
return solve_notification_from_graphql(
dequeued=bool(solve_data["dequeuedAt"]),
data=edges[0]["node"] if edges else None,
)

status = outcome_data["status"]
if status == "ABORTED":
return cast(SolveOutcome, AbortedOutcome())
Expand All @@ -596,10 +613,7 @@ async def poll_solve(
return UnboundedOutcome()
if status == "FEASIBLE" or status == "OPTIMAL":
return feasible_outcome_from_graphql(outcome_data)
failure_data = solve_data["failure"]
if not failure_data:
raise Exception(f"Unexpected status {status} without failure")
return failed_outcome_from_graphql(failure_data)
raise Exception(f"Unexpected status {status} without failure")

@backoff.on_predicate(
backoff.fibo,
Expand Down
4 changes: 2 additions & 2 deletions opvious/data/outcomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ class FailedOutcome:
message: str
"""The underlying error's message"""

code: Optional[str]
code: Optional[str] = None
"""The underlying error's error code"""

tags: Any
tags: Any = None
"""Structured data associated with the failure"""


Expand Down
8 changes: 4 additions & 4 deletions opvious/data/queued_solves.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@ class QueuedSolve:
uuid: str
"""The solve's unique identifier"""

started_at: datetime
"""The time the solve was created"""

outline: ProblemOutline = dataclasses.field(repr=False)
"""The specification outline corresponding to this solve"""

started_at: datetime
"""The time the solve was created"""


def queued_solve_from_graphql(
data: Any, outline: ProblemOutline
) -> QueuedSolve:
return QueuedSolve(
uuid=data["uuid"],
started_at=datetime.fromisoformat(data["startedAt"]),
outline=outline,
started_at=datetime.fromisoformat(data["attempt"]["startedAt"]),
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "opvious"
version = "0.18.5rc2"
version = "0.18.6rc1"
description = "Opvious Python SDK"
authors = ["Opvious Engineering <oss@opvious.io>"]
readme = "README.md"
Expand Down
5 changes: 3 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def test_queue_diet_solve(self):

@pytest.mark.asyncio
async def test_queue_relaxed_solve(self):
solve = await client.queue_solve(
queued = await client.queue_solve(
opvious.Problem(
specification=opvious.FormulationSpecification("bounded"),
transformations=[
Expand All @@ -113,7 +113,8 @@ async def test_queue_relaxed_solve(self):
parameters={"bound": 3},
),
)
outcome = await client.wait_for_solve_outcome(solve)
fetched = await client.fetch_solve(queued.uuid)
outcome = await client.wait_for_solve_outcome(fetched)
assert isinstance(outcome, opvious.FeasibleOutcome)
assert outcome.objective_value == 2

Expand Down

0 comments on commit 4b30413

Please sign in to comment.