Skip to content

Commit

Permalink
feat: add solve pagination method (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtth authored May 19, 2024
1 parent 42cfc4e commit a690686
Show file tree
Hide file tree
Showing 5 changed files with 873 additions and 750 deletions.
88 changes: 85 additions & 3 deletions opvious/client/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,23 @@
import json
import humanize
import logging
from typing import cast, Iterable, Optional, Sequence, Union
from typing import (
cast,
Any,
AsyncIterator,
Iterable,
Optional,
Sequence,
Union,
)

from ..common import (
Json,
format_percent,
json_dict,
)
from ..data.queued_solves import (
AttemptAttributes,
QueuedSolve,
queued_solve_from_graphql,
SolveNotification,
Expand Down Expand Up @@ -475,7 +484,9 @@ async def solve(

return solution

async def queue_solve(self, problem: Problem) -> QueuedSolve:
async def queue_solve(
self, problem: Problem, attributes: Optional[AttemptAttributes] = None
) -> QueuedSolve:
"""Queues a solve for asynchronous processing
Inputs will be validated locally before the request is sent to the API.
Expand Down Expand Up @@ -527,7 +538,7 @@ async def queue_solve(self, problem: Problem) -> QueuedSolve:
result_type=JsonExecutorResult,
url="/queue-solve",
method="POST",
json_data=json_dict(problem=problem),
json_data=json_dict(problem=problem, attributes=attributes),
) as res:
uuid = res.json_data()["uuid"]
return QueuedSolve(
Expand Down Expand Up @@ -707,3 +718,74 @@ async def fetch_solve_outputs(self, solve: QueuedSolve) -> SolveOutputs:
raw_variables=data["variables"],
raw_constraints=data["constraints"],
)

async def paginate_formulation_solves(
self,
name: str,
attributes: Optional[AttemptAttributes] = None,
limit: int = 25,
) -> AsyncIterator[QueuedSolve]:
"""Lists queued solves
Args:
name: Formulation name
attributes: Optional attributes to filter by
limit: Maximum number of results to return
"""
attribute_list = (
[json_dict(key=k, value=v) for k, v in attributes.items()]
if attributes
else None
)
cursor = None
outlines: dict[int, ProblemOutline] = {}
while limit > 0:
solves, cursor = await self._list_formulation_solves(
name, attribute_list, outlines, cursor, limit
)
if not solves:
return
for solve in solves:
yield solve
limit -= len(solves)

async def _list_formulation_solves(
self,
name: str,
attribute_list: Any,
outlines: dict[int, ProblemOutline],
cursor: Optional[str],
limit: Optional[int],
) -> tuple[list[QueuedSolve], str]:
data = await self._executor.execute_graphql_query(
query="@PaginateFormulationQueuedSolves",
variables=json_dict(
name=name, last=limit, before=cursor, attributes=attribute_list
),
)
formulation = data["formulation"]
if not formulation:
return [], ""
solves: list[QueuedSolve] = []
for edge in formulation["attempts"]["edges"]:
node = edge["node"]
content = node["content"]
if not content:
continue
spec = content["specification"]
outline = outlines.get(spec["revno"])
if not outline:
outline = await generate_outline(
self._executor,
spec["outline"],
content["transformations"],
)
outlines[spec["revno"]] = outline
solves.append(
QueuedSolve(
uuid=content["uuid"],
outline=outline,
started_at=datetime.fromisoformat(node["startedAt"]),
)
)
return solves, formulation["attempts"]["pageInfo"]["startCursor"]
5 changes: 4 additions & 1 deletion opvious/data/queued_solves.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

import dataclasses
from datetime import datetime
from typing import Any, Optional
from typing import Any, Mapping, Optional

from .outlines import ProblemOutline
from .tensors import Value


AttemptAttributes = Mapping[str, str]


@dataclasses.dataclass(frozen=True)
class QueuedSolve:
"""Queued optimization attempt
Expand Down
Loading

0 comments on commit a690686

Please sign in to comment.